In [1]:
import glob
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim.lr_scheduler as lr_scheduler
from bouncing_ball.dataloaders.bouncing_data import BouncingBallDataLoader
from kalman_vae import KalmanVariationalAutoencoder
from natsort import natsorted
from torch.utils.data import DataLoader
from tqdm import tqdm

In [2]:
checkpoint_dir = "checkpoints/bouncing_ball_double"
warmup_epochs = 3
device = torch.device("cuda:2")
symmetrize_covariance = True
dtype = torch.float64

In [3]:
os.makedirs(checkpoint_dir, exist_ok=True)

In [4]:
# fix random seeds for reproducibility
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [5]:
_dataloader_train = BouncingBallDataLoader(
    root_dir="bouncing_ball/datasets/bouncing-ball/train"
)
_dataloader_test = BouncingBallDataLoader(
    root_dir="bouncing_ball/datasets/bouncing-ball/test"
)

In [6]:
def sequence_first_collate_fn(batch):
    data = torch.Tensor(np.stack(batch, axis=0))
    # data.shape: [batch size, sequence length, channels, height, width]
    # Reshape to [sequence length, batch size, channels, height, width]
    data = data.permute(1, 0, 2, 3, 4)
    return data

In [7]:
dataloader_train = DataLoader(
    _dataloader_train, batch_size=128, shuffle=True, collate_fn=sequence_first_collate_fn
)
dataloader_test = DataLoader(
    _dataloader_test, batch_size=128, shuffle=True, collate_fn=sequence_first_collate_fn
)

In [8]:
for i, data in enumerate(dataloader_train):
    print(data.shape)
    # To Float32
    data = (data > 0.5).float()
    break

torch.Size([50, 128, 1, 16, 16])


In [9]:
kvae = KalmanVariationalAutoencoder(
    image_size=data.shape[3:],
    image_channels=data.shape[2],
    a_dim=2,
    z_dim=4,
    K=3,
    decoder_type="bernoulli",
).to(dtype=dtype).to(device)

In [10]:
optimizer = torch.optim.Adam(kvae.parameters(), lr=7e-3)
scheduler = lr_scheduler.ExponentialLR(optimizer, 0.85)

In [11]:
def find_latest_checkpoint_index(pattern):
    files = glob.glob(pattern)
    if files:
        return int(
            max(files, key=lambda x: int(x.split("-")[-1].split(".")[0]))
            .split("-")[-1]
            .split(".")[0]
        )
    return None


latest_index = find_latest_checkpoint_index(os.path.join(checkpoint_dir, "state-*.pth"))

if latest_index is not None:
    checkpoint = torch.load(os.path.join(checkpoint_dir, f"state-{latest_index}.pth"))
    kvae.load_state_dict(checkpoint["model_state_dict"], strict=False)
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    epoch_start = checkpoint["epoch"] + 1
    print("Loaded checkpoint at epoch {}".format(latest_index))
else:
    epoch_start = 0

In [12]:
p = tqdm(range(epoch_start, 100))
for epoch in p:
    kvae.train()
    learn_weight_model = epoch >= warmup_epochs
    losses = []
    for i, data in enumerate(dataloader_train):
        data = (data > 0.5).to(dtype=dtype).to(device)
        optimizer.zero_grad()
        elbo, info = kvae.elbo(data, learn_weight_model=learn_weight_model, symmetrize_covariance=symmetrize_covariance)
        loss = -elbo
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        p.set_description(
            f"Train Epoch {epoch}, Batch {i}/{len(dataloader_train)}, Loss {loss.item()}"
        )
    train_loss = sum(losses) / len(losses)

    # Test
    kvae.eval()
    losses = []
    for i, data in enumerate(dataloader_test):
        data = (data > 0.5).to(dtype=dtype).to(device)
        elbo, info = kvae.elbo(data, symmetrize_covariance=symmetrize_covariance)
        loss = -elbo
        losses.append(loss.item())
        p.set_description(
            f"Test Epoch {epoch}, Batch {i}/{len(dataloader_test)}, Loss {loss.item()}"
        )

    test_loss = sum(losses) / len(losses)
    
    if (epoch > 0) & (epoch % 20 == 0):
        scheduler.step()

    # Save
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": kvae.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "train_loss": train_loss,
            "test_loss": test_loss,
        },
        os.path.join(checkpoint_dir, f"state-{epoch}.pth"),
    )

Test Epoch 99, Batch 7/8, Loss -42700.16918306143: 100%|██████████████| 100/100 [58:20<00:00, 35.00s/it]


In [13]:
torch.linalg.svd(kvae.state_space_model.mat_A_K[0])

