## Task 3

Conditional VAE 

In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt

# use ggplot style
plt.style.use('ggplot')


def load_array(filename, task):
    """Load one simulation sample from an ``.npz`` archive.

    Args:
        filename: Path to the ``.npz`` file.
        task: One of ``'task 1'``, ``'task 2'`` or ``'task 3'``.

    Returns:
        For ``'task 1'``: ``(initial_state, terminal_state)`` as stored in the
        archive.
        For ``'task 2'`` / ``'task 3'``: ``(initial_state, target)`` where the
        stored trajectory has its last two axes swapped (per the original
        comment: (num_bodies, attributes, time) -> (num_bodies, time,
        attributes)), ``initial_state`` is the first timepoint with all
        attributes, and ``target`` is every later timepoint without the first
        attribute.

    Raises:
        NotImplementedError: If ``task`` is not one of the supported names.
    """
    # np.load on an .npz returns a lazy NpzFile that keeps the file open;
    # the context manager guarantees the handle is released once the arrays
    # have been read into memory.
    with np.load(filename) as datapoint:
        if task == 'task 1':
            return datapoint['initial_state'], datapoint['terminal_state']
        if task not in ('task 2', 'task 3'):
            raise NotImplementedError("'task' argument should be 'task 1', 'task 2' or 'task 3'!")
        whole_trajectory = datapoint['trajectory']

    # change shape: (num_bodies, attributes, time) -> (num_bodies, time, attributes)
    whole_trajectory = np.swapaxes(whole_trajectory, 1, 2)
    initial_state = whole_trajectory[:, 0]
    # drop the first timepoint (second dim) and mass (last dim, index 0) for the prediction task
    target = whole_trajectory[:, 1:, 1:]
    return initial_state, target

#### Create adjacency matrix

# Define distance metrics
def euclidean_distance(x, y):
    """Return the Euclidean (L2) distance between tensors ``x`` and ``y``.

    The squared element-wise differences are summed over all elements, so the
    result is a 0-dim tensor.
    """
    delta = x - y
    return delta.pow(2).sum().sqrt()

def inverse_distance(x, y):
    """Return the reciprocal of the Euclidean distance between ``x`` and ``y``.

    Coincident inputs have distance zero, so the reciprocal is ``inf``.
    """
    distance = euclidean_distance(x, y)
    return distance.reciprocal()

# Create adjacency matrix function
def create_adjacency_matrix(data, distance_metric):
    """Build a dense pairwise weight matrix between objects.

    Args:
        data: 2-D tensor with one row per object; columns 1 and 2 are used as
            the object's [x, y] position.
        distance_metric: Callable ``(position_i, position_j) -> scalar`` that
            produces the weight for the ordered pair (i, j).

    Returns:
        An ``(n, n)`` tensor of weights whose diagonal stays zero, because an
        object is never compared with itself.
    """
    num_objects = data.shape[0]
    weights = torch.zeros((num_objects, num_objects))
    for row in range(num_objects):
        row_position = data[row, 1:3]
        for col in range(num_objects):
            if row == col:
                # skip self-pairs; the diagonal keeps its initial zero
                continue
            weights[row, col] = distance_metric(row_position, data[col, 1:3])
    return weights

# Validate input
def validate_input(X, adjacency_matrix):
    """Check that node features and adjacency matrix are mutually consistent.

    Args:
        X: Node-feature tensor; must be 2-D with one row per node.
        adjacency_matrix: Tensor expected to be square with one row and one
            column per node.

    Raises:
        AssertionError: If ``X`` is not 2-D, if the adjacency matrix is not
            square, or if the node counts disagree. The exception is raised
            explicitly (rather than via ``assert``) so the checks still run
            under ``python -O``.
    """
    # X should be a 2D tensor
    if X.dim() != 2:
        raise AssertionError(f"X must be 2D, but got shape {X.shape}")

    # The adjacency matrix should be square. This is checked before the node
    # count so that a non-square matrix is reported as such instead of being
    # reported as a node-count mismatch.
    if adjacency_matrix.shape[0] != adjacency_matrix.shape[1]:
        raise AssertionError(
            f"Adjacency matrix must be square, but got shape {adjacency_matrix.shape}"
        )

    # The number of nodes should be the same in X and the adjacency matrix
    if X.shape[0] != adjacency_matrix.shape[0]:
        raise AssertionError(
            f"Mismatch in number of nodes: got {X.shape[0]} nodes in X, but {adjacency_matrix.shape[0]} nodes in adjacency matrix"
        )

    print("All checks passed.")


