# Latent-ODE to generate embeddings

## Importing Libraries

In [1]:
import os
import sys
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot
import matplotlib.pyplot as plt

import time
import datetime
import numpy as np
import pandas as pd
from random import SystemRandom
from sklearn import model_selection

import torch
import torch.nn as nn
from torch.nn.functional import relu
import torch.optim as optim

import lib.utils as utils
from lib.plotting import *

from lib.ode_rnn import *
from lib.create_latent_ode_model import create_LatentODE_model
from lib.parse_datasets import parse_datasets
from lib.ode_func import ODEFunc, ODEFunc_w_Poisson
from lib.diffeq_solver import DiffeqSolver

from lib.utils import compute_loss_all_batches

Couldn't import umap


## Keyword Arguments for Running the model

In [2]:
class dotdict(dict):
    """Dictionary whose entries can also be read, written and removed as attributes.

    Reading an attribute that is not a key returns ``None`` (the ``dict.get``
    default) instead of raising ``AttributeError``. Removing an attribute that
    is not a key raises ``KeyError``.
    """

    def __getattr__(self, name, default=None):
        # Python only calls __getattr__ after normal attribute lookup has
        # failed, so built-in dict methods are still resolved first.
        return dict.get(self, name, default)

    def __setattr__(self, name, value):
        # Attribute assignment is stored as a regular dictionary item.
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        # Attribute deletion removes the matching dictionary item.
        dict.__delitem__(self, name)

In [3]:
# Run configuration. Keys not commented individually are handed to
# preprocess(...) / create_LatentODE_model(...) through `args`.
params = dict(
    # Data
    train_miss_rate=40,
    test_miss_rate=60,
    batch_size=40,
    num_workers=2,
    n=8000,            # reported as the number of training examples
    # Experiment bookkeeping
    save="experiments",  # root directory for the experiment_<id> checkpoint folder
    load=3,              # used as the experiment ID; None draws a random one
    # Model
    classif=False,     # when True, the data object must provide "n_labels"
    latents=3,
    rec_dims=5,
    poisson=False,
    gen_layers=2,
    rec_layers=2,
    units=20,
    gru_units=20,
    z0_encoder="odernn",
    # Optimisation
    lr=1e-2,           # Adamax learning rate; the decay schedule floors at lr / 10
    epochs=20,
    chkpt=5,           # evaluate on the test loader and save weights every `chkpt` epochs
)
args = dotdict(params)

## Setting files and configurations

In [4]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Seed the global torch and NumPy generators (note: the two seeds differ).
torch.manual_seed(seed=42)
np.random.seed(seed=10)

# Root directory for everything this run saves.
save_path = args.save
utils.makedirs(save_path)

# Reuse the experiment ID given in args.load; draw a random one only if none was given.
experimentID = int(SystemRandom().random()*100000) if args.load is None else args.load

# Per-experiment checkpoint folder: <save_path>/experiment_<id>
ckpt_path = os.path.join(save_path, "experiment_" + str(experimentID))
ckpt_dir_missing = not os.path.exists(ckpt_path)
if ckpt_dir_missing:
    os.makedirs(ckpt_path, exist_ok=True)
del ckpt_dir_missing

## Loading the data

In [5]:
from lib.preprocess_data import preprocess

In [6]:
# SAMPLES is not referenced by the cells visible in this notebook section.
SAMPLES = 1_000
# Length of the time grid: train()/test() build the points 0 .. TIMESTAMPS - 1.
TIMESTAMPS = 500

In [7]:
start = time.time()
print("Sampling dataset of {} training examples".format(args.n))
utils.makedirs("results/")

# Build the dataset bundle (dataloaders plus metadata) from the run configuration.
data_obj = preprocess(args)
input_dim = data_obj["input_dim"]

# Optional metadata: default to False when the bundle does not carry the flag.
classif_per_tp = data_obj.get("classif_per_tp", False)

# -1 is the placeholder used when no classification head is requested.
n_labels = -1
if args.classif:
    if "n_labels" not in data_obj:
        raise Exception("Please provide number of labels for classification task")
    n_labels = data_obj["n_labels"]

