# Develop EdgeConvBlock as a DNN layer in PyTorch

In [2]:
import matplotlib.pyplot as plt # plotting library
import numpy as np # this module is useful to work with numerical arrays
import pandas as pd # this module is useful to work with tabular data
import random # this module will be used to select random samples from a collection
import os # this module will be used just to create directories in the local filesystem
from tqdm.notebook import tqdm # this module is useful to plot progress bars

import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn

## kNN function
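This function returns a new tensor with an extra dimension of size $k$ that stores, for every point, the features of its $k$ nearest neighbours, where $k$ is the number of nearest neighbours we want to consider. An input of shape $[B, n, d]$ becomes $[B, n, k, d]$. Distances are measured on the first two features only, and each point counts as its own nearest neighbour.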

In [57]:
def kNN_opt(k, x):
    """Gather the k nearest neighbours of every point in a batch of point clouds.

    Distances are Euclidean and computed on the first two features only
    (assumed to be the (phi, eta) coordinates). Every point is at distance 0
    from itself, so it is always its own first neighbour. Ties are broken by
    point index (lower index first) thanks to the stable sort.

    Args:
        k: number of neighbours to keep per point (values above n yield n).
        x: tensor of shape [B, n, d] with d >= 2.

    Returns:
        Tensor of shape [B, n, min(k, n), d]; ``out[b, i, j]`` holds the
        features of the j-th nearest neighbour of point i in batch element b.
    """
    # x_all[b, i, j] = x[b, j]; expand creates a view, not a copy.
    x_all = x.unsqueeze(1).expand(-1, x.shape[1], -1, -1)

    # Pairwise squared distances in the first two features, shape [B, n, n].
    # Sorting on the squared distance gives the Euclidean ordering (sqrt is
    # monotonic) without the extra rounding that taking roots introduces,
    # which could otherwise merge distinct distances into ties.
    coords = x[:, :, :2]
    dist2 = (coords.unsqueeze(1) - coords.unsqueeze(2)).pow(2).sum(dim=-1)

    knn = dist2.sort(dim=2, stable=True).indices[:, :, :k]
    return torch.gather(x_all, 2, knn.unsqueeze(-1).expand(-1, -1, -1, x.shape[-1]))

In [58]:
def test_kNN():
    """Check kNN_opt on a tiny hand-computed batch of two 3-point clouds."""
    points = torch.tensor([
        [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
        [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]],
    ])

    # Each point is its own nearest neighbour; for the middle point the two
    # equidistant neighbours come out in index order (stable sort).
    want = torch.tensor([
        [[[1.0, 2.0], [3.0, 4.0]], [[3.0, 4.0], [1.0, 2.0]], [[5.0, 6.0], [3.0, 4.0]]],
        [[[7.0, 8.0], [9.0, 10.0]], [[9.0, 10.0], [7.0, 8.0]], [[11.0, 12.0], [9.0, 10.0]]],
    ])

    got = kNN_opt(2, points)
    assert torch.allclose(got, want), f'Expected {want}, but got {got}'


# Run the sanity check. A failure raises AssertionError, so reaching the
# print below means every case passed.
test_kNN()
print("All test cases pass")

All test cases pass


## EdgeConv Block
Define the Edge Convolution operation as a `torch.nn.Module`

In [None]:
########## Edge Convolution Block ###########
# The root block of our DNN.
# Initialized by:
#   - d     the number of features
#   - k     number of nearest neighbours to consider in the convolution
#   - C     a list-like or an int with the number of neurons of the three linear layers
#   - aggr  the aggregation function, must be symmetric

class EdgeConv(nn.Module):
    """Edge-convolution block with a residual shortcut (work in progress).

    Args:
        d: number of input features per point.
        k: number of nearest neighbours gathered for every point.
        C: an int, or a sequence of exactly three ints, giving the widths of
            the three linear layers of ``self.mlp``; an int is reused for all
            three. The block's output width is ``C[2]``.
        aggr: aggregation function over the neighbour dimension (must be
            symmetric). Stored as ``self.aggr``; not used yet because
            ``linear_aggregate`` is still a stub.

    Raises:
        ValueError: if ``C`` is a sequence whose length is not 3.
    """

    def __init__(self, d, k, C, aggr=None):
        super().__init__()

        # Normalise C once and only use self.C afterwards (indexing the raw
        # argument crashed when C was an int).
        self.C = [C] * 3 if isinstance(C, int) else list(C)
        if len(self.C) != 3:
            raise ValueError(f"C must be an int or a sequence of 3 ints, got {C!r}")

        self.k = k
        self.aggr = aggr
        self.act = nn.ReLU()

        ### Shortcut path: 1x1 convolution projecting d -> C[2] features per point.
        # nn.Conv1d is channels-first, so forward() transposes around it.
        self.shortcut = nn.Sequential(
            nn.Conv1d(d, self.C[2], 1, 1),
            nn.BatchNorm1d(self.C[2]),
        )

        ### Linear section, approximation of a MLP
        self.mlp = nn.Sequential(
            nn.Linear(d, self.C[0]),
            nn.BatchNorm1d(self.C[0]),
            nn.ReLU(),
            nn.Linear(self.C[0], self.C[1]),
            nn.BatchNorm1d(self.C[1]),
            nn.ReLU(),
            nn.Linear(self.C[1], self.C[2]),
            nn.BatchNorm1d(self.C[2]),
            nn.ReLU(),
        )

    def kNN(self, x):
        """Return the features of the ``self.k`` nearest neighbours of every point.

        Distances use only the first two features (assumed to be the
        (phi, eta) coordinates). Every point is at distance 0 from itself, so
        it is its own first neighbour; ties are broken by point index.

        Args:
            x: tensor of shape [B, n, d] with d >= 2.

        Returns:
            Tensor of shape [B, n, min(k, n), d].
        """
        # x_all[b, i, j] = x[b, j]; expand creates a view, not a copy.
        x_all = x.unsqueeze(1).expand(-1, x.shape[1], -1, -1)

        # Pairwise squared distances, shape [B, n, n]. Sorting on the squared
        # distance gives the Euclidean ordering (sqrt is monotonic) without
        # extra rounding from taking roots.
        coords = x[:, :, :2]
        dist2 = (coords.unsqueeze(1) - coords.unsqueeze(2)).pow(2).sum(dim=-1)

        knn = dist2.sort(dim=2, stable=True).indices[:, :, :self.k]
        return torch.gather(x_all, 2, knn.unsqueeze(-1).expand(-1, -1, -1, x.shape[-1]))

    def linear_aggregate(self, x):
        """Apply ``self.mlp`` to every (point, neighbour) pair and aggregate over k.

        Not implemented yet. To be added to the shortcut in ``forward``, the
        result must have shape [B, n, C[2]] for an input ``x`` of shape
        [B, n, k, d].

        NOTE(review): ``nn.BatchNorm1d`` only accepts [N, C] or [N, C, L]
        inputs, so the 4-D neighbour tensor must be reshaped (e.g. to
        [B*n*k, d]) before going through ``self.mlp``.

        Raises:
            NotImplementedError: always, until the aggregation is written.
        """
        # Previously this returned None, which only failed later with an
        # obscure TypeError in forward(); fail loudly and explain instead.
        raise NotImplementedError(
            "EdgeConv.linear_aggregate is not implemented yet: it must map "
            "x_knn of shape [B, n, k, d] to [B, n, C[2]]"
        )

    def forward(self, x):
        """Run the block on ``x`` of shape [B, n, d].

        Returns a tensor of shape [B, n, C[2]] once ``linear_aggregate`` is
        implemented; until then it raises NotImplementedError.
        """
        # x_knn.shape = [B, n, k, d]
        x_knn = self.kNN(x)

        # Conv1d wants channels first: [B, n, d] -> [B, d, n] -> [B, C[2], n] -> [B, n, C[2]]
        shortcut = self.shortcut(x.transpose(1, 2)).transpose(1, 2)
        out = self.linear_aggregate(x_knn)

        return self.act(out + shortcut)