In [None]:
# Reload the autoreload extension and set it to mode 2, so edits to imported
# project modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the cell output.
%matplotlib inline

In [None]:
# Optional (left disabled): install torchsummary with the same interpreter
# that runs the current Jupyter kernel. Uncomment the two lines below to use it.
# import sys
# !{sys.executable} -m pip install torchsummary

In [None]:
import torch
from torch.utils.data import Dataset, Subset, SubsetRandomSampler, SequentialSampler
from torch.utils.data.dataset import TensorDataset
from torch.distributions.multivariate_normal import MultivariateNormal
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from torch_geometric.data import Data
from torch_geometric.utils.convert import to_networkx

import networkx as nx # for visualizing graphs
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import pdb
from torchsummary import summary
import json
import pickle

from functions.load_data import MarielDataset, edges
from functions.functions import *
from functions.modules import *
from train import train

# Load data

In [None]:
# Dataset / windowing configuration.
batch_size = 32
seq_len = 10
predicted_timesteps = 0
data = MarielDataset(seq_len=seq_len, reduced_joints=False, predicted_timesteps=predicted_timesteps, no_overlap=True)

# Contiguous, unshuffled 70% / 15% / 15% split into train / validation / test.
n_samples = len(data)
train_end = int(0.7*n_samples)
val_end = int(0.85*n_samples)
train_indices = np.arange(train_end) # first 70% for training
val_indices = np.arange(train_end, val_end) # next 15% for validation
test_indices = np.arange(val_end, n_samples) # last 15% for testing

# NOTE: SequentialSampler(indices) only looks at len(indices) and yields
# 0..len(indices)-1, so passing the index arrays as its data source made the
# validation and test loaders re-read the *first* samples of the dataset
# (i.e. training data). Wrapping the dataset in a Subset restricts each loader
# to exactly the requested indices while keeping the sequential order.
dataloader_train = DataLoader(Subset(data, train_indices.tolist()), batch_size=batch_size, shuffle=False, drop_last=True)
dataloader_val = DataLoader(Subset(data, val_indices.tolist()), batch_size=batch_size, shuffle=False, drop_last=True)
dataloader_test = DataLoader(Subset(data, test_indices.tolist()), batch_size=batch_size, shuffle=False, drop_last=True)

print("\nGenerated {:,} batches of shape: {}".format(len(dataloader_train), data[0]))

Alternatively, skip the cell above and load previously saved dataloaders:

In [None]:
# Restore the three pickled dataloaders (train / validation / test) saved by an earlier run.
dataloader_train, dataloader_val, dataloader_test = map(
    torch.load,
    (
        "./logs/nooverlap_53joints_seqlen10_pred0/dataloader_train.pth",
        "./logs/nooverlap_53joints_seqlen10_pred0/dataloader_val.pth",
        "./logs/nooverlap_53joints_seqlen10_pred0/dataloader_test.pth",
    ),
)

# Define model

In [None]:
# Sequence configuration (repeated here so this cell also works when the
# dataloaders were loaded from disk instead of being built above).
seq_len = 10
predicted_timesteps = 0

# Architecture hyperparameters.
node_features = seq_len*3 # = 30; 3 coordinates per timestep (data.seq_len*data.n_dim)
edge_features = 10 # data[0].num_edge_features
node_embedding_dim = 25
edge_embedding_dim = 4 # number of edge types
hidden_size = 50
num_layers = 2
checkpoint_loaded = False 

model = VAE(node_features=node_features, 
            edge_features=edge_features, 
            hidden_size=hidden_size, 
            node_embedding_dim=node_embedding_dim,
            edge_embedding_dim=edge_embedding_dim,
            num_layers=num_layers,
            input_size=node_embedding_dim, 
            output_size=node_features+predicted_timesteps*3,
           )

# Move the model to its device *before* constructing the optimizer, as
# recommended by the torch.optim documentation, so the optimizer is guaranteed
# to reference the parameters that are actually trained.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using {}".format(device))
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=5e-4)