In [83]:
device = "mps" # NOTE(review): hard-coded "mps" backend string; the training cell later rebinds `device` to a cuda/cpu torch.device, so this is not the only place the device is defined
from torch_geometric.data import Data, DataLoader, Dataset

# DataLoaders for task 2
class MyDataset(Dataset):
    """Graph dataset that lazily turns trajectory files into ``Data`` objects.

    Each item is built on access: the file is read with ``load_array``, the
    initial state becomes the node features ``x``, the remaining trajectory
    becomes the target ``y``, and edges are taken from the non-zero entries of
    the inverse-distance adjacency matrix.
    """

    def __init__(self, root, filenames, transform=None, pre_transform=None):
        # Stored before calling the base initialiser, matching the original
        # ordering (the base class may consult raw_file_names during setup).
        self.filenames = filenames
        super().__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        """File names backing this dataset (the list passed to the constructor)."""
        return self.filenames

    def len(self):
        """Number of samples, i.e. number of file names."""
        return len(self.filenames)

    def get(self, idx):
        """Load sample ``idx`` from disk and wrap it in a ``Data`` graph."""
        initial_np, target_np = load_array(self.filenames[idx], task='task 2')
        node_features = torch.tensor(initial_np, dtype=torch.float32)
        # targets keep their 3-D layout as returned by load_array
        targets = torch.tensor(target_np, dtype=torch.float32)

        # Every non-zero weight in the dense matrix becomes a directed edge.
        weights = create_adjacency_matrix(node_features, inverse_distance)
        connected_pairs = weights.nonzero()
        edge_index = connected_pairs.t().contiguous().to(torch.long)

        return Data(x=node_features, y=targets, edge_index=edge_index)


# Training files are indexed 0..899; the first 80% are used for training and
# the remaining 20% for validation (split by position, no shuffling).
filenames = [f'data/task 2_3/train/trajectory_{i}.npz' for i in range(900)]
split_point = int(len(filenames) * 0.8)
# Do training validation split
train_filenames = filenames[:split_point]
val_filenames = filenames[split_point:]

# batch_size=1: each batch yielded by the loaders holds exactly one graph.
train_dataset = MyDataset(root='data/task 2_3/train', filenames=train_filenames)
train_dataloader = DataLoader(train_dataset, batch_size=1)

val_dataset = MyDataset(root='data/task 2_3/train', filenames=val_filenames)
val_dataloader = DataLoader(val_dataset, batch_size=1)

# Prepare the test data set

# NOTE(review): range(901, 1000) skips index 900 although the training files
# end at 899 — confirm whether trajectory_900.npz exists and should be included.
test_filenames = [f'data/task 2_3/test/trajectory_{i}.npz' for i in range(901, 1000)]
test_dataset = MyDataset(root='data/task 2_3/test', filenames=test_filenames)
test_dataloader = DataLoader(test_dataset, batch_size=1)



In [7]:
# Grab one graph from the first training batch for interactive experiments.
# The loaders are built with batch_size=1, so a batch contains a single graph
# and index 0 is the only valid position (index 3 would be out of range).
graph = next(iter(train_dataloader))[0]

graph

Data(x=[6, 5], edge_index=[2, 30], y=[6, 49, 4])

In [77]:
# Encoder

import torch
import torch.nn as nn
from torch_geometric.nn import SAGEConv

class Encoder(nn.Module):
    """Trajectory encoder: per-timestep graph convolutions followed by an LSTM.

    Produces, for every node, the mean and log-variance of a diagonal Gaussian
    over the latent space.

    Args:
        num_features: Number of attributes per node and timestep in ``data.y``.
        hidden_channels: Width of the last two graph layers and of the LSTM.
        latent_dim: Dimensionality of the latent space.
    """

    def __init__(self, num_features, hidden_channels, latent_dim):
        super(Encoder, self).__init__()

        # Graph convolution layers
        self.conv1 = SAGEConv(num_features, hidden_channels*2)
        self.conv2 = SAGEConv(hidden_channels*2, hidden_channels)
        self.conv3 = SAGEConv(hidden_channels, hidden_channels)

        # LSTM layer
        self.lstm = nn.LSTM(input_size=hidden_channels, hidden_size=hidden_channels, batch_first=True)

        # Linear layers to compute mean and log variance of latent space
        self.fc_mu = nn.Linear(hidden_channels, latent_dim)
        self.fc_var = nn.Linear(hidden_channels, latent_dim)

    def forward(self, data):
        """Encode the trajectory stored in ``data.y``.

        Args:
            data: Graph object with ``y`` of shape (n, timesteps, num_features)
                and ``edge_index`` describing the connectivity of the n nodes.

        Returns:
            Tuple ``(mu, log_var)``, each of shape (n, latent_dim).
        """
        trajectory, edge_index = data.y, data.edge_index

        # Message passing layers treat the second-to-last axis as the node
        # axis (default node_dim=-2). Feeding (n, T, F) directly would make
        # edge_index select timesteps instead of bodies, so move the node axis
        # into position: (n, T, F) -> (T, n, F). The timestep axis then acts
        # as a leading batch axis sharing the same edges.
        x = trajectory.transpose(0, 1)

        # Graph convolution layers
        x = torch.relu(self.conv1(x, edge_index))
        x = torch.relu(self.conv2(x, edge_index))
        x = torch.relu(self.conv3(x, edge_index))

        # Back to (n, T, hidden_channels): one sequence per node for the
        # batch_first LSTM, which encodes the temporal information.
        x = x.transpose(0, 1).contiguous()
        lstm_outputs, _ = self.lstm(x)
        final_output = lstm_outputs[:, -1, :]  # LSTM output at the last timestep

        # Linear layers for mean and log variance
        mu = self.fc_mu(final_output)
        log_var = self.fc_var(final_output)

        return mu, log_var