torch.return_types.linalg_svd(
U=tensor([[ 0.8003, -0.2298, -0.5529,  0.0328],
        [ 0.2916, -0.3158,  0.5936,  0.6803],
        [-0.3086, -0.9204, -0.0777, -0.2273],
        [-0.4235,  0.0190, -0.5795,  0.6960]], device='cuda:2',
       dtype=torch.float64, grad_fn=<LinalgSvdBackward0>),
S=tensor([4.0181, 2.8965, 2.3432, 0.0133], device='cuda:2', dtype=torch.float64,
       grad_fn=<LinalgSvdBackward0>),
Vh=tensor([[ 0.5411,  0.5638, -0.0017,  0.6240],
        [ 0.5422, -0.1799,  0.7618, -0.3055],
        [ 0.4475, -0.7550, -0.3792,  0.2931],
        [-0.4615, -0.2824,  0.5252,  0.6568]], device='cuda:2',
       dtype=torch.float64, grad_fn=<LinalgSvdBackward0>))

In [14]:
torch.linalg.svd(kvae.state_space_model.mat_A_K[1])

torch.return_types.linalg_svd(
U=tensor([[ 0.8264,  0.5620, -0.0350, -0.0105],
        [ 0.4371, -0.6724, -0.5489,  0.2355],
        [-0.0199, -0.0126, -0.3945, -0.9186],
        [-0.3546,  0.4815, -0.7361,  0.3172]], device='cuda:2',
       dtype=torch.float64, grad_fn=<LinalgSvdBackward0>),
S=tensor([3.2200, 2.3092, 1.4143, 0.6962], device='cuda:2', dtype=torch.float64,
       grad_fn=<LinalgSvdBackward0>),
Vh=tensor([[ 0.7228, -0.6866,  0.0593,  0.0509],
        [ 0.0519,  0.1600,  0.9230,  0.3461],
        [ 0.1078,  0.0728,  0.3317, -0.9344],
        [ 0.6806,  0.7055, -0.1859,  0.0675]], device='cuda:2',
       dtype=torch.float64, grad_fn=<LinalgSvdBackward0>))

In [15]:
torch.linalg.svd(kvae.state_space_model.mat_A_K[2])

torch.return_types.linalg_svd(
U=tensor([[ 0.6829,  0.5509,  0.2139, -0.4294],
        [ 0.0075, -0.6561,  0.4498, -0.6060],
        [ 0.2669, -0.1049,  0.7097,  0.6435],
        [-0.6799,  0.5050,  0.4984, -0.1853]], device='cuda:2',
       dtype=torch.float64, grad_fn=<LinalgSvdBackward0>),
S=tensor([4.1904, 2.3722, 0.8549, 0.0771], device='cuda:2', dtype=torch.float64,
       grad_fn=<LinalgSvdBackward0>),
Vh=tensor([[ 0.6636, -0.5675,  0.2022,  0.4435],
        [-0.2826,  0.0500,  0.9566,  0.0508],
        [ 0.4698, -0.0565,  0.1874, -0.8608],
        [ 0.5089,  0.8199,  0.0946,  0.2446]], device='cuda:2',
       dtype=torch.float64, grad_fn=<LinalgSvdBackward0>))

In [16]:
kvae.state_space_model.mat_C_K

Parameter containing:
tensor([[[-0.2851, -0.2979, -0.2173, -0.3037],
         [ 1.3102, -0.5970, -0.8599, -1.7043]],

        [[-1.6409, -1.0691, -1.2499,  4.1112],
         [ 1.0235,  1.2006,  0.4771, -1.6253]],

        [[-3.8794, -2.4226, -1.0353,  2.9258],
         [ 0.6643,  0.2844,  0.6819, -0.2459]]], device='cuda:2',
       dtype=torch.float64, requires_grad=True)

In [17]:
torch.linalg.svd(kvae.state_space_model.mat_Q)

torch.return_types.linalg_svd(
U=tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]], device='cuda:2', dtype=torch.float64,
       grad_fn=<LinalgSvdBackward0>),
S=tensor([0.0010, 0.0010, 0.0010, 0.0010], device='cuda:2', dtype=torch.float64,
       grad_fn=<LinalgSvdBackward0>),
Vh=tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]], device='cuda:2', dtype=torch.float64,
       grad_fn=<LinalgSvdBackward0>))

In [18]:
torch.linalg.svd(kvae.state_space_model.mat_R)

torch.return_types.linalg_svd(
U=tensor([[ 0.0076,  1.0000],
        [ 1.0000, -0.0076]], device='cuda:2', dtype=torch.float64,
       grad_fn=<LinalgSvdBackward0>),
S=tensor([1.0090e+02, 1.0000e-03], device='cuda:2', dtype=torch.float64,
       grad_fn=<LinalgSvdBackward0>),
Vh=tensor([[ 0.0076,  1.0000],
        [ 1.0000, -0.0076]], device='cuda:2', dtype=torch.float64,
       grad_fn=<LinalgSvdBackward0>))