In [1]:
import os
import matplotlib.pyplot as plt
import numpy as np
import time
import wandb
import torch
import torch.optim.lr_scheduler as lr_scheduler

from dataloaders.bouncing_data import BouncingBallDataLoader
from torch.utils.data import DataLoader
from datetime import datetime

from Kalman_Filter import Kalman_Filter
from Kalman_VAE import KalmanVAE



In [2]:
# Root of the Bouncing Ball dataset. Set the BOUNCING_BALL_DATA environment
# variable to point elsewhere instead of editing this absolute path.
data_root = os.environ.get("BOUNCING_BALL_DATA", "/data2/users/lr4617/data/Bouncing_Ball")
train_dir = os.path.join(data_root, "train")
test_dir = os.path.join(data_root, "test")

# Note: despite the `_dl` suffix these are datasets; they are wrapped in
# torch DataLoaders further down.
train_dl = BouncingBallDataLoader(train_dir, images=True)
test_dl = BouncingBallDataLoader(test_dir, images=True)


In [3]:
# Two 16-point grids over [-2, 2]: one ascending (-2 -> 2), one descending
# (2 -> -2). 16 presumably matches the image side length — confirm.
x_tensor = torch.linspace(start=-2, end=2, steps=16)
y_tensor = torch.linspace(start=2, end=-2, steps=16)

In [4]:
def sequence_first_collate_fn(batch):
    """Collate image sequences into per-frame 2-D intensity-weighted positions.

    Despite its name, this function does not reorder axes: the output stays
    batch-first.

    Args:
        batch: list of array-likes that ``np.stack`` can combine; assumed to be
            shaped [sequence length, channels, height, width] with a single
            channel — TODO confirm against BouncingBallDataLoader.

    Returns:
        float32 tensor of shape [batch size, sequence length, 2]. Each frame
        is reduced to the centre of mass of its pixel intensities on a
        [-2, 2] grid. A completely blank frame maps to (0, 0) instead of NaN.
    """
    data = torch.Tensor(np.stack(batch, axis=0))
    # data.shape: [batch size, sequence length, channels, height, width]
    height, width = data.shape[-2], data.shape[-1]
    # Same grids as the notebook's 16-point x_tensor / y_tensor, but sized from
    # the actual image so other resolutions work too.
    row_grid = torch.linspace(-2, 2, height)
    col_grid = torch.linspace(2, -2, width)

    def _to_weights(profile):
        # Normalise the last axis to sum to 1. An all-zero (blank) profile
        # would divide 0/0 -> NaN, so its denominator is replaced by 1 and
        # its weights stay 0.
        mass = profile.sum(-1, keepdim=True)
        mass = torch.where(mass == 0, torch.ones_like(mass), mass)
        # squeeze(-2) drops the channel axis (only when it has size 1).
        return (profile / mass).squeeze(-2)

    # NOTE(review): the mean over width is a profile over *rows*, and it is
    # weighted by the ascending grid. The mean over height is a profile over
    # *columns*, weighted by the descending grid. So data_x follows the
    # vertical pixel axis and data_y the horizontal one. If image-aligned
    # (x = column, y = row) coordinates were intended, these are swapped.
    # Kept as-is because existing checkpoints depend on it — confirm.
    row_weights = _to_weights(data.mean(-1))
    col_weights = _to_weights(data.mean(-2))

    data_x = (row_weights * row_grid).sum(-1)
    data_y = (col_weights * col_grid).sum(-1)

    return torch.stack([data_x, data_y], dim=-1)

In [5]:
# Both splits are batched identically: 128 shuffled sequences per batch, each
# frame reduced to a 2-D position by sequence_first_collate_fn.
loader_kwargs = dict(batch_size=128, shuffle=True, collate_fn=sequence_first_collate_fn)
dataloader_train = DataLoader(train_dl, **loader_kwargs)
dataloader_test = DataLoader(test_dl, **loader_kwargs)

In [6]:
# Model hyperparameters, passed straight to KalmanVAE below.
T = 50
n_channels_in = None
dim = None
dim_a = 2
dim_z = 4
K = 3
use_MLP = True

# CUDA device index and parameter dtype.
device = 0
dtype = torch.float32

