# Edge classification

In [1]:
from pytorch_lightning import Trainer
from torch import nn
from pytorch_lightning.core.mixins import HyperparametersMixin
import torch
from functools import partial

from torch_geometric.data import Data
from gnn_tracking.models.graph_construction import MLGraphConstructionFromChkpt

from gnn_tracking.metrics.losses import EdgeWeightFocalLoss
from gnn_tracking.training.ec import ECModule

from gnn_tracking.utils.loading import TrackingDataModule

## From on-disk graphs

### 1. Setting up the data

We can either load graphs directly from disk, or load point clouds and build the edges on the fly using the module from `009_build_graphs_ml.ipynb`.

We'll first do the former (for simplicity), using the simplified data from [here](https://cernbox.cern.ch/files/link/public/YQxujEYrVFFpylN?tiles-size=1&items-per-page=100&view-mode=resource-table&sort-dir=desc).

In [2]:
# Training and validation read from the same directory of pre-built graphs;
# they differ only in the start/stop bounds passed for each split.
dm = TrackingDataModule(
    train={
        "dirs": ["/Users/fuchur/tmp/truth_cut_graphs_for_gsoc_ec_challenge/batch_1_0/"],
        "stop": 5,
    },
    val={
        "dirs": ["/Users/fuchur/tmp/truth_cut_graphs_for_gsoc_ec_challenge/batch_1_0/"],
        "start": 5,
        "stop": 10,
    },
    # a 'test' split could be configured here as well
)

### 2. Defining the module

In [3]:
class SillyEC(nn.Module, HyperparametersMixin):
    """Minimal edge classifier: a small fully connected network on edge features.

    Only ``data.edge_attr`` is fed to the network; ``node_in_dim`` is stored
    on the instance but no layer consumes node features.
    """

    def __init__(
        self,
        node_in_dim: int,
        edge_in_dim: int,
        hidden_dim: int = 12,
    ):
        """Build the network.

        Args:
            node_in_dim: Number of node features (stored, not used by any layer).
            edge_in_dim: Number of edge features; input width of the first layer.
            hidden_dim: Width of the two hidden layers.
        """
        super().__init__()
        self.save_hyperparameters()
        self.node_in_dim = node_in_dim
        self.edge_in_dim = edge_in_dim
        self.hidden_dim = hidden_dim

        self.fcnn = nn.Sequential(
            nn.Linear(edge_in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, data):
        """Compute one sigmoid score per row of ``data.edge_attr``.

        Args:
            data: Object exposing an ``edge_attr`` tensor whose last dimension
                has size ``edge_in_dim``.

        Returns:
            Dict with the single key ``"W"`` holding the scores with the
            trailing singleton dimension removed.
        """
        # Squeeze only the last dimension: a bare ``squeeze()`` would also drop
        # the edge dimension when there is exactly one edge, yielding a 0-d tensor.
        w = self.fcnn(data.edge_attr).squeeze(-1)
        return {"W": w}

In [4]:
# edge_in_dim sets the input width of the network's first layer;
# node_in_dim is stored on the model but not consumed by its layers.
model = SillyEC(
    node_in_dim=6,
    edge_in_dim=4,
)

### 3. Setting up the loss function and the Lightning module

Unfortunately, the GSoC challenge data does not contain the $p_T$ of each hit, which the trainer expects. So we cheat a little and add a constant placeholder value to each graph in a preprocessing step:

In [9]:
class AddFakePt(HyperparametersMixin):
    """Preprocessing step that attaches a constant dummy ``pt`` to graphs lacking one."""

    def __call__(self, data: Data) -> Data:
        """Add ``data.pt`` (all ones) if it is missing.

        Args:
            data: Graph whose node features are stored in ``data.x``.

        Returns:
            The same ``data`` object; it is modified in place when ``pt`` was absent.
        """
        if not hasattr(data, "pt"):
            # One value per hit (row of ``data.x``), not one per node feature:
            # ``full_like(data.x, ...)`` would give a (num_hits, num_features) tensor.
            # NOTE(review): assumes downstream code expects a 1-D per-hit pt — confirm.
            data.pt = data.x.new_ones(data.x.shape[0])
        return data

In [10]:
# Build each ingredient separately (same order as they would be evaluated
# inline), then pass them to the module.
focal_loss = EdgeWeightFocalLoss(alpha=0.3)
adam_factory = partial(torch.optim.Adam, lr=1e-4)  # passed as a factory, not an optimizer instance
fake_pt = AddFakePt()  # supplies the `pt` attribute missing from the data
lmodel = ECModule(
    model=model,
    loss_fct=focal_loss,
    optimizer=adam_factory,
    preproc=fake_pt,
)

[36m[15:07:19] DEBUG: Got obj of type <class '__main__.SillyEC'>, assuming I have to save hyperparameters[0m
[36m[15:07:19] DEBUG: Saving hyperperameters {'class_path': '__main__.SillyEC', 'init_args': {'node_in_dim': 6, 'edge_in_dim': 4, 'hidden_dim': 12}}[0m
[36m[15:07:19] DEBUG: Got obj of type <class '__main__.AddFakePt'>, assuming I have to save hyperparameters[0m
[36m[15:07:19] DEBUG: Saving hyperperameters {'class_path': '__main__.AddFakePt', 'init_args': {}}[0m
[36m[15:07:19] DEBUG: Got obj of type <class 'gnn_tracking.metrics.losses.EdgeWeightFocalLoss'>, assuming I have to save hyperparameters[0m
[36m[15:07:19] DEBUG: Saving hyperperameters {'class_path': 'gnn_tracking.metrics.losses.EdgeWeightFocalLoss', 'init_args': {'pt_thld': 0.0, 'alpha': 0.3, 'gamma': 2.0, 'pos_weight': tensor([1.])}}[0m


### 4. Starting training

In [11]:
# Short CPU-only demo run: a single epoch, logging after every step.
trainer = Trainer(
    max_epochs=1,
    accelerator="cpu",
    log_every_n_steps=1,
)
trainer.fit(lmodel, datamodule=dm)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
[32m[15:07:20] INFO: DataLoader will load 5 graphs (out of 1000 available).[0m
[36m[15:07:20] DEBUG: First graph is /Users/fuchur/tmp/truth_cut_graphs_for_gsoc_ec_challenge/batch_1_0/data21000_s14.pt, last graph is /Users/fuchur/tmp/truth_cut_graphs_for_gsoc_ec_challenge/batch_1_0/data21006_s27.pt[0m
[32m[15:07:20] INFO: DataLoader will load 5 graphs (out of 1000 available).[0m
[36m[15:07:20] DEBUG: First graph is /Users/fuchur/tmp/truth_cut_graphs_for_gsoc_ec_challenge/batch_1_0/data21008_s17.pt, last graph is /Users/fuchur/tmp/truth_cut_graphs_for_gsoc_ec_challenge/batch_1_0/data21012_s24.pt[0m

  | Name     | Type                | Params
-------------------------------------------------
0 | model    | SillyEC             | 229   
1 | loss_fct | EdgeWeightFocalLoss | 0     
-----------------------------------

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


## With graphs built on-the-fly from point clouds

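1. Configure the data module to load point clouds (rather than graphs).
2. Add `MLGraphConstructionFromChkpt` as `preproc`.

Replace the placeholder `/path/to/your/checkpoint` in the cell below with the path to your own checkpoint.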

In [8]:
# Same module setup as for on-disk graphs, but with a graph-construction
# preprocessing step. Ingredients are created in the same order as they
# would be evaluated inline.
focal_loss = EdgeWeightFocalLoss(alpha=0.3)
adam_factory = partial(torch.optim.Adam, lr=1e-4)
# `ml_chkpt_path` is a placeholder: point it at your own checkpoint file.
graph_builder = MLGraphConstructionFromChkpt(
    ml_class_name="gnn_tracking.models.graph_construction.GraphConstructionFCNN",
    ml_chkpt_path="/path/to/your/checkpoint",
)
lmodel = ECModule(
    model=model,
    loss_fct=focal_loss,
    optimizer=adam_factory,
    preproc=graph_builder,
)

[36m[15:04:52] DEBUG: Getting class GraphConstructionFCNN from module gnn_tracking.models.graph_construction[0m


AssertionError: 

Instead of `MLGraphConstructionFromChkpt`, you can also use `MLGraphConstruction`, which simply takes a model that you can instantiate in any way you like.