In [None]:
# Notebook setup: (re)load the autoreload extension, reload imported modules
# automatically before each cell runs (mode 2), and render figures inline.
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# # Install a pip package in the current Jupyter kernel
# import sys
# !{sys.executable} -m pip install torchsummary

In [None]:
import torch
from torch.utils.data import Dataset, Subset, SubsetRandomSampler, SequentialSampler
from torch.utils.data.dataset import TensorDataset
from torch.distributions.multivariate_normal import MultivariateNormal
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from torch_geometric.data import Data
from torch_geometric.utils.convert import to_networkx

import networkx as nx # for visualizing graphs
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import pdb
from torchsummary import summary
import json
import pickle
import os

from functions.load_data import MarielDataset, edges
from functions.functions import *
from functions.modules import *
from functions.seq_autoencoder import *

# Train interactively

### Load data

In [None]:
# Dataset / loader configuration.
batch_size = 16
seq_len = 49
predicted_timesteps = 0
data = MarielDataset(seq_len=seq_len, reduced_joints=False, predicted_timesteps=predicted_timesteps, no_overlap=False)

# Contiguous, unshuffled split of the dataset indices.
train_indices = np.arange(int(0.7*len(data))) # 70% split for training data, no shuffle
val_indices = np.arange(int(0.7*len(data)),int(0.85*len(data))) # next 15% on validation
test_indices = np.arange(int(0.85*len(data)), len(data)) # last 15% on test

# SequentialSampler(indices) iterates over range(len(indices)) -- it ignores the
# index *values* -- so passing val/test indices to it would silently re-read the
# first samples of the training range. Wrapping the dataset in Subset makes each
# loader visit exactly its own indices, in order.
dataloader_train = DataLoader(Subset(data, train_indices.tolist()), batch_size=batch_size, shuffle=False, drop_last=True)
dataloader_val = DataLoader(Subset(data, val_indices.tolist()), batch_size=batch_size, shuffle=False, drop_last=True)
dataloader_test = DataLoader(Subset(data, test_indices.tolist()), batch_size=batch_size, shuffle=False, drop_last=True)

print("\nGenerated {:,} training batches of shape: {}".format(len(dataloader_train), data[0]))

### Define model

In [None]:
# Sizes derived from the loaded dataset.
node_features = data.seq_len*data.n_dim
edge_features = data[0].num_edge_features  # per the original note: one value per edge per timestep

# Architecture hyperparameters.
node_embedding_dim = 64
edge_embedding_dim = 32  # described by the author as the number of edge types
hidden_size = 64
num_layers = 3

# Sequence settings (same values as in the data-loading cell) and checkpoint flag.
seq_len = 49
predicted_timesteps = 0
checkpoint_loaded = False

# Collect the constructor arguments in one place, then build the model.
vae_config = {
    "node_features": node_features,
    "edge_features": edge_features,
    "hidden_size": hidden_size,
    "node_embedding_dim": node_embedding_dim,
    "edge_embedding_dim": edge_embedding_dim,
    "num_layers": num_layers,
    "input_size": node_embedding_dim,
    "output_size": node_features+predicted_timesteps*3,
    "sampling": False,
    "recurrent": True,
}
model = VAE(**vae_config)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=5e-4)

# Use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using {}".format(device))
model = model.to(device)

print(model)
print("Total trainable parameters: {:,}".format(count_parameters(model)))

#### Optional: load pre-trained weights

Load the whole model + weights:

In [None]:
# model = torch.load("weights/seqlen3_model.pth")

OR load the model state into the pre-existing model above:

In [None]:
# Restore model and optimizer state from a saved checkpoint.
# NOTE(review): the path says "seqlen10" while this notebook sets seq_len = 49 --
# confirm the checkpoint matches the architecture built above before loading.
checkpoint_path = "./logs/nooverlap_53joints_seqlen10_pred0/best_weights.pth"
# map_location remaps the stored tensors onto the active device, so a checkpoint
# saved on a GPU can also be restored on a CPU-only machine.
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss_checkpoint = checkpoint['loss']  # read by train_model to seed its best-loss tracking
checkpoint_loaded = True

### Train

