In [1]:
import glob
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from bouncing_ball.dataloaders.bouncing_data import BouncingBallDataLoader
from kalman_vae import KalmanVariationalAutoencoder
from natsort import natsorted
from torch.utils.data import DataLoader
from tqdm import tqdm

In [2]:
# Run configuration.
device = torch.device("cpu")  # device the model and every batch are moved to
warmup_epochs = 50  # epochs trained before `learn_weight_model` is switched on
checkpoint_dir = "checkpoints/bouncing_ball"  # where `state-<epoch>.pth` files are read and written

In [3]:
# Create the checkpoint directory (and any parents); no-op if it already exists.
os.makedirs(name=checkpoint_dir, exist_ok=True)

In [4]:
# fix random seeds for reproducibility
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [5]:
# One dataset per split directory, constructed in order: train first, then test.
# NOTE: despite the `_dataloader_*` names, these objects are used as *datasets* —
# they are wrapped in torch DataLoaders further below.
_dataloader_train, _dataloader_test = (
    BouncingBallDataLoader(root_dir=split_dir)
    for split_dir in (
        "bouncing_ball/datasets/bouncing-ball/train",
        "bouncing_ball/datasets/bouncing-ball/test",
    )
)

In [6]:
def sequence_first_collate_fn(batch):
    """Stack per-sample sequences into one sequence-first float32 tensor.

    Each sample is expected to be sequence-first, e.g.
    ``[sequence length, channels, height, width]``. Stacking adds a leading batch
    axis, which is then swapped with the sequence axis, so such a batch becomes
    ``[sequence length, batch size, channels, height, width]``.

    Only the first two axes of the stacked array are swapped, so samples of any
    rank >= 1 are supported (not just 4-D image sequences).

    Args:
        batch: Sequence of array-likes that all have the same shape.

    Returns:
        torch.Tensor of dtype float32 with the batch and sequence axes swapped
        (a view; it is not made contiguous).
    """
    stacked = np.stack(batch, axis=0)  # [batch size, sequence length, ...]
    # as_tensor replaces the legacy torch.Tensor(...) constructor; the explicit
    # dtype keeps the float32 result that constructor produced.
    data = torch.as_tensor(stacked, dtype=torch.float32)
    # For 5-D input this equals permute(1, 0, 2, 3, 4), but works for any rank.
    return data.transpose(0, 1)

In [7]:
# Batched loaders yielding sequence-first tensors (see sequence_first_collate_fn).
# Training batches are reshuffled every epoch. Test batches keep a fixed order, so the
# per-epoch test loss (a mean over batch losses) is computed over the same batches
# each epoch and does not consume extra draws from torch's global RNG.
batch_size = 128
dataloader_train = DataLoader(
    _dataloader_train, batch_size=batch_size, shuffle=True, collate_fn=sequence_first_collate_fn
)
dataloader_test = DataLoader(
    _dataloader_test, batch_size=batch_size, shuffle=False, collate_fn=sequence_first_collate_fn
)

In [8]:
# Draw a single training batch to inspect its shape and to size the model below.
data = next(iter(dataloader_train))
print(data.shape)
# Binarize pixels to {0., 1.} (float32), matching the preprocessing in the training loop.
data = (data > 0.5).float()

torch.Size([50, 128, 1, 16, 16])


In [9]:
# Size the KVAE from the sample batch `data`, which is laid out as
# [sequence length, batch size, channels, height, width] (see sequence_first_collate_fn).
sample_image_size = data.shape[3:]  # trailing spatial dims (height, width)
sample_image_channels = data.shape[2]
# a_dim / z_dim / K are passed straight through to the project's KVAE class;
# NOTE(review): their exact meaning is defined in kalman_vae — confirm there.
kvae = KalmanVariationalAutoencoder(
    image_channels=sample_image_channels,
    image_size=sample_image_size,
    a_dim=2,
    z_dim=4,
    K=3,
    decoder_type="bernoulli",  # matches the {0, 1} binarized inputs
)
kvae = kvae.to(device)

In [10]:
# A single Adam optimizer over every KVAE parameter, learning rate 1e-3.
optimizer = torch.optim.Adam(params=kvae.parameters(), lr=1e-3)