# Build the model with train_VAE=False, then move it to the GPU and cast its
# parameters in a single .to() call.
nonlinear_ssm = KalmanVAE(
    n_channels_in, dim, dim_a, dim_z, K,
    T=T,
    use_MLP=use_MLP,
    dtype=dtype,
    train_VAE=False,
).to(device='cuda:{}'.format(device), dtype=dtype)

In [7]:
# Optimisation settings.
lr = 1e-3
gamma_lr_schedule = 0.85  # LR multiplier applied on every scheduler.step()
num_epochs = 50

optimizer = torch.optim.Adam(nonlinear_ssm.parameters(), lr=lr)
scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=gamma_lr_schedule)

In [8]:
# Root folder for this experiment's outputs. Each run gets its own
# timestamped sub-folder holding checkpoints, the text log and the plots.
# Override the default with the NONLINEAR_SSM_RESULTS environment variable.
output_folder = os.environ.get('NONLINEAR_SSM_RESULTS',
                               '/data2/users/lr4617/KalmanVAE/results/nonlinear_SSM/')

now = datetime.now()
run_name = 'run_' + now.strftime("%Y_%m_%d_%H_%M_%S")
save_filename = os.path.join(output_folder, '', run_name, '')
# exist_ok avoids the check-then-create race of isdir() + makedirs().
os.makedirs(save_filename, exist_ok=True)

