# Integrate with PyTorch Geometric

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import os
import yaml

import torch
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

sys.path.append("../")
device = "cuda" if torch.cuda.is_available() else "cpu"

In [6]:
from architectures.EquivariantGNN.Models.legnn import LEGNN

# Roadmap

- [X] Pull in toy quickstart
- [X] Test batch 1 of toy
- [X] Test batch N of toy
- [ ] Tweak edge conv to fit LEGNN conv
- [ ] Tweak forward to fit LEGNN and ParticleNet

## PyG Toy

In [8]:
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid

# Load the Cora dataset through PyG's Planetoid wrapper, using the current directory as the dataset root
dataset = Planetoid(root='.', name='Cora')

class GCN(torch.nn.Module):
    """Toy two-layer graph convolutional network built from two ``GCNConv`` layers."""

    def __init__(self, in_channels, hidden_channels, out_channels):
        """
        :param in_channels: Number of input features per node
        :param hidden_channels: Number of features produced by the first convolution
        :param out_channels: Number of features produced by the second convolution
        """
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """
        Applies the first convolution followed by a ReLU, then the second convolution.

        :param x: Node feature tensor
        :param edge_index: Graph connectivity handed to both convolutions
        :return: Output of the second convolution (no activation applied)
        """
        hidden = torch.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

# Size the toy GCN from the dataset: one input channel per node feature, 16 hidden channels, one output per class
model = GCN(dataset.num_features, 16, dataset.num_classes)

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


In [3]:
from torch_geometric.data import Data, DataLoader

In [4]:
# Build a list of random toy graphs: every graph has NUM_NODES nodes with
# NUM_NODE_FEATURES features each, a random number of edges in [0, MAX_EDGES)
# and one binary (0./1.) label per edge.
NUM_GRAPHS = 100
NUM_NODES = 100
NUM_NODE_FEATURES = 3
MAX_EDGES = 1000

# Dedicated, seeded generator: the toy data is identical on every
# "Restart & Run All" without touching the global torch RNG state.
rng = torch.Generator().manual_seed(0)

random_graphs = []
for _ in range(NUM_GRAPHS):
    x = torch.rand((NUM_NODES, NUM_NODE_FEATURES), generator=rng)
    # Convert to a plain int so it is a valid entry of a size tuple
    num_edges = int(torch.randint(MAX_EDGES, (1,), generator=rng))
    # Each column is one edge; entries are node indices in [0, NUM_NODES)
    edges = torch.randint(NUM_NODES, (2, num_edges), generator=rng)
    y = torch.round(torch.rand(num_edges, generator=rng))
    random_graphs.append(Data(x=x, edge_index=edges, y=y))

In [5]:
# Batch the toy graphs two at a time, keeping them in their original order
dataloader = DataLoader(random_graphs, shuffle=False, batch_size=2)

In [6]:
# Look at the connectivity of the first batch only
graph = next(iter(dataloader), None)
if graph is not None:
    print(graph.edge_index)

tensor([[ 88,   3,  57,  ..., 137, 193, 168],
        [ 49,  81,  85,  ..., 105, 110, 124]])


## Test custom message passing

In [13]:
from torch_geometric.nn import MessagePassing
from torch.nn import Sequential as Seq, Linear, ReLU

In [14]:
class EdgeConv(MessagePassing):
    """
    Edge convolution used to trace the order of PyG's message-passing hooks.

    Each hook prints a line when it runs, so a forward pass shows the sequence
    forward -> message -> update.
    """

    def __init__(self, in_channels, out_channels):
        """
        :param in_channels: Number of features per node going in
        :param out_channels: Number of features per node coming out
        """
        super().__init__(aggr='max')  # messages are combined with "max"
        # The MLP sees both endpoint-derived vectors, hence 2 * in_channels inputs
        self.mlp = Seq(
            Linear(2 * in_channels, out_channels),
            ReLU(),
            Linear(out_channels, out_channels),
        )

    def forward(self, x, edge_index):
        """
        :param x: Node features, shape [N, in_channels]
        :param edge_index: Connectivity, shape [2, E]
        :return: Whatever ``propagate`` returns for these inputs
        """
        print("Propagating")
        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):
        """
        Builds one message per edge from the two per-edge feature tensors.

        :param x_i: Per-edge features, shape [E, in_channels]
        :param x_j: Per-edge features, shape [E, in_channels]
        :return: MLP output for the concatenation of x_i and (x_j - x_i)
        """
        print("Calculating message")
        # Concatenated input has shape [E, 2 * in_channels]
        edge_inputs = torch.cat((x_i, x_j - x_i), dim=1)
        return self.mlp(edge_inputs)

    def update(self, aggr_out):
        """Returns the aggregated messages unchanged (after printing a trace line)."""
        print("Aggregating")
        return aggr_out