In [None]:
# Loss settings read as globals by train_model.
mse_loss = torch.nn.MSELoss(reduction='mean')
prediction_to_reconstruction_loss_ratio = 0 # multiplier applied to the prediction MSE before it is added to the reconstruction MSE; raise it to help the prediction term compete with the reconstruction term
batch_limit = 1 # when > 0, the loops in train_model break early once the batch counter reaches this value; <= 0 disables the early stop

def _batch_losses(batch, output):
    """Compute the loss terms for one batch that is already on ``device``.

    The first ``node_features`` columns of ``output`` are compared with
    ``batch.x`` (reconstruction). When ``predicted_timesteps > 0`` the remaining
    columns are compared with ``batch.y`` (prediction) and scaled by
    ``prediction_to_reconstruction_loss_ratio``.

    Returns:
        (loss, reco_value, pred_value): the total loss tensor (suitable for
        ``backward()``), plus the reconstruction and prediction terms as Python
        floats. ``pred_value`` is 0.0 when no timesteps are predicted.
    """
    reco_loss = mse_loss(batch.x, output[:,:node_features]) # compare first seq_len timesteps
    if predicted_timesteps > 0:
        pred_loss = prediction_to_reconstruction_loss_ratio*mse_loss(batch.y, output[:,node_features:]) # compare last part to unseen data
        return reco_loss + pred_loss, reco_loss.item(), pred_loss.item()
    return reco_loss, reco_loss.item(), 0.0


def train_model(epochs):
    """Train the notebook-level ``model`` and validate it after every epoch.

    Reads these globals: ``model``, ``optimizer``, ``device``,
    ``dataloader_train``, ``dataloader_val``, ``mse_loss``, ``node_features``,
    ``predicted_timesteps``, ``prediction_to_reconstruction_loss_ratio``,
    ``batch_limit``, ``checkpoint_loaded`` and, when a checkpoint was loaded,
    ``loss_checkpoint``.

    Args:
        epochs: number of passes over ``dataloader_train``.

    Returns:
        dict mapping "train_losses", "train_reco_losses", "train_pred_losses",
        "val_losses", "val_reco_losses" and "val_pred_losses" to lists holding
        one per-epoch average each.
    """
    loss_dict = {
        "train_losses": [],
        "train_reco_losses": [],
        "train_pred_losses": [],
        "val_losses": [],
        "val_reco_losses": [],
        "val_pred_losses": [],
    }
    # Best validation loss so far; seeded from the checkpoint when one was loaded.
    best_loss = loss_checkpoint if checkpoint_loaded else float("inf")

    for epoch in range(epochs):
        t = time.time()

        ### TRAINING LOOP
        model.train()
        n_train_batches = 0
        total_train_loss = total_train_reco_loss = total_train_pred_loss = 0.0
        for batch in dataloader_train:
            batch = batch.to(device)
            output = model(batch)
            train_loss, reco_value, pred_value = _batch_losses(batch, output)

            total_train_loss += train_loss.item()
            total_train_reco_loss += reco_value
            total_train_pred_loss += pred_value

            ### BACKPROPAGATE
            optimizer.zero_grad() # reset the gradients to zero
            train_loss.backward()
            optimizer.step()

            ### OPTIONAL -- STOP TRAINING EARLY
            n_train_batches += 1
            if (batch_limit > 0) and (n_train_batches >= batch_limit): break

        ### VALIDATION LOOP
        # Validation keeps its own batch counter: sharing one counter with the
        # training loop would divide both averages by the combined batch count
        # and end validation after a single batch whenever batch_limit > 0.
        model.eval()
        n_val_batches = 0
        total_val_loss = total_val_reco_loss = total_val_pred_loss = 0.0
        with torch.no_grad(): # no gradients are needed for evaluation
            for batch in dataloader_val:
                batch = batch.to(device)
                output = model(batch)
                val_loss, reco_value, pred_value = _batch_losses(batch, output)

                total_val_loss += val_loss.item()
                total_val_reco_loss += reco_value
                total_val_pred_loss += pred_value

                ### OPTIONAL -- STOP VALIDATION EARLY
                n_val_batches += 1
                if (batch_limit > 0) and (n_val_batches >= batch_limit): break

        ### CALCULATE AVERAGE LOSSES PER EPOCH
        # max(..., 1) avoids a ZeroDivisionError when a loader yields no batches.
        train_denominator = max(n_train_batches, 1)
        val_denominator = max(n_val_batches, 1)

        epoch_train_loss = total_train_loss / train_denominator
        epoch_train_reco_loss = total_train_reco_loss / train_denominator
        epoch_train_pred_loss = total_train_pred_loss / train_denominator

        epoch_val_loss = total_val_loss / val_denominator
        epoch_val_reco_loss = total_val_reco_loss / val_denominator
        epoch_val_pred_loss = total_val_pred_loss / val_denominator

        loss_dict["train_losses"].append(epoch_train_loss)
        loss_dict["train_reco_losses"].append(epoch_train_reco_loss)
        loss_dict["train_pred_losses"].append(epoch_train_pred_loss)
        loss_dict["val_losses"].append(epoch_val_loss)
        loss_dict["val_reco_losses"].append(epoch_val_reco_loss)
        loss_dict["val_pred_losses"].append(epoch_val_pred_loss)

        print("epoch : {}/{} | train_loss = {:,.4f} | train_reco_loss: {:,.4f} | train_pred_loss: {:,.4f} | val_loss = {:,.4f} | val_reco_loss: {:,.4f} | val_pred_loss: {:,.4f} |time: {:.4f} sec".format(
            epoch+1, epochs,
            epoch_train_loss,
            epoch_train_reco_loss,
            epoch_train_pred_loss,
            epoch_val_loss,
            epoch_val_reco_loss,
            epoch_val_pred_loss,
            time.time() - t))

        if epoch_val_loss < best_loss:
            best_loss = epoch_val_loss
