In [1]:
import torch
import torch.nn as nn
import time
import wandb
import matplotlib.pyplot as plt 
import random
import numpy as np
import os

from Kalman_VAE import KalmanVAE
from datetime import datetime

from dataloaders.bouncing_data import BouncingBallDataLoader
from torch.utils.data import DataLoader

In [2]:
# Dataset location: the 'train' and 'test' splits live side by side under one root.
data_root = '/data2/users/lr4617/data/Bouncing_Ball/'
train_dir, test_dir = (os.path.join(data_root, split) for split in ('train', 'test'))

# Both loaders are built with images=True; the train loader is created first,
# then the test loader.
train_dl, test_dl = (
    BouncingBallDataLoader(split_dir, images=True)
    for split_dir in (train_dir, test_dir)
)

# Item at index 1 of the test split is the sequence analysed in the rest of the notebook.
data = test_dl[1]

In [4]:
# Model hyper-parameters (latent sizes and number of dynamics matrices).
dim_a = 2
dim_z = 4
K = 3

# Input geometry is read off the sample sequence, which is indexed here as
# (time, channels, *spatial dims).
seq_len = data.shape[0]
n_channels_in = data.shape[1]
image_size = data.shape[2:]

# Build the model on the GPU; only the first spatial dimension is passed as the image size.
kvae = KalmanVAE(
    n_channels_in,
    image_size[0],
    dim_a,
    dim_z,
    K,
    T=seq_len,
    recon_scale=0.3,
).cuda()

# Restore the trained weights from the run directory and switch to evaluation mode.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
root_model_dir = '/data2/users/lr4617/KalmanVAE/results/Kalman_VAE/Bouncing_Ball/run_2023_11_04_12_59_26/'
model_path = os.path.join(root_model_dir, '', 'kvae100.pt')
kvae.load_state_dict(
    torch.load(model_path, map_location=torch.device('cuda:0'))
)
kvae.eval()

# List the name of every parameter that has requires_grad set.
trainable_names = (
    param_name
    for param_name, tensor in kvae.named_parameters()
    if tensor.requires_grad
)
for name in trainable_names:
    print(name)

kalman_filter.A
kalman_filter.C
kalman_filter.a_0
kalman_filter.dyn_net.lstm.weight_ih_l0
kalman_filter.dyn_net.lstm.weight_hh_l0
kalman_filter.dyn_net.lstm.bias_ih_l0
kalman_filter.dyn_net.lstm.bias_hh_l0
kalman_filter.dyn_net.lstm.weight_ih_l1
kalman_filter.dyn_net.lstm.weight_hh_l1
kalman_filter.dyn_net.lstm.bias_ih_l1
kalman_filter.dyn_net.lstm.bias_hh_l1
kalman_filter.dyn_net.linear.weight
kalman_filter.dyn_net.linear.bias
encoder.conv_modules.0.weight
encoder.conv_modules.0.bias
encoder.conv_modules.1.weight
encoder.conv_modules.1.bias
encoder.conv_modules.2.weight
encoder.conv_modules.2.bias
encoder.to_mean.weight
encoder.to_mean.bias
encoder.to_std.weight
encoder.to_std.bias
decoder.to_conv.weight
decoder.to_conv.bias
decoder.conv_tranpose_modules.0.weight
decoder.conv_tranpose_modules.0.bias
decoder.conv_tranpose_modules.1.weight
decoder.conv_tranpose_modules.1.bias
decoder.conv_tranpose_modules.2.weight
decoder.conv_tranpose_modules.2.bias
decoder.to_mean.weight
decoder.to_me

In [5]:
# Run the trained model on a single sequence: encode each frame, sample the
# codes a_t, query the dynamics network for its per-step output (k_weights)
# and decode the sampled codes back to images.

# Keep the input on whatever device the model's parameters live on instead of
# hard-coding a device string.
model_device = next(kvae.parameters()).device
batched_data = torch.Tensor(data).unsqueeze(0).to(model_device)  # add a leading batch axis
batch_size = batched_data.shape[0]

# This is pure inference, so no autograd graph is needed; disabling gradient
# tracking avoids keeping intermediate activations alive.
with torch.no_grad():
    # Merge the batch and time axes so every frame is encoded independently.
    # The per-frame shape is taken from the data rather than hard-coded.
    frames = batched_data.view(-1, *batched_data.shape[2:])
    a_mean, a_std = kvae.encoder(frames)

    # Sample a = mean + std * eps, with eps drawn from a zero-mean normal
    # (torch.normal's default std of 1.0), then restore (batch, time, dim_a).
    a_sample = (a_mean + a_std*torch.normal(mean=torch.zeros_like(a_mean))).view(batch_size, seq_len, dim_a)

    # Prepend the model's initial code a_0 and drop the last sample, so the
    # dynamics network receives the sequence shifted by one step.
    a_0 = kvae.kalman_filter.a_0.unsqueeze(0).unsqueeze(1).repeat(batch_size, 1, 1)
    joint_code_obs = torch.cat([a_0, a_sample], dim=1)
    k_weights = kvae.kalman_filter.dyn_net(joint_code_obs[:, :-1, :])

    # The decoder returns two values; only the first is used here.
    x_hat, _ = kvae.decoder(a_sample)