In [15]:
class VanillaGNN(torch.nn.Module):
    """Two stacked ``EdgeConv`` layers with a ReLU in between."""

    def __init__(self, in_channels, hidden_channels, out_channels):
        """
        :param in_channels: Number of input features per node
        :param hidden_channels: Number of features produced by the first layer
        :param out_channels: Number of features produced by the second layer
        """
        super().__init__()
        self.conv1 = EdgeConv(in_channels, hidden_channels)
        self.conv2 = EdgeConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """
        :param x: Node feature tensor
        :param edge_index: Graph connectivity handed to both layers
        :return: Output of the second layer (no activation applied)
        """
        hidden = torch.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

# 3 input channels (the toy graphs carry 3 features per node), 16 hidden channels, 2 output channels
model = VanillaGNN(3, 16, 2)

In [16]:
# Run the model on the first batch only
graph = next(iter(dataloader), None)
if graph is not None:
    output = model(graph.x, graph.edge_index)

Propagating
Calculating message
Aggregating
Propagating
Calculating message
Aggregating


In [60]:
output.shape

torch.Size([200, 7])

## LEGNN into PyG

- [ ] 

In [3]:
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch.nn import Sequential as Seq, Linear, ReLU
import torch.nn as nn

class L_GCL(MessagePassing):
    """
    Lorentz-equivariant graph convolutional layer written as a PyTorch Geometric ``MessagePassing`` module.

    For every edge, an MLP turns the two endpoint feature vectors, the squared Minkowski distance between the
    endpoint coordinates and (optionally) edge attributes into a message. Messages are summed per node and used
    to update both the node features h and the node coordinates x.
    """

    def __init__(self, input_feature_dim, message_dim, output_feature_dim, edge_feature_dim, activation = nn.SiLU()):
        """
        Sets up the MLPs used for messages, feature updates and coordinate updates.

        :param input_feature_dim: Width of the node feature vectors h fed into the layer
        :param message_dim: Width of the messages (and of the hidden layers of the MLPs)
        :param output_feature_dim: Width of the node feature vectors returned by the layer
        :param edge_feature_dim: Width of the edge attributes; use 0 when no edge attributes are passed
        :param activation: Activation module placed inside the MLPs. The same instance is shared by all MLPs
                           (and the default is created once, at definition time), so it should be stateless
        """
        super().__init__(aggr='add')  # Incoming messages are summed ("add" aggregation)

        radial_dim = 1  # Only one number is needed to specify Minkowski distance
        coordinate_dim = 4
        self.message_dim = message_dim

        # The MLP used to calculate messages
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * input_feature_dim + radial_dim + edge_feature_dim, message_dim),
            activation,
            nn.Linear(message_dim, message_dim),
            nn.Softsign()
        )

        # The MLP used to update the feature vectors h_i
        self.feature_mlp = nn.Sequential(
            nn.Linear(input_feature_dim + message_dim, message_dim),
            activation,
            nn.Linear(message_dim, output_feature_dim),
            nn.Softsign()
        )

        # Final, bias-free layer of the coordinate MLP, initialised with a very small gain
        self.layer = nn.Linear(message_dim, 1, bias = False)
        torch.nn.init.xavier_uniform_(self.layer.weight, gain = 0.001)

        # The MLP used to update coordinates (node embeddings) x_i; outputs one scalar per message
        self.coordinate_mlp = nn.Sequential(
            nn.Linear(message_dim, message_dim),
            activation,
            self.layer
        )

        # Not used by forward/message/update of this class; kept so parameter names and checkpoints stay unchanged
        self.coordinate_linear_combination_mlp = nn.Linear(2 * coordinate_dim, coordinate_dim, bias = False)

    def forward(self, x, h, edge_index, edge_attribute = None):
        """
        Runs one round of message passing.

        :param x: The coordinates (node embeddings), one row per node
        :param h: The node feature vectors, one row per node
        :param edge_index: Array containing the connection between nodes, shape [2, E]
        :param edge_attribute: Optional per-edge features; required when the layer was built with edge_feature_dim > 0
        :return: Tuple (updated features h, updated coordinates x)
        """

        radial, _ = self.compute_radials(edge_index, x)

        # edge_attribute has to be handed to propagate explicitly, otherwise it never reaches message()
        return self.propagate(edge_index, x=x, h=h, radial=radial, edge_attribute=edge_attribute)

    def message(self, x_i, x_j, h_i, h_j, radial, edge_attribute = None):
        """
        Builds the per-edge message: the feature message followed by the coordinate message.

        :param x_i: Coordinates gathered per edge for one endpoint
        :param x_j: Coordinates gathered per edge for the other endpoint
        :param h_i: Features gathered per edge for one endpoint
        :param h_j: Features gathered per edge for the other endpoint
        :param radial: Squared Minkowski distance per edge, shape [E, 1]
        :param edge_attribute: Optional per-edge features
        :return: Tensor whose first message_dim columns are the feature messages and whose remaining columns
                 are the coordinate messages
        """

        h_messages = self.compute_messages(h_i, h_j, radial, edge_attribute)
        # coordinate_mlp yields one scalar per edge, which scales the coordinate difference by broadcasting
        x_messages = (x_i - x_j)*self.coordinate_mlp(h_messages)

        return torch.cat([h_messages, x_messages], dim=1)

    def update(self, aggr, x, h):
        """
        Updates features and coordinates from the aggregated messages.

        :param aggr: Aggregated output of message(); feature part first, coordinate part after message_dim columns
        :param x: The coordinates before the update
        :param h: The node features before the update
        :return: Tuple (updated features h, updated coordinates x)
        """
        h_next = self.feature_mlp(torch.cat([aggr[:, :self.message_dim], h], dim=1))
        x_next = x + aggr[:, self.message_dim:]

        return h_next, x_next

    def compute_messages(self, source, target, radial, edge_attribute = None):
        """
        Calculates the messages to send between two nodes 'target' and 'source' to be passed through the network.
        The message is computed via an MLP of Lorentz invariants.

        :param source: The source node's feature vector h_i
        :param target: The target node's feature vector h_j
        :param radial: The Minkowski distance between the source and target's coordinates
        :param edge_attribute: Features at the edge connecting the source and target nodes
        :return: The message m_{ij}
        """

        if edge_attribute is None:
            message_inputs = torch.cat([source, target, radial], dim = 1)  # Setup input for computing messages through MLP
        else:
            message_inputs = torch.cat([source, target, radial, edge_attribute], dim = 1)  # Setup input for computing messages through MLP

        out = self.edge_mlp(message_inputs)  # Apply \phi_e to calculate the messages
        return out

    @staticmethod
    def compute_radials(edge_index, x):
        """
        Calculates the Minkowski distance (squared) between coordinates (node embeddings) x_i and x_j

        :param edge_index: Array containing the connection between nodes
        :param x: The coordinates (node embeddings)
        :return: Minkowski distances (squared) with shape [E, 1] and coordinate differences x_i - x_j
        """

        row, col = edge_index
        coordinate_differences = x[row] - x[col]
        minkowski_distance_squared = coordinate_differences ** 2
        minkowski_distance_squared[:, 0] = -minkowski_distance_squared[:, 0]  # Place minus sign on time coordinate as \eta = diag(-1, 1, 1, 1)
        radial = torch.sum(minkowski_distance_squared, 1).unsqueeze(1)
        return radial, coordinate_differences