#             torch.save({
#              'epoch': epoch,
#              'model_state_dict': model.state_dict(),
#              'optimizer_state_dict': optimizer.state_dict(),
#              'loss': best_loss,
#              }, checkpoint_path)
#             print("Better loss achieved -- saved model checkpoint to {}.".format(checkpoint_path))

    return loss_dict

In [None]:
# Short interactive run: call train_model for two epochs.
train_model(epochs=2)

# Test

In [None]:
# Shell escape: list the contents of ./logs/ (the `folder` used in the next cell lives there).
! ls ./logs/

In [None]:
# Load the artifacts saved in one run folder: the pickled test dataloader,
# the checkpoint path and the recorded loss curves.
folder = "./logs/vae_moreparams__53joints_seqlen49_pred0"
# NOTE(security): torch.load unpickles arbitrary Python objects -- only load
# files that come from a trusted source.
dataloader_test = torch.load(os.path.join(folder,"dataloader_test.pth"))
checkpoint_path = os.path.join(folder,"best_weights.pth")
# Use a context manager so the file handle is closed, and avoid naming the
# result `dict`, which would shadow the builtin for the rest of the session.
with open(os.path.join(folder,"losses.json")) as losses_file:
    saved_losses = json.load(losses_file)
train_losses = saved_losses['train_losses']
val_losses = saved_losses['val_losses']

In [None]:
# Data-related sizes, written out by hand rather than read from a dataset object.
n_joints = 53
seq_len = 49
batch_size = 32
predicted_timesteps = 0
node_features = seq_len*3 # data.seq_len*data.n_dim
edge_features = seq_len # data[0].num_edge_features

# Architecture hyperparameters for the model whose weights are loaded afterwards.
node_embedding_dim = 40
edge_embedding_dim = 4 # described by the author as the number of edge types
hidden_size = 515
num_layers = 4
checkpoint_loaded = False

# Collect the constructor arguments in one place, then build the model.
vae_config = {
    "node_features": node_features,
    "edge_features": edge_features,
    "hidden_size": hidden_size,
    "node_embedding_dim": node_embedding_dim,
    "edge_embedding_dim": edge_embedding_dim,
    "num_layers": num_layers,
    "input_size": node_embedding_dim,
    "output_size": node_features+predicted_timesteps*3,
}
model = VAE(**vae_config)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=5e-4)

