In [1]:
import os

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from kalman_vae import KalmanVariationalAutoencoder
from bouncing_ball.dataloaders.bouncing_data import BouncingBallDataLoader

In [2]:
# Build the train and test splits of the bouncing-ball data.
# NOTE: despite the class/variable names, these objects are used as datasets:
# the next cell passes them as the `dataset` argument of torch's DataLoader.
dataloader_train, dataloader_test = (
    BouncingBallDataLoader(root_dir=split_dir)
    for split_dir in (
        'bouncing_ball/datasets/bouncing-ball/train',
        'bouncing_ball/datasets/bouncing-ball/test',
    )
)

In [3]:
# Wrap both splits in batching loaders (64 sequences per batch).
# Only the training split is shuffled; the test split keeps a fixed order so
# that evaluation batches are the same on every run.
# NOTE: this cell rebinds the names created by the previous cell, so it is not
# safe to re-run on its own -- re-run the previous cell first.
dataloader_train = DataLoader(dataloader_train, batch_size=64, shuffle=True)
dataloader_test = DataLoader(dataloader_test, batch_size=64, shuffle=False)

In [4]:
# Fetch a single batch to inspect its shape. `data` is kept in the namespace
# because the model-construction cell reads `data.shape` to size the network.
data = next(iter(dataloader_train))
print(data.shape)
# Binarize pixel intensities at 0.5 and cast the boolean mask to float32.
data = (data > 0.5).float()

torch.Size([64, 50, 1, 16, 16])


In [5]:
# Size the model from the inspected batch, whose printed shape is
# [64, 50, 1, 16, 16] -- presumably (batch, time, channels, height, width);
# TODO confirm against the dataset class.
kvae = KalmanVariationalAutoencoder(
    image_size=data.shape[3:],      # trailing two dims of the batch
    image_channels=data.shape[2],   # third dim of the batch
    a_dim=2,                        # presumably the encoded-observation dimension -- verify
    z_dim=4,                        # presumably the latent-state dimension -- verify
    K=3,                            # presumably the number of mixed dynamics matrices -- verify
    decoder_type='bernoulli',
)

In [6]:
# Adam over every parameter exposed by the model, with a learning rate of 1e-3.
optimizer = torch.optim.Adam(params=kvae.parameters(), lr=1e-3)

In [7]:
# Training configuration for this run.
NUM_EPOCHS = 10
CHECKPOINT_DIR = 'checkpoints/bouncing_ball'

os.makedirs(CHECKPOINT_DIR, exist_ok=True)
# One entry per epoch: the mean loss (negative ELBO) over that epoch's batches.
# Recording only the final mini-batch would give a noisy per-epoch value.
losses = []
p = tqdm(range(NUM_EPOCHS))
for epoch in p:
    batch_losses = []
    for data in dataloader_train:
        # Binarize frames at 0.5 and cast to float32 before feeding the model.
        data = (data > 0.5).float()
        optimizer.zero_grad()
        # `info` is not used in this loop; it stays bound so the last batch's
        # diagnostics can be inspected after training.
        elbo, info = kvae.elbo(data)
        loss = -elbo  # maximize the ELBO by minimizing its negative
        loss.backward()
        optimizer.step()
        # Read the scalar once per step and reuse it for logging and averaging.
        batch_loss = loss.item()
        batch_losses.append(batch_loss)
        p.set_description("Epoch: %d, Loss: %.4f" % (epoch, batch_loss))
    # Skip the entry if the loader produced no batches at all.
    if batch_losses:
        losses.append(sum(batch_losses) / len(batch_losses))

    # Save a checkpoint of the model weights after every epoch.
    torch.save(kvae.state_dict(), '%s/kvae-%d.pth' % (CHECKPOINT_DIR, epoch))

  0%|                                                    | 0/10 [00:00<?, ?it/s]

> /Users/naoki/Documents/PhD/kalman-vae/ssm.py(205)state_transition_log_likelihood()
    203         import pdb; pdb.set_trace()
    204
--> 205         state_transition_log_likelihood = 0.0
    206
    207         for t in range(sequence_length):

    192     def state_transition_log_likelihood(self, zs, mat_As):
    193         sequence_length, batch_size, _, _

  0%|                                                    | 0/10 [00:37<?, ?it/s]
