In [None]:
# Automatically re-import edited project modules (e.g. the local `functions` package)
# before each cell executes, so library edits are picked up without a kernel restart.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline below the producing cell.
%matplotlib inline

In [None]:
import torch
from torch.utils.data import Dataset
from torch.utils.data.dataset import TensorDataset
from torch.distributions.multivariate_normal import MultivariateNormal
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from torch_geometric.data import Data
from torch_geometric.utils.convert import to_networkx

import networkx as nx # for visualizing graphs
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import pdb

from functions.load_data import MarielDataset, edges
from functions.functions import *
from functions.modules import *

# Load data

In [None]:
# Data configuration: each sample covers `seq_len` observed timesteps, and the model is
# additionally trained to predict `predicted_timesteps` more (presumably the frames that
# follow the observed window -- TODO confirm against MarielDataset).
batch_size = 32
seq_len = 50
predicted_timesteps = 10
data = MarielDataset(seq_len=seq_len, reduced_joints=True, predicted_timesteps=predicted_timesteps)
# shuffle=False keeps batches in dataset order from epoch to epoch.
dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
# `data[0]` is a single sample (not a batch), so label it as such.
print("\nGenerated {:,} batches of size {}; first sample: {}".format(len(dataloader), batch_size, data[0]))

# Define model & train

In [None]:
# --- Hyperparameters ---------------------------------------------------------
# Width of each node's input vector (the number of columns of `batch.x`):
# seq_len timesteps times n_dim values per timestep.
node_features = data.seq_len * data.n_dim
edge_features = data[0].num_edge_features
node_embedding_dim = 25
hidden_size = 25
edge_embedding_dim = 10
num_layers = 2        # depth of the RNN decoder
num_edge_types = 2    # number of edge types; not used in this cell

# --- Models --------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Encoder: maps node/edge features to node and edge embeddings (plus edge-type log-probs).
encoder = MLPEncoder(
    node_features=node_features,
    edge_features=edge_features,
    hidden_size=hidden_size,
    node_embedding_dim=node_embedding_dim,
    edge_embedding_dim=edge_embedding_dim,
).to(device)

# Decoder: its output covers the reconstructed input window (node_features values)
# followed by predicted_timesteps * n_dim values for the prediction.
decoder = RNNDecoder(
    input_size=node_embedding_dim,
    output_size=node_features + predicted_timesteps * data.n_dim,
    edge_embedding_dim=edge_embedding_dim,
    edge_features=edge_features,
    num_layers=num_layers,
).to(device)

print("Using {}".format(device))
print(encoder)
print(decoder)

In [None]:
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=0.01, weight_decay=5e-4)
mse_loss = torch.nn.MSELoss(reduction='mean')
# Weight on the prediction term relative to the reconstruction term; you might want to
# raise it so the prediction loss can compete with the reconstruction loss.
prediction_to_reconstruction_loss_ratio = 1
sigma = 0.001  # TODO: how to pick sigma? (only referenced by the disabled Gaussian NLL loss)


def _train_step(batch):
    """Run one optimisation step on ``batch``.

    Returns:
        ``(total, reconstruction, prediction)`` losses for this batch as Python floats.
    """
    # PyG's Data/Batch ``.to`` moves every tensor attribute (including x and y),
    # so no per-tensor ``.to(device)`` is needed when computing the losses below.
    batch = batch.to(device)
    optimizer.zero_grad()  # reset gradients accumulated by the previous step

    # Encoder. TODO (open question): dim=0 or dim=1 for the probabilities?
    node_embedding, edge_index, edge_embedding, log_probabilities = encoder(batch)

    # Differentiable (soft, hard=False) sample of edge types from the log-probabilities.
    z = F.gumbel_softmax(log_probabilities, tau=0.5)

    output = decoder(node_embedding, edge_index, edge_embedding, z)

    # The first `node_features` output columns reconstruct the observed window (batch.x);
    # the remaining columns are compared against the unseen targets (batch.y).
    reconstruction_loss = mse_loss(batch.x, output[:, :node_features])
    prediction_loss = mse_loss(batch.y, output[:, node_features:])
    # Not yet enabled: Gaussian NLL losses (gaussian_neg_log_likelihood / nll_gaussian, which
    # need the decoder to output a mean) and a KL term via kl_categorical_uniform.
    batch_loss = reconstruction_loss + prediction_to_reconstruction_loss_ratio * prediction_loss

    batch_loss.backward()
    optimizer.step()
    return batch_loss.item(), reconstruction_loss.item(), prediction_loss.item()