# Use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using {}".format(device))
model = model.to(device)
print(model)
print("Total trainable parameters: {:,}".format(count_parameters(model)))

In [None]:
# Restore model and optimizer state from `checkpoint_path`.
# map_location remaps the stored tensors onto the active device, so a checkpoint
# saved on a GPU can also be restored on a CPU-only machine.
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss_checkpoint = checkpoint['loss']
checkpoint_loaded = True

In [None]:
def test(max_batches=2):
    """Evaluate the notebook-level ``model`` on ``dataloader_test``.

    Reads these globals: ``model``, ``device``, ``dataloader_test``,
    ``node_features`` and ``predicted_timesteps``. Prints the average total,
    reconstruction and prediction losses over the evaluated batches.

    Args:
        max_batches: stop after this many batches. The default of 2 matches
            the early stop that used to be hard-coded here; pass ``None``
            (or a value <= 0) to evaluate the whole loader.

    Returns:
        (actuals, preds): two lists with one numpy array per evaluated batch,
        holding ``batch.x`` and the model output respectively.
    """
    mse_loss = torch.nn.MSELoss(reduction='mean')
    prediction_to_reconstruction_loss_ratio = 0
    total_test_loss = 0
    total_test_reco_loss = 0
    total_test_pred_loss = 0
    n_batches = 0
    actuals = []
    preds = []
    model.eval()
    with torch.no_grad(): # inference only: skip building the autograd graph
        for batch in tqdm(dataloader_test, desc="Test batches"):
            batch = batch.to(device)

            ### CALCULATE MODEL OUTPUTS
            output = model(batch)

            ### SAVE FOR ANIMATIONS
            actuals.append(batch.x.detach().cpu().numpy())
            preds.append(output.detach().cpu().numpy())

            ### CALCULATE LOSS
            test_reco_loss = mse_loss(batch.x, output[:,:node_features]) # compare first seq_len timesteps
            if predicted_timesteps > 0:
                test_pred_loss = prediction_to_reconstruction_loss_ratio*mse_loss(batch.y, output[:,node_features:]) # compare last part to unseen data
                test_loss = test_reco_loss + test_pred_loss
            else:
                test_loss = test_reco_loss

            ### ADD LOSSES TO TOTALS
            total_test_loss += test_loss.item()
            total_test_reco_loss += test_reco_loss.item()
            if predicted_timesteps > 0:
                total_test_pred_loss += test_pred_loss.item()
            n_batches += 1

            ### OPTIONAL: STOP EARLY
            if max_batches is not None and max_batches > 0 and n_batches >= max_batches: break

    ### CALCULATE AVERAGE LOSSES
    denominator = max(n_batches, 1) # avoids a ZeroDivisionError on an empty loader
    average_test_loss = total_test_loss / denominator
    average_test_reco_loss = total_test_reco_loss / denominator
    average_test_pred_loss = total_test_pred_loss / denominator
    print("Loss = {:,.4f} | Reconstruction Loss: {:,.4f} | Prediction Loss: {:,.4f}".format(average_test_loss,
                                                                                            average_test_reco_loss,
                                                                                            average_test_pred_loss))

    return actuals, preds

In [None]:
# Run the evaluation; keep the per-batch inputs and model outputs for the animation cells.
actuals, preds = test()

In [None]:
# Which evaluated batch to unpack.
batch_number = 0
truth_sequences, predicted_sequences = [], []

# Each consecutive block of n_joints rows is reshaped to (n_joints, seq_len, 3)
# and transposed to (seq_len, n_joints, 3).
# Assumes every row block belongs to one sequence of the batch -- TODO confirm.
for seq_number in np.arange(batch_size):
    joint_rows = slice(seq_number*n_joints, (seq_number+1)*n_joints)
    actual = actuals[batch_number][joint_rows].reshape((n_joints,seq_len,3)).transpose(1,0,2)
    pred = preds[batch_number][joint_rows].reshape((n_joints,seq_len,3)).transpose(1,0,2)
    truth_sequences.append(actual)
    predicted_sequences.append(pred)