# Example usage (the cell previously repeated these statements verbatim; one
# instantiation is sufficient):
latent_dim = 16
encoder = Encoder(num_features=4, hidden_channels=64, latent_dim=latent_dim)


In [26]:
# Conditional prior

class MinGraphSAGE(torch.nn.Module):
    """Three-layer GraphSAGE network used as the conditional prior.

    Maps the node features in ``data.x`` to one ``latent_dim``-sized vector
    per node. Layer widths are
    num_features -> 2 * hidden_channels -> hidden_channels -> latent_dim,
    with dropout (p=0.3) applied after the first activation only.
    """

    def __init__(self, num_features, hidden_channels, latent_dim):
        super().__init__()
        wide_channels = hidden_channels * 2
        self.conv1 = SAGEConv(num_features, wide_channels)
        self.conv2 = SAGEConv(wide_channels, hidden_channels)
        self.conv3 = SAGEConv(hidden_channels, latent_dim)

        self.dropout = torch.nn.Dropout(p=0.3)

    def forward(self, data):
        """Return the per-node conditioning vectors for graph ``data``."""
        edge_index = data.edge_index

        # 1st GraphSAGE layer, activation, then dropout
        hidden = self.dropout(torch.relu(self.conv1(data.x, edge_index)))

        # 2nd GraphSAGE layer with activation
        hidden = torch.relu(self.conv2(hidden, edge_index))

        # 3rd GraphSAGE layer: linear output, no activation
        return self.conv3(hidden, edge_index)

# Example usage: the conditional prior reads the 5 initial-state attributes per
# node (data.x) and emits latent_dim values per node.
latent_dim = 16
min_graph_sage = MinGraphSAGE(num_features=5, hidden_channels=64, latent_dim=latent_dim)


In [29]:
# Decoder

