In [None]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules before
# each cell executes, so edits to the local `functions` package are picked up
# without restarting the kernel.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline below the cell that produces them.
%matplotlib inline

In [None]:
import torch
from torch.utils.data import Dataset
from torch.utils.data.dataset import TensorDataset
from torch_geometric.data import DataLoader
from torch_geometric.data import Data
from torch_geometric.utils.convert import to_networkx
import torch.nn as nn

import networkx as nx # for visualizing graphs
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

from functions.load_data import MarielDataset, edges

##### Encoder Procedure
1. Create an embedding of the node features ($H$) using an MLP.
2. Create a message to pass along each edge of the graph from the embedded node features ($H$), using another MLP.
3. Aggregate the messages created in Step 2 at each node to update the node features.
4. Pass the updated node features through another MLP to get the "pre-posterior"; the posterior is then the softmax of the pre-posterior.

# Load data

In [None]:
# Build the dataset (length-10 sequences, reduced joint set) and wrap it in a
# PyG DataLoader that yields batches of 2 graphs in a fixed order (no shuffling).
data = MarielDataset(seq_len=10, reduced_joints=True)
dataloader = DataLoader(data, batch_size=2, shuffle=False)

# Report the dataset size and show the first sample for a quick sanity check.
n_sequences = len(data)
first_sample = data[0]
print("\nGenerated {:,} sequences of shape: {}".format(n_sequences, first_sample))

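# To do list
- Turn the node-embedding MLP into a single linear layer (no hidden layer).
- Add a `depth` parameter to `MLPEncoder` (multiple stacked `MLPGraphConv` layers). **TODO(review):** the model cell below already passes `depth=2`; confirm whether this item is done.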

# Construct model & train

In [None]:
from functions.modules import MLP, MLPGraphConv, MLPEncoder

# Model input sizes. These are hard-coded to match the dataset loaded above; the
# trailing comments give the expressions they are meant to equal.
# TODO(review): derive them from `data` once those attributes are confirmed.
node_features = 30 # data.seq_len*data.n_dim
edge_features = 10 # data[0].num_edge_features

# Seed torch's RNG so model initialization is reproducible under Restart & Run All.
torch.manual_seed(42)

model = MLPEncoder(node_features=node_features,
                   edge_features=edge_features,
                   hidden_size=25,
                   node_embedding_dim=20,
                   depth=2)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using {}".format(device))
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
loss_func = nn.MSELoss()


def train(verbose=True):
    """Run one training epoch over `dataloader` and return the per-batch losses.

    Only the encoder exists so far, so the loss compares the encoder output with a
    constant all-ones target of the same shape. This is a placeholder objective;
    the full autoencoder would compare against `batch.y`.

    Args:
        verbose: if True (the default), print the loss after every batch.

    Returns:
        list of float: the loss of each batch, in iteration order.
    """
    losses = []
    model.train()
    for batch in dataloader:
        batch = batch.to(device)
        optimizer.zero_grad()
        output = model(batch)
        # Derive the placeholder target from `output` itself, so it always matches
        # the output's shape, dtype and device. A hard-coded (36, 20) matrix breaks
        # on a smaller final batch or when batch_size / node_embedding_dim change.
        target = torch.ones_like(output)
        loss = loss_func(output, target) # would normally use batch.y here for the full AE
        loss_value = loss.item()  # read once; .item() forces a device sync
        if verbose:
            print("Loss: {}".format(loss_value))
        loss.backward()
        optimizer.step()
        losses.append(loss_value)
    return losses

In [None]:
# Train for a fixed number of epochs. The per-batch losses returned by train() are
# kept, so the learning curve can be inspected or plotted afterwards.
n_epochs = 10
loss_history = []  # one list of per-batch losses per epoch
for epoch in range(n_epochs):
    epoch_losses = train()
    loss_history.append(epoch_losses)
    # Guard against an empty epoch (e.g. an empty dataloader) instead of dividing by zero.
    mean_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float("nan")
    print("Epoch {}/{}: mean loss {:.6f}".format(epoch + 1, n_epochs, mean_loss))