In [None]:
# (Re)load IPython's autoreload extension and set it to mode 2, which re-imports
# all modules before each cell is executed -- so edits to the local `functions`
# package are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures directly in the notebook output.
%matplotlib inline

In [None]:
import torch
from torch.utils.data import Dataset
from torch.utils.data.dataset import TensorDataset
from torch.distributions.multivariate_normal import MultivariateNormal
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from torch_geometric.data import Data
from torch_geometric.utils.convert import to_networkx

import networkx as nx # for visualizing graphs
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import pdb

from functions.load_data import MarielDataset, edges
from functions.functions import *
from functions.modules import *

##### Encoder Procedure
1. Create an embedding of the node features (H) using an MLP
2. Create a message to pass through the edges of the graph using the embedded node features (H) and another MLP
3. Aggregate the messages created in Step 2 for each node to update node features
4. Pass the updated node features through another MLP to get the "pre-posterior"; the posterior is then the softmax of the pre-posterior.

# Load data

In [None]:
# Settings for the dataset and for mini-batching.
seq_len = 50
batch_size = 32

# Build the dataset (reduced_joints disabled) and wrap it in a loader that
# keeps the original sample order (no shuffling).
data = MarielDataset(
    seq_len=seq_len,
    reduced_joints=False,
)
dataloader = DataLoader(
    data,
    batch_size=batch_size,
    shuffle=False,
)

# Report the number of batches together with the first sample of the dataset.
print(
    "\nGenerated {:,} batches of shape: {}".format(
        len(dataloader),
        data[0],
    )
)

# Define model & train

In [None]:
# --- Sizes derived from the dataset -----------------------------------------
node_features = data.seq_len * data.n_dim
edge_features = data[0].num_edge_features

# --- Hand-picked hyperparameters --------------------------------------------
node_embedding_dim = 5
hidden_size = 5
edge_embedding_dim = 10
num_layers = 10
num_edge_types = 4

# --- Model components -------------------------------------------------------
encoder = MLPEncoder(
    node_features=node_features,
    edge_features=edge_features,
    hidden_size=hidden_size,
    node_embedding_dim=node_embedding_dim,
    edge_embedding_dim=edge_embedding_dim,
)

decoder = RNNDecoder(
    input_size=node_embedding_dim,
    output_size=node_features,
    edge_embedding_dim=edge_embedding_dim,
    edge_features=edge_features,
    num_layers=num_layers,
)

# --- Device placement: use the GPU when one is available, else the CPU -------
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Using {}".format(device))

encoder = encoder.to(device)
decoder = decoder.to(device)

# Show both architectures, one after the other.
print(encoder, decoder, sep="\n")

In [None]:
# A single Adam optimizer updates the encoder and decoder parameters together.
optimizer = torch.optim.Adam(
    [*encoder.parameters(), *decoder.parameters()],
    lr=0.01,
    weight_decay=5e-4,
)

# Reconstruction loss: squared error averaged over all elements.
loss_func = torch.nn.MSELoss(reduction='mean')

# Passed as `sigma` to gaussian_neg_log_likelihood inside train().
sigma = 0.001  # how to pick sigma?

def train(num_epochs, max_batches=6):
    """Jointly train the encoder and decoder; return the average loss per epoch.

    Uses the notebook-level globals ``encoder``, ``decoder``, ``dataloader``,
    ``optimizer``, ``loss_func``, ``sigma``, ``device``, ``data`` and
    ``num_edge_types``.

    Args:
        num_epochs: Number of passes over ``dataloader``.
        max_batches: Positive cap on the number of batches processed per epoch
            (a temporary shortcut for quick experiments). The default of 6
            reproduces the previously hard-coded ``if i > 5: break``. Pass
            ``None`` to run over every batch.

    Returns:
        list[float]: For each epoch, the mean of the per-batch training loss
        (MSE + KL) over the batches that were processed.

    Raises:
        ValueError: If ``dataloader`` yields no batches, since an average loss
            cannot be computed.
    """
    losses = []
    encoder.train()
    decoder.train()

    for epoch in tqdm(range(num_epochs)):
        epoch_loss = 0.0   # running sum of per-batch losses (plain Python float)
        num_batches = 0    # batches processed so far in this epoch
        for batch in dataloader:
            batch = batch.to(device)
            optimizer.zero_grad() # reset the gradients to zero

            ### ENCODER
            node_embedding, edge_index, edge_embedding, log_probabilities = encoder(batch) # dim=0 or dim=1 for probability?

            ### SAMPLING
            # Relaxed (soft) Gumbel-Softmax sample at temperature 0.5.
            z = torch.nn.functional.gumbel_softmax(log_probabilities, tau=0.5)

            ### DECODER
            output = decoder(node_embedding, edge_index, edge_embedding, z)

            ### CALCULATE LOSS
            mse_loss = loss_func(batch.x.to(device), output)
            print("===== Batch {} =====".format(num_batches))
            print("Minimum prediction: {:.2f} & maximum prediction: {:.2f}".format(torch.min(output).item(),torch.max(output).item()))
            # The two NLL values are computed for logging only; they are not
            # part of the objective that is back-propagated.
            my_nll_loss = gaussian_neg_log_likelihood(target=batch.x, mu=output, sigma=sigma)
            nll_loss = nll_gaussian(preds=output, target=batch.x.to(device), variance=5e-5)
            kl_loss = kl_categorical_uniform(torch.exp(log_probabilities), data[0].num_nodes, num_edge_types, add_const=True)
            batch_loss = mse_loss + kl_loss
            print("MSE Loss: {:.5f} | My NLL Loss: {:.5f} | NLL Loss: {:.5f} | KL Loss: {:.5f}".format(mse_loss.item(), my_nll_loss, nll_loss, kl_loss))
            batch_loss.backward()
            optimizer.step()
            epoch_loss += batch_loss.item()
            num_batches += 1
            if max_batches is not None and num_batches >= max_batches: # temporary -- for stopping training early
                break

        if num_batches == 0:
            raise ValueError("dataloader yielded no batches; cannot compute an average loss")

        # Average over the batches actually processed. `epoch_loss` is already
        # a Python float (built from `.item()` values), so it is appended
        # directly -- calling `.item()` on it would raise AttributeError.
        epoch_loss = epoch_loss / num_batches
        losses.append(epoch_loss)
        print("epoch : {}/{}, loss = {:.6f}".format(epoch, num_epochs, epoch_loss))
    return losses

In [None]:
# Run 5 training epochs; `losses` receives the list returned by train().
losses = train(num_epochs=5)

In [None]:
# Loss curve: one point per entry of `losses`, plotted against its index.
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(
    np.arange(len(losses)),  # x-axis: zero-based epoch index
    losses,                  # y-axis: value recorded for that epoch
)
ax.set_xlabel("Epoch", fontsize=16)
ax.set_ylabel("Loss", fontsize=16)

### Still missing: 
- How do we use edge_embedding in the decoder? (Edge weight?)