In [6]:
# Render one figure per time step (input frame, reconstruction, dynamics
# weights) and save each as 'weight-<step>.png' inside the run directory.
root_analysis_dir = os.path.join(root_model_dir, '', 'dyn_analysis')
os.makedirs(root_analysis_dir, exist_ok=True)

# Move everything to host memory once, instead of once per frame inside the loop.
# Input frames are binarised at 0.5 before plotting.
images = batched_data.squeeze(0).detach().cpu().numpy() > 0.5
reconstructions = x_hat.detach().cpu().numpy()
weights = k_weights.squeeze(0).detach().cpu().numpy()

# One bar label per weight component, derived from the data rather than
# hard-coded for three components.
weight_labels = [str(k) for k in range(weights.shape[-1])]

for step, (image, reconstruction, weight) in enumerate(zip(images, reconstructions, weights)):
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10, 5))
    fig.suptitle(f"$t = {step}$")

    # Index 0 selects the first channel of each frame.
    axes[0].imshow(image[0], vmin=0, vmax=1, cmap="Greys", aspect='equal')
    axes[0].set_adjustable('box')
    axes[1].imshow(reconstruction[0], vmin=0, vmax=1, cmap="Greys", aspect='equal')
    axes[2].bar(weight_labels, weight)
    axes[2].set_ylim(0, 1)  # fixed range keeps the bars comparable across frames

    axes[0].set_title(r"image $\mathbf{x}_t$")
    axes[1].set_title(r"reconstruction $\hat{\mathbf{x}}_t$")
    axes[2].set_title(r"weight $\mathbf{k}_t$")

    # Give the bar chart the same vertical extent as the image panel.
    pos_img = axes[0].get_position()
    pos_bar = axes[2].get_position()
    axes[2].set_position([pos_bar.x0, pos_img.y0, pos_bar.width, pos_img.height])

    fig.savefig(os.path.join(root_analysis_dir, 'weight-{}.png'.format(step)))
    # Close this specific figure so figures do not accumulate in memory.
    plt.close(fig)

In [7]:
!ffmpeg -framerate 10 -i /data2/users/lr4617/KalmanVAE/results/Kalman_VAE/Bouncing_Ball/run_2023_11_04_12_59_26/dyn_analysis/weight-%d.png -c:v libopenh264 -r 30 -pix_fmt yuv420p /data2/users/lr4617/KalmanVAE/results/Kalman_VAE/Bouncing_Ball/run_2023_11_04_12_59_26/dyn_analysis/weight.mp4 -y

ffmpeg version 4.3 Copyright (c) 2000-2020 the FFmpeg developers
  built with gcc 7.3.0 (crosstool-NG 1.23.0.449-a04d0)
  configuration: --prefix=/vol/bitbucket/lr4617/anaconda3/envs/py38_pytorch --cc=/opt/conda/conda-bld/ffmpeg_1597178665428/_build_env/bin/x86_64-conda_cos6-linux-gnu-cc --disable-doc --disable-openssl --enable-avresample --enable-gnutls --enable-hardcoded-tables --enable-libfreetype --enable-libopenh264 --enable-pic --enable-pthreads --enable-shared --disable-static --enable-version3 --enable-zlib --enable-libmp3lame
  libavutil      56. 51.100 / 56. 51.100
  libavcodec     58. 91.100 / 58. 91.100
  libavformat    58. 45.100 / 58. 45.100
  libavdevice    58. 10.100 / 58. 10.100
  libavfilter     7. 85.100 /  7. 85.100
  libavresample   4.  0.  0 /  4.  0.  0
  libswscale      5.  7.100 /  5.  7.100
  libswresample   3.  7.100 /  3.  7.100
Input #0, image2, from '/data2/users/lr4617/KalmanVAE/results/Kalman_VAE/Bouncing_Ball/run_2023_11_04_12_59_26/dyn_analysis/weight-

In [8]:
from IPython.display import Video

# Build the path from root_analysis_dir rather than repeating the absolute path.
# embed=True inlines the file as a data URI: with the default (embed=False) the
# path is emitted as a <video src=...> URL, which the browser generally cannot
# resolve when it is an absolute filesystem path outside the served directory.
Video(os.path.join(root_analysis_dir, 'weight.mp4'), embed=True)