In [None]:
# (Re)load the autoreload extension and switch it to mode 2 so that edits to
# imported project modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import torch
from torch.utils.data import Dataset
from torch.utils.data.dataset import TensorDataset
from torch_geometric.data import DataLoader
from torch_geometric.data import Data
from torch_geometric.utils.convert import to_networkx
import torch.nn as nn

import networkx as nx # for visualizing graphs
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

from functions.load_data import MarielDataset, edges

##### Encoder Procedure
1. Create an embedding of the node features (H) using an MLP.
2. Use the embedded node features (H) and a second MLP to create the messages that are passed along the edges of the graph.
3. For each node, aggregate the messages created in Step 2 to update its node features.
4. Pass the updated node features through another MLP to get the "pre-posterior"; the posterior is then the softmax of the pre-posterior.

# Load data

In [None]:
# Instantiate the dataset (seq_len=4, reduced_joints=True) and wrap it in a
# loader that serves batches of `batch_size` items in dataset order.
batch_size = 32
data = MarielDataset(reduced_joints=True, seq_len=4)
dataloader = DataLoader(
    data,
    shuffle=False,
    batch_size=batch_size,
)

# Report how many items were generated, showing the first one as an example.
print(
    "\nGenerated {:,} sequences of shape: {}".format(
        len(data),
        data[0],
    )
)

In [None]:
# for batch in dataloader:
#     print("Mean: {:.5f} & Std: {:.5f}".format(torch.mean(batch.x), torch.std(batch.x)))
#     print(batch.x)

# Define model & train

In [None]:
from functions.modules import MLPEncoder, RNNDecoder
from torch.distributions.multivariate_normal import MultivariateNormal
import torch.nn.functional as F

# Model dimensions: the node/edge feature sizes are read from the dataset,
# while the hidden and embedding widths are chosen by hand.
node_features = data.seq_len * data.n_dim
edge_features = data[0].num_edge_features
hidden_size = 5
node_embedding_dim = 2
edge_embedding_dim = 3

# The encoder is configured with both the raw feature sizes and the
# embedding widths.
encoder = MLPEncoder(
    node_features=node_features,
    edge_features=edge_features,
    hidden_size=hidden_size,
    node_embedding_dim=node_embedding_dim,
    edge_embedding_dim=edge_embedding_dim,
)

# The decoder takes node embeddings as input and its output width is set to
# the node feature size.
decoder = RNNDecoder(
    input_size=node_embedding_dim,
    hidden_size=hidden_size,
    output_size=node_features,
    edge_embedding_dim=edge_embedding_dim,
)

# Prefer CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Using {}".format(device))

# Move both networks onto the selected device.
encoder, decoder = encoder.to(device), decoder.to(device)

# Show the layer structure of each network.
print(encoder)
print(decoder)

In [None]:
# A single Adam optimiser updates the encoder and decoder parameters together.
optimizer = torch.optim.Adam(
    [*encoder.parameters(), *decoder.parameters()],
    weight_decay=5e-4,
    lr=0.01,
)
# Mean-squared-error criterion used by the training loop.
loss_func = nn.MSELoss()

# dummy_matrix = torch.tensor(np.random.rand(batch_size*data[0].num_nodes,node_features), dtype=torch.float)
# dummy_matrix = dummy_matrix.to(device)

def train(num_epochs):
    """Train the encoder/decoder pair and record the loss of every batch.

    Relies on the notebook-level ``encoder``, ``decoder``, ``dataloader``,
    ``optimizer``, ``loss_func`` and ``device`` objects.

    Args:
        num_epochs: Number of full passes over ``dataloader``.

    Returns:
        list[float]: One loss value per processed batch, in processing order
        (``num_epochs * len(dataloader)`` entries) -- not one per epoch.
    """
    losses = []
    encoder.train()
    decoder.train()

    for _ in range(num_epochs):
        for batch in dataloader:
            batch = batch.to(device)
            optimizer.zero_grad()

            ### ENCODER
            node_embedding, edge_index, edge_embedding = encoder(batch)
            if torch.isnan(encoder.node_embedding[0].weight).any():
                print("NaNs detected in the node embedding weights!")

            ### SAMPLING
            # Open question: what dimensionality should this multivariate
            # Gaussian have? One distribution per edge-embedding dimension?
            # Mean is taken over dim 0 of the edge embeddings; the covariance
            # is the identity, sized from the encoder output rather than from
            # a notebook-level constant so the two cannot drift apart.
            mean = edge_embedding.mean(dim=0).cpu()
            covariance = torch.eye(mean.size(-1))
            distribution = MultivariateNormal(mean, covariance)
            # rsample() draws a reparameterised sample, so gradients flow back
            # through the mean into the encoder; sample() runs under no_grad
            # and would detach `a_soft` from everything upstream.
            a = distribution.rsample()
            # dim=-1 normalises across the components of the sampled vector
            # (dim=0 above was only the reduction used to build the mean).
            a_soft = F.softmax(a, dim=-1).to(device)

            ### DECODER
            output = decoder(node_embedding, edge_index, edge_embedding, a_soft)

            ### CALCULATE LOSS
            # `batch` was already moved to `device` above, so batch.x needs no
            # further transfer.
            loss = loss_func(output, batch.x)
            loss_value = loss.item()
            print("Loss: {}".format(loss_value))
            loss.backward()
            optimizer.step()
            losses.append(loss_value)
    return losses

In [None]:
# Run 5 epochs of training; `losses` receives the list returned by train().
losses = train(num_epochs=5)

In [None]:
# `losses` has one entry per batch (train() appends inside its batch loop),
# so the x-axis counts training steps, not epochs.
fig, ax = plt.subplots(figsize=(8,6))
ax.plot(np.arange(len(losses)), losses)
ax.set_xlabel("Training step (batch)", fontsize=16)
ax.set_ylabel("Loss", fontsize=16)
# Trailing semicolon keeps the Text repr out of the cell output.
ax.set_title("Training loss per batch", fontsize=16);

Still missing (TODO):
- How should `edge_embedding` be used in the decoder (e.g., as an edge weight)?
- Gumbel-softmax sampling
- NLL loss vs. MSE loss
- NaNs in the node-embedding layer when using higher-dimensional inputs?