class Decoder(torch.nn.Module):
    """LSTM decoder that unrolls one feature vector per node into a sequence.

    Args:
        input_dim: Dimension of the concatenated latent representation and
            conditioning vector.
        hidden_dim: Number of hidden units in the LSTM.
        output_dim: Dimension of the output at each timestep (4 for
            x, y, vel_x, vel_y).
        num_layers: Number of stacked LSTM layers.
        num_timesteps: Length of the generated sequence. Defaults to 49, the
            value that was previously hard-coded.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1, num_timesteps=49):
        super(Decoder, self).__init__()
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.num_timesteps = num_timesteps

    def forward(self, x):
        """Decode per-node feature vectors into per-node sequences.

        Args:
            x: Tensor of shape [n, input_dim].

        Returns:
            Tensor of shape [n, num_timesteps, output_dim].
        """
        # Repeat x along a new temporal dimension so the same vector is the
        # LSTM input at every time step; the batch_first LSTM expects
        # (batch_size, sequence_length, input_size).
        x = x.unsqueeze(1).repeat(1, self.num_timesteps, 1)

        # Pass the sequence through the LSTM
        lstm_out, _ = self.lstm(x)  # [n, num_timesteps, hidden_dim]

        # Project every timestep's hidden state to the output attributes
        output = self.fc(lstm_out)  # [n, num_timesteps, output_dim]

        return output

# Example usage: input_dim=32 matches the width of a 16-dim latent sample
# concatenated with a 16-dim conditioning vector.
decoder = Decoder(input_dim=32, hidden_dim=64, output_dim=4)


In [58]:
# Trial run of the individual components on the single sample `graph`
# Encode the trajectories (read from graph.y by the encoder)
# 16 dimensions for mu, logvar PER NODE
mu, logvar = encoder(graph)

# Encode the initial states (graph.x) into 16 dimensions per node
conditioning_vector = min_graph_sage(graph)

def reparameterize(mu, log_var):
    """Draw a sample from N(mu, exp(log_var)) via the reparameterization trick.

    Args:
        mu: Mean matrix, e.g. [n, 16].
        log_var: Log-variance matrix of the same shape as ``mu``.

    Returns:
        ``mu + eps * std`` where ``eps`` is standard-normal noise with the
        shape of ``log_var`` and ``std = exp(0.5 * log_var)``.
    """
    sigma = log_var.mul(0.5).exp()    # standard deviation
    noise = torch.randn_like(sigma)   # 'random' noise
    return noise * sigma + mu

# Sampling from the latent space (one latent vector per node)
z = reparameterize(mu, logvar)

# Concatenate the sampled latent representation with the conditioning vector
# along the feature axis
combined_features = torch.cat((z, conditioning_vector), dim=1)

# Decode to a trajectory; `original` is the ground-truth it should reproduce
reconstructed = decoder(combined_features)
original = graph.y


In [59]:
# Final MODEL

class GraphCVAE(torch.nn.Module):
    """Conditional VAE composed of an encoder, a conditional prior and a decoder.

    Args:
        encoder: Module called as ``encoder(data=...)`` returning
            ``(mu, log_var)`` per node.
        min_graph_sage: Module called as ``min_graph_sage(data=...)`` returning
            one conditioning vector per node.
        decoder: Module mapping the concatenated ``[z, conditioning]`` features
            to the output sequence.
        latent_dim: Width of the latent sample drawn at inference time. When
            omitted it is read from ``encoder.fc_mu.out_features`` if that
            attribute exists, and otherwise falls back to 16 (the value that
            was previously hard-coded).
    """

    def __init__(self, encoder, min_graph_sage, decoder, latent_dim=None):
        super(GraphCVAE, self).__init__()
        self.encoder = encoder
        self.min_graph_sage = min_graph_sage
        self.decoder = decoder

        if latent_dim is None:
            # Keep the prior sample in sync with the encoder's latent width
            # instead of assuming 16.
            fc_mu = getattr(encoder, "fc_mu", None)
            latent_dim = getattr(fc_mu, "out_features", 16)
        self.latent_dim = latent_dim

    def sample_latent(self, mu, log_var, use_prior=False):
        """Sample z from N(0, I) (``use_prior``) or from N(mu, exp(log_var))."""
        if use_prior:
            z = torch.randn_like(mu)
        else:
            std = torch.exp(0.5 * log_var)
            eps = torch.randn_like(std)
            z = mu + eps * std
        return z

    def forward(self, data, trajectory=None, is_inference=False):
        """Generate an output sequence for every node of ``data``.

        Args:
            data: Graph object passed to the encoder and the conditional
                prior; must expose ``num_nodes`` when ``is_inference`` is True.
            trajectory: Unused; kept so existing callers that pass it keep
                working (the encoder is called with ``data`` only).
            is_inference: If True, skip the encoder and draw z from a standard
                normal; otherwise sample z from the encoder's posterior.

        Returns:
            The decoder output for the concatenated ``[z, conditioning]``
            features.
        """
        if is_inference:
            # In inference mode, encode only the initial state into conditioning vector
            conditioning_vector = self.min_graph_sage(data=data)
            # Draw the noise directly on the conditioning vector's device/dtype
            # so the concatenation below never mixes devices.
            z = torch.randn(
                (data.num_nodes, self.latent_dim),
                device=conditioning_vector.device,
                dtype=conditioning_vector.dtype,
            )
        else:
            # In training mode, encode both the initial state and the full trajectory
            mu, log_var = self.encoder(data=data)
            conditioning_vector = self.min_graph_sage(data=data)
            z = self.sample_latent(mu, log_var)

        # Concatenate the latent representation and conditioning vector
        combined_features = torch.cat((z, conditioning_vector), dim=1)

        # Pass the combined features to the decoder
        output_sequence = self.decoder(combined_features)

        return output_sequence


# Initialize components (16-dim latent + 16-dim conditioning = 32 decoder inputs)
encoder = Encoder(num_features=4, hidden_channels=16, latent_dim=16)
min_graph_sage = MinGraphSAGE(num_features=5, hidden_channels=16, latent_dim=16)
decoder = Decoder(input_dim=32, hidden_dim=64, output_dim=4)

# Initialize model and smoke-test the inference path on the sample graph by
# displaying the output shape
graph_cvae = GraphCVAE(encoder, min_graph_sage, decoder)
graph_cvae(graph, is_inference=True).shape

torch.Size([6, 49, 4])

In [86]:
# Training 
# NOTE: Training only works on single graphs, not on batches of graphs.
# Because of this, training is very slow and does not necessarily converge.
import torch
import torch.optim as optim

# Loss functions
def loss_function(recon_x, x, mu, log_var):
    """Return the summed reconstruction error plus the KL regulariser.

    Args:
        recon_x: Reconstructed tensor.
        x: Target tensor.
        mu: Posterior means.
        log_var: Posterior log-variances.

    Returns:
        Scalar tensor: sum-reduced MSE between ``recon_x`` and ``x`` plus
        ``-0.5 * sum(1 + log_var - mu^2 - exp(log_var))``, the KL divergence
        of the diagonal Gaussian posterior from a standard normal.
    """
    reconstruction_error = torch.nn.functional.mse_loss(recon_x, x, reduction='sum')
    kl_terms = 1 + log_var - mu.pow(2) - log_var.exp()
    kl_divergence = kl_terms.sum() * -0.5
    return reconstruction_error + kl_divergence


# Training settings
# NOTE(review): the logged losses below this cell grow over the epochs; a
# learning rate of 0.01 may be too high for this model — confirm by sweeping.
epochs = 50
learning_rate = 0.01 

# Initialize components
encoder = Encoder(num_features=4, hidden_channels=16, latent_dim=16)
min_graph_sage = MinGraphSAGE(num_features=5, hidden_channels=16, latent_dim=16)
decoder = Decoder(input_dim=32, hidden_dim=64, output_dim=4)

# Initialize model
model = GraphCVAE(encoder, min_graph_sage, decoder)

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Move model to GPU if available
# NOTE: this rebinds the global `device`, replacing the "mps" string set in the
# data-loading cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Training Loop
for epoch in range(epochs):
    model.train()
    train_loss = 0
    
    # Training
    for batch in train_dataloader:
        # The loader yields batches of size 1; take the single graph out of it
        batch = batch[0]
        # Move data to the device (CPU or GPU)
        batch = batch.to(device)

        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        recon_batch = model(data=batch, trajectory=batch.y)

        # Loss calculation
        # The model's forward pass returns only the reconstruction, so the
        # encoder is run a second time here to obtain mu/log_var for the KL term.
        mu, log_var = model.encoder(data=batch)
        loss = loss_function(recon_batch, batch.y, mu, log_var)

        # Backward pass
        loss.backward()

        # Optimization step
        optimizer.step()

        train_loss += loss.item()

    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in val_dataloader:
            # Single graph per batch, as in the training loop
            batch = batch[0]
            # Move data to the device (CPU or GPU)
            batch = batch.to(device)

            # Forward pass (training-mode path of the model: z is sampled from
            # the encoder's posterior, not from the prior)
            recon_batch = model(data=batch, trajectory=batch.y)

            # Loss calculation
            mu, log_var = model.encoder(data=batch)
            loss = loss_function(recon_batch, batch.y, mu, log_var)

            val_loss += loss.item()

    # Logging: summed losses divided by the number of samples in each dataset
    print(f'====> Epoch: {epoch} Average loss: {train_loss / len(train_dataloader.dataset):.4f}, '
          f'Validation loss: {val_loss / len(val_dataloader.dataset):.4f}')


====> Epoch: 0 Average loss: 4834.6256, Validation loss: 3012.5416
====> Epoch: 1 Average loss: 4255.2876, Validation loss: 3696.2077
====> Epoch: 2 Average loss: 4267.2710, Validation loss: 3924.9832
====> Epoch: 3 Average loss: 4720.8010, Validation loss: 4587.9356
====> Epoch: 4 Average loss: 4931.4741, Validation loss: 4179.8552
====> Epoch: 5 Average loss: 4529.7283, Validation loss: 4508.9857
====> Epoch: 6 Average loss: 6181.6733, Validation loss: 4790.1939
====> Epoch: 7 Average loss: 5705.9921, Validation loss: 5001.0549
====> Epoch: 8 Average loss: 6337.1200, Validation loss: 5230.5267
====> Epoch: 9 Average loss: 5412.9052, Validation loss: 5728.8489
====> Epoch: 10 Average loss: 6081.0069, Validation loss: 4964.9698
====> Epoch: 11 Average loss: 7334.0022, Validation loss: 8674.6104
====> Epoch: 12 Average loss: 7993.9612, Validation loss: 5331.7389
====> Epoch: 13 Average loss: 5657.9289, Validation loss: 4193.4042
====> Epoch: 14 Average loss: 5098.0322, Validation loss: 