# Integrate with PyTorch Geometric

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import os
import yaml

import torch
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

sys.path.append("../")
device = "cuda" if torch.cuda.is_available() else "cpu"

In [6]:
from architectures.EquivariantGNN.Models.legnn import LEGNN

# Roadmap

- [X] Pull in toy quickstart
- [X] Test batch 1 of toy
- [X] Test batch N of toy
- [ ] Tweak edge conv to fit LEGNN conv
- [ ] Tweak forward to fit LEGNN and ParticleNet

## PyG Toy

In [8]:
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='.', name='Cora')

class GCN(torch.nn.Module):
    """Two-layer graph convolutional network built from PyG ``GCNConv`` layers."""

    def __init__(self, in_channels, hidden_channels, out_channels):
        """Create the two stacked convolution layers.

        Args:
            in_channels: Input width handed to the first ``GCNConv``.
            hidden_channels: Output width of the first layer and input width
                of the second.
            out_channels: Output width of the second ``GCNConv``.
        """
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Run ``conv1`` -> ReLU -> ``conv2`` and return the raw second-layer output.

        Args:
            x: Node feature tensor passed to the first convolution.
            edge_index: Graph connectivity, passed unchanged to both layers.
        """
        hidden = torch.relu(self.conv1(x, edge_index))
        # No activation after the last layer: the caller receives its raw output.
        return self.conv2(hidden, edge_index)

model = GCN(dataset.num_features, 16, dataset.num_classes)

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


In [3]:
from torch_geometric.data import Data, DataLoader

In [4]:
# Build 100 random toy graphs. Each graph has 100 nodes with 3 uniform-random
# features, a random number of edges (0-999) between randomly chosen node
# pairs, and one random 0/1 label per edge.
random_graphs = []
for i in range(100):
    x = torch.rand((100, 3))
    # Convert the sampled edge count to a plain Python int once, instead of
    # passing a one-element tensor as a size dimension and relying on implicit
    # tensor-to-int conversion in the calls below.
    num_edges = int(torch.randint(1000, (1,)))
    # Row 0 / row 1 hold the two endpoint node indices of every edge.
    edges = torch.randint(len(x), (2, num_edges))
    # Rounding uniform samples in [0, 1) yields float labels that are 0. or 1.
    y = torch.round(torch.rand(num_edges))
    random_graphs.append(Data(x=x, edge_index=edges, y=y))

In [5]:
dataloader = DataLoader(random_graphs, shuffle=False, batch_size=2)

In [6]:
# Peek at the first collated batch only (hence the immediate `break`) to see
# what the edge indices look like after two graphs are merged into one batch.
for graph in dataloader:
    
    print(graph.edge_index)
    
    break

tensor([[ 88,   3,  57,  ..., 137, 193, 168],
        [ 49,  81,  85,  ..., 105, 110, 124]])


## Test custom message passing

In [13]:
from torch_geometric.nn import MessagePassing
from torch.nn import Sequential as Seq, Linear, ReLU

In [14]:
class EdgeConv(MessagePassing):
    """Edge convolution with max aggregation and a two-layer MLP on edge features.

    Each hook prints a short trace line so the order in which the
    message-passing machinery calls them is visible when the layer runs.
    """

    def __init__(self, in_channels, out_channels):
        """Set up max aggregation and the per-edge MLP.

        Args:
            in_channels: Feature width of each node; the MLP consumes
                ``2 * in_channels`` values per edge.
            out_channels: Width of the MLP's hidden and output layers.
        """
        super().__init__(aggr='max')  # "Max" aggregation.
        layers = [
            Linear(2 * in_channels, out_channels),
            ReLU(),
            Linear(out_channels, out_channels),
        ]
        self.mlp = Seq(*layers)

    def forward(self, x, edge_index):
        """Start message passing over ``edge_index`` using node features ``x``.

        Args:
            x: Node features, shape [N, in_channels].
            edge_index: Connectivity, shape [2, E].
        """
        print("Propagating")
        propagated = self.propagate(edge_index, x=x)
        return propagated

    def message(self, x_i, x_j):
        """Compute one message per edge from the two endpoint feature tensors.

        Args:
            x_i: Features of one endpoint of each edge, shape [E, in_channels].
            x_j: Features of the other endpoint, shape [E, in_channels].
        """
        print("Calculating message")
        offset = x_j - x_i
        # Concatenate along the feature axis -> shape [E, 2 * in_channels].
        edge_features = torch.cat([x_i, offset], dim=1)
        return self.mlp(edge_features)

    def update(self, aggr_out):
        """Return the aggregated messages unchanged (only prints a trace line)."""
        print("Aggregating")
        return aggr_out

In [15]:
class VanillaGNN(torch.nn.Module):
    """Two stacked ``EdgeConv`` layers with a ReLU in between."""

    def __init__(self, in_channels, hidden_channels, out_channels):
        """Create the two edge-convolution layers.

        Args:
            in_channels: Input width handed to the first ``EdgeConv``.
            hidden_channels: Output width of the first layer and input width
                of the second.
            out_channels: Output width of the second ``EdgeConv``.
        """
        super().__init__()
        self.conv1 = EdgeConv(in_channels, hidden_channels)
        self.conv2 = EdgeConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Run ``conv1`` -> ReLU -> ``conv2`` and return the raw second-layer output.

        Args:
            x: Node feature tensor passed to the first layer.
            edge_index: Graph connectivity, passed unchanged to both layers.
        """
        hidden = torch.relu(self.conv1(x, edge_index))
        # No activation after the last layer: the caller receives its raw output.
        return self.conv2(hidden, edge_index)

