# Building graphs with Metric Learning

This notebook shows how to build graphs using a metric learning strategy. For this, every hit is independently projected to a latent space using a fully connected neural network. The network is trained to put hits from the same particle close to each other and hits from different particles far from each other. An initial graph can then be constructed by connecting hits that are close in this space.
This strategy has been adopted by ExaTrkx; see, for example, [section 5.2 here](https://link.springer.com/10.1140/epjc/s10052-021-09675-8).
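
To make the last step concrete, here is a minimal pure-Python sketch of connecting hits that lie within a fixed radius of each other in the latent space. The function name `radius_graph_edges`, the toy 2D coordinates and the radius are illustrative assumptions; this is not how `gnn_tracking` implements graph construction.

```python
import math
from itertools import combinations


def radius_graph_edges(latent, radius):
    """Return index pairs (i, j) with i < j whose latent points lie within `radius`."""
    return [
        (i, j)
        for (i, a), (j, b) in combinations(enumerate(latent), 2)
        if math.dist(a, b) <= radius
    ]


# Toy latent coordinates: hits 0 and 1 are close, hits 2 and 3 are close
latent = [(0.0, 0.0), (0.1, 0.0), (5.0, 5.0), (5.0, 5.2)]
edges = radius_graph_edges(latent, radius=0.5)
print(edges)
```

A well-trained embedding makes the pairs that survive the radius cut mostly same-particle pairs, which is what keeps the initial graph sparse but useful.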

This notebook also serves as an introduction to the new pytorch lightning-based framework.

In [1]:
from functools import partial

import torch

from gnn_tracking.training.ml import MLModule
from gnn_tracking.models.graph_construction import GraphConstructionFCNN
from gnn_tracking.metrics.losses import GraphConstructionHingeEmbeddingLoss
from pytorch_lightning import Trainer
from gnn_tracking.utils.loading import TrackingDataModule

from torch_geometric.data import Data
from torch import nn
from pytorch_lightning.core.mixins import HyperparametersMixin

## Step 1: Configuring the data

The configuration for train/val/test data and its dataloader is held in the `TrackingDataModule` (subclass of `LightningDataModule`).

In [2]:
dm = TrackingDataModule(
    train=dict(
        dirs=["/Users/fuchur/tmp/truth_cut_graphs_for_gsoc/part_1_0/"],
        stop=5,
    ),
    val=dict(
        dirs=["/Users/fuchur/tmp/truth_cut_graphs_for_gsoc/part_1_0/"],
        start=5,
        stop=10,
    ),
    # could also configure a 'test' set here
)

Other keys let you configure the loaders (batch size, number of workers, etc.). See the docstring of `TrackingDataModule` for details.
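
The `start` and `stop` keys above can be pictured as slice bounds over the list of graph files. The following sketch assumes exactly that (slice semantics over a sorted file list); the helper `select_graphs` and the file names are hypothetical and only illustrate why the train and validation ranges do not overlap.

```python
def select_graphs(files, start=0, stop=None):
    """Pick a contiguous range of graph files, like a Python slice."""
    return sorted(files)[start:stop]


# Hypothetical file names standing in for the graphs in one directory
files = [f"data{i:05d}_s0.pt" for i in range(90)]

train_files = select_graphs(files, stop=5)
val_files = select_graphs(files, start=5, stop=10)

# The two ranges are disjoint, so no validation graph is used for training
assert not set(train_files) & set(val_files)
print(len(train_files), len(val_files))
```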

### Details (for understanding)

Note that all of the following will be done implicitly by the `Trainer` and you won't have to worry about it. But if you want to inspect the data, you can do so.

When the `setup` method is called, the `LightningDataModule` initializes an instance of `TrackingDataset` (a `torch_geometric.data.Dataset`) for each configured split. We can get the corresponding dataloaders by calling `dm.train_dataloader()`, and analogously for validation and test.

Example:

In [3]:
# This is called by the Trainer automatically and sets up the datasets
dm.setup(stage="fit")  # 'fit' combines 'train' and 'val'
# Now the datasets are available:
dm.datasets

[10:06:48] INFO: DataLoader will load 5 graphs (out of 90 available).
[10:06:48] DEBUG: First graph is /Users/fuchur/tmp/truth_cut_graphs_for_gsoc/part_1_0/data21025_s0.pt, last graph is /Users/fuchur/tmp/truth_cut_graphs_for_gsoc/part_1_0/data21053_s0.pt
[10:06:48] INFO: DataLoader will load 5 graphs (out of 90 available).
[10:06:48] DEBUG: First graph is /Users/fuchur/tmp/truth_cut_graphs_for_gsoc/part_1_0/data21058_s0.pt, last graph is /Users/fuchur/tmp/truth_cut_graphs_for_gsoc/part_1_0/data21094_s0.pt


{'train': TrackingDataset(5), 'val': TrackingDataset(5)}

For example, we can inspect the first element of the training dataset:

In [4]:
data = dm.datasets["train"][0]

To get the corresponding dataloaders, use one of the following methods (but again, you probably won't need to):

In [5]:
dm.train_dataloader(), dm.val_dataloader()

(<torch_geometric.loader.dataloader.DataLoader at 0x28c37ed90>,
 <torch_geometric.loader.dataloader.DataLoader at 0x17772abd0>)

## Step 2: Configuring a model

We write a normal `torch.nn.Module`. The easiest way is to import one of the modules that we have already written in the `gnn_tracking` library.

In [6]:
model = GraphConstructionFCNN(in_dim=14, out_dim=8, depth=5, hidden_dim=64)

However, you can also write your own. Here is a very simple one:

In [7]:
class DemoGraphConstructionModel(nn.Module, HyperparametersMixin):
    def __init__(
        self,
        in_dim: int,
        hidden_dim: int,
        out_dim: int,
        depth: int = 5,
    ):
        super().__init__()
        # save_hyperparameters() is provided by HyperparametersMixin.
        # It stores all of the __init__ arguments in self.hparams
        # (we don't need them in this example).
        self.save_hyperparameters()
        assert depth > 2
        _layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
        ]
        for _ in range(depth - 2):
            _layers.append(nn.Linear(hidden_dim, hidden_dim))
            _layers.append(nn.ReLU())
        _layers.append(nn.Linear(hidden_dim, out_dim))
        self._model = nn.Sequential(*_layers)

    def forward(self, data: Data):
        # Our trainer class will expect us to return a dictionary, where
        # the key H has the transformed latent space.
        return {"H": self._model(data.x)}
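
As a quick sanity check on the size of the demo model above, its number of trainable parameters can be worked out by hand: each `nn.Linear(n_in, n_out)` contributes `n_in * n_out` weights plus `n_out` biases. The helper `count_fc_params` below is a hypothetical name used only for this arithmetic; it applies to `DemoGraphConstructionModel` as written and not necessarily to `GraphConstructionFCNN`, whose layers may differ.

```python
def count_fc_params(in_dim, hidden_dim, out_dim, depth=5):
    """Count weights and biases of the fully connected stack in the demo model."""

    def linear(n_in, n_out):
        return n_in * n_out + n_out  # weight matrix plus bias vector

    total = linear(in_dim, hidden_dim)  # input layer
    total += (depth - 2) * linear(hidden_dim, hidden_dim)  # hidden layers
    total += linear(hidden_dim, out_dim)  # output layer
    return total


n_params = count_fc_params(in_dim=14, hidden_dim=64, out_dim=8, depth=5)
print(n_params)  # → 13960
```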

In [8]:
# model = DemoGraphConstructionModel(in_dim=14, out_dim=8, hidden_dim=64)

If you are familiar with normal pytorch, there are only two differences:

1. We inherit from `HyperparametersMixin`
2. We call `self.save_hyperparameters()`

### Details (for understanding)

We saved all hyperparameters:

In [9]:
model.hparams

"beta":       0.4
"depth":      5
"hidden_dim": 64
"in_dim":     14
"out_dim":    8

Note how `beta=0.4` was saved despite not being specified explicitly (it was picked up as a default parameter).
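
The idea of recording defaults alongside explicitly passed arguments can be sketched with the standard library's `inspect` module. This is only a conceptual illustration under the assumption that the arguments are bound against the `__init__` signature; `collect_init_args` and `ToyModel` are hypothetical names, and Lightning's actual `save_hyperparameters` works differently internally.

```python
import inspect


def collect_init_args(cls, *args, **kwargs):
    """Bind the given arguments to the class signature and fill in defaults."""
    bound = inspect.signature(cls).bind(*args, **kwargs)
    bound.apply_defaults()
    return dict(bound.arguments)


class ToyModel:
    def __init__(self, in_dim, hidden_dim, out_dim, depth=5, beta=0.4):
        pass


# Neither `depth` nor `beta` is passed, yet both end up in the result
hparams = collect_init_args(ToyModel, in_dim=14, out_dim=8, hidden_dim=64)
print(hparams)
```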

As always, you can simply evaluate the `model` on a piece of data:

In [10]:
out = model(data)

## Step 3: Configuring loss functions, metrics and the lightning module

The pytorch model is bundled together with a set of loss functions (just one here), that we backpropagate from in the training step, and a set of metrics. Together, these components make up the `LightningModule` that we pass to the pytorch lightning `Trainer` for training.

If you are familiar with our previous `TCNTrainer` training class, the `MLModule` now fulfills almost exactly the same role.

In [11]:
lmodel = MLModule(
    model=model,
    loss_fct=GraphConstructionHingeEmbeddingLoss(max_num_neighbors=10),
    lw_repulsive=0.5,  # loss weight, see below
    optimizer=partial(torch.optim.Adam, lr=1e-4),
)

[10:06:49] DEBUG: Got obj of type <class 'gnn_tracking.models.graph_construction.GraphConstructionFCNN'>, assuming I have to save hyperparameters
[10:06:49] DEBUG: Saving hyperperameters {'class_path': 'gnn_tracking.models.graph_construction.GraphConstructionFCNN', 'init_args': {'in_dim': 14, 'hidden_dim': 64, 'out_dim': 8, 'depth': 5, 'beta': 0.4}}
[10:06:49] DEBUG: Got obj of type <class 'gnn_tracking.metrics.losses.GraphConstructionHingeEmbeddingLoss'>, assuming I have to save hyperparameters
[10:06:49] DEBUG: Saving hyperperameters {'class_path': 'gnn_tracking.metrics.losses.GraphConstructionHingeEmbeddingLoss', 'init_args': {'r_emb': 0.002, 'max_num_neighbors': 10, 'attr_pt_thld': 0.9, 'p_attr': 1, 'p_rep': 1}}


### Details (for understanding)

Again, all hyperparameters are accessible (even the ones that weren't explicitly specified but only set by default):

In [12]:
lmodel.hparams

"loss_fct":     {'class_path': 'gnn_tracking.metrics.losses.GraphConstructionHingeEmbeddingLoss', 'init_args': {'r_emb': 0.002, 'max_num_neighbors': 10, 'attr_pt_thld': 0.9, 'p_attr': 1, 'p_rep': 1}}
"lw_repulsive": 0.5
"model":        {'class_path': 'gnn_tracking.models.graph_construction.GraphConstructionFCNN', 'init_args': {'in_dim': 14, 'hidden_dim': 64, 'out_dim': 8, 'depth': 5, 'beta': 0.4}}

As you can see, any _objects_ that were passed to the model are also saved to the hyperparameters in a way that we can bring them back.
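
The `class_path` / `init_args` layout is what makes it possible to rebuild such objects later. A minimal sketch of that idea, using only the standard library, is shown here; the helper `instantiate` is a hypothetical name, and `fractions.Fraction` merely stands in for a real model or loss class.

```python
import importlib


def instantiate(spec):
    """Import the class named by `class_path` and call it with `init_args`."""
    module_name, _, class_name = spec["class_path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**spec.get("init_args", {}))


# Stand-in spec with the same layout as the saved hyperparameters
spec = {
    "class_path": "fractions.Fraction",
    "init_args": {"numerator": 1, "denominator": 2},
}
obj = instantiate(spec)
print(type(obj).__name__, obj)
```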

The loss function takes output from the model and the data and returns two separate losses:

In [13]:
loss_fct = GraphConstructionHingeEmbeddingLoss()
loss_fct(
    x=out["H"],
    particle_id=data.particle_id,
    batch=data.batch,
    edge_index=data.edge_index,
    pt=data.pt,
)

{'attractive': tensor(0.0221, grad_fn=<DivBackward0>),
 'repulsive': tensor(0.0020, grad_fn=<SumBackward0>)}
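
To build intuition for what the two terms measure, here is a toy hinge-style embedding loss in pure Python. It is a simplified illustration, not the formula used by `GraphConstructionHingeEmbeddingLoss` (which also takes `edge_index`, `pt` and a neighbor limit into account); the function name, the squared-distance attractive term and the toy inputs are all assumptions.

```python
import math
from itertools import combinations


def _mean(values):
    return sum(values) / len(values) if values else 0.0


def toy_hinge_embedding_loss(latent, particle_id, r_emb=1.0):
    """Pull same-particle pairs together, push different-particle pairs apart.

    Same-particle pairs are penalized by their squared distance. Pairs from
    different particles are penalized only while closer than `r_emb` (the hinge).
    """
    attractive, repulsive = [], []
    for i, j in combinations(range(len(latent)), 2):
        dist = math.dist(latent[i], latent[j])
        if particle_id[i] == particle_id[j]:
            attractive.append(dist**2)
        else:
            repulsive.append(max(0.0, r_emb - dist))
    return {"attractive": _mean(attractive), "repulsive": _mean(repulsive)}


latent = [(0.0, 0.0), (0.0, 0.5), (0.3, 0.0), (4.0, 0.0)]
particle_id = [1, 1, 2, 2]
losses = toy_hinge_embedding_loss(latent, particle_id)
print(losses)
```

In this toy setup, hits 2 and 3 belong to the same particle but sit far apart, so they dominate the attractive term, while hit 2 sits too close to the hits of particle 1 and drives the repulsive term.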

Both parts of the loss function are combined using the loss weights we configured above (weight of 1 for attractive, weight of 0.5 for repulsive). All of this is done in `MLModule.get_losses`, which returns the total loss and a dictionary of the individual losses:

In [14]:
lmodel.get_losses(out, data)

(tensor(0.0270, grad_fn=<AddBackward0>),
 {'attractive': 0.02667856030166149,
  'repulsive': 0.0006350235780701041,
  'attractive_weighted': 0.02667856030166149,
  'repulsive_weighted': 0.00031751178903505206,
  'total': 0.026996072381734848})
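
The weighted values in this output follow from simple arithmetic, which can be reproduced by hand. The snippet below is only a worked check of the numbers printed above (with a weight of 1 for the attractive term and 0.5 for the repulsive term); the dictionary names are illustrative and small differences in the last digits are expected from floating-point precision.

```python
losses = {"attractive": 0.02667856030166149, "repulsive": 0.0006350235780701041}
weights = {"attractive": 1.0, "repulsive": 0.5}

# Multiply each loss by its weight, then sum the weighted terms
weighted = {f"{name}_weighted": weights[name] * value for name, value in losses.items()}
total = sum(weighted.values())

print(weighted)
print(total)
```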

## Step 4: Training

In [15]:
trainer = Trainer(max_epochs=1, accelerator="cpu", log_every_n_steps=1)
trainer.fit(model=lmodel, datamodule=dm)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[10:06:53] INFO: DataLoader will load 5 graphs (out of 90 available).
[10:06:53] DEBUG: First graph is /Users/fuchur/tmp/truth_cut_graphs_for_gsoc/part_1_0/data21025_s0.pt, last graph is /Users/fuchur/tmp/truth_cut_graphs_for_gsoc/part_1_0/data21053_s0.pt
[10:06:53] INFO: DataLoader will load 5 graphs (out of 90 available).
[10:06:53] DEBUG: First graph is /Users/fuchur/tmp/truth_cut_graphs_for_gsoc/part_1_0/data21058_s0.pt, last graph is /Users/fuchur/tmp/truth_cut_graphs_for_gsoc/part_1_0/data21094_s0.pt

  | Name     | Type                                | Params
-----------------------------------------------------------------
0 | model    | GraphConstructionFCNN               | 17.8 K
1 | loss_fct | GraphConstructionHingeEmbeddingLoss | 0     
-----------------------------------

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


            Validation epoch=1
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┓
┃ Metric              ┃   Value ┃   Error ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━┩
│ attractive          │ 0.02582 │ 0.00079 │
│ attractive_weighted │ 0.02582 │ 0.00079 │
│ repulsive           │ 0.00068 │ 0.00002 │
│ repulsive_weighted  │ 0.00034 │ 0.00001 │
│ total               │ 0.02616 │ 0.00079 │
└─────────────────────┴─────────┴─────────┘



### If there are issues with the progress bar

The lightning progress bar can be finicky when combined with printing the validation results to the command line, especially when running from a Jupyter notebook. Here are a few things to try:

* set `enable_progress_bar=False` in the `Trainer` initialization to disable the progress bar
* use `callbacks=[pytorch_lightning.callbacks.RichProgressBar(leave=True)]` in the `Trainer` initialization (this is a prettier progress bar, anyway)
* use `callbacks=[gnn_tracking.utils.lightning.SimpleTqdmProgressBar(leave=True)]`
* set `lmodel.print_validation_results=False` to disable printing the validation results to the command line

## Restoring a pre-trained model

Take a look at the `lightning_logs` directory:

In [None]:
! ls lightning_logs

In [None]:
! ls lightning_logs/version_0/checkpoints

Navigate to one of the versions and take a look at the `hparams.yaml` file. It should contain exactly the hyperparameters from the run.


In [None]:
! cat lightning_logs/version_0/hparams.yaml

We can bring back the trained model by loading one of the checkpoints:

In [None]:
restored_model = MLModule.load_from_checkpoint(
    "lightning_logs/version_0/checkpoints/epoch=0-step=5.ckpt"
)

Note how we didn't have to specify any of the hyperparameters again.

However, we can easily change some of them by adding them as additional keyword arguments.

In [None]:
restored_model_modified = MLModule.load_from_checkpoint(
    "lightning_logs/version_0/checkpoints/epoch=0-step=5.ckpt",
    lw_repulsive=0.1,
    loss_fct=GraphConstructionHingeEmbeddingLoss(max_num_neighbors=5),
)

Note, however, that you cannot modify the model architecture this way (though you could in principle change the `beta` parameter of the residual connections).
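
Conceptually, keyword arguments passed to `load_from_checkpoint` take precedence over the values stored in the checkpoint. The sketch below mimics that precedence with plain dictionaries; the simplified entries in `saved` and `overrides` are stand-ins for the real hyperparameters, and this is not Lightning's actual loading code.

```python
# Simplified stand-ins for the hyperparameters stored in a checkpoint
saved = {
    "lw_repulsive": 0.5,
    "loss_fct": {"max_num_neighbors": 10},
    "model": {"depth": 5, "hidden_dim": 64},
}

# Values passed as extra keyword arguments when restoring
overrides = {"lw_repulsive": 0.1, "loss_fct": {"max_num_neighbors": 5}}

# Later entries win, so overrides replace the saved values key by key
effective = {**saved, **overrides}
print(effective)
```

Anything not overridden, such as the `model` entry here, is taken from the checkpoint unchanged.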

## Running all of this from the command line

All of the above can be achieved by running the following command:

```bash
python3 gnn_tracking/trainers/run.py fit --model configs/model.yml --data configs/data.yml --trainer.accelerator cpu
```

with the data config file

```yaml
train:
  dirs:
    - /Users/fuchur/tmp/truth_cut_graphs_for_gsoc/part_1_0/
  stop: 5
test:
  dirs:
    - /Users/fuchur/tmp/truth_cut_graphs_for_gsoc/part_1_0/
  start: 10
  stop: 15
val:
  dirs:
    - /Users/fuchur/tmp/truth_cut_graphs_for_gsoc/part_1_0/
  start: 5
  stop: 10
```

and model config file:

```yaml
class_path: gnn_tracking.training.ml.MLModule
init_args:
  model:
    class_path: gnn_tracking.models.graph_construction.GraphConstructionFCNN
    init_args:
      in_dim: 14
      out_dim: 8
      hidden_dim: 512
      depth: 5
  lw_repulsive: 0.5
  loss_fct:
    class_path: gnn_tracking.metrics.losses.GraphConstructionHingeEmbeddingLoss
    init_args: {}
  optimizer:
    class_path: torch.optim.Adam
    init_args:
      lr: 0.0001
```

To quickly override one of the options, add it to the command line, e.g., `--model.init_args.lw_repulsive=0.1` or `--model.model.init_args.depth=6`.
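
A dotted command line key can be read as a path into the nested configuration. The following sketch shows that mapping for the first override, using a hypothetical helper `set_by_dotted_path` on a plain dictionary; the real command line interface also handles parsing and type conversion, which is not modelled here.

```python
def set_by_dotted_path(config, dotted_key, value):
    """Walk the nested dict along the dotted key and set the final entry."""
    *parents, leaf = dotted_key.split(".")
    node = config
    for key in parents:
        node = node.setdefault(key, {})
    node[leaf] = value
    return config


# Reduced version of the model config shown above
config = {
    "model": {
        "class_path": "gnn_tracking.training.ml.MLModule",
        "init_args": {"lw_repulsive": 0.5},
    }
}

set_by_dotted_path(config, "model.init_args.lw_repulsive", 0.1)
print(config["model"]["init_args"])
```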