In [1]:
import glob
import json
import time
import os
import shutil

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim.lr_scheduler as lr_scheduler
from bouncing_ball.dataloaders.bouncing_data import BouncingBallDataLoader
from kalman_vae import KalmanVariationalAutoencoder
from natsort import natsorted
from sample_control import SampleControl
from torch.utils.data import DataLoader
from tqdm import tqdm

In [2]:
for path in glob.glob("checkpoints/bouncing_ball_*_"):
    shutil.rmtree(path)

In [3]:
warmup_epochs = 5
burn_in = 10
device = torch.device("cuda:0")
symmetrize_covariance = True
sample_control_train = SampleControl.training_defaults()
dtype = torch.float64
checkpoint_dir = (
    f"checkpoints/bouncing_ball_wo-vae_dtype-{dtype}_warmup-{warmup_epochs}_burnin-{burn_in}_"
)

In [4]:
os.makedirs(checkpoint_dir, exist_ok=True)

In [5]:
# fix random seeds for reproducibility
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [6]:
_dataloader_train = BouncingBallDataLoader(
    root_dir="bouncing_ball/datasets/bouncing-ball/train"
)
_dataloader_test = BouncingBallDataLoader(
    root_dir="bouncing_ball/datasets/bouncing-ball/test"
)

In [7]:
x_tensor = torch.linspace(-2, 2, 16)
y_tensor = torch.linspace(2, -2, 16)

In [8]:
x_tensor

tensor([-2.0000, -1.7333, -1.4667, -1.2000, -0.9333, -0.6667, -0.4000, -0.1333,
         0.1333,  0.4000,  0.6667,  0.9333,  1.2000,  1.4667,  1.7333,  2.0000])

In [9]:
def sequence_first_collate_fn(batch):
    data = torch.Tensor(np.stack(batch, axis=0))
    # data.shape: [batch size, sequence length, channels, height, width]
    # Reshape to [sequence length, batch size, channels, height, width]
    data = data.permute(1, 0, 2, 3, 4)
    weight_x = data.mean(-1)
    weight_x = (weight_x / weight_x.sum(-1).unsqueeze(-1)).squeeze(-2)
    weight_y = data.mean(-2)
    weight_y = (weight_y / weight_y.sum(-1).unsqueeze(-1)).squeeze(-2)
    
    data_x = (weight_x * x_tensor).sum(-1)
    data_y = (weight_y * y_tensor).sum(-1)

    return torch.stack([data_x, data_y], dim=-1)

In [10]:
dataloader_train = DataLoader(
    _dataloader_train,
    batch_size=128,
    shuffle=True,
    collate_fn=sequence_first_collate_fn,
)
dataloader_test = DataLoader(
    _dataloader_test, batch_size=128, shuffle=True, collate_fn=sequence_first_collate_fn
)

In [11]:
for i, data in enumerate(dataloader_train):
    break

In [14]:
data[0, 0]

tensor([-0.5905,  1.3333])

In [None]:
kvae = (
    KalmanVariationalAutoencoder(
        image_size=(16,16), # dummy
        image_channels=1, # dummy
        a_dim=2,
        z_dim=4,
        K=3,
        decoder_type="bernoulli",
    )
    .to(dtype=dtype)
    .to(device)
)

In [None]:
optimizer = torch.optim.Adam(kvae.parameters(), lr=7e-3)
scheduler = lr_scheduler.ExponentialLR(optimizer, 0.85)

In [None]:
def find_latest_checkpoint_index(pattern):
    files = glob.glob(pattern)
    if files:
        return int(
            max(files, key=lambda x: int(x.split("-")[-1].split(".")[0]))
            .split("-")[-1]
            .split(".")[0]
        )
    return None


latest_index = find_latest_checkpoint_index(os.path.join(checkpoint_dir, "state-*.pth"))

if latest_index is not None:
    checkpoint = torch.load(os.path.join(checkpoint_dir, f"state-{latest_index}.pth"))
    kvae.load_state_dict(checkpoint["model_state_dict"], strict=False)
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    epoch_start = checkpoint["epoch"] + 1
    print("Loaded checkpoint at epoch {}".format(latest_index))
else:
    epoch_start = 0

In [None]:
p = tqdm(range(epoch_start, 80))
for epoch in p:
    kvae.train()
    learn_weight_model = epoch >= warmup_epochs
    losses = []
    for i, data in enumerate(dataloader_train):
        mask = torch.ones(data.shape[:2]).to(device)
        data = (data > 0.5).to(dtype=dtype).to(device)
        optimizer.zero_grad()
        elbo, info = kvae.elbo(
            as_=data,
            # observation_mask=mask,
            reconst_weight=0.3*50,
            regularization_weight=1.0*50,
            learn_weight_model=learn_weight_model,
            symmetrize_covariance=symmetrize_covariance,
            burn_in=burn_in,
            sample_control=sample_control_train,
        )
        loss = -elbo/50
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        p.set_description(
            f"Train Epoch {epoch}, Batch {i}/{len(dataloader_train)}, Loss {loss.item()}"
        )
    train_loss = sum(losses) / len(losses)

    # Test
    kvae.eval()
    losses = []
    for i, data in enumerate(dataloader_test):
        data = (data > 0.5).to(dtype=dtype).to(device)
        elbo, info = kvae.elbo(
            as_=data,
            reconst_weight=0.3*50,
            regularization_weight=1.0*50,
            symmetrize_covariance=symmetrize_covariance,
            sample_control=sample_control_train,
        )
        loss = -elbo/50
        losses.append(loss.item())
        p.set_description(
            f"Test Epoch {epoch}, Batch {i}/{len(dataloader_test)}, Loss {loss.item()}"
        )

    test_loss = sum(losses) / len(losses)

    if (epoch > 0) & (epoch % 20 == 0):
        scheduler.step()

    # Save
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": kvae.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "train_loss": train_loss,
            "test_loss": test_loss,
        },
        os.path.join(checkpoint_dir, f"state-{epoch}.pth"),
    )