Sampling dataset of 8000 training examples


  data_norm = np.nan_to_num((data - data_min) / (data_max - data_min))


In [8]:
# Inspect the bundle returned by preprocess(args): dataloaders plus metadata.
data_obj

{'dataset_obj': <torch.utils.data.dataset.TensorDataset at 0x273168bfd60>,
 'train_dataloader': <torch.utils.data.dataloader.DataLoader at 0x27334916290>,
 'test_dataloader': <torch.utils.data.dataloader.DataLoader at 0x27334915de0>,
 'input_dim': 4,
 'n_train_batches': 20,
 'n_test_batches': 5,
 'classif_per_tp': False,
 'n_labels': 1}

## Creating the model

In [9]:
# Observation noise scale (0.01), wrapped as a one-element tensor on the target device.
obsrv_std = torch.Tensor([0.01]).to(device)

# Prior handed to the model for the initial latent state: Normal(0, 1).
z0_prior = Normal(
    torch.Tensor([0.0]).to(device),
    torch.Tensor([1.]).to(device),
)

model = create_LatentODE_model(
    args,
    input_dim,
    z0_prior,
    obsrv_std,
    device,
    classif_per_tp=classif_per_tp,
    n_labels=n_labels,
)

# `viz` is not one of the keys set in `params`; dotdict returns None for
# missing keys, so this branch is skipped unless something sets args.viz.
if args.viz:
    viz = Visualizations(device)

In [10]:
# Display the module tree of the freshly built model.
model