# Flatten (sequence, timestep) into a single leading frame axis.
flat_shape = (batch_size*seq_len, n_joints, 3)
truth_sequences = np.asarray(truth_sequences).reshape(flat_shape)
predicted_sequences = np.asarray(predicted_sequences).reshape(flat_shape)

In [None]:
# Window of frames to animate.
start_index = 0
# timesteps = seq_len*batch_size
timesteps = 100
frame_window = slice(start_index, start_index+timesteps)

# Ground-truth frames are drawn with the model output passed as the `ghost` argument.
animation = animate_stick(truth_sequences[frame_window],
                          ghost=predicted_sequences[frame_window],
                          ghost_shift=0.,
                          ax_lims=(-0.7,0.7),
                          figsize=(10,8),
                          cmap='inferno')
HTML(animation.to_html5_video())

In [None]:
# Hard-coded validation losses, plotted below as "Graph VAE (1.8M params)".
# Presumably pasted from that run's training log (one value per epoch) -- TODO confirm,
# and consider loading them from the run's saved losses instead.
val_losses_moreparams = [1.2585,0.2413,0.2063,0.3304,0.1753,0.1665,0.1516,0.2043,0.1444,32.0709,0.1563,15.5246,0.7867,1.8139,2.1597,0.1700,0.1611,0.1606,0.1605,0.1671,0.1589,0.7795,0.1553,0.1550,0.1546,0.1544,0.1546,0.1546,0.1522,0.1533,0.1545,0.1509,0.1521,0.1538,0.1538,0.1467,0.1402,0.1433,0.1413,0.1402,0.1425,0.1419,0.1372,0.1359,0.1319,0.1331,0.1354,0.1371,0.1360,0.1283,0.1283,0.1277,0.1277,0.1271,0.1353,0.1353,0.1390,0.1249,0.0118,0.0040,0.1584,0.1568,0.1558,0.1556,0.1556,0.1554,0.1536,0.1535,0.1533,0.1523,0.1419,0.1338,]

In [None]:
fig, ax = plt.subplots(figsize=(8,6))
# ax.plot(np.arange(len(nooverlap_val_losses)), nooverlap_val_losses, label="Validation (No Overlap)")

# One (loss series, legend label) pair per curve, plotted against the epoch index.
# NOTE(review): `val_losses` is whatever run was loaded into that name earlier in
# the notebook -- confirm it really is the 250k-parameter run its label claims.
curves = [
    (val_losses, "Graph VAE (250k params)"),
    (val_losses_moreparams, "Graph VAE (1.8M params)"),
]
for curve, curve_label in curves:
    ax.plot(np.arange(len(curve)), curve, label=curve_label)

ax.set_xlabel("Epoch", fontsize=16)
ax.set_ylabel("Validation Reco Loss", fontsize=16)
# ax.set_yscale("log")
ax.set_ylim(-0.1,0.5)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
ax.legend(fontsize=14)

### Up next:
- Try to overfit the input data (loss = 0) by setting prediction weight = 0
- Weight the prediction loss per timestep so that the most immediate future steps matter more than far-future ones (e.g. with weights that decay like $1/2^t$)

### For later:
- The Gaussian negative log-likelihood loss functions will only make sense once the output of the decoder is the mean $\mu$ (Eqs. 16 & 17)

### Done
- ~~Predict 50 + k timesteps w/ separate MSE losses~~
- ~~Look at VAE outputs!~~
- ~~Look at NRI outputs!~~

# Scratch work

In [None]:
#             my_nll_loss = gaussian_neg_log_likelihood(x=batch.x, mu=output, sigma=sigma)
#             nll_loss = nll_gaussian(preds=output, target=batch.x.to(device), variance=5e-5)
#             kl_loss = kl_categorical_uniform(torch.exp(log_probabilities), data[0].num_nodes, num_edge_types, add_const=True)

In [None]:
# start_index = 0
# timesteps = seq_len
# animation = animate_stick(actual[start_index:start_index+timesteps,:,:], 
#                           ghost=pred[start_index:start_index+timesteps,:,:], 
#                           ghost_shift=0.4,
#                           ax_lims = (-0.7,0.7),
#                           figsize=(10,8), cmap='inferno')
# HTML(animation.to_html5_video())