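# Result analysis

Load the highest-numbered checkpoint of `experiment_20231110_220254` and evaluate the Kalman VAE on a test batch of the bouncing-ball dataset with a contiguous block of frames hidden. Render the filtered and smoothed reconstructions next to the latent $\mathbf{a}$ trajectories as `trajectory.mp4`.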

In [1]:
import glob
import json
import os
from tempfile import TemporaryDirectory

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributions as D
from bouncing_ball.dataloaders.bouncing_data import BouncingBallDataLoader
from kalman_vae import KalmanVariationalAutoencoder
from matplotlib.colors import LinearSegmentedColormap
from moviepy.editor import ImageSequenceClip
from natsort import natsorted
from sample_control import SampleControl
from torch.utils.data import DataLoader

In [2]:
SEED = 0
# torch.manual_seed seeds the generators on all devices, so DataLoader shuffling
# and torch-based sampling repeat across Restart & Run All.
torch.manual_seed(SEED)

dtype = torch.float64
# Prefer the second GPU this analysis was written for, but fall back instead of
# crashing on machines with fewer GPUs.
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

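# Interpolation

Hide `mask_length` consecutive frames in the middle of each test sequence, then compare how well the filtered and smoothed latent estimates reconstruct the missing frames.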

In [3]:
# Test split of the bouncing-ball data. Despite its name, this object is the
# dataset that gets wrapped by a torch DataLoader.
_test_root_dir = "bouncing_ball/datasets/bouncing-ball/test"
_dataloader_test = BouncingBallDataLoader(root_dir=_test_root_dir)


def sequence_first_collate_fn(batch):
    """Collate a list of sequences into a single time-major tensor.

    The elements of ``batch`` are stacked along a new leading axis, giving
    [batch size, sequence length, channels, height, width]. The first two axes
    are then swapped, so the result is laid out as
    [sequence length, batch size, channels, height, width] in torch's default
    floating-point dtype.
    """
    stacked = np.stack(batch, axis=0)
    batch_major = torch.as_tensor(stacked, dtype=torch.get_default_dtype())
    return batch_major.permute(1, 0, 2, 3, 4)


# Shuffled test loader; the collate function makes every batch time-major.
dataloader_test = DataLoader(
    dataset=_dataloader_test,
    batch_size=128,
    shuffle=True,
    collate_fn=sequence_first_collate_fn,
)

In [4]:
# Pull just the first batch; it is used to read off the tensor dimensions below.
data = next(iter(dataloader_test))

In [5]:
# Batches are time-major: [sequence length, batch size, channels, height, width].
seq_length, batch_size, image_channels = data.shape[:3]
image_size = list(data.shape[3:])

In [6]:
CHECKPOINT_GLOB = "checkpoints/experiment_20231110_220254/state-*.pth"

# Natural sort keeps "state-10" after "state-9", so the last entry is the
# highest-numbered checkpoint of this run.
checkpoint_paths = natsorted(glob.glob(CHECKPOINT_GLOB))
if not checkpoint_paths:
    raise FileNotFoundError(f"No checkpoints match {CHECKPOINT_GLOB!r}")

# NOTE: torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(checkpoint_paths[-1], map_location=device)

In [7]:
# Architecture hyperparameters. They must agree with the training run, or
# load_state_dict will fail on parameter-shape mismatches.
kvae_kwargs = dict(
    image_size=image_size,
    image_channels=image_channels,
    a_dim=2,
    z_dim=4,
    K=3,
    decoder_type="bernoulli",
)
kvae = KalmanVariationalAutoencoder(**kvae_kwargs)
kvae = kvae.to(device=device, dtype=dtype)

In [8]:
# strict=True makes a missing or unexpected parameter name fail loudly instead of
# silently evaluating a partially initialised model.
kvae.load_state_dict(checkpoint["model_state_dict"], strict=True)

<All keys matched successfully>

In [9]:
mask_length: int = 42  # number of consecutive frames hidden from the model

In [None]:
def create_continuous_mask(seq_length, mask_length, batch_size, device, dtype):
    """Build a [seq_length, batch_size] observation mask with one centred gap.

    Time steps ``start .. start + mask_length - 1``, where
    ``start = (seq_length - mask_length) // 2``, are 0.0 (hidden). Every other
    step is 1.0 (observed). All sequences in the batch share the same mask.

    Args:
        seq_length: number of time steps (rows).
        mask_length: number of consecutive hidden steps.
        batch_size: number of sequences (columns).
        device: device of the returned tensor.
        dtype: dtype of the returned tensor.

    Raises:
        ValueError: if ``mask_length`` is negative or exceeds ``seq_length``.
    """
    # Without this check a too-long gap gives a negative start index, which
    # wraps around to the end of the sequence or raises IndexError.
    if not 0 <= mask_length <= seq_length:
        raise ValueError(
            f"mask_length must be in [0, {seq_length}], got {mask_length}"
        )
    start_index = (seq_length - mask_length) // 2
    mask = torch.ones((seq_length, batch_size), device=device, dtype=dtype)
    mask[start_index : start_index + mask_length] = 0.0
    return mask

def create_random_mask(seq_length, batch_size, mask_rate, device, dtype, generator=None):
    """Build a [seq_length, batch_size] observation mask with i.i.d. random gaps.

    ``torch.rand`` samples uniformly from [0, 1), so each entry is 0.0 (hidden)
    with probability ``mask_rate`` and 1.0 (observed) otherwise.

    Args:
        seq_length: number of time steps (rows).
        batch_size: number of sequences (columns).
        mask_rate: probability that any single entry is hidden.
        device: device the random draws and result live on.
        dtype: dtype of the returned tensor.
        generator: optional ``torch.Generator`` (on ``device``) for reproducible
            masks; the global generator is used when omitted.
    """
    uniform = torch.rand((seq_length, batch_size), device=device, generator=generator)
    # Tensor.type() does not accept a ``device`` keyword (TypeError); the draws
    # are already on ``device``, so only the dtype needs converting.
    return (uniform >= mask_rate).to(dtype=dtype)

In [None]:
# Sanity check of the interpolation mask: exactly `mask_length` steps are hidden
# and every sequence in the batch shares the same mask.
_mask_check = create_continuous_mask(seq_length, mask_length, batch_size, device, dtype)
assert tuple(_mask_check.shape) == (seq_length, batch_size)
assert int((_mask_check[:, 0] == 0).sum()) == mask_length
assert bool((_mask_check == _mask_check[:, :1]).all())

In [None]:
kvae.eval()

# Draw one test batch, binarize it, and move it to the model's device/dtype.
data = next(iter(dataloader_test))
data = (data > 0.5).to(dtype=dtype, device=device)

# Build the mask from the batch actually being evaluated, on the same device and
# dtype, so it lines up with `data` inside the ELBO.
observation_mask = create_continuous_mask(
    data.shape[0], mask_length, data.shape[1], device=device, dtype=dtype
)

# Evaluation only: no autograd graph is needed (later cells detach anyway).
with torch.no_grad():
    elbo, info = kvae.elbo(
        xs=data,
        observation_mask=observation_mask,
        sample_control=SampleControl(),
    )

In [None]:
# Diagnostics returned by kvae.elbo; the cells below read "as", "filter_as" and
# "as_resampled".
available_info_keys = info.keys()
available_info_keys

In [None]:
def _decode_mean_images(latents, image_shape):
    """Decode latent ``a`` vectors into mean images as a float32 NumPy array.

    ``latents`` has shape [..., a_dim]; its leading dimensions (time and batch
    here) are flattened for the decoder and restored on the output, which has
    shape [..., *image_shape]. ``.mean`` is read from whatever ``kvae.decoder``
    returns (presumably a Bernoulli distribution, given
    ``decoder_type="bernoulli"`` — TODO confirm).
    """
    leading_shape = latents.shape[:-1]
    a_dim = latents.shape[-1]
    with torch.no_grad():
        decoded = kvae.decoder(latents.reshape(-1, a_dim)).mean
    return (
        decoded.reshape(*leading_shape, *image_shape)
        .cpu()
        .float()
        .detach()
        .numpy()
    )


# The per-frame image shape (channels, height, width) is taken from the evaluated
# batch itself.
filtered_images = _decode_mean_images(info["filter_as"], data.shape[-3:])
smoothed_images = _decode_mean_images(info["as_resampled"], data.shape[-3:])

In [None]:
idx = 0  # which sequence of the batch to visualise
cmap = plt.get_cmap("tab10")

# The colormaps do not depend on the frame, so build them once.
red_grad = LinearSegmentedColormap.from_list(
    "red_grad", [(1, 1, 1), (1, 0, 0)], N=256
)
black_grad = LinearSegmentedColormap.from_list(
    "black_grad", [(1, 1, 1), (0, 0, 0)], N=256
)


def _to_numpy(tensor):
    """Detach a tensor and copy it to host memory as a NumPy array."""
    return tensor.detach().cpu().numpy()


# Ground-truth frames of sequence `idx`, thresholded to {0, 1}: one 2-D image
# (channel 0) per time step.
true_frames = _to_numpy((data[:, idx, 0] > 0.5).float())

# First two components of each latent trajectory of sequence `idx`, converted
# once rather than on every frame: each value is [T, 2].
trajectories = {
    key: _to_numpy(info[key][:, idx, :2])
    for key in ("as", "filter_as", "as_resampled")
}
curve_styles = (
    ("as", cmap(0), "Encoded"),
    ("filter_as", cmap(1), "Filtered"),
    ("as_resampled", cmap(2), "Smoothed"),
)

# Select observed steps with a boolean mask. Multiplying by the mask instead
# would draw every hidden step as a spurious "Observed" marker at (0, 0).
observed_steps = _to_numpy(observation_mask[:, idx] > 0)
observed_as = trajectories["as"][observed_steps]


def _render_frame(step):
    """Draw the three-panel figure for time step `step` and return it."""
    fig, axes = plt.subplots(figsize=(12, 4), nrows=1, ncols=3)
    fig.suptitle(f"$t = {step}$")

    # Panels 0 and 1: ground truth (red) overlaid with a reconstruction (black).
    for ax, reconstruction, title in (
        (axes[0], filtered_images, "from filtered $\\mathbf{z}$"),
        (axes[1], smoothed_images, "from smoothed $\\mathbf{z}$"),
    ):
        ax.imshow(
            true_frames[step], vmin=0, vmax=1, cmap=red_grad, aspect="equal", alpha=0.5
        )
        ax.imshow(
            reconstruction[step, idx, 0],
            vmin=0,
            vmax=1,
            cmap=black_grad,
            aspect="equal",
            alpha=0.5,
        )
        ax.set_title(title)

    # Panel 2: full trajectories in a-space, the current step highlighted in red,
    # and black squares at the steps the model actually observed.
    for key, color, label in curve_styles:
        axes[2].plot(
            trajectories[key][:, 0],
            trajectories[key][:, 1],
            ".-",
            color=color,
            label=label,
        )
    for key, _, _ in curve_styles:
        axes[2].plot(
            trajectories[key][step, 0],
            trajectories[key][step, 1],
            "o",
            markersize=8,
            color="red",
            linestyle="none",
            zorder=10,
        )
    axes[2].plot(
        observed_as[:, 0], observed_as[:, 1], "s", color="black", label="Observed"
    )
    axes[2].set_title("$\\mathbf{a}$ space")
    axes[2].legend()
    axes[2].grid()

    fig.tight_layout()
    return fig


with TemporaryDirectory() as dname:
    png_files = []
    for step in range(len(true_frames)):
        fig = _render_frame(step)
        png_file = os.path.join(dname, f"{step}.png")
        fig.savefig(png_file)
        plt.close(fig)  # close explicitly so figures do not accumulate across steps
        png_files.append(png_file)
    # The clip reads the PNGs, so it must be written before the directory is removed.
    clip = ImageSequenceClip(png_files, fps=10)
    clip.write_videofile("trajectory.mp4", codec="libx264")