print(model)
print("Total trainable parameters: {:,}".format(count_parameters(model)))

### Optional: load pre-trained weights

Load the whole model + weights:

In [None]:
# Disabled alternative: replace `model` with an entire saved model object
# (architecture + weights) instead of the one constructed above.
# model = torch.load("weights/seqlen3_model.pth")

OR load the model state into the pre-existing model above:

In [None]:
# Restore model and optimizer state from a training checkpoint.
checkpoint_path = "./logs/nooverlap_53joints_seqlen10_pred0/best_weights.pth"
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU machine can still be loaded on a CPU-only one.
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# Bookkeeping stored alongside the weights: the epoch and loss at save time.
epoch = checkpoint['epoch']
loss_checkpoint = checkpoint['loss']
checkpoint_loaded = True

# Train

In [None]:
# Mean-squared-error criterion, averaged over all elements.
mse_loss = torch.nn.MSELoss(reduction='mean')
# Weight applied to the prediction loss relative to the reconstruction loss.
# You might want to weight the prediction loss higher to help it compete with
# the larger prediction seq_len.
# NOTE(review): presumably read by the training code; confirm how train() receives it.
prediction_to_reconstruction_loss_ratio = 0

In [None]:
# Run a single training epoch using `train` imported from train.py.
# NOTE(review): only num_epochs is passed here; confirm that train() obtains the
# model, optimizer and dataloaders on its own (its signature is not visible in this notebook).
train(num_epochs=1)

# Test

In [None]:
def test():
    mse_loss = torch.nn.MSELoss(reduction='mean')
    prediction_to_reconstruction_loss_ratio = 0
    total_test_loss = 0
    total_test_reco_loss = 0
    total_test_pred_loss = 0
    n_batches = 0
    model.eval()
    for batch in tqdm(dataloader_test, desc="Test batches"):
        batch = batch.to(device)

        ### CALCULATE MODEL OUTPUTS
        output = model(batch)
        
        ### CALCULATE LOSS
        test_reco_loss = mse_loss(batch.x.to(device), output[:,:node_features]) # compare first seq_len timesteps
        if predicted_timesteps > 0: 
            test_pred_loss = prediction_to_reconstruction_loss_ratio*mse_loss(batch.y.to(device), output[:,node_features:]) # compare last part to unseen data
            test_loss = test_reco_loss + test_pred_loss
        else:
            test_loss = test_reco_loss

        ### ADD LOSSES TO TOTALS
        total_test_loss += test_loss.item()
        total_test_reco_loss += test_reco_loss.item()
#         if args.predicted_timesteps > 0: 
#             total_test_pred_loss += test_pred_loss.item()
        n_batches += 1
            
    ### CALCULATE AVERAGE LOSSES PER EPOCH   
    average_test_loss = total_test_loss / n_batches
    average_test_reco_loss = total_test_reco_loss / n_batches
    average_test_pred_loss = total_test_pred_loss / n_batches
    print("Loss = {:,.4f} | Reconstruction Loss: {:,.4f} | Prediction Loss: {:,.4f}".format(average_test_loss, 
                                                                                            average_test_reco_loss, 
                                                                                            average_test_pred_loss))

In [None]:
# Evaluate the current model on the test dataloader and print the average losses.
test()

# Load losses & predictions from pickle/json files

In [None]:
# Read the per-epoch loss histories written during training.
# A context manager closes the file handle, and the result is bound to
# `loss_log` rather than `dict` so the builtin `dict` is not shadowed for the
# rest of the notebook session.
with open("./logs/vae_53joints_seqlen49_pred0/losses.json") as losses_file:
    loss_log = json.load(losses_file)
losses = loss_log['overall_losses']
reconstruction_losses = loss_log['reconstruction_losses']
prediction_losses = loss_log['prediction_losses']

