In [42]:
import os
import sys

import matplotlib.pyplot as plt
from hydra import compose, initialize
    
import math
import torch
import torch.nn as nn

import pytorch_lightning as lightning
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint 

import xfads.utils as utils

from xfads.ssm_modules.dynamics import DenseGaussianDynamics
from xfads.ssm_modules.likelihoods import GaussianLikelihood
from xfads.smoothers.lightning_trainers import LightningNonlinearSSM
from xfads.ssm_modules.dynamics import DenseGaussianInitialCondition
from xfads.ssm_modules.encoders import LocalEncoderLRMvn, BackwardEncoderLRMvn
from xfads.smoothers.nonlinear_smoother_causal import NonlinearFilter, LowRankNonlinearStateSpaceModel

In [44]:
torch.cuda.empty_cache()

In [46]:
"""config"""

initialize(version_base=None, config_path="", job_name="lds")
cfg = compose(config_name="config")
lightning.seed_everything(cfg.seed, workers=True)
torch.set_default_dtype(torch.float32)

ValueError: GlobalHydra is already initialized, call GlobalHydra.instance().clear() if you want to re-initialize

In [48]:
"""generate data -- 2d oscillator with decay"""

n_trials = 500
n_neurons = 100
n_time_bins = 250

C = utils.FanInLinear(cfg.n_latents, n_neurons, device=cfg.device).requires_grad_(False)
Q_diag = 1e-3 * torch.ones(cfg.n_latents, device=cfg.device)
Q_0_diag = 2. * torch.ones(cfg.n_latents, device=cfg.device)
R_diag = 1e-1 * torch.ones(n_neurons, device=cfg.device)
m_0 = torch.zeros(cfg.n_latents, device=cfg.device)

mean_fn = utils.VdpDynamicsModel()

z = utils.sample_gauss_z(mean_fn, Q_diag, m_0, Q_0_diag, n_trials, n_time_bins)
y = C(z) + torch.sqrt(R_diag) * torch.randn((n_trials, n_time_bins, n_neurons), device=cfg.device)
y = y.detach()

y_train, z_train = y[:2*n_trials//3], z[:2*n_trials//3]
y_valid, z_valid = y[2*n_trials//3:], z[2*n_trials//3:]

In [72]:
y_train_dataset = torch.utils.data.TensorDataset(y_train,)
y_valid_dataset = torch.utils.data.TensorDataset(y_valid,)

train_dataloader = torch.utils.data.DataLoader(
    y_train_dataset,
    batch_size=cfg.batch_sz,
    shuffle=True,
    num_workers=9
)

valid_dataloader = torch.utils.data.DataLoader(
    y_valid_dataset,
    batch_size=cfg.batch_sz,
    shuffle=True,
)

In [52]:
"""likelihood pdf"""

likelihood_pdf = GaussianLikelihood(C, n_neurons, R_diag, device=cfg.device)

In [54]:
"""dynamics module"""

dynamics_fn = utils.build_gru_dynamics_function(cfg.n_latents, cfg.n_hidden_dynamics, device=cfg.device)
dynamics_mod = DenseGaussianDynamics(dynamics_fn, cfg.n_latents, Q_diag, device=cfg.device)

In [56]:
"""initial condition"""

initial_condition_pdf = DenseGaussianInitialCondition(cfg.n_latents, m_0, Q_0_diag, device=cfg.device)

In [58]:
"""local/backward encoder"""

backward_encoder = BackwardEncoderLRMvn(cfg.n_latents, cfg.n_hidden_backward, cfg.n_latents,
                                        rank_local=cfg.rank_local, rank_backward=cfg.rank_backward,
                                        device=cfg.device)

local_encoder = LocalEncoderLRMvn(cfg.n_latents, n_neurons, cfg.n_hidden_local, cfg.n_latents, rank=cfg.rank_local,
                                  device=cfg.device, dropout=cfg.p_local_dropout)

nl_filter = NonlinearFilter(dynamics_mod, initial_condition_pdf, device=cfg.device)

In [60]:
"""sequence vae"""

ssm = LowRankNonlinearStateSpaceModel(dynamics_mod, likelihood_pdf, initial_condition_pdf, backward_encoder,
                                      local_encoder, nl_filter, device=cfg.device)

In [62]:
"""lightning"""

seq_vae = LightningNonlinearSSM(ssm, cfg)

csv_logger = CSVLogger('logs/', name=f'r_y_{cfg.rank_local}_r_b_{cfg.rank_backward}', version='noncausal')
ckpt_callback = ModelCheckpoint(save_top_k=3, monitor='valid_loss', mode='min', dirpath='ckpts/',
                                filename='{epoch:0}_{valid_loss}')

trainer = lightning.Trainer(max_epochs=cfg.n_epochs,
                            gradient_clip_val=1.0,
                            default_root_dir='lightning/',
                            callbacks=[ckpt_callback],
                            logger=csv_logger
                            )

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [85]:
trainer.fit(model=seq_vae, train_dataloaders=train_dataloader, val_dataloaders=valid_dataloader)
torch.save(ckpt_callback.best_model_path, 'ckpts/best_model_path.pt')


  | Name | Type                            | Params
---------------------------------------------------------
0 | ssm  | LowRankNonlinearStateSpaceModel | 41.3 K
---------------------------------------------------------
41.0 K    Trainable params
300       Non-trainable params
41.3 K    Total params
0.165     Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

In [76]:
trainer.test(dataloaders=valid_dataloader, ckpt_path='last')

TypeError: `Trainer.test()` requires a `LightningModule` when it hasn't been passed in a previous run