# EdgeConv (PyTorch Geometric) 

EdgeConv updates each node by aggregating messages from its neighbors:

$$
x'_i = \operatorname{AGG}_{j \in \mathcal{N}(i)} \; h_\theta\!\left(\left[x_i,\; x_j - x_i\right]\right)
$$

- $x_i$: feature vector of node $i$  
- $\mathcal{N}(i)$: neighbors of node $i$ (defined by `edge_index`)  
- $\left[x_i,\; x_j - x_i\right]$: **concatenation** (the node itself + relative difference to a neighbor)  
- $h_\theta(\cdot)$: a learnable function, typically an **MLP**  
- $\operatorname{AGG}$: neighborhood aggregation (e.g., $\max$, $\sum$, mean)

If the aggregation is **sum**, this becomes:

$$
x'_i = \sum_{j \in \mathcal{N}(i)} h_\theta\!\left(\left[x_i,\; x_j - x_i\right]\right)
$$
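To make the sum form concrete, here is a minimal hand-written sketch in plain PyTorch. The helper name `edge_conv_sum` and the three-node toy graph are illustrative assumptions, and $h_\theta$ is replaced by the identity so the result can be checked by hand; this is not how PyG implements the layer internally.

```python
import torch

def edge_conv_sum(x, edge_index, h):
    """Sum-aggregated EdgeConv written out by hand (hypothetical helper)."""
    src, dst = edge_index                 # edge j -> i: src holds j, dst holds i
    x_i, x_j = x[dst], x[src]
    messages = h(torch.cat([x_i, x_j - x_i], dim=-1))      # [E, out_dim]
    out = torch.zeros(x.size(0), messages.size(-1), dtype=messages.dtype)
    out.index_add_(0, dst, messages)      # sum the messages arriving at each target
    return out

# Toy graph: 3 nodes with 2 features each; edges 1->0, 2->0, 0->1
x_toy = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
edge_index_toy = torch.tensor([[1, 2, 0],
                               [0, 0, 1]])
out_toy = edge_conv_sum(x_toy, edge_index_toy, torch.nn.Identity())
print(out_toy)
```

Node 0 receives `[0, 0, 1, 0]` from node 1 and `[0, 0, 0, 2]` from node 2, so its row sums to `[0, 0, 1, 2]`. Node 2 has no incoming edges and stays at zero.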


### `edge_index` with kNN (PyG)

EdgeConv needs a graph structure: `edge_index` (who is connected to whom).  
A common way to build this graph is **k-nearest neighbors (kNN)**:

- We choose a feature space for distance (PMT/DOM positions in our case).
- For each node, we find its **k nearest neighbors** within the same event (using `batch`).
- The result is `edge_index` with shape **[2, E]**, where each column is one directed edge `(source -> target)`. Here the source is the neighbor $j$ and the target is the node $i$ being updated, so messages flow from the neighbor to the node.

Key idea:
- `x_knn = x[:, features_subset]` decides **which features** are used for the distance computation.
- `batch` prevents connecting nodes from different events.
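The role of `batch` can be seen in a brute-force sketch written with plain PyTorch. The helper `knn_edges_dense` is a hypothetical stand-in for illustration, not PyG's `knn_graph`; it assumes every event has at least `k + 1` nodes and may order edges differently from the library.

```python
import torch

def knn_edges_dense(pos, k, batch):
    """Brute-force kNN edges (source = neighbor, target = node); hypothetical helper."""
    dist = torch.cdist(pos, pos)                        # [N, N] pairwise distances
    different_event = batch[:, None] != batch[None, :]
    dist = dist.masked_fill(different_event, float("inf"))  # forbid cross-event pairs
    dist.fill_diagonal_(float("inf"))                   # forbid self-loops
    neighbors = dist.topk(k, dim=1, largest=False).indices  # [N, k]
    targets = torch.arange(pos.size(0)).repeat_interleave(k)
    return torch.stack([neighbors.reshape(-1), targets])    # [2, N*k]

# Two events with three 1-D "positions" each
pos_demo = torch.tensor([[0.0], [1.0], [3.0], [0.1], [10.0], [11.0]])
batch_demo = torch.tensor([0, 0, 0, 1, 1, 1])
edges_demo = knn_edges_dense(pos_demo, k=1, batch=batch_demo)
print(edges_demo)
```

Node 3 sits at 0.1, right next to node 0, but the two belong to different events, so node 3 is connected to node 4 instead.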


### Learnable Function Example: MLP

EdgeConv is a **message passing** layer. It needs a learnable function to turn a pair
of nodes (a node and one of its neighbors) into a **message**.  
That learnable function is the **MLP** we pass in here.

For each directed edge (neighbor relation) $j \rightarrow i$:

1. Take the target node features $x_i$  
2. Take the relative difference to the neighbor $x_j - x_i$  
3. Concatenate them to form the MLP input:
   $$
   z_{ij} = \left[x_i,\; x_j - x_i\right]
   $$
   After the concatenation, the feature dimension doubles (from $F$ to $2F$).
4. Compute a message with the MLP:
   $$
   m_{ij} = \mathrm{MLP}(z_{ij})
   $$
5. Aggregate all neighbor messages (e.g. max/sum/mean) to update the node:
   $$
   x'_i = \operatorname{AGG}_{j \in \mathcal{N}(i)} m_{ij}
   $$
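The shapes involved in these steps can be traced for a single target node. The sizes (`F_in`, `k`, `out_dim`) and the small `toy_mlp` below are illustrative assumptions, not the values used later in this notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
F_in, k, out_dim = 4, 3, 5            # illustrative sizes
x_i = torch.randn(F_in)               # features of the target node
x_nb = torch.randn(k, F_in)           # features of its k neighbors

# Concatenate the node with each relative difference: width doubles to 2 * F_in
z = torch.cat([x_i.expand(k, -1), x_nb - x_i], dim=1)     # [k, 2 * F_in]

# One message per neighbor
toy_mlp = nn.Sequential(nn.Linear(2 * F_in, 8), nn.ReLU(), nn.Linear(8, out_dim))
m = toy_mlp(z)                                            # [k, out_dim]

# Three possible aggregations over the neighbor axis
x_new_max = m.max(dim=0).values
x_new_sum = m.sum(dim=0)
x_new_mean = m.mean(dim=0)
print(z.shape, m.shape, x_new_max.shape)
```

Whatever the aggregation, the updated node has `out_dim` features, and the result does not depend on the order in which the neighbors are listed.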



In [1]:
import torch
from torch_geometric.nn.pool import knn_graph

In [2]:
# --- Fake mini-batch: 2 events ---
# Event 0 has 4 nodes, Event 1 has 6 nodes -> total N=10
n0, n1 = 4, 6
N = n0 + n1

In [3]:
# batch[i] = event_id of node i
batch = torch.tensor([0]*n0 + [1]*n1)

In [4]:
batch

tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
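When the number of nodes per event is already available as a list or tensor, the same vector can be built without Python list arithmetic. The name `nodes_per_event` is a hypothetical input chosen for this sketch.

```python
import torch

nodes_per_event = torch.tensor([4, 6])            # hypothetical per-event node counts
event_ids = torch.arange(len(nodes_per_event))
batch_demo = event_ids.repeat_interleave(nodes_per_event)
print(batch_demo)                                 # tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])

# The counts can be recovered from the batch vector
print(torch.bincount(batch_demo))                 # tensor([4, 6])
```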

In [5]:
# Node features, shape [N, 8]: pmt_x, pmt_y, pmt_z, pmt_i, pmt_j, pmt_k, q, t (x, y, z: position; i, j, k: angle)
# First 3 features = "position" (used for kNN)
pos = torch.randn(N, 3) * 5.0
other_5 = torch.randn(N, 5)
x = torch.cat([pos, other_5], dim=1)

print("x shape:", x.shape)         # [N, 8]
print("batch shape:", batch.shape) # [N]

x shape: torch.Size([10, 8])
batch shape: torch.Size([10])


In [6]:
x

tensor([[ -0.1572,  -1.1941, -10.0239,  -0.5515,  -0.2864,   1.2036,   0.2426,
           0.3425],
        [  6.5740,   3.3514,   3.4227,   2.1173,  -0.5528,  -0.0833,   0.5151,
           0.0869],
        [  3.7905,  -6.8601,  -0.0668,  -0.9295,   1.4618,  -0.7746,   0.1659,
           0.3784],
        [  1.5559,  -2.7138,  -5.7097,  -0.6423,   1.3420,  -0.7579,   1.2664,
          -0.3835],
        [ -3.6128,   1.0768,   5.0375,  -0.6368,   0.1879,  -0.9061,  -0.1070,
          -0.6018],
        [ -6.1118,   8.2047,  -6.6751,  -2.0629,  -0.2136,   0.1861,  -0.3162,
           0.2089],
        [ -0.9872,   2.5065,  -6.6366,  -3.0602,   0.6658,   0.2832,   0.2519,
          -1.9083],
        [ -4.6310,  -0.3009,   5.3977,   0.9575,  -1.7559,  -0.1788,  -0.5045,
           1.0550],
        [ -0.0400,  -3.3818,   4.0182,   1.8160,   0.0778,  -0.5094,  -0.0233,
           0.2756],
        [ -4.3483,   0.6645,   8.6566,   0.0539,   0.7440,   0.1583,  -1.3908,
          -0.5963]])

In [7]:
# --- Build edges with kNN using only the first 3 features (pos) ---
k = 3
edge_index = knn_graph(
    x=x[:, :3],      # use position only for distance
    k=k,
    batch=batch,
    loop=False
)

In [8]:

print("edge_index shape:", edge_index.shape)  # [2, N*k] (directed edges)


edge_index shape: torch.Size([2, 30])


In [9]:
edge_index
# you can see that there is no connection between different events

tensor([[3, 2, 1, 2, 3, 0, 3, 1, 0, 0, 2, 1, 7, 9, 8, 6, 4, 7, 5, 4, 8, 4, 9, 8,
         7, 4, 9, 7, 4, 8],
        [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7,
         8, 8, 8, 9, 9, 9]])
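Rather than checking by eye, the claim can be tested directly. The sketch below copies the printed `edge_index` and `batch` into new variables (the `_chk` names are only for this example) and verifies that every edge stays inside one event and that no node is its own neighbor.

```python
import torch

batch_chk = torch.tensor([0] * 4 + [1] * 6)
edge_index_chk = torch.tensor([
    [3, 2, 1, 2, 3, 0, 3, 1, 0, 0, 2, 1, 7, 9, 8, 6, 4, 7, 5, 4, 8, 4, 9, 8, 7, 4, 9, 7, 4, 8],
    [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9],
])
src, dst = edge_index_chk

same_event = batch_chk[src] == batch_chk[dst]   # True where both ends share an event
has_self_loop = (src == dst).any()

print(bool(same_event.all()), bool(has_self_loop))   # True False
```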

> **Note (`loop=False`)**: This disables self-loops in the kNN graph.  
> With `loop=False`, edges like `i → i` are **not** added (a node is not considered its own neighbor).  
> With `loop=True`, self-loop edges `i → i` **are included**.


## Apply EdgeConv (one layer)

We now have:
- `x` with shape `[N, F]` (node features)
- `edge_index` with shape `[2, E]` (kNN graph)

EdgeConv needs an MLP `h_θ` that maps the concatenated vector  
`[x_i, x_j - x_i]` (dimension `2F`) into an output embedding.

Then it aggregates messages from neighbors (here: `sum`, matching the code below) to produce `x_out`.


In [10]:
x

tensor([[ -0.1572,  -1.1941, -10.0239,  -0.5515,  -0.2864,   1.2036,   0.2426,
           0.3425],
        [  6.5740,   3.3514,   3.4227,   2.1173,  -0.5528,  -0.0833,   0.5151,
           0.0869],
        [  3.7905,  -6.8601,  -0.0668,  -0.9295,   1.4618,  -0.7746,   0.1659,
           0.3784],
        [  1.5559,  -2.7138,  -5.7097,  -0.6423,   1.3420,  -0.7579,   1.2664,
          -0.3835],
        [ -3.6128,   1.0768,   5.0375,  -0.6368,   0.1879,  -0.9061,  -0.1070,
          -0.6018],
        [ -6.1118,   8.2047,  -6.6751,  -2.0629,  -0.2136,   0.1861,  -0.3162,
           0.2089],
        [ -0.9872,   2.5065,  -6.6366,  -3.0602,   0.6658,   0.2832,   0.2519,
          -1.9083],
        [ -4.6310,  -0.3009,   5.3977,   0.9575,  -1.7559,  -0.1788,  -0.5045,
           1.0550],
        [ -0.0400,  -3.3818,   4.0182,   1.8160,   0.0778,  -0.5094,  -0.0233,
           0.2756],
        [ -4.3483,   0.6645,   8.6566,   0.0539,   0.7440,   0.1583,  -1.3908,
          -0.5963]])

In [11]:
import torch
import torch.nn as nn
from torch_geometric.nn import EdgeConv

F = x.size(1) 
print("F =", F)


F = 8


In [12]:
# EdgeConv uses concat([x_i, x_j - x_i]) -> input dim = 2F
mlp = nn.Sequential(
    nn.Linear(2*F, 32),
    nn.ReLU(),
    nn.Linear(32, 256)   # used in the article
)
# PyTorch's default initialization is used here.


conv = EdgeConv(nn=mlp, aggr="sum")   # sum aggregation; EdgeConv's default is "max"
x_out = conv(x, edge_index)

print("x_out shape:", x_out.shape)    # [N, 256]
print("First node before:\n", x[0])
print("First node after:\n", x_out[0])


x_out shape: torch.Size([10, 256])
First node before:
 tensor([ -0.1572,  -1.1941, -10.0239,  -0.5515,  -0.2864,   1.2036,   0.2426,
          0.3425])
First node after:
 tensor([-1.0370e+00,  5.4676e+00, -2.6796e+00, -3.2898e-03,  2.0559e+00,
         1.0472e+00, -3.2227e+00, -2.0603e+00, -2.3729e-01,  8.9579e-01,
        -1.2599e+00,  4.1485e+00,  1.2905e+00,  4.1283e+00,  3.2679e+00,
        -1.1494e-01,  2.7642e+00,  7.1270e-01, -8.6909e-03,  3.5283e+00,
         1.2398e+00, -2.4999e+00,  3.7444e+00,  1.7295e+00, -6.8036e-01,
         2.5756e+00, -5.6287e+00,  7.7296e-01, -1.0315e-01,  1.7377e+00,
        -3.1972e+00, -2.2684e+00,  1.7628e+00, -4.6151e-01,  1.3792e-01,
         2.3389e+00,  2.6050e+00, -1.5212e+00,  2.5155e+00, -5.7116e-02,
         2.0707e+00, -7.3036e-01, -4.2728e+00,  4.3175e-01, -3.0296e+00,
         4.9928e-02, -1.1414e-01,  9.0156e-01, -5.2839e+00,  3.4741e+00,
        -1.2232e+00, -1.7794e+00, -3.2967e+00, -2.2214e+00,  4.3977e+00,
        -2.3234e+00, -3.76

In [13]:
x.shape

torch.Size([10, 8])

In [14]:
x_out.shape

torch.Size([10, 256])
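The layer changes the feature width from 8 to 256, but the number of nodes stays at 10. Because the same MLP is applied to every edge, the number of trainable parameters depends only on the layer widths, not on the number of nodes or neighbors. A quick count, rebuilding the same MLP under the hypothetical name `mlp_demo`:

```python
import torch.nn as nn

F_in = 8   # input features per node, as in the notebook
mlp_demo = nn.Sequential(
    nn.Linear(2 * F_in, 32),
    nn.ReLU(),
    nn.Linear(32, 256),
)
n_params = sum(p.numel() for p in mlp_demo.parameters())
print(n_params)   # 8992 = (16*32 + 32) + (32*256 + 256)
```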

In [15]:
# Each node appears exactly k times as a target (second row), provided its event has more than k nodes, because edges are directed.
targets = edge_index[1]
counts = torch.bincount(targets, minlength=x.size(0))

print("Target counts per node:", counts.tolist())

Target counts per node: [3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
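The same is not true for the first row. A node can be chosen as a neighbor by many nodes or by few, so its count as a source varies. The sketch below copies the source row of the `edge_index` printed earlier into a new variable and counts the occurrences.

```python
import torch

# Source row of the kNN graph printed earlier in this notebook
sources = torch.tensor([3, 2, 1, 2, 3, 0, 3, 1, 0, 0, 2, 1, 7, 9, 8,
                        6, 4, 7, 5, 4, 8, 4, 9, 8, 7, 4, 9, 7, 4, 8])
out_degree = torch.bincount(sources, minlength=10)
print("Source counts per node:", out_degree.tolist())
# Source counts per node: [3, 3, 3, 3, 5, 1, 1, 4, 4, 3]
```

In the first event (4 nodes, `k = 3`) every node is a neighbor of all the others, so the counts are uniform. In the second event, node 4 is a neighbor of five nodes while nodes 5 and 6 are each a neighbor of only one.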


# DynEdgeConv

### What is a `LightningModule`?

A `LightningModule` is the core class in **PyTorch Lightning**. It is a structured way to define a model *and* the training logic in one place, so you don’t have to write the full PyTorch training loop manually.

Typically, a `LightningModule` contains:

- **`forward()`**: how the model produces outputs from inputs  
- **`training_step(batch, batch_idx)`**: how to compute the training loss for one batch  
- **`validation_step(...)` / `test_step(...)`**: evaluation logic  
- **`configure_optimizers()`**: how to set up the optimizer (and optional schedulers)

In short:  
**PyTorch** gives you the building blocks,  
**PyTorch Lightning** (via `LightningModule`) standardizes the training loop and organizes your code.


In [16]:
# This class is taken from the GraphNeT source code.


from typing import Any, Callable, Optional, Sequence, Union

import torch
from torch import Tensor
from torch_geometric.nn import EdgeConv
from torch_geometric.nn.pool import knn_graph
from torch_geometric.typing import Adj
from pytorch_lightning import LightningModule


class DynEdgeConv(EdgeConv, LightningModule):
    """Dynamical edge convolution layer."""

    def __init__(
        self,
        nn: Callable,
        aggr: str = "add", # sum is used in the article
        nb_neighbors: int = 8, # 8 is used in the article
        features_subset: Optional[Union[Sequence[int], slice]] = None,
        **kwargs: Any,
    ):
        """Construct `DynEdgeConv`.

        Args:
            nn: The MLP/torch.Module to be used within the `EdgeConv`.
            aggr: Aggregation method to be used with `EdgeConv`.
            nb_neighbors: Number of neighbours to be clustered after the
                `EdgeConv` operation.
            features_subset: Subset of features in `Data.x` that should be used
                when dynamically performing the new graph clustering after the
                `EdgeConv` operation. Defaults to all features.  
            **kwargs: Additional features to be passed to `EdgeConv`.
        """
        # Check(s)
        if features_subset is None:
            features_subset = slice(None)  # Use all features
        assert isinstance(features_subset, (list, slice))

        # Base class constructor
        super().__init__(nn=nn, aggr=aggr, **kwargs)

        # Additional member variables
        self.nb_neighbors = nb_neighbors
        self.features_subset = features_subset

    def forward(
        self, x: Tensor, edge_index: Adj, batch: Optional[Tensor] = None
    ) -> Tensor:
        """Forward pass."""
        # Standard EdgeConv forward pass
        x = super().forward(x, edge_index)

        # compute new adjacency
        edge_index = knn_graph(
            x=x[:, self.features_subset],
            k=self.nb_neighbors,
            batch=batch,
        ).to(self.device)

        return x, edge_index

# Only `forward` is defined here. Should anything else be added?

# When initializing, leave `features_subset` as None, because for the first layer
# the graph is already computed in advance from the 3 position features. This
# parameter is meant for the second (and later) layers.
# nb_neighbors: the article seems to use 8 throughout, but choose it yourself.
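Inside `forward`, the subset is applied as `x[:, self.features_subset]`, so it selects which columns of the layer output feed the next kNN search. The small sketch below, with a made-up tensor `x_demo`, shows what the accepted forms select. The constructor's `assert` only accepts a `list` or a `slice`, so a tuple of indices would be rejected.

```python
import torch

x_demo = torch.arange(12.0).reshape(3, 4)    # 3 nodes, 4 features

all_cols = x_demo[:, slice(None)]            # default: same as x_demo[:, :]
first_two = x_demo[:, slice(0, 2)]           # same as x_demo[:, :2]
picked = x_demo[:, [0, 3]]                   # explicit list of column indices

print(all_cols.shape, first_two.shape, picked.shape)
# torch.Size([3, 4]) torch.Size([3, 2]) torch.Size([3, 2])
```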

In [17]:
# Fake data: 2 events, 5 nodes each.
n0, n1 = 5, 5
n = n0 + n1

batch = torch.tensor([0]*n0 + [1]*n1)


In [18]:
batch
# index = pulse (node) id, value = event id
# Note: this "pulse id" is only valid within this mini-batch (it is not a persistent ID across the dataset).

tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

In [19]:
# Node features, shape [n, 8]: pmt_x, pmt_y, pmt_z, pmt_i, pmt_j, pmt_k, q, t (x, y, z: position; i, j, k: angle)
# First 3 features = "position" (used for kNN)
pos = torch.randn(n, 3) * 5.0      # use n from this section, not N from the earlier cells
other_5 = torch.randn(n, 5)
x = torch.cat([pos, other_5], dim=1)

In [20]:
pos

tensor([[-1.3131, -5.2460,  4.5672],
        [ 4.9336, -4.5426,  3.3155],
        [-4.2067,  3.4617,  3.2696],
        [ 0.3581, -5.7185,  0.3676],
        [-0.9504, -2.5171,  1.5353],
        [ 1.1992, -1.7472,  4.5736],
        [-4.4386, -1.0320, -5.0398],
        [ 1.5857, -3.4883, -3.0318],
        [-3.5653, -7.1069, -0.4550],
        [-3.3526,  3.7619, -7.6058]])

In [21]:
other_5

tensor([[-0.1102, -1.0765,  1.2198, -0.3551, -1.3987],
        [ 0.5555, -1.7933,  1.4585, -0.2492, -0.7961],
        [-1.5236,  1.3001, -0.1664, -1.5490, -0.2919],
        [-0.0208, -0.3597,  0.0589,  1.0251,  0.1441],
        [-0.6544, -0.2587,  0.0216, -1.5524, -0.4753],
        [ 0.9003,  1.8899,  1.1911,  1.4694, -0.1391],
        [-1.9213, -0.3719,  2.1403, -0.0090,  0.7390],
        [ 0.5260, -1.5670, -0.4099, -1.9789,  0.0352],
        [ 0.5567, -0.1976,  0.4111,  0.6739, -2.0867],
        [-1.2404,  0.1228,  0.4560, -0.8187, -1.3543]])

In [22]:
x

tensor([[-1.3131, -5.2460,  4.5672, -0.1102, -1.0765,  1.2198, -0.3551, -1.3987],
        [ 4.9336, -4.5426,  3.3155,  0.5555, -1.7933,  1.4585, -0.2492, -0.7961],
        [-4.2067,  3.4617,  3.2696, -1.5236,  1.3001, -0.1664, -1.5490, -0.2919],
        [ 0.3581, -5.7185,  0.3676, -0.0208, -0.3597,  0.0589,  1.0251,  0.1441],
        [-0.9504, -2.5171,  1.5353, -0.6544, -0.2587,  0.0216, -1.5524, -0.4753],
        [ 1.1992, -1.7472,  4.5736,  0.9003,  1.8899,  1.1911,  1.4694, -0.1391],
        [-4.4386, -1.0320, -5.0398, -1.9213, -0.3719,  2.1403, -0.0090,  0.7390],
        [ 1.5857, -3.4883, -3.0318,  0.5260, -1.5670, -0.4099, -1.9789,  0.0352],
        [-3.5653, -7.1069, -0.4550,  0.5567, -0.1976,  0.4111,  0.6739, -2.0867],
        [-3.3526,  3.7619, -7.6058, -1.2404,  0.1228,  0.4560, -0.8187, -1.3543]])

In [23]:
print("x shape:", x.shape)  # [10, 8]: 10 nodes (across 2 events), 8 features
print("batch:", batch)


x shape: torch.Size([10, 8])
batch: tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])


In [24]:
# Initial adjacency (edge_index), built with kNN on the positions
edge_index0 = knn_graph(x=pos, k=3, batch=batch, loop=False)
print("edge_index0 shape:", edge_index0.shape)  # [2, n*k] = [2, 30]: 10 nodes, each with 3 neighbours

edge_index0 shape: torch.Size([2, 30])


In [25]:
print("edge_index0:\n", edge_index0)

edge_index0:
 tensor([[4, 3, 1, 3, 0, 4, 4, 0, 3, 4, 0, 1, 3, 0, 1, 7, 8, 6, 9, 7, 8, 8, 6, 5,
         7, 6, 5, 6, 7, 8],
        [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7,
         8, 8, 8, 9, 9, 9]])


In [26]:
# MLP input size = 2*F = 16 (x_i concatenated with x_j - x_i)
mlp = nn.Sequential(
    nn.Linear(16, 32),
    nn.ReLU(),
    nn.Linear(32, 256)   
)

In [27]:
mlp

Sequential(
  (0): Linear(in_features=16, out_features=32, bias=True)
  (1): ReLU()
  (2): Linear(in_features=32, out_features=256, bias=True)
)

In [28]:
layer = DynEdgeConv(
    nn=mlp,
    nb_neighbors=3,                
)

In [29]:
x1, edge_index1 = layer(x, edge_index0, batch=batch)


In [30]:
x1

tensor([[ 1.1364, -0.8595, -0.5129,  ..., -0.4237, -0.4554, -1.7924],
        [-1.7836, -0.6169, -0.9150,  ...,  1.1913,  1.2121, -3.4018],
        [-0.1512,  0.9915,  1.5075,  ..., -1.6531, -1.6588,  0.4486],
        ...,
        [-2.1455, -2.0212, -1.3974,  ..., -0.6470, -0.2686, -1.9106],
        [-2.9431, -2.2061,  0.2430,  ..., -2.6065, -0.0694, -0.5376],
        [-4.3823, -1.0552, -1.1999,  ..., -1.6722, -2.5916,  0.7589]],
       grad_fn=<ScatterAddBackward0>)

In [31]:
print("\n--- AFTER DynEdgeConv ---")
print("x1 shape:", x1.shape)               # [n, out_dim] = [10, 256]


--- AFTER DynEdgeConv ---
x1 shape: torch.Size([10, 256])


In [32]:
print("\n--- AFTER DynEdgeConv ---")
print("edge_index1 shape:", edge_index1.shape)  # [2, n*k] ~ [2, 30]



--- AFTER DynEdgeConv ---
edge_index1 shape: torch.Size([2, 30])


In [33]:
edge_index1

tensor([[4, 3, 1, 3, 4, 0, 4, 0, 3, 4, 0, 1, 0, 3, 1, 7, 6, 8, 9, 7, 8, 6, 9, 8,
         6, 7, 9, 6, 7, 8],
        [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7,
         8, 8, 8, 9, 9, 9]])
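Comparing the two printed graphs shows what "dynamic" means here: the second graph comes from a kNN search in the 256-dimensional output space rather than in position space. The sketch below copies the source rows of `edge_index0` and `edge_index1` from the outputs above (the target row is identical in both) and lists the edges that changed.

```python
import torch

targets = torch.arange(10).repeat_interleave(3)   # 0,0,0,1,1,1,...,9,9,9
src0 = torch.tensor([4, 3, 1, 3, 0, 4, 4, 0, 3, 4, 0, 1, 3, 0, 1,
                     7, 8, 6, 9, 7, 8, 8, 6, 5, 7, 6, 5, 6, 7, 8])
src1 = torch.tensor([4, 3, 1, 3, 4, 0, 4, 0, 3, 4, 0, 1, 0, 3, 1,
                     7, 6, 8, 9, 7, 8, 6, 9, 8, 6, 7, 9, 6, 7, 8])

# Represent each graph as a set of (source, target) pairs so that order is ignored
edges0 = set(zip(src0.tolist(), targets.tolist()))
edges1 = set(zip(src1.tolist(), targets.tolist()))

kept = edges0 & edges1
dropped = sorted(edges0 - edges1)
added = sorted(edges1 - edges0)
print(len(kept), dropped, added)
# 28 [(5, 7), (5, 8)] [(9, 7), (9, 8)]
```

Only two of the 30 edges differ: nodes 7 and 8 lose neighbor 5 and gain neighbor 9. With a randomly initialized MLP this change is not meaningful yet; after training, the learned feature space is what determines the neighborhoods.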

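When you call a layer or model in PyTorch as `layer(...)`, what actually runs is **`__call__`**, which in turn invokes the **`forward(...)`** method you wrote. Around that call, `__call__` runs any registered forward pre-hooks and forward hooks and sets up backward hooks. The other mechanics live elsewhere: the autograd graph (for gradients) is recorded by the tensor operations inside `forward`, the `training`/`eval` mode is a flag that layers such as dropout and batch norm read in their own `forward`, and mixed precision is controlled by the `torch.autocast` context manager. In short: **`forward` is the recipe for the computation**, and **`__call__` is the mechanism that runs it with PyTorch's hooks applied**, which is why you call `layer(...)` rather than `layer.forward(...)`.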
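A small sketch of the difference, using a plain `nn.Linear` and a forward hook (both chosen only for illustration): the hook fires when the layer is called, but not when `forward` is invoked directly.

```python
import torch
import torch.nn as nn

layer_demo = nn.Linear(3, 2)
calls = []

# The hook records the output shape every time it is triggered
handle = layer_demo.register_forward_hook(
    lambda module, inputs, output: calls.append(tuple(output.shape))
)

inp = torch.randn(4, 3)
out_call = layer_demo(inp)           # goes through __call__: the hook fires
out_fwd = layer_demo.forward(inp)    # bypasses __call__: the hook does not fire
handle.remove()

print(calls)                         # [(4, 2)]
print(torch.equal(out_call, out_fwd))  # True
```

Both calls return the same values; only the surrounding machinery differs.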