LatentODE(
  (encoder_z0): Encoder_z0_ODE_RNN(
    (GRU_update): GRU_unit(
      (update_gate): Sequential(
        (0): Linear(in_features=18, out_features=20, bias=True)
        (1): Tanh()
        (2): Linear(in_features=20, out_features=5, bias=True)
        (3): Sigmoid()
      )
      (reset_gate): Sequential(
        (0): Linear(in_features=18, out_features=20, bias=True)
        (1): Tanh()
        (2): Linear(in_features=20, out_features=5, bias=True)
        (3): Sigmoid()
      )
      (new_state_net): Sequential(
        (0): Linear(in_features=18, out_features=20, bias=True)
        (1): Tanh()
        (2): Linear(in_features=20, out_features=10, bias=True)
      )
    )
    (z0_diffeq_solver): DiffeqSolver(
      (ode_func): ODEFunc(
        (gradient_net): Sequential(
          (0): Linear(in_features=5, out_features=20, bias=True)
          (1): Tanh()
          (2): Linear(in_features=20, out_features=20, bias=True)
          (3): Tanh()
          (4): Linear(in_featur

## Model Training

In [11]:
# Prefix for the log file. The previous expression,
# os.path.basename(os.curdir)[:-3], always evaluated to "" because
# os.curdir is "." (a leftover of the script form `basename(__file__)[:-3]`),
# which produced log files named "logs/_<id>.log". Use the working
# directory's name instead, falling back to a fixed label at a filesystem root.
file_name = os.path.basename(os.getcwd()) or "notebook"
# setting logs
log_path = "logs/" + file_name + "_" + str(experimentID) + ".log"
# exist_ok makes a separate existence check unnecessary.
os.makedirs("logs/", exist_ok=True)
logger = utils.get_logger(logpath=log_path, filepath=os.path.abspath(os.curdir))
logger.info(f"Run_{experimentID}")

# training parameters
optimizer = optim.Adamax(model.parameters(), lr=args.lr)
num_batches = data_obj["n_train_batches"]

c:\Mihir\Ashoka RAship\ESIMC Datathon\Code\Trajectory\latent_ode
Run_3


In [12]:
def train(kl_coef):
    """Run one optimisation pass over ``data_obj['train_dataloader']``.

    For every batch the learning rate is decayed, the losses are computed with
    ``model.compute_all_losses`` (3 trajectory samples), and one optimizer
    step is taken.

    Args:
        kl_coef: Value forwarded to ``model.compute_all_losses`` as ``kl_coef``.

    Returns:
        A tuple ``(train_batch_loss, train_batch_mse, train_res_update)``:
        the per-batch ``"loss"`` and ``"mse"`` entries (detached and moved to
        the CPU when they are tensors) and the full result dict of the last
        batch, or ``None`` if the dataloader yielded no batches.
    """
    def _host_copy(value):
        # Detach and move tensors to the CPU so callers can hand them to NumPy
        # even when the model runs on CUDA; non-tensor values pass through.
        return value.detach().cpu() if torch.is_tensor(value) else value

    train_batch_loss = []
    train_batch_mse = []
    # None (as in test()) rather than a list: this slot holds a result dict.
    train_res_update = None

    # Time grid 0 .. TIMESTAMPS-1. It is identical for every batch, so build
    # it once, directly on the same device the batches are moved to.
    time_points = torch.arange(TIMESTAMPS, dtype=torch.float32, device=device)

    for data_batch, mask_batch, reconst_batch in data_obj['train_dataloader']:
        optimizer.zero_grad()
        utils.update_learning_rate(optimizer, decay_rate = 0.999, lowest = args.lr / 10)

        batch_dict = {
            "tp_to_predict": time_points,
            "observed_data": data_batch.to(device),
            "observed_tp": time_points,
            "observed_mask": mask_batch.to(device),
            "data_to_predict": reconst_batch.to(device),
            "mask_predicted_data": None,
            "labels": None
        }
        train_res = model.compute_all_losses(batch_dict, n_traj_samples = 3, kl_coef = kl_coef)
        train_res["loss"].backward()
        optimizer.step()

        train_batch_loss.append(_host_copy(train_res["loss"]))
        train_batch_mse.append(_host_copy(train_res["mse"]))
        train_res_update = train_res

    return train_batch_loss, train_batch_mse, train_res_update

def test(kl_coef):
    """Evaluate the model on ``data_obj['test_dataloader']`` without gradients.

    Args:
        kl_coef: Value forwarded to ``model.compute_all_losses`` as ``kl_coef``.

    Returns:
        A tuple ``(test_batch_loss, test_batch_mse, test_res_update)``:
        the per-batch ``"loss"`` and ``"mse"`` entries (detached and moved to
        the CPU when they are tensors) and the full result dict of the last
        batch, or ``None`` if the dataloader yielded no batches.
    """
    def _host_copy(value):
        # Detach and move tensors to the CPU so callers can hand them to NumPy
        # even when the model runs on CUDA; non-tensor values pass through.
        return value.detach().cpu() if torch.is_tensor(value) else value

    test_batch_loss = []
    test_batch_mse = []
    test_res_update = None

    # Time grid 0 .. TIMESTAMPS-1. It is identical for every batch, so build
    # it once, directly on the same device the batches are moved to.
    time_points = torch.arange(TIMESTAMPS, dtype=torch.float32, device=device)

    # A single no_grad context around the loop covers every batch.
    with torch.no_grad():
        for data_batch, mask_batch, reconst_batch in data_obj['test_dataloader']:
            batch_dict = {
                "tp_to_predict": time_points,
                "observed_data": data_batch.to(device),
                "observed_tp": time_points,
                "observed_mask": mask_batch.to(device),
                "data_to_predict": reconst_batch.to(device),
                "mask_predicted_data": None,
                "labels": None
            }
            test_res = model.compute_all_losses(batch_dict, n_traj_samples = 3, kl_coef = kl_coef)
            test_batch_loss.append(_host_copy(test_res["loss"]))
            test_batch_mse.append(_host_copy(test_res["mse"]))

            test_res_update = test_res

    return test_batch_loss, test_batch_mse, test_res_update

def fit():
    """Train for ``args.epochs`` epochs with periodic evaluation and checkpoints.

    The KL coefficient is 0 for the first 10 epochs and then follows
    ``1 - 0.99 ** (epoch - 10)``. Every ``args.chkpt`` epochs (when that value
    is truthy) the model is evaluated with ``test``, the metrics are logged,
    the latent variables of the last test batch are captured, and the model
    weights are saved. The final weights are always saved as
    ``final_model.pth`` in ``ckpt_path``.

    Returns:
        A tuple ``(train_epoch_loss, train_epoch_mse, latest_embeddings)``:
        per-epoch means of the batch losses and batch MSEs, and the latent
        variables from the most recent checkpoint as a NumPy array (``None``
        if no checkpoint was reached).
    """
    def _mean(values):
        # float() accepts Python numbers and 0-dim tensors on any device,
        # whereas np.array(list_of_tensors) only works for CPU tensors.
        return np.mean([float(v) for v in values])

    train_epoch_loss = []
    train_epoch_mse = []
    latest_embeddings = None

    # Number of warm-up epochs trained with kl_coef == 0 (loop-invariant).
    wait_until_kl_inc = 10

    for epoch in range(args.epochs):
        kl_coef = 0 if epoch < wait_until_kl_inc else (1-0.99**(epoch - wait_until_kl_inc))
        train_batch_loss, train_batch_mse, train_res_update = train(kl_coef)
        train_epoch_loss.append(_mean(train_batch_loss))
        train_epoch_mse.append(_mean(train_batch_mse))

        if args.chkpt and (epoch+1) % args.chkpt == 0:
            print(f"Checkpoint {(epoch+1) // args.chkpt} reached...")
            test_batch_loss, test_batch_mse, test_res_update = test(kl_coef)
            message = 'Epoch {:04d} [Test seq (cond on sampled tp)] | Loss {:.6f} | Likelihood {:.6f} | KL fp {:.4f} | FP STD {:.4f}|'.format(
                epoch+1, test_res_update["loss"].detach(), test_res_update["likelihood"].detach(), test_res_update["kl_first_p"], test_res_update["std_first_p"]
            )

            logger.info("Experiment " + str(experimentID))
            logger.info(message)
            logger.info("KL coef: {}".format(kl_coef))
            logger.info("Train loss (one batch): {}".format(train_res_update["loss"].detach()))
            logger.info("Test MSE: {:.4f}".format(test_res_update["mse"]))

            # Tensor.numpy() requires a CPU tensor; .cpu() is a no-op on CPU
            # and makes this work when the model runs on CUDA.
            latest_embeddings = test_res_update["latent_variables"].detach().cpu().numpy()
            # NOTE(review): the file name uses the 0-based epoch index while
            # the log message above reports epoch + 1.
            torch.save(model.state_dict(), os.path.join(ckpt_path,f"state_{epoch}.pth"))

    torch.save(model.state_dict(), os.path.join(ckpt_path,"final_model.pth"))
    return train_epoch_loss, train_epoch_mse, latest_embeddings

In [13]:
# Run the full training loop; keeps the per-epoch mean loss / MSE and the
# latent variables captured at the most recent checkpoint.
train_epoch_loss, train_epoch_mse, latest_embeddings = fit()

Checkpoint 1 reached...


Experiment 3
Epoch 0005 [Test seq (cond on sampled tp)] | Loss 309.338898 | Likelihood -310.437531 | KL fp 5.4295 | FP STD 0.0169|
KL coef: 0
Train loss (one batch): 310.05780029296875
Test MSE: 0.0628


Checkpoint 2 reached...


Experiment 3
Epoch 0010 [Test seq (cond on sampled tp)] | Loss 309.144623 | Likelihood -310.243256 | KL fp 4.7811 | FP STD 0.0278|
KL coef: 0
Train loss (one batch): 309.6271667480469
Test MSE: 0.0628


Checkpoint 3 reached...


Experiment 3
Epoch 0015 [Test seq (cond on sampled tp)] | Loss 309.253998 | Likelihood -310.202423 | KL fp 3.8181 | FP STD 0.0653|
KL coef: 0.039403990000000055
Train loss (one batch): 309.8526611328125
Test MSE: 0.0628


Checkpoint 4 reached...


Experiment 3
Epoch 0020 [Test seq (cond on sampled tp)] | Loss 309.340240 | Likelihood -310.160583 | KL fp 3.2259 | FP STD 0.1126|
KL coef: 0.08648275251635917
Train loss (one batch): 311.0486145019531
Test MSE: 0.0628


In [16]:
# Shape of the captured latent variables; this run printed (3, 40, 500, 3).
# Presumably (n_traj_samples, batch, time points, latent dims) - TODO confirm.
latest_embeddings.shape

(3, 40, 500, 3)