# Log hyperparameters as numbers (not strings) so W&B can filter and sort on
# them. Naming the W&B run after the local folder ties the two together.
run = wandb.init(project='nonlinearSSM',
                 name=run_name,
                 config={'learning-rate': lr,
                         'num_epochs': num_epochs,
                         'gamma_lr_schedule': gamma_lr_schedule,
                         'T': T,
                         'dim_a': dim_a,
                         'dim_z': dim_z,
                         'K': K,
                         'use_MLP': use_MLP})

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mlapo0510[0m ([33minformation-theoretic-view-of-bn[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [9]:
def plot_dynamics(train_loader, alpha, epoch, output_folder, dtype, num_samples=1):
    """Save one figure per time step showing a trajectory and its weights.

    Only the first batch yielded by ``train_loader`` is used. For each of the
    first ``num_samples`` sequences ``i`` of that batch, the folder
    ``<output_folder>/epoch_<epoch>/sample_<i>`` is created and filled with
    ``weight-<t>.png`` files:
      - left panel: the point ``(batch[i][t][0], batch[i][t][1])``;
      - right panel: a bar chart of ``alpha[i][t]``.

    Args:
        train_loader: any iterable of batches (a DataLoader or a plain list).
            ``batch[i][t]`` must hold at least two coordinates.
        alpha: indexable as ``alpha[i][t]`` -> 1-D weights, one bar per entry.
        epoch: used only to name the output folder.
        output_folder: root directory for the figures.
        dtype: unused; kept so existing call sites keep working.
        num_samples: how many sequences of the batch to plot (default 1).
    """
    batch = next(iter(train_loader), None)
    if batch is None:
        return  # empty loader: nothing to plot

    for i in range(min(num_samples, len(batch))):
        sample_dir = os.path.join(output_folder, '', 'epoch_{}'.format(epoch), 'sample_{}'.format(i))
        os.makedirs(sample_dir, exist_ok=True)

        # Move to CPU so matplotlib can consume the values even when the
        # caller passes CUDA tensors.
        trajectory = torch.as_tensor(batch[i]).detach().cpu()
        weights = torch.as_tensor(alpha[i]).detach().cpu()
        labels = ['k={}'.format(k) for k in range(weights.shape[-1])]

        # Plot only the steps both sequences have (instead of a global T).
        num_steps = min(len(trajectory), len(weights))
        for t in range(num_steps):
            fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
            fig.suptitle(f"$t = {t}$")

            axes[0].plot(trajectory[t][0], trajectory[t][1], "o")
            axes[0].set_adjustable('box')
            axes[0].set_title(r"Observation $\mathbf{a}_t$")
            axes[0].set_xlim([-2, 2])
            axes[0].set_ylim([-2, 2])

            axes[1].bar(labels, weights[t].numpy())
            axes[1].set_title(r"weight $\mathbf{k}_t$")

            fig.savefig(os.path.join(sample_dir, 'weight-{}.png'.format(t)))
            # Close this specific figure; hundreds are created per call.
            plt.close(fig)
    

In [10]:
# Folder for the per-epoch mixing-weight plots produced by plot_dynamics.
dyn_save_filename = os.path.join(
    save_filename, '', 'visualize_dynamics', '', 'training', '')
os.makedirs(dyn_save_filename, exist_ok=True)

# Epoch index (0-based) from which the dynamics network is trained as well.
dyn_net_start_epoch = 5
# The LR scheduler steps at every epoch index that is a positive multiple of this.
lr_decay_every = 20

start = time.time()
log_list = []

train_dyn_net = False

for epoch in range(num_epochs):

    # The flag depends only on the epoch index, so decide it once per epoch.
    train_dyn_net = epoch >= dyn_net_start_epoch

    # train
    loss_epoch = 0.
    idv_losses = {'LGSSM observation log likelihood': 0.,
                  'LGSSM tranisition log likelihood': 0.,
                  'LGSSM tranisition log posterior': 0.}
    last_sample, alphas = None, None

    for sample in dataloader_train:

        optimizer.zero_grad()

        sample = sample.to(dtype).to('cuda:' + str(device))

        # NOTE(review): the recorded run crashed with a Cholesky
        # "not positive-definite" error shortly after train_dyn_net was
        # switched on. Gradient clipping or covariance jitter inside
        # KalmanVAE may be needed — not verified here.
        alpha, loss, loss_dict = nonlinear_ssm.calculate_loss(sample, train_dyn_net=train_dyn_net)

        loss.backward()
        optimizer.step()

        # Accumulate Python floats. Summing the loss tensors themselves keeps
        # a reference to every batch's autograd graph until the epoch ends.
        loss_epoch += loss.item()
        # Assumes each loss_dict entry is a scalar (tensor or number).
        for key in idv_losses:
            idv_losses[key] += float(loss_dict[key])

        # Keep the batch that produced `alpha`, so the plots pair each
        # trajectory with its own mixing weights. A fresh pass over the
        # shuffled loader would return a different batch.
        last_sample = sample.detach().cpu()
        alphas = alpha.detach().cpu()

    # Per-batch averages: each accumulator is divided exactly once.
    num_batches = len(dataloader_train)
    loss_epoch = loss_epoch / num_batches
    for key in idv_losses:
        idv_losses[key] = idv_losses[key] / num_batches

    # Log the epoch averages rather than only the last batch's loss_dict.
    run.log({**idv_losses, 'loss_train': loss_epoch, 'epoch': epoch + 1})

    # logistics
    if epoch % lr_decay_every == 0 and epoch > 0:
        scheduler.step()
    end = time.time()
    log = 'epoch = {}, loss_train = {}, time = {}'.format(
        epoch+1, loss_epoch, end-start)
    start = end
    print(log)
    log_list.append(log + '\n')

    # Plot dynamics. The batch is wrapped in a list because plot_dynamics
    # plots the first batch of any iterable.
    if last_sample is not None:
        plot_dynamics([last_sample],
                      alphas,
                      epoch,
                      output_folder=dyn_save_filename,
                      dtype=dtype)

    # save checkpoints
    if epoch % 10 == 0 or epoch == num_epochs-1:
        torch.save(nonlinear_ssm.state_dict(),
                   os.path.join(save_filename, 'nonlinear_ssm' + str(epoch+1) + '.pt'))

    # save training log
    with open(os.path.join(save_filename, 'training.cklog'), "a+") as log_file:
        log_file.writelines(log_list)
        log_list.clear()


epoch = 1, loss_train = 3326.27734375, time = 12.469932794570923
epoch = 2, loss_train = 1030.6702880859375, time = 18.47687792778015
epoch = 3, loss_train = 749.6433715820312, time = 18.250099182128906
epoch = 4, loss_train = 611.7440795898438, time = 18.143516540527344
epoch = 5, loss_train = 523.51025390625, time = 17.080658674240112
epoch = 6, loss_train = 391.158935546875, time = 18.547069311141968


_LinAlgError: linalg.cholesky: (Batch element 768): The factorization could not be completed because the input is not positive-definite (the leading minor of order 4 is not positive-definite).