def train(num_epochs, max_batches=51):
    """Jointly train ``encoder`` and ``decoder`` and return per-epoch average losses.

    Args:
        num_epochs: number of passes over ``dataloader``.
        max_batches: stop each epoch after this many batches (a temporary early stop for
            quick experiments). The default of 51 reproduces the previous hard-coded
            ``i > 50`` check, which ran 51 batches. Pass ``None`` to use every batch.

    Returns:
        Tuple ``(losses, reconstruction_losses, prediction_losses)``; each is a list with
        one entry per epoch: the mean of that loss over the batches actually processed.

    Raises:
        ValueError: if ``max_batches`` is not None and below 1, or if an epoch yields no
            batches (an empty dataloader), since no average can be computed.
    """
    if max_batches is not None and max_batches < 1:
        raise ValueError("max_batches must be a positive integer or None, got {!r}".format(max_batches))

    losses = []
    reconstruction_losses = []
    prediction_losses = []

    encoder.train()
    decoder.train()

    for epoch in tqdm(range(num_epochs)):
        total_sum = reconstruction_sum = prediction_sum = 0.0
        num_batches = 0
        for batch in dataloader:
            batch_loss, reconstruction_loss, prediction_loss = _train_step(batch)
            total_sum += batch_loss
            reconstruction_sum += reconstruction_loss
            prediction_sum += prediction_loss
            num_batches += 1
            # Checked after processing so we never fetch a batch we will not use.
            if max_batches is not None and num_batches >= max_batches:
                break
        if num_batches == 0:
            raise ValueError("dataloader yielded no batches; cannot average the epoch loss")

        # Dividing by the number of processed batches is correct with or without the early stop.
        average_loss = total_sum / num_batches
        average_reconstruction_loss = reconstruction_sum / num_batches
        average_prediction_loss = prediction_sum / num_batches
        losses.append(average_loss)
        reconstruction_losses.append(average_reconstruction_loss)
        prediction_losses.append(average_prediction_loss)
        # tqdm.write keeps the progress bar intact while logging.
        tqdm.write("epoch : {}/{} | Loss = {:,.3f} | Reconstruction Loss: {:,.3f} | Prediction Loss: {:,.3f}".format(
            epoch + 1, num_epochs, average_loss, average_reconstruction_loss, average_prediction_loss))
    return losses, reconstruction_losses, prediction_losses

In [None]:
# Train for 5 epochs; `train` returns one averaged value per epoch for each loss series.
losses, reconstruction_losses, prediction_losses = train(5)

In [None]:
# Per-epoch average losses returned by `train`. Epochs are numbered from 1 so the
# x axis matches the "epoch : k/N" lines printed during training.
epochs = np.arange(1, len(losses) + 1)
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(epochs, losses, label="Total")
ax.plot(epochs, reconstruction_losses, label="Reconstruction")
ax.plot(epochs, prediction_losses, label="Prediction")
ax.set_title("Training loss per epoch", fontsize=16)
ax.set_xlabel("Epoch", fontsize=16)
ax.set_ylabel("Loss", fontsize=16)
ax.legend(fontsize=14);  # trailing ';' suppresses the Legend object's repr in the cell output

### Up next:
- Use the edge index to create discrete decoders
- How do we use `edge_embedding` in the decoder? (As an edge weight?)
- Should we do an additional edge transform in the decoder?

### For later:
- The Gaussian negative log-likelihood loss functions only make sense once the decoder outputs the mean $\mu$ (Eqs. 16 & 17)

### Done
- ~~Predict 50 + k timesteps w/ separate MSE losses~~