# Develop EdgeConvBlock as a DNN layer in PyTorch

In [2]:
import matplotlib.pyplot as plt # plotting library
import numpy as np # this module is useful to work with numerical arrays
import pandas as pd # this module is useful to work with tabular data
import random # this module will be used to select random samples from a collection
import os # this module will be used just to create directories in the local filesystem
from tqdm.notebook import tqdm # this module is useful to plot progress bars

import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn

## kNN function
For each jet, this function returns an $n\times k$ matrix of particle indices, where $n$ is the number of particles in the jet and $k$ is the number of nearest neighbours to consider: row $i$ lists the indices of the $k$ particles closest to particle $i$.

## EdgeConv Block
Define the edge convolution (EdgeConv) operation as a `torch.nn.Module`.

In [None]:
########## Edge Convolution Block ###########
# The root block of our DNN.
# Initialized by:
#   - d     the number of features
#   - k     the number of nearest neighbours to consider in the convolution
#   - C     an int, or a list-like of three ints, with the number of neurons of the three linear layers
#   - aggr  the aggregation function, must be symmetric

class EdgeConv(nn.Module):
    """Edge convolution (EdgeConv) block -- work in progress.

    Builds the layers of the block: a 1x1-convolution shortcut and a
    three-layer Linear/BatchNorm/ReLU stack. It also provides the
    k-nearest-neighbour search. ``forward`` does not yet apply the
    convolution; it only runs the neighbour search and returns its input
    unchanged.

    Args:
        d: number of input features per particle.
        k: number of nearest neighbours to consider in the convolution.
        C: widths of the three linear layers, either a single int used for
            all three or a sequence of exactly three ints.
        aggr: aggregation function over the neighbours; must be symmetric
            (permutation invariant). Stored but not used yet.

    Raises:
        ValueError: if ``C`` is a sequence whose length is not 3.
    """

    def __init__(self, d, k, C, aggr=None):
        super().__init__()

        # Normalise C to a list of three layer widths. The layers below read
        # self.C rather than the raw argument, so an int C works too.
        if isinstance(C, int):
            self.C = [C] * 3
        else:
            self.C = list(C)
        if len(self.C) != 3:
            raise ValueError(
                f"C must be an int or a sequence of 3 ints, got {C!r}"
            )
        c0, c1, c2 = self.C

        self.k = k
        self.aggr = aggr

        ### Shortcut path
        # NOTE(review): Conv1d expects (batch, channels=d, n), while forward
        # documents x as (batch, n, d); a transpose will be needed before
        # this path is applied.
        self.shortcut = nn.Sequential(
            nn.Conv1d(d, c2, 1, 1),
            nn.BatchNorm1d(c2),
        )

        ### Linear section, approximation of a MLP
        # NOTE(review): BatchNorm1d normalises dim 1, so this stack expects
        # (N, features) input; per-edge tensors will need flattening first.
        self.mlp = nn.Sequential(
            nn.Linear(d, c0),
            nn.BatchNorm1d(c0),
            nn.ReLU(),
            nn.Linear(c0, c1),
            nn.BatchNorm1d(c1),
            nn.ReLU(),
            nn.Linear(c1, c2),
            nn.BatchNorm1d(c2),
            nn.ReLU(),
        )

    @staticmethod
    def kNN(x, k):
        """Return the indices of the k nearest neighbours of every particle.

        Distances are Euclidean in the plane of the first two features
        (``x[..., 0]`` as phi, ``x[..., 1]`` as eta), computed independently
        for every jet in the batch. A particle is never its own neighbour.

        Args:
            x: tensor of shape (B, n, d) with d >= 2.
            k: number of neighbours, 1 <= k <= n - 1.

        Returns:
            Integer tensor of shape (B, n, k). Entry [b, i, j] is the index of
            the (j+1)-th closest particle to particle i in jet b, with the
            closest particle first.

        Raises:
            ValueError: if k is not in [1, n - 1].
        """
        n = x.shape[1]
        if not 1 <= k <= n - 1:
            raise ValueError(
                f"k must be between 1 and n - 1 = {n - 1}, got {k}"
            )

        coords = x[:, :, :2]
        # Pairwise squared distances, shape (B, n, n). Squaring preserves
        # the ordering, so no sqrt is needed to rank neighbours.
        diff = coords.unsqueeze(2) - coords.unsqueeze(1)
        dist2 = diff.pow(2).sum(dim=-1)
        # Exclude each particle from its own neighbour list. The (n, n) mask
        # broadcasts over the batch dimension.
        self_mask = torch.eye(n, dtype=torch.bool, device=x.device)
        dist2 = dist2.masked_fill(self_mask, float("inf"))
        # The k smallest distances per row, sorted nearest first.
        return dist2.topk(k, dim=-1, largest=False).indices

    def forward(self, x):
        """Run the block on ``x`` (assumed shape (B, n, d)).

        So far only the neighbour search is performed; ``x`` is returned
        unchanged.
        """
        # x.size = [B, n, d]
        knn = self.kNN(x, self.k)  # (B, n, k) neighbour indices
        # TODO: gather neighbour features with `knn`, apply self.mlp and
        # self.aggr, add self.shortcut, and return the convolved features.
        return x