model = VanillaGNN(3, 16, 2)

In [16]:
# Smoke-test the model on the first batch only (hence the `break`); `output`
# stays in the notebook namespace so it can be inspected afterwards.
for graph in dataloader:
    output = model(graph.x, graph.edge_index)
    
    break

Propagating
Calculating message
Aggregating
Propagating
Calculating message
Aggregating


In [60]:
output.shape

torch.Size([200, 7])

## LEGNN into PyG

- [ ] 

## Load model

In [5]:
# Read the hyperparameters for the model from the YAML config file. The path
# is relative to the notebook's working directory.
# NOTE(review): FullLoader is used here; if the config only contains plain
# scalars/lists/dicts, `yaml.safe_load` would be the stricter choice - confirm.
with open("configs/lgnn_config.yaml") as config_file:
    hparams = yaml.load(config_file, Loader=yaml.FullLoader)

In [6]:
model = LEGNN(hparams)

## Train

In [None]:
# Send training metrics to Weights & Biases under the "LorentzNet" project.
logger = WandbLogger(project="LorentzNet", group="InitialTest")
# Request one GPU only when CUDA is actually available, otherwise fall back to
# CPU (gpus=0). A hard-coded gpus=1 makes the Trainer fail on CPU-only
# machines, even though the notebook already computes a CPU fallback `device`.
trainer = Trainer(gpus=int(torch.cuda.is_available()), max_epochs=30, logger=logger)
# `fit` is called with the model alone, so the model is expected to provide
# its own dataloaders - TODO confirm against the LEGNN implementation.
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[34m[1mwandb[0m: Currently logged in as: [33mmurnanedaniel[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.1 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Set SLURM handle signals.

  | Name        | Type   | Params
---------------------------------------
0 | feature_in  | Linear | 256   
1 | feature_out | Linear | 129   
2 | gcl_0       | L_GCL  | 115 K 
3 | gcl_1       | L_GCL  | 115 K 
4 | gcl_2       | L_GCL  | 115 K 
5 | gcl_3       | L_GCL  | 115 K 
---------------------------------------
462 K     Trainable params
0         Non-trainable params
462 K     Total params
1.851     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
  f'The dataloader, {name}, does not have many workers which may be a bottleneck.'


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

  "Relying on `self.log('val_loss', ...)` to set the ModelCheckpoint monitor is deprecated in v1.2"


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

## Validate