# DiffPool tutorial
The purpose of this notebook is to guide you through building a DiffPool model based on the paper [arxiv:1806.08804](https://arxiv.org/abs/1806.08804).

References:
- [YouTube](https://www.youtube.com/watch?v=Uqc3O3-oXxM)
- [Colab](https://colab.research.google.com/github/AntonioLonga/PytorchGeometricTutorial/blob/main/Tutorial16/Tutorial16.ipynb)

In [1]:
import warnings
from math import ceil

import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.nn import DenseGCNConv as GCNConv
from torch_geometric.nn import GraphSAGE, SAGEConv, dense_diff_pool, global_mean_pool
from torch_geometric.nn.conv import GravNetConv, MessagePassing
from torch_geometric.utils import to_dense_adj

# silence UserWarnings so they do not clutter the cell outputs
warnings.simplefilter("ignore", UserWarning)

# define the global base device; it is a `torch.device` on both branches so that
# downstream code can rely on attributes such as `device.type`
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    print(f"Will use {torch.cuda.get_device_name(device)}")
else:
    device = torch.device("cpu")
    print("Will use cpu")

Will use cpu


Load a processed CLIC `.pt` file.

In [2]:
# read the pre-processed events from disk
# NOTE(review): `torch.load` deserialises arbitrary Python objects - only point it at trusted files
data = torch.load("../data/clic_edm4hep_2023_02_27/p8_ee_tt_ecm380/processed/data_0.pt")
num_events = len(data)
print(f"num of clic events {num_events}")

num of clic events 10000


In [3]:
# build a data loader
batch_size = 50

loader = DataLoader(data, batch_size=batch_size, shuffle=True)

# fetch one batch to inspect; note that the printed object holds `batch_size` events
# merged into a single Batch. `next(iter(...))` fails right here if the dataset is
# empty, instead of leaving `batch` undefined for the lines below.
batch = next(iter(loader))
print(f"A single event: \n {batch}")

# number of features per node, used as the model input size
input_dim = batch.x.shape[-1]

A single event: 
 Batch(x=[6867, 17], ygen=[6867, 6], ygen_id=[6867], ycand=[6867, 6], ycand_id=[6867], batch=[6867], ptr=[51])


# Build the model

In [4]:
class GNN(torch.nn.Module):
    """
    The fundamental "GNN" block in the DiffPool model.

    Three dense graph convolutions are stacked
    (in_channels -> hidden_channels -> hidden_channels -> out_channels);
    each one is followed by a LayerNorm and an ELU activation.

    NOTE(review): `normalize` is forwarded as the third positional argument of every
    conv. For `DenseGCNConv` that slot appears to be `improved` rather than a
    normalisation switch - confirm this is intended.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, normalize=True):
        super().__init__()

        self.convs = torch.nn.ModuleList()
        # the attribute is called `bns` but the layers stored here are LayerNorms
        self.bns = torch.nn.ModuleList()

        # consecutive (input, output) sizes of the three layers
        sizes = [in_channels, hidden_channels, hidden_channels, out_channels]
        for size_in, size_out in zip(sizes[:-1], sizes[1:]):
            self.convs.append(GCNConv(size_in, size_out, normalize))
            self.bns.append(torch.nn.LayerNorm(size_out))

    def forward(self, x, adj):
        """Run conv -> norm -> ELU for every layer, feeding the same `adj` to each conv."""
        for conv, norm in zip(self.convs, self.bns):
            x = F.elu(norm(conv(x, adj)))

        return x


class DiffPool(torch.nn.Module):
    """
    A DiffPool model based on GCNConv.

    Two (pool, embed) GNN pairs each coarsen the graph through `dense_diff_pool`.
    A third embedding GNN runs on the coarsest graph, whose clusters are averaged
    and passed through a two-layer MLP head.

    Args
        input_dim: number of features per input node.
        width: hidden size of every GNN block and of the MLP head.
        output_dim: size of the output vector produced for each event.
        max_nodes: nominal graph size from which the cluster counts are derived.
        pool_ratio: fraction of nodes (then clusters) kept by each pooling stage.
    """

    def __init__(self, input_dim, width, output_dim, max_nodes=40, pool_ratio=0.25):
        super().__init__()

        # clusters produced by the first pooling stage (10 with the defaults)
        num_clusters = ceil(pool_ratio * max_nodes)
        self.gnn1_pool = GNN(input_dim, width, num_clusters)
        self.gnn1_embed = GNN(input_dim, width, width)

        # clusters produced by the second pooling stage (3 with the defaults)
        num_clusters = ceil(pool_ratio * num_clusters)
        self.gnn2_pool = GNN(width, width, num_clusters)
        self.gnn2_embed = GNN(width, width, width)

        self.gnn3_embed = GNN(width, width, width)

        self.lin1 = torch.nn.Linear(width, width)
        self.lin2 = torch.nn.Linear(width, output_dim)

    def forward(self, X):
        """
        We have to first 0-pad the events and convert the batch object to the good old array of
         [batch_size, num_elements, input_dim] since GCNConv doesn't accept Batch() objects :(

        Events are padded to the largest event of the batch. A boolean mask marks the
        real nodes so that the padded slots get no edges and are ignored by the first
        pooling stage. The auxiliary link/entropy losses returned by `dense_diff_pool`
        are not used here.

        Args
            Big `X`: a Batch() object.
        Returns
            Array of dimension [batch_size, output_dim]
        """

        x, mask = self._dense_with_mask(X.x, X.batch)

        # build a naive fully connected graph among the real nodes of each event
        # (`adj` stands for adjacency matrix); padded slots are left without edges
        adj = (mask.unsqueeze(1) & mask.unsqueeze(2)).to(x.dtype)

        # begin with the diffpool operations
        s = self.gnn1_pool(x, adj)
        x = self.gnn1_embed(x, adj)

        # passing `mask` lets dense_diff_pool ignore the padded slots
        x, adj, _, _ = dense_diff_pool(x, adj, s, mask)

        # after the first pooling every cluster is a real node, so no mask is needed
        s = self.gnn2_pool(x, adj)
        x = self.gnn2_embed(x, adj)

        x, adj, _, _ = dense_diff_pool(x, adj, s)

        x = self.gnn3_embed(x, adj)

        # average over the remaining clusters to get one vector per event
        x = x.mean(dim=1)
        x = F.elu(self.lin1(x))
        x = self.lin2(x)

        return x

    # helper functions needed because `dense_diff_pool` doesn't natively accept batch objects

    def _dense_with_mask(self, x, batch):
        """Pad every event to the largest one in the batch and return (dense, mask)."""
        if batch.numel() == 0:
            # nothing to stack: return empty tensors instead of failing
            dense = x.new_zeros((0, 0) + tuple(x.shape[1:]))
            return dense, torch.zeros((0, 0), dtype=torch.bool, device=x.device)

        # sorted event ids and the number of nodes belonging to each of them
        event_ids, counts = torch.unique(batch, return_counts=True)
        num_nodes = int(counts.max())

        # pad each event once and stack a single time (no repeated concatenation)
        dense = torch.stack([self.pad_data(x[batch == b], num_nodes) for b in event_ids])

        # mask[i, j] is True when slot j of event i holds a real node
        slots = torch.arange(num_nodes, device=x.device)
        mask = slots.unsqueeze(0) < counts.to(x.device).unsqueeze(1)
        return dense, mask

    def from_batch_to_dense(self, x, batch):
        """
        Convert flat node features into a dense, zero-padded array.

        Args
            x: node features of all events stacked along the first dimension.
            batch: for every row of `x`, the id of the event it belongs to.
        Returns
            Array of dimension [num_events, largest_event_size, ...]; events are ordered
            by increasing id and no node is dropped.
        """
        dense, _ = self._dense_with_mask(x, batch)
        return dense

    def pad_data(self, x, num_nodes=100):
        """
        Append zero rows to `x` until it has `num_nodes` rows.

        Raises
            ValueError: if `x` already has more than `num_nodes` rows. A negative pad
            amount would make `F.pad` silently crop the event.
        """
        missing = num_nodes - len(x)
        if missing < 0:
            raise ValueError(f"cannot pad {len(x)} nodes to {num_nodes}: the event would be truncated")
        return F.pad(input=x, pad=(0, 0, 0, missing), mode='constant', value=0)

In [5]:
# instantiate a DiffPool model (arbitrary `width` and `output_dim`) on the base device;
# the bare `model` on the last line displays the module layout
model = DiffPool(input_dim, width=200, output_dim=6).to(device)
model

DiffPool(
  (gnn1_pool): GNN(
    (convs): ModuleList(
      (0): DenseGCNConv(17, 200)
      (1): DenseGCNConv(200, 200)
      (2): DenseGCNConv(200, 10)
    )
    (bns): ModuleList(
      (0): LayerNorm((200,), eps=1e-05, elementwise_affine=True)
      (1): LayerNorm((200,), eps=1e-05, elementwise_affine=True)
      (2): LayerNorm((10,), eps=1e-05, elementwise_affine=True)
    )
  )
  (gnn1_embed): GNN(
    (convs): ModuleList(
      (0): DenseGCNConv(17, 200)
      (1): DenseGCNConv(200, 200)
      (2): DenseGCNConv(200, 200)
    )
    (bns): ModuleList(
      (0): LayerNorm((200,), eps=1e-05, elementwise_affine=True)
      (1): LayerNorm((200,), eps=1e-05, elementwise_affine=True)
      (2): LayerNorm((200,), eps=1e-05, elementwise_affine=True)
    )
  )
  (gnn2_pool): GNN(
    (convs): ModuleList(
      (0): DenseGCNConv(200, 200)
      (1): DenseGCNConv(200, 200)
      (2): DenseGCNConv(200, 3)
    )
    (bns): ModuleList(
      (0): LayerNorm((200,), eps=1e-05, elementwise_affin

In [6]:
# run one forward pass on the batch fetched earlier and display the output shape
batch_on_device = batch.to(device)
pred = model(batch_on_device)
pred.shape

torch.Size([50, 6])