# Edge classification

This notebook shows how to classify the edges of a graph. In many GNN tracking approaches, we start from an initial graph (e.g., built from a point cloud with the strategy described in `009_build_graphs_ml.ipynb`). We then try to identify and reject all edges that connect hits of two different particles. If edge classification (EC) were perfect, we could reconstruct tracks as the connected components of the graph that remains.
For our object condensation approach, EC is only an auxiliary step: edges are used for message passing but do not directly determine what the final tracks look like. Still, EC is important because it helps the model learn quickly.
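If EC were perfect, track building would reduce to a connected-components search. Here is a minimal pure-Python sketch of that idea: keep the edges whose score passes a threshold and label each hit by the connected component it ends up in, using a union-find structure. The function name `tracks_from_edges`, the threshold of 0.5, and the representation of edges as a list of `(source, target)` pairs (instead of PyG's `2 × n_edges` `edge_index` tensor) are illustrative assumptions, not part of `gnn_tracking`.

```python
def tracks_from_edges(num_hits, edges, edge_scores, threshold=0.5):
    """Label hits by connected component after dropping low-score edges.

    Hypothetical helper for illustration; edges is a list of (src, dst) pairs.
    """
    parent = list(range(num_hits))  # union-find: every hit starts as its own root

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving keeps the trees shallow
            i = parent[i]
        return i

    for (src, dst), score in zip(edges, edge_scores):
        if score >= threshold:  # edge classified as "same particle"
            root_a, root_b = find(src), find(dst)
            if root_a != root_b:
                parent[root_a] = root_b  # merge the two components

    # Relabel component roots to consecutive track ids 0, 1, 2, ...
    labels = {}
    return [labels.setdefault(find(i), len(labels)) for i in range(num_hits)]


edges = [(0, 1), (1, 2), (2, 3), (3, 4)]
scores = [0.9, 0.8, 0.1, 0.95]  # the edge (2, 3) is rejected
print(tracks_from_edges(5, edges, scores))  # [0, 0, 0, 1, 1]
```

Hits without any surviving edge end up as single-hit tracks.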

For background on PyTorch Lightning, see `009_build_graphs_ml.ipynb`.

In [76]:
from pytorch_lightning import Trainer
from torch import nn
from pytorch_lightning.core.mixins.hparams_mixin import HyperparametersMixin
import torch
from functools import partial

from gnn_tracking.models.graph_construction import MLGraphConstructionFromChkpt

from gnn_tracking.metrics.losses.ec import EdgeWeightFocalLoss
from gnn_tracking.training.callbacks import PrintValidationMetrics
from gnn_tracking.training.ec import ECModule

from gnn_tracking.utils.loading import TrackingDataModule


from gnn_tracking.utils.versioning import assert_version_geq

assert_version_geq("23.12.0")

We can either load prebuilt graphs directly from disk, or load point clouds and build the edges on the fly using the graph construction module from `009_build_graphs_ml.ipynb`.

## From on-disk graphs

### 1. Setting up the data

If you are not working on Princeton's `della`, you can download these example graphs [here](https://cernbox.cern.ch/s/4xYL99cd7zNe0VK). Note that this is simplified data (pt > 1 GeV truth cut) and a single event has been broken up into 32 sectors.

In [85]:
dm = TrackingDataModule(
    train=dict(
        dirs=[
            "/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v1/part1_pt1/all"
        ],
        stop=28_000,
        # If you run into memory issues, reduce this
        batch_size=10,
    ),
    val=dict(
        dirs=[
            "/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v1/part1_pt1/all"
        ],
        start=28_000,
        stop=28_100,
    ),
    identifier="graphs_v1",
)
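As a quick sanity check on this configuration: with `stop=28_000` the training set holds 28,000 graphs (the training log further below confirms this), so `batch_size=10` gives 2,800 batches per epoch, and the `max_steps=100` used for training covers only a small fraction of one epoch. The sketch assumes one optimizer step per batch (no gradient accumulation) and that the last partial batch is kept.

```python
import math

n_train_graphs = 28_000  # train split: graphs [0, 28_000)
batch_size = 10
max_steps = 100  # value passed to the Trainer in the training step

# Assumes the last partial batch is kept (PyTorch DataLoader default drop_last=False)
batches_per_epoch = math.ceil(n_train_graphs / batch_size)
fraction_of_epoch = max_steps / batches_per_epoch

print(batches_per_epoch)  # 2800
print(f"{fraction_of_epoch:.1%}")  # 3.6%
```

This matches the `100/2800` shown in the progress bar of the training output.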

### 2. Defining the module

In [86]:
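class SillyEC(nn.Module, HyperparametersMixin):
    """Deliberately simple edge classifier: an MLP applied to each edge's features.

    Node features are ignored; `node_in_dim` is only stored as a hyperparameter.
    """

    def __init__(
        self,
        node_in_dim: int,
        edge_in_dim: int,
        hidden_dim: int = 12,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.node_in_dim = node_in_dim
        self.edge_in_dim = edge_in_dim
        self.hidden_dim = hidden_dim

        # Maps each edge's feature vector to a single score in [0, 1]
        self.fcnn = nn.Sequential(
            nn.Linear(edge_in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, data):
        # data.edge_attr has shape (n_edges, edge_in_dim). Squeeze only the last
        # dimension so that a graph with a single edge still yields shape (n_edges,)
        w = self.fcnn(data.edge_attr).squeeze(-1)
        # Edge weights are returned under the key "W"
        return {"W": w}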
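`SillyEC` ends in a `Sigmoid`, so each entry of `W` can be read as the probability that an edge connects two hits of the same particle. The Lightning module set up in the next step trains these probabilities with `EdgeWeightFocalLoss(alpha=0.3)`. To illustrate what a focal loss does, here is a minimal pure-Python sketch of the standard binary focal loss: `alpha` weights true edges against false ones, and `gamma` down-weights edges that are already classified confidently. This is not the library's implementation. The function name `binary_focal_loss`, the default `gamma=2.0`, and the convention that `alpha` applies to the positive (true-edge) class are assumptions, and `EdgeWeightFocalLoss` may add further details, such as selecting which edges enter the loss.

```python
import math


def binary_focal_loss(p, y, alpha=0.25, gamma=2.0, eps=1e-7):
    """Mean binary focal loss for probabilities p and labels y in {0, 1}.

    Illustrative sketch, not gnn_tracking's EdgeWeightFocalLoss.
    """
    total = 0.0
    for p_i, y_i in zip(p, y):
        p_i = min(max(p_i, eps), 1 - eps)  # clamp to avoid log(0)
        if y_i == 1:
            p_t, alpha_t = p_i, alpha  # true edge: probability assigned to "true"
        else:
            p_t, alpha_t = 1 - p_i, 1 - alpha  # false edge: probability assigned to "false"
        # (1 - p_t) ** gamma shrinks the contribution of confident, correct predictions
        total += -alpha_t * (1 - p_t) ** gamma * math.log(p_t)
    return total / len(p)


scores = [0.9, 0.2, 0.6]  # predicted edge weights W
labels = [1, 0, 1]  # 1 = both hits belong to the same particle
print(binary_focal_loss(scores, labels, alpha=0.3))
```

With `gamma=0` and `alpha=0.5`, this reduces to half the binary cross-entropy, which is a convenient consistency check.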

In [87]:
model = SillyEC(node_in_dim=6, edge_in_dim=4, hidden_dim=128)
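For a sense of scale, this instance (`edge_in_dim=4`, `hidden_dim=128`) has three `nn.Linear` layers, each with an `in × out` weight matrix plus an `out`-dimensional bias; `ReLU` and `Sigmoid` have no parameters, and `node_in_dim` does not contribute. The count can be worked out by hand:

```python
def linear_params(n_in, n_out):
    # nn.Linear stores an (n_out, n_in) weight matrix and an n_out bias vector
    return n_in * n_out + n_out


edge_in_dim, hidden_dim = 4, 128
n_params = (
    linear_params(edge_in_dim, hidden_dim)  # 4 -> 128: 640
    + linear_params(hidden_dim, hidden_dim)  # 128 -> 128: 16512
    + linear_params(hidden_dim, 1)  # 128 -> 1: 129
)
print(n_params)  # 17281
```

On the actual model, `sum(p.numel() for p in model.parameters())` should give the same number.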

### 3. Setting up the loss function and the Lightning module

In [88]:
lmodel = ECModule(
    model=model,
    loss_fct=EdgeWeightFocalLoss(alpha=0.3),
    optimizer=partial(torch.optim.Adam, lr=1e-4),
)

### 4. Starting training

In [89]:
trainer = Trainer(
    max_steps=100,
    val_check_interval=100,
    accelerator="cpu",
    log_every_n_steps=1,
    callbacks=[PrintValidationMetrics()],
)
trainer.fit(model=lmodel, datamodule=dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[16:06:55] INFO: DataLoader will load 28000 graphs (out of 28800 available).
[16:06:55] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v1/part1_pt1/all/data21000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v1/part1_pt1/all/data21974_s9.pt
[16:06:56] INFO: DataLoader will load 100 graphs (out of 28800 available).
[16:06:56] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/objec

Epoch 0:   4%| 100/2800 [00:32<14:50,  3.03it/s, v_num=8, total_train=0.0646]

              Validation epoch=0
┏━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┓
┃ Metric                 ┃   Value ┃   Error ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━┩
│ max_ba                 │ 0.81276 │ 0.00218 │

`Trainer.fit` stopped: `max_steps=100` reached.
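The validation table reports `max_ba`. Our reading of the name (an assumption; check the metric definitions in `gnn_tracking`) is the maximum balanced accuracy, i.e. the mean of the true-positive rate and the true-negative rate, over a scan of thresholds applied to the edge scores `W`. A minimal pure-Python sketch of that computation, using a hypothetical grid of 101 evenly spaced thresholds:

```python
def balanced_accuracy(scores, labels, threshold):
    """Mean of TPR and TNR when edges with score >= threshold are called true.

    Assumes both classes (label 1 and label 0) are present.
    """
    tp = sum(1 for s, y in zip(scores, labels) if s >= threshold and y == 1)
    tn = sum(1 for s, y in zip(scores, labels) if s < threshold and y == 0)
    n_pos = sum(1 for y in labels if y == 1)
    n_neg = len(labels) - n_pos
    return 0.5 * (tp / n_pos + tn / n_neg)


def max_balanced_accuracy(scores, labels, n_thresholds=100):
    """Best balanced accuracy over thresholds 0, 1/n, ..., 1 (hypothetical grid)."""
    thresholds = [i / n_thresholds for i in range(n_thresholds + 1)]
    best = max(thresholds, key=lambda t: balanced_accuracy(scores, labels, t))
    return balanced_accuracy(scores, labels, best), best


ba, threshold = max_balanced_accuracy([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1])
print(ba, threshold)  # 0.75 0.11
```

Balanced accuracy is useful when the classes are imbalanced (e.g. many more false than true edges): a classifier that rejects every edge would then score a high plain accuracy but only 0.5 balanced accuracy.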


## With graphs built on-the-fly from point clouds

1. Configure the data module to load point clouds (rather than graphs).
2. Pass `MLGraphConstructionFromChkpt` to the Lightning module as `preproc`.

In [None]:
lmodel = ECModule(
    model=model,
    loss_fct=EdgeWeightFocalLoss(alpha=0.3),
    optimizer=partial(torch.optim.Adam, lr=1e-4),
    preproc=MLGraphConstructionFromChkpt(
        ml_class_name="gnn_tracking.models.graph_construction.GraphConstructionFCNN",
        ml_chkpt_path="/path/to/your/checkpoint",
    ),
)

Instead of `MLGraphConstructionFromChkpt`, you can also use `MLGraphConstruction`, which takes an already instantiated model (so you can construct that model in any way you like).