In [119]:
# Notebook setup magics.
# (Re)load IPython's autoreload extension so edits to imported project modules
# are picked up without restarting the kernel.
%reload_ext autoreload
# Mode 2: reload all imported modules before each cell is executed.
%autoreload 2
# Render matplotlib figures inline, directly below the cell that creates them.
%matplotlib inline

In [120]:
import torch
from torch.utils.data import Dataset
from torch.utils.data.dataset import TensorDataset
from torch_geometric.data import DataLoader
from torch_geometric.data import Data
from torch_geometric.utils.convert import to_networkx
import torch.nn as nn

import networkx as nx # for visualizing graphs
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

from functions.load_data import MarielDataset, edges

##### Encoder Procedure
1. Create an embedding of the node features (H) using an MLP
2. Create a message to pass through the edges of the graph using the embedded node features (H) and another MLP
3. Aggregate the messages created in Step 2 for each node to update node features
4. Pass the updated node features through another MLP to get the "Pre-posterior"; the posterior then becomes the softmax of the pre-posterior

# Load data

In [121]:
# Build the joint-sequence dataset (sequences of 4 timesteps, reduced joint
# set) and wrap it in a graph DataLoader that yields batches in dataset order.
batch_size = 1
data = MarielDataset(seq_len=4, reduced_joints=True)
dataloader = DataLoader(
    data,
    batch_size=batch_size,
    shuffle=False,
)

# Report how many sequences were generated and what a single sample looks like.
summary_template = "\nGenerated {:,} sequences of shape: {}"
print(summary_template.format(len(data), data[0]))

Original numpy dataset contains 38,309 timesteps of 53 joints with 3 dimensions each.

Generating overlapping sequences...
Using (x,y)-centering...
Reducing joints...

Generated 38,305 sequences of shape: Data(edge_attr=[648, 4], edge_index=[2, 648], x=[18, 12], y=[18, 12])


In [122]:
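# Optional sanity check (disabled): print the mean/std and the raw values of `batch.x` for every batch.
# for batch in dataloader:
#     print("Mean: {:.5f} & Std: {:.5f}".format(torch.mean(batch.x), torch.std(batch.x)))
#     print(batch.x)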

# Define model & train

In [160]:
from functions.modules import MLPEncoder, RNNDecoder
from torch.distributions.multivariate_normal import MultivariateNormal
import torch.nn.functional as F

# --- Sizes -------------------------------------------------------------------
# Per-node input width: one value per (timestep, coordinate) pair.
node_features = data.seq_len * data.n_dim
# Edge attribute width, read off the first sample.
edge_features = data[0].num_edge_features
hidden_size = 5
node_embedding_dim = 2
edge_embedding_dim = 3

# --- Modules -----------------------------------------------------------------
# Encoder: node features -> node embeddings of width `node_embedding_dim`.
encoder = MLPEncoder(
    node_features=node_features,
    edge_features=edge_features,
    hidden_size=hidden_size,
    node_embedding_dim=node_embedding_dim,
)

# Decoder: node embeddings -> outputs of the original node-feature width.
decoder = RNNDecoder(
    input_size=node_embedding_dim,
    hidden_size=hidden_size,
    output_size=node_features,
    edge_features=edge_features,
    edge_embedding_dim=edge_embedding_dim,
)

# --- Device placement --------------------------------------------------------
# Prefer the GPU when one is available, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print("Using {}".format(device))

encoder, decoder = encoder.to(device), decoder.to(device)

# Show both architectures, encoder first.
print(encoder, decoder, sep="\n")

Using cuda
MLPEncoder(
  (node_embedding): Sequential(
    (0): Linear(in_features=12, out_features=2, bias=True)
    (1): ReLU()
  )
  (MLP): Sequential(
    (0): Linear(in_features=4, out_features=5, bias=True)
    (1): ReLU()
    (2): Linear(in_features=5, out_features=8, bias=True)
  )
  (graph_conv_1): MLPGraphConv(2, 2)
  (graph_conv_2): MLPGraphConv(2, 2)
  (graph_conv_list): ModuleList(
    (0): MLPGraphConv(2, 2)
    (1): MLPGraphConv(2, 2)
  )
)
RNNDecoder(
  (node_embedding): Sequential(
    (0): Linear(in_features=2, out_features=12, bias=True)
    (1): ReLU()
  )
  (edge_embedding): Sequential(
    (0): Linear(in_features=4, out_features=3, bias=True)
    (1): ReLU()
  )
  (graph_conv_1): GatedGraphConv(12, num_layers=2)
  (graph_conv_2): GatedGraphConv(12, num_layers=2)
  (graph_conv_list): ModuleList(
    (0): GatedGraphConv(12, num_layers=2)
    (1): GatedGraphConv(12, num_layers=2)
  )
)


In [161]:
# Optimise the encoder and the decoder jointly. No `model` object is defined
# anywhere in this notebook, so the optimizer is built directly from the two
# modules' parameters (referring to `model` raises NameError on a fresh kernel).
optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(decoder.parameters()),
    lr=0.01,
    weight_decay=5e-4,
)
loss_func = nn.MSELoss()
# loss_func = nn.L1Loss()

# Random placeholder target with one row per node in a batch and
# `node_embedding_dim` columns. It stands in for a real target while the
# autoencoder is still being assembled (see the note on the loss below).
dummy_matrix = torch.tensor(np.random.rand(batch_size*data[0].num_nodes, node_embedding_dim), dtype=torch.float)
# dummy_matrix = torch.tensor(np.ones((batch_size*data[0].num_nodes,node_embedding_dim)), dtype=torch.float)
dummy_matrix = dummy_matrix.to(device)

def train():
    """Run one pass over `dataloader`, taking one optimizer step per batch.

    For every batch the encoder is run, a vector is sampled from a Gaussian
    built from the encoder's `edge_attr` output, softmaxed, and passed to the
    decoder together with the encoder outputs. The loss is currently computed
    between the *encoder* output and `dummy_matrix`; the decoder's return
    value is not used yet.

    Returns:
        list[float]: the loss value of every batch, in iteration order.
    """
    losses = []
    # Put both modules in training mode (previously `model.train()`, but
    # `model` does not exist in this notebook).
    encoder.train()
    decoder.train()
    for batch in dataloader:
        batch = batch.to(device)
        optimizer.zero_grad()
        output, edge_index, edge_attr = encoder(batch)
#         print("Are any weights NaNs? {}".format(torch.isnan(encoder.node_embedding[0].weight).any())) # are any of the node embedding weights NaNs?
        # Gaussian with identity covariance whose mean is the mean of
        # `edge_attr` over dim 1, truncated to its first `edge_embedding_dim`
        # entries. Built on the CPU because `torch.eye(...)` is created there.
        distribution = MultivariateNormal(edge_attr.mean(dim=1).cpu()[:edge_embedding_dim], torch.eye(edge_embedding_dim))
        a = distribution.sample()
        # Normalise the sample with a softmax over its last dimension and move
        # it to the modules' device.
        a_soft = F.softmax(a, dim=-1).to(device)
        # NOTE(review): the decoder output is discarded, so the decoder
        # receives no gradient yet; wire it into the loss once the full
        # autoencoder target is in place.
        decoder(output, edge_index, edge_attr, a_soft)
#         print("Slice of output: {}".format(output[0,:5]))
        loss = loss_func(output, dummy_matrix) # would normally use batch.y here for the full AE
        print("Loss: {}".format(loss.item()))
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
#         break
    return losses

In [162]:
# Train for a fixed number of epochs. `train()` returns the per-batch losses
# of the epoch it just ran; keep them (one list per epoch) instead of dropping
# them, so the loss curve can be inspected or plotted afterwards.
n_epochs = 5
loss_history = []
for epoch in range(n_epochs):
    loss_history.append(train())

  def message(self, x_i, x_j, pseudo):


Loss: 127.16078186035156
Loss: 125.07447814941406
Loss: 122.41666412353516
Loss: 120.989990234375
Loss: 118.71804809570312
Loss: 117.9212417602539
Loss: 117.67632293701172
Loss: 118.1209487915039
Loss: 119.21983337402344
Loss: 119.12320709228516
Loss: 118.07303619384766


KeyboardInterrupt: 

**TODO — still missing:**
- Edge embedding in the encoder (`edge_embedding_dim` is defined but not yet passed to `MLPEncoder`)
- How should the decoder use its `edge_embedding`?
- Edge weight?