In [4]:
from architectures.EquivariantGNN.egnn_base import EGNNBase
from torch import nn

class LEGNN(EGNNBase):
    """
    Network for Lorentz group equivariance: a linear embedding of the initial node feature,
    a stack of L_GCL layers, and a linear read-out of the final node features.
    """

    def __init__(self, hparams):
        """
        Builds the network from a hyperparameter mapping and creates the L_GCL layers.

        :param hparams: Mapping that is passed on to EGNNBase and is read here for the keys
                        "input_feature_dim", "message_dim", "output_feature_dim", "edge_feature_dim",
                        "activation" (name of an attribute of torch.nn that is called without
                        arguments to create the activation) and "n_layers"
        """

        super().__init__(hparams)
        self.message_dim = hparams["message_dim"]
        self.activation = getattr(nn, hparams["activation"])
        self.n_layers = hparams["n_layers"]

        # Initial and final mixing of features
        self.feature_in = nn.Linear(hparams["input_feature_dim"], self.message_dim)
        self.feature_out = nn.Linear(self.message_dim, hparams["output_feature_dim"])

        # Every layer maps message_dim features to message_dim features and gets its own activation instance
        for layer_idx in range(hparams["n_layers"]):
            gcl = L_GCL(
                self.message_dim,
                self.message_dim,
                self.message_dim,
                hparams["edge_feature_dim"],
                activation = self.activation(),
            )
            self.add_module("gcl_%d" % layer_idx, gcl)

    def forward(self, x, edges, edge_attribute = None):
        """
        Embeds the initial per-node feature, applies every L_GCL layer in order and reads out the features.

        :param x: The coordinates (node embeddings), one row per node
        :param edges: Array containing the connection between nodes
        :param edge_attribute: Optional edge features handed to every L_GCL layer
        :return: Tuple (output features, final coordinates)
        """

        minkowski_norm = self.compute_initial_feature(edges, x)
        # One scalar per node, turned into a column before the input layer
        h = self.feature_in(minkowski_norm.unsqueeze(1))
        for layer_idx in range(self.n_layers):
            gcl = getattr(self, "gcl_%d" % layer_idx)
            h, x = gcl(x, h, edges, edge_attribute = edge_attribute)
        return self.feature_out(h), x

    @staticmethod
    def compute_initial_feature(edge_index, x):
        """
        Calculates, for every node, the sum of its squared coordinates with the sign of column 0 flipped
        (the squared Minkowski magnitude with the first column treated as the time component).

        :param edge_index: Not used by this method
        :param x: The coordinates (node embeddings), one row per node
        :return: One value per node
        """

        squared = x.pow(2)
        squared[:, 0] = squared[:, 0].neg()  # Minus sign on the first (time) column
        return squared.sum(dim=1)

