In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import torch
from torch.utils.data import Dataset
from torch.utils.data.dataset import TensorDataset
from torch.distributions.multivariate_normal import MultivariateNormal
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from torch_geometric.data import Data
from torch_geometric.utils.convert import to_networkx

import networkx as nx # for visualizing graphs
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

from functions.load_data import MarielDataset, edges
from functions.modules import *

##### Encoder Procedure
1. Create an embedding of the node features (H) using an MLP
2. Create a message to pass through the edges of the graph using the embedded node features (H) and another MLP
3. Aggregate the messages created in Step 2 for each node to update node features
4. Pass the updated node features through another MLP to get the "Pre-posterior"; the posterior then becomes the softmax of the pre-posterior

# Load data

In [9]:
batch_size = 64
seq_len = 10
data = MarielDataset(seq_len=seq_len, reduced_joints=True)
dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
print("\nGenerated {:,} batches of shape: {}".format(len(dataloader), data[0]))

Original numpy dataset contains 38,309 timesteps of 53 joints with 3 dimensions each.

Generating overlapping sequences...
Using (x,y)-centering...
Reducing joints...

Generated 599 batches of shape: Data(edge_attr=[648, 10], edge_index=[2, 648], x=[18, 30], y=[18, 30])


# Define model & train

In [13]:
node_features = data.seq_len*data.n_dim
edge_features = data[0].num_edge_features
node_embedding_dim = 50
hidden_size = 50
edge_embedding_dim = 10
num_layers = 10

encoder = MLPEncoder(node_features=node_features, 
                   edge_features=edge_features, 
                   hidden_size=hidden_size, 
                   node_embedding_dim=node_embedding_dim,
                   edge_embedding_dim=edge_embedding_dim)

decoder = RNNDecoder(input_size=node_embedding_dim, 
                    hidden_size=hidden_size, 
                    output_size=node_features,
                    edge_embedding_dim=edge_embedding_dim,
                    num_layers=num_layers
                    )

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using {}".format(device))

encoder = encoder.to(device)
decoder = decoder.to(device)

print(encoder)
print(decoder)

Using cuda
MLPEncoder(
  (node_embedding): Sequential(
    (0): Linear(in_features=30, out_features=50, bias=True)
    (1): ReLU()
    (2): Linear(in_features=50, out_features=50, bias=True)
    (3): ReLU()
    (4): Linear(in_features=50, out_features=50, bias=True)
    (5): ReLU()
    (6): Linear(in_features=50, out_features=50, bias=True)
    (7): ReLU()
    (8): Linear(in_features=50, out_features=50, bias=True)
    (9): ReLU()
  )
  (edge_embedding): Sequential(
    (0): Linear(in_features=10, out_features=10, bias=True)
    (1): ReLU()
  )
  (MLP): Sequential(
    (0): Linear(in_features=10, out_features=50, bias=True)
    (1): ReLU()
    (2): Linear(in_features=50, out_features=5000, bias=True)
  )
  (graph_conv_1): MLPGraphConv(50, 50)
  (graph_conv_2): MLPGraphConv(50, 50)
  (graph_conv_list): ModuleList(
    (0): MLPGraphConv(50, 50)
    (1): MLPGraphConv(50, 50)
  )
)
RNNDecoder(
  (node_transform): Sequential(
    (0): Linear(in_features=50, out_features=30, bias=True)
    (

In [14]:
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=0.01, weight_decay=5e-4)
loss_func = torch.nn.MSELoss(reduction='mean')
# dummy_matrix = torch.tensor(np.random.rand(batch_size*data[0].num_nodes,node_features), dtype=torch.float)
# dummy_matrix = dummy_matrix.to(device)

def train(num_epochs):
    losses = []
    encoder.train()
    decoder.train()
    
    for epoch in range(num_epochs):
        loss = 0
        i = 0
        for batch in dataloader:
            batch = batch.to(device)
            optimizer.zero_grad() # reset the gradients to zero

            ### ENCODER
            if torch.isnan(encoder.node_embedding[0].weight).any(): 
                print("NaNs detected in the node embedding weights BEFORE encoder!")
                break
            node_embedding, edge_index, edge_embedding = encoder(batch)
            if torch.isnan(encoder.node_embedding[0].weight).any(): 
                print("NaNs detected in the node embedding weights AFTER encoder!")
                break
            ### SAMPLING
#             edges = gumbel_softmax(edge_embedding, tau=0.5, hard=False)
#             prob = F.softmax(edge_embedding)
            
            ### What dimensionality should this multivariate Gaussian be? 1 distribution per edge embedding dim?
            distribution = MultivariateNormal(edge_embedding.mean(dim=0).cpu(), torch.eye(edge_embedding_dim)) # mean & covariance matrix
            z = distribution.sample()
            z_soft = F.softmax(z, dim=-1).to(device) # is dim=-1 necessary if I have dim=0 above?

            ### DECODER
            output = decoder(node_embedding, edge_index, edge_embedding, z_soft)

            ### CALCULATE LOSS
#             loss = loss_func()
            print("Maximum prediction:",torch.max(output).item())
#             print(torch.sum((output-batch.x.to(device))**2))
            batch_loss = nll_gaussian(preds=output, target=batch.x.to(device), variance=5e-5)
            print("Batch {}, Loss: {}".format(i, batch_loss.item()))
            batch_loss.backward()
            optimizer.step()
            losses.append(batch_loss.item())
            loss += batch_loss.item()
            i += 1
        loss = loss / len(dataloader)
        print("epoch : {}/{}, loss = {:.6f}".format(epoch, num_epochs, loss))
    return losses

In [15]:
losses = train(num_epochs=1)

Maximum prediction: 1.0
Batch 0, Loss: 8444.435546875
Maximum prediction: 1.0
Batch 1, Loss: 9669.2197265625


KeyboardInterrupt: 

In [None]:
fig, ax = plt.subplots(figsize=(8,6))
ax.plot(np.arange(len(losses)), losses)
ax.set_xlabel("Epoch", fontsize=16)
ax.set_ylabel("Loss", fontsize=16)

Still missing: 
- How do we use edge_embedding in the decoder? (Edge weight?)
- Gumbel softmax


- NLL Loss vs MSE Loss
- NaNs in the node embedding layer with higher-dim inputs? -- does reducing MSELoss w/ either sum or mean help?
- Adding more layers to node embedding?