In [11]:
def find_latest_checkpoint_index(pattern):
    """Return the largest integer index among files matching a glob pattern.

    Files are expected to be named ``<prefix>-<index>.<ext>`` (e.g. ``state-12.pth``).
    The index is the part of the *file name* after its last ``-`` and before the
    next ``.``. Matching files whose index is not an integer (e.g.
    ``state-backup.pth``) are skipped instead of raising ``ValueError``.

    Args:
        pattern: glob pattern such as ``"checkpoints/run/state-*.pth"``.

    Returns:
        The largest parsed index, or ``None`` if no matching file carries one.
    """
    indices = []
    for path in glob.glob(pattern):
        # Parse only the base name so directory components never affect the index.
        suffix = os.path.basename(path).rsplit("-", 1)[-1].split(".")[0]
        try:
            indices.append(int(suffix))
        except ValueError:
            continue  # not a "<prefix>-<int>.<ext>" file; ignore it
    return max(indices, default=None)


# Resume from the newest `state-<epoch>.pth` checkpoint in `checkpoint_dir`, if one exists.
latest_index = find_latest_checkpoint_index(os.path.join(checkpoint_dir, "state-*.pth"))

if latest_index is not None:
    # map_location remaps stored tensors onto `device`, so checkpoints written on a
    # different device (e.g. a GPU) still load. torch.load unpickles: only load trusted files.
    checkpoint = torch.load(
        os.path.join(checkpoint_dir, f"state-{latest_index}.pth"), map_location=device
    )
    # strict=False tolerates parameter-name mismatches; report them rather than hide them.
    incompatible = kvae.load_state_dict(checkpoint["model_state_dict"], strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(
            "Warning: checkpoint does not match the model exactly "
            f"(missing keys: {incompatible.missing_keys}, "
            f"unexpected keys: {incompatible.unexpected_keys})"
        )
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    # Continue with the epoch after the last completed (saved) one.
    epoch_start = checkpoint["epoch"] + 1
    print("Loaded checkpoint at epoch {}".format(latest_index))
else:
    epoch_start = 0

Loaded checkpoint at epoch 56


In [None]:
num_epochs = 100  # total epoch budget; training resumes from `epoch_start`
p = tqdm(range(epoch_start, num_epochs))
for epoch in p:
    # ---- Train ----
    kvae.train()
    # The weight model is only learned once the warm-up period is over.
    learn_weight_model = epoch >= warmup_epochs
    losses = []
    for i, data in enumerate(dataloader_train):
        # Binarize pixels to {0., 1.}, as the model uses decoder_type="bernoulli".
        data = (data > 0.5).float().to(device)
        optimizer.zero_grad()
        elbo, info = kvae.elbo(data, learn_weight_model=learn_weight_model)
        loss = -elbo  # maximize the ELBO by minimizing its negation
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        losses.append(loss_value)
        p.set_description(
            f"Train Epoch {epoch}, Batch {i}/{len(dataloader_train)}, Loss {loss_value}"
        )
    # Unweighted mean of per-batch losses; NaN (instead of ZeroDivisionError) if no batches.
    train_loss = sum(losses) / len(losses) if losses else float("nan")

    # ---- Test ----
    kvae.eval()
    losses = []
    # Evaluation needs no gradients: skip building (and holding) the autograd graph.
    with torch.no_grad():
        for i, data in enumerate(dataloader_test):
            data = (data > 0.5).float().to(device)
            # NOTE(review): relies on kvae.elbo's default `learn_weight_model` — confirm intended.
            elbo, info = kvae.elbo(data)
            loss = -elbo
            loss_value = loss.item()
            losses.append(loss_value)
            p.set_description(
                f"Test Epoch {epoch}, Batch {i}/{len(dataloader_test)}, Loss {loss_value}"
            )

    test_loss = sum(losses) / len(losses) if losses else float("nan")

    # ---- Save ----
    # One checkpoint per epoch; the resume cell picks the highest `state-<epoch>.pth`.
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": kvae.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "train_loss": train_loss,
            "test_loss": test_loss,
        },
        os.path.join(checkpoint_dir, f"state-{epoch}.pth"),
    )

Train Epoch 58, Batch 30/40, Loss 23893.365234375:   2%|▊                                  | 1/43 [03:56<1:33:22, 133.39s/it]