## Load model

In [5]:
# Read the LEGNN hyperparameters from the YAML config; the path is relative to the working directory.
# NOTE(review): yaml.FullLoader is acceptable for a trusted local file; prefer yaml.safe_load
# if this config could ever come from an untrusted source.
with open("configs/lgnn_config.yaml") as f:
        hparams = yaml.load(f, Loader=yaml.FullLoader)

In [6]:
# Build the network from the hyperparameters loaded from the config file
model = LEGNN(hparams)

## Train

In [None]:
# Log the run to Weights & Biases and train for 30 epochs.
logger = WandbLogger(project="LorentzNet", group="InitialTest")
# Use one GPU when CUDA is available and fall back to CPU (gpus=0) otherwise,
# so the cell does not fail on a machine without a GPU.
n_gpus = 1 if torch.cuda.is_available() else 0
trainer = Trainer(gpus=n_gpus, max_epochs=30, logger=logger)
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[34m[1mwandb[0m: Currently logged in as: [33mmurnanedaniel[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.1 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Set SLURM handle signals.

  | Name        | Type   | Params
---------------------------------------
0 | feature_in  | Linear | 256   
1 | feature_out | Linear | 129   
2 | gcl_0       | L_GCL  | 115 K 
3 | gcl_1       | L_GCL  | 115 K 
4 | gcl_2       | L_GCL  | 115 K 
5 | gcl_3       | L_GCL  | 115 K 
---------------------------------------
462 K     Trainable params
0         Non-trainable params
462 K     Total params
1.851     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
  f'The dataloader, {name}, does not have many workers which may be a bottleneck.'


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

  "Relying on `self.log('val_loss', ...)` to set the ModelCheckpoint monitor is deprecated in v1.2"


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

## Validate