# Edge classification

This notebook shows how to classify the edges of a graph. Many GNN tracking approaches start from an initial graph (e.g., built from a point cloud with the strategy described in `009_build_graphs_ml.ipynb`) and then try to reject all edges that connect hits of two different particles. If edge classification (EC) were perfect, we could then reconstruct tracks as the connected components of the remaining graph.
For our object condensation approach, EC is only an auxiliary step: edges are used for message passing but do not directly determine what the final tracks look like. Even so, EC remains important because it helps the model learn quickly.
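
To make the "tracks as connected components" idea concrete, here is a minimal pure-Python sketch. The toy edges, the scores, and the `threshold` value are made up for illustration, and the union-find helper `connected_components` is a hypothetical name rather than part of `gnn_tracking`.

```python
from collections import defaultdict


def connected_components(n_nodes, edges):
    """Label each node with the root of its connected component (union-find)."""
    parent = list(range(n_nodes))

    def find(i):
        # Walk up to the root, shortening the path along the way
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    for a, b in edges:
        root_a, root_b = find(a), find(b)
        if root_a != root_b:
            parent[root_b] = root_a

    return [find(i) for i in range(n_nodes)]


# Toy graph: 5 hits and 4 candidate edges with EC scores (made-up numbers)
edges = [(0, 1), (1, 2), (2, 3), (3, 4)]
scores = [0.9, 0.8, 0.1, 0.95]
threshold = 0.5  # assumed cut; in practice this would be tuned

# Keep only edges that the classifier considers true
kept = [edge for edge, score in zip(edges, scores) if score > threshold]
labels = connected_components(5, kept)

# Group hits by component label to obtain track candidates
tracks = defaultdict(list)
for hit, label in enumerate(labels):
    tracks[label].append(hit)
track_list = sorted(tracks.values())
print(track_list)
```

Rejecting the low-scoring edge between hits 2 and 3 splits the chain into two track candidates. With an imperfect classifier, a single wrongly kept edge would merge two tracks, which is one reason EC alone is fragile.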

For background on PyTorch Lightning, see `009_build_graphs_ml.ipynb`.

In [1]:
from pytorch_lightning import Trainer
from torch import nn
from pytorch_lightning.core.mixins import HyperparametersMixin
import torch
from functools import partial

from gnn_tracking.models.graph_construction import MLGraphConstructionFromChkpt

from gnn_tracking.metrics.losses import EdgeWeightFocalLoss
from gnn_tracking.training.ec import ECModule

from gnn_tracking.utils.loading import TrackingDataModule

We can either load graphs directly from disk, or load point clouds and build the edges on the fly using the module from `009_build_graphs_ml.ipynb`.

## From on-disk graphs

### 1. Setting up the data

If you are not working on Princeton's `della`, you can download these example graphs [here](https://cernbox.cern.ch/s/4xYL99cd7zNe0VK). Note that this is simplified data (pt > 1 GeV truth cut) and that each event has been split into 32 sectors.

In [2]:
dm = TrackingDataModule(
    train=dict(
        dirs=[
            "/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v1/part1_pt1/all"
        ],
        stop=28000,
        # If you run into memory issues, reduce this
        batch_size=10,
    ),
    val=dict(
        dirs=[
            "/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v1/part1_pt1/all"
        ],
        start=28000,
    ),
)
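
The `stop=28000` and `start=28000` arguments above split one directory into a training and a validation set. A minimal sketch of that idea follows, assuming the files are sorted and then sliced as `files[start:stop]`. The helper `select` and the file names are hypothetical, and the actual loader may order or select files differently.

```python
def select(files, start=0, stop=None):
    """Return the half-open slice [start, stop) of a file list."""
    return files[start:stop]


# Hypothetical file names: 3 events with 4 sectors each
files = sorted(
    f"data{event:05d}_s{sector}.pt" for event in range(3) for sector in range(4)
)

train_files = select(files, stop=9)  # first 9 files
val_files = select(files, start=9)  # remaining files

print(len(train_files), len(val_files))
```

Because both sets are slices of the same sorted list with a shared boundary, they cannot overlap as long as `stop` of the training set equals `start` of the validation set.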

### 2. Defining the module

In [3]:
class SillyEC(nn.Module, HyperparametersMixin):
    def __init__(
        self,
        node_in_dim: int,
        edge_in_dim: int,
        hidden_dim: int = 12,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.node_in_dim = node_in_dim
        self.edge_in_dim = edge_in_dim
        self.hidden_dim = hidden_dim

        self.fcnn = nn.Sequential(
            nn.Linear(edge_in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

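    def forward(self, data):
        # Only the edge features are used; drop the trailing dimension of size 1.
        # squeeze(-1) keeps a 1D tensor even if the graph has a single edge.
        w = self.fcnn(data.edge_attr).squeeze(-1)
        return {"W": w}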

In [5]:
model = SillyEC(node_in_dim=6, edge_in_dim=4, hidden_dim=128)

### 3. Setting up the loss function and the Lightning module

In [6]:
lmodel = ECModule(
    model=model,
    loss_fct=EdgeWeightFocalLoss(alpha=0.3),
    optimizer=partial(torch.optim.Adam, lr=1e-4),
)

[18:26:19] DEBUG: Got obj of type <class '__main__.SillyEC'>, assuming I have to save hyperparameters
[18:26:19] DEBUG: Saving hyperperameters {'class_path': '__main__.SillyEC', 'init_args': {'node_in_dim': 6, 'edge_in_dim': 4, 'hidden_dim': 128}}
[18:26:19] DEBUG: Got obj of type <class 'gnn_tracking.metrics.losses.EdgeWeightFocalLoss'>, assuming I have to save hyperparameters
[18:26:19] DEBUG: Saving hyperperameters {'class_path': 'gnn_tracking.metrics.losses.EdgeWeightFocalLoss', 'init_args': {'pt_thld': 0.0, 'alpha': 0.3, 'gamma': 2.0, 'pos_weight': tensor([1.])}}


### 4. Starting training

In [7]:
trainer = Trainer(max_epochs=1, accelerator="cpu", log_every_n_steps=1)
trainer.fit(model=lmodel, datamodule=dm)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /home/kl5675/Documents/23/git_sync/hpo/slurm/lightning_logs
[18:26:24] INFO: DataLoader will load 28000 graphs (out of 28800 available).
[18:26:24] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v1/part1_pt1/all/data21000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v1/part1_pt1/all/data21974_s9.pt
[18:26:24] INFO: DataLoader will load 800 graphs (out of 28800 available).
[18:26:24] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v1/part1_pt1/all/data21975_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v1/part1_pt1/all/data21999_s9.pt

`Trainer.fit` stopped: `max_epochs=1` reached.
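
For intuition about the `alpha` and `gamma` values that appear in the `EdgeWeightFocalLoss` configuration in this section, here is a pure-Python sketch of the textbook binary focal loss for a single edge. This is an assumption about the general technique, not the `gnn_tracking` implementation, which may differ (for example in how it handles `pos_weight` or `pt_thld`). The function name `binary_focal_loss` is hypothetical.

```python
import math


def binary_focal_loss(p, y, alpha=0.3, gamma=2.0):
    """Textbook focal loss for one edge with predicted probability p and label y."""
    # Probability that the model assigns to the correct class
    p_t = p if y == 1 else 1.0 - p
    # Class-balancing weight: alpha for true edges, 1 - alpha for false edges
    alpha_t = alpha if y == 1 else 1.0 - alpha
    # The factor (1 - p_t)^gamma suppresses edges that are already classified well
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


easy = binary_focal_loss(0.95, 1)  # true edge, confidently correct
hard = binary_focal_loss(0.30, 1)  # true edge, mostly missed
print(easy, hard)
```

With `gamma=0` the modulating factor disappears and the expression reduces to a class-weighted binary cross-entropy. Larger `gamma` shifts the training signal towards edges that the classifier still gets wrong.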


## With graphs built on-the-fly from point clouds

1. Configure the data module to load point clouds (rather than graphs).
2. Pass `MLGraphConstructionFromChkpt` as the `preproc` argument of `ECModule`.

In [None]:
lmodel = ECModule(
    model=model,
    loss_fct=EdgeWeightFocalLoss(alpha=0.3),
    optimizer=partial(torch.optim.Adam, lr=1e-4),
    preproc=MLGraphConstructionFromChkpt(
        ml_class_name="gnn_tracking.models.graph_construction.GraphConstructionFCNN",
        ml_chkpt_path="/path/to/your/checkpoint",
    ),
)

As an alternative to `MLGraphConstructionFromChkpt`, take a look at `MLGraphConstruction`, which simply takes a model that you can instantiate in any way you like.