In [None]:
# Load the saved training inputs and the corresponding model outputs from the run's log directory.
inputs, outputs = (
    np.load(array_path)
    for array_path in (
        "./logs/vae_53joints_seqlen49_pred0/train_inputs.npy",
        "./logs/vae_53joints_seqlen49_pred0/train_outputs.npy",
    )
)

In [None]:
# Plot the reconstruction loss per epoch.
fig, ax = plt.subplots(figsize=(8,6))
# Derive the x-axis from the series actually plotted, so a length mismatch
# with `losses` cannot break the plot.
epochs = np.arange(len(reconstruction_losses))
# ax.plot(np.arange(len(losses)), losses, label="Total")
ax.plot(epochs, reconstruction_losses, label="Reconstruction")
# ax.plot(np.arange(len(prediction_losses)), prediction_losses, label="Prediction")
ax.set_title("Loss per epoch", fontsize=16)
ax.set_xlabel("Epoch", fontsize=16)
ax.set_ylabel("Loss", fontsize=16)
# ax.set_yscale("log")
# ax.set_ylim(0,1)
ax.tick_params(axis="both", labelsize=14)
ax.legend(fontsize=14);  # trailing ';' suppresses the Legend repr in the cell output

In [None]:
# Take the first saved input batch and keep only the rows of its first graph.
# The rows are divided by 32 (presumably the batch size used for that run) to
# get the number of joints per graph.
first_input_batch = inputs[0]
n_joints = first_input_batch.shape[0] // 32
first_input_seq = first_input_batch[:n_joints]

# Unflatten the feature axis: n_joints x n_timesteps x n_dim (n_dim = 3)
n_input_timesteps = first_input_seq.shape[1] // 3
first_input_seq = first_input_seq.reshape((first_input_seq.shape[0], n_input_timesteps, 3))

In [None]:
# Take the first saved output batch and keep only the rows of its first graph.
# The rows are divided by 32 (presumably the batch size used for that run) to
# get the number of joints per graph.
first_predicted_batch = outputs[0]
n_joints = first_predicted_batch.shape[0] // 32
first_predicted_seq = first_predicted_batch[:n_joints]

# Unflatten the feature axis: n_joints x n_timesteps x n_dim (n_dim = 3)
n_predicted_timesteps = first_predicted_seq.shape[1] // 3
first_predicted_seq = first_predicted_seq.reshape((first_predicted_seq.shape[0], n_predicted_timesteps, 3))

In [None]:
# Overlay the x/y trajectory of the input sequence and of the model output for the selected joints.
plt.figure(figsize=(10,7))
for joint in range(1): # first few joints
# for joint in range(first_seq.shape[0]): # all joints
    # draw the input first, then the prediction, each as x vs. y over the sequence
    for label_prefix, joint_seq in (("Input Joint ", first_input_seq),
                                    ("Predicted Joint ", first_predicted_seq)):
        plt.plot(joint_seq[joint,:,0], joint_seq[joint,:,1], 'o--', label=label_prefix+str(joint))
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.legend(fontsize=12)

### Up next:
- Try to overfit the input data (loss = 0) by setting prediction weight = 0
- Weight the prediction loss so that the most immediate steps matter more to reconstruct than far-future steps (e.g. a weight of $1/2^t$ for predicted step $t$)
- Look at VAE outputs! 
- Look at NRI outputs!

### For later:
- The Gaussian negative log likelihood loss functions will only make sense when the output of the decoder is mu (eq'n 16 & 17)

### Done
- ~~Predict 50 + k timesteps w/ separate MSE losses~~

# Scratch work

In [None]:
# Scratch: alternative loss terms (Gaussian NLL and categorical-uniform KL) kept
# for reference only; none of the cells above call them.
#             my_nll_loss = gaussian_neg_log_likelihood(x=batch.x, mu=output, sigma=sigma)
#             nll_loss = nll_gaussian(preds=output, target=batch.x.to(device), variance=5e-5)
#             kl_loss = kl_categorical_uniform(torch.exp(log_probabilities), data[0].num_nodes, num_edge_types, add_const=True)