In [1]:
import os
import json
import glob

from natsort import natsorted
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from kalman_vae import KalmanVariationalAutoencoder
from bouncing_ball.dataloaders.bouncing_data import BouncingBallDataLoader

In [2]:
# Reproducibility setup: switch cuDNN benchmark mode off, request
# deterministic cuDNN behaviour, and seed torch's global RNG with 0.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.manual_seed(0)

In [3]:
# Build the train and test BouncingBallDataLoader objects from their
# respective dataset directories (train is constructed first, then test).
dataloader_train, dataloader_test = (
    BouncingBallDataLoader(root_dir='bouncing_ball/datasets/bouncing-ball/train'),
    BouncingBallDataLoader(root_dir='bouncing_ball/datasets/bouncing-ball/test'),
)

In [4]:
# Wrap both objects in torch DataLoaders that yield mini-batches of 64.
# Training batches are reshuffled every epoch.
dataloader_train = DataLoader(dataloader_train, batch_size=64, shuffle=True)
# Evaluation does not benefit from shuffling; keeping a fixed order means the
# per-batch test losses are computed on the same batches every epoch, so the
# averaged test loss is comparable across epochs.
dataloader_test = DataLoader(dataloader_test, batch_size=64, shuffle=False)

In [5]:
# Pull a single batch from the training loader to inspect its shape.
# `data` is kept around on purpose: the model-construction cell reads
# `data.shape` to derive the image size and channel count.
data = next(iter(dataloader_train))
print(data.shape)
# Binarize: values above 0.5 become 1.0, everything else 0.0 (float32).
data = (data > 0.5).float()

torch.Size([64, 50, 1, 16, 16])


In [6]:
# Configure the model from the sample batch: `data.shape[2]` supplies the
# channel count and `data.shape[3:]` the trailing image dimensions
# (the batch printed above has shape [64, 50, 1, 16, 16]).
kvae = KalmanVariationalAutoencoder(
    image_size=data.shape[3:],
    image_channels=data.shape[2],
    a_dim=2,
    z_dim=4,
    K=3,
    decoder_type='bernoulli',
)

In [7]:
# Adam over every parameter of the KVAE, learning rate 1e-3.
optimizer = torch.optim.Adam(
    params=kvae.parameters(),
    lr=1e-3,
)

In [8]:
def find_latest_checkpoint_index(pattern):
    """Return the highest numeric checkpoint index among files matching ``pattern``.

    The index is read from each matching file name as the text between the
    last ``-`` and the first ``.`` that follows it, e.g. ``state-12.pth`` -> 12.
    Matching files whose suffix is not an integer (e.g. ``state-best.pth``)
    are skipped instead of raising ``ValueError``.

    Args:
        pattern: Glob pattern handed to ``glob.glob``.

    Returns:
        The largest index found as an ``int``, or ``None`` when no matching
        file carries a numeric index.
    """
    indices = []
    for path in glob.glob(pattern):
        # Parse the file name only, so a '-' inside a directory name is never
        # mistaken for the separator in front of the index.
        suffix = os.path.basename(path).rsplit('-', 1)[-1].split('.')[0]
        try:
            indices.append(int(suffix))
        except ValueError:
            # Matches the glob but has no numeric index; not a resumable epoch.
            continue
    return max(indices, default=None)

latest_index = find_latest_checkpoint_index('checkpoints/bouncing_ball/state-*.pth')

# Resume from the newest checkpoint when one exists; otherwise start at epoch 0.
if latest_index is None:
    epoch_start = 0
else:
    checkpoint_path = 'checkpoints/bouncing_ball/state-{}.pth'.format(latest_index)
    checkpoint = torch.load(checkpoint_path)
    # Restore both the model weights and the optimizer state.
    kvae.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # The checkpoint stores the epoch it was saved in, so continue with the next one.
    epoch_start = checkpoint['epoch'] + 1
    print('Loaded checkpoint at epoch {}'.format(latest_index))

In [None]:
# Train from `epoch_start` up to (but excluding) epoch 50. Every epoch runs a
# training pass, an evaluation pass on the test loader, and then writes a
# checkpoint containing the model/optimizer state and both mean losses.

# torch.save does not create missing directories; make sure the target exists.
os.makedirs('checkpoints/bouncing_ball', exist_ok=True)

p = tqdm(range(epoch_start, 50))
for epoch in p:
    # Train
    kvae.train()
    losses = []
    for i, data in enumerate(dataloader_train):
        # Binarize: values above 0.5 become 1.0, everything else 0.0.
        data = (data > 0.5).float()
        optimizer.zero_grad()
        elbo, info = kvae.elbo(data)
        # Maximizing the ELBO is done by minimizing its negative.
        loss = -elbo
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        p.set_description(f'Train Epoch {epoch}, Batch {i}/{len(dataloader_train)}, Loss {loss.item()}')
    # Unweighted mean of the per-batch losses.
    train_loss = sum(losses) / len(losses)

    # Test
    kvae.eval()
    losses = []
    # No parameters are updated here, so skip autograd bookkeeping entirely;
    # this avoids building a graph for every evaluation batch.
    with torch.no_grad():
        for i, data in enumerate(dataloader_test):
            data = (data > 0.5).float()
            elbo, info = kvae.elbo(data)
            loss = -elbo
            losses.append(loss.item())
            p.set_description(f'Test Epoch {epoch}, Batch {i}/{len(dataloader_test)}, Loss {loss.item()}')

    test_loss = sum(losses) / len(losses)

    # Save one checkpoint per epoch; the epoch number is part of the file name
    # so the resume logic can locate the most recent one.
    torch.save({
        'epoch': epoch,
        'model_state_dict': kvae.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'test_loss': test_loss
    }, f'checkpoints/bouncing_ball/state-{epoch}.pth')

Train Epoch 9, Batch 30/79, Loss 51532.46875:  18%|▏| 9/50 [15:48<1:06:19, 97.06

In [None]:
# Collect the mean train/test loss stored in every saved checkpoint,
# visiting the files in natural (numeric-aware) order of their names.
checkpoint_files = natsorted(glob.glob('checkpoints/bouncing_ball/state-*.pth'))
train_losses, test_losses = [], []

for file in checkpoint_files:
    checkpoint = torch.load(file)
    train_losses += [checkpoint['train_loss']]
    test_losses += [checkpoint['test_loss']]
    