# Track Parameter Regression with Pooling

This notebook explores track parameter regression using edge contraction pooling, as introduced in the paper [Towards Graph Pooling by Edge Contraction](https://graphreason.github.io/papers/17.pdf).

In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import yaml

import numpy as np
import torch
import torch.nn.functional as F
from torch_scatter import scatter_add
from torch_sparse import coalesce
from torch_geometric.utils import softmax
from pytorch_lightning import Trainer

# Extend the module search path; presumably this is where the project-local
# `lightning_modules` package imported below lives (confirm).
sys.path.append("../..")
# NOTE(review): absolute, user-specific path -- this will not resolve on another
# machine or account. Consider a configurable project root instead.
sys.path.append("/global/homes/d/danieltm/ExaTrkX/Tracking-ML-Exa.TrkX/src/Pipelines/TrackML_Example/")
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
from lightning_modules.utils import make_mlp
from torch_scatter import scatter_mean

## Roadmap

We would like to regress at least one track parameter; let's start with the transverse momentum, pT.

1. Set up toy graph, from TrackML 1pT cut
2. Run it through AGNN training
3. Add PyG pooling step
4. See what comes out!

Then see if the regular AGNN can do track parameter regression

5. Adapt base class to include pT regression L1 / L2 loss function
6. Try to train on node feature --> pT
7. Try to train on edge feature --> pT

Then compare pooling behaviour

8. Try to train on pooled node feature --> pT

### 1. Set up toy graph

In [21]:
from lightning_modules.GNN.Models.agnn import ResAGNN

In [22]:
# Read the GNN configuration file into the `hparams` object handed to the model.
with open("example_gnn.yaml") as config_file:
    hparams = yaml.load(config_file, Loader=yaml.FullLoader)

In [23]:
%%time
# Instantiate the ResAGNN from the loaded configuration and run its `setup`
# hook for the "fit" stage (presumably this prepares the datasets, since
# `model.trainset` is read further down -- confirm against the module).
model = ResAGNN(hparams)
model.setup(stage="fit")

CPU times: user 186 ms, sys: 165 ms, total: 352 ms
Wall time: 1.68 s


### 2. Train AGNN

In [24]:
# Train for at most 10 epochs. Request one GPU only when CUDA is actually
# available, so the cell also runs on a CPU-only machine (consistent with the
# `device` fallback chosen in the import cell) instead of raising.
trainer = Trainer(gpus=1 if torch.cuda.is_available() else 0, max_epochs=10)
trainer.fit(model)

GPU available: True, used: True
INFO:lightning:GPU available: True, used: True
TPU available: None, using: 0 TPU cores
INFO:lightning:TPU available: None, using: 0 TPU cores
Set SLURM handle signals.
INFO:lightning:Set SLURM handle signals.

  | Name          | Type        | Params
----------------------------------------------
0 | input_network | Sequential  | 9.0 K 
1 | edge_network  | EdgeNetwork | 18.6 K
2 | node_network  | NodeNetwork | 17.2 K
----------------------------------------------
44.8 K    Trainable params
0         Non-trainable params
44.8 K    Total params
0.179     Total estimated model params size (MB)
INFO:lightning:
  | Name          | Type        | Params
----------------------------------------------
0 | input_network | Sequential  | 9.0 K 
1 | edge_network  | EdgeNetwork | 18.6 K
2 | node_network  | NodeNetwork | 17.2 K
----------------------------------------------
44.8 K    Trainable params
0         Non-trainable params
44.8 K    Total params
0.179     Total

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




1

### 3. Edge Pooling Step

In [25]:
from collections import namedtuple

In [28]:
class EdgePooling(torch.nn.Module):
    """Edge-contraction pooling layer.

    Every edge receives a score from a linear layer applied to the
    concatenated features of its two endpoints. Edges are then visited in
    order of decreasing score and contracted greedily: an edge is contracted
    only if neither endpoint is already part of a contracted edge. The
    features of a contracted pair are summed and multiplied by the edge's
    score. Nodes that end up in no contracted edge are kept as single-node
    clusters with a score of 1.

    Args:
        in_channels (int): Size of each node feature vector.
        edge_score_method (callable, optional): Function with signature
            ``(raw_edge_score, edge_index, num_nodes) -> edge_score`` that
            normalises the raw scores. Defaults to
            :meth:`compute_edge_score_softmax`.
        dropout (float): Dropout probability applied to the raw edge scores
            while the module is in training mode.
        add_to_edge_score (float): Constant added to every normalised score.
    """

    # Record of one pooling step; holds what `unpool` needs to revert it.
    unpool_description = namedtuple(
        "UnpoolDescription",
        ["edge_index", "cluster", "batch", "new_edge_score"])

    def __init__(self, in_channels, edge_score_method=None, dropout=0,
                 add_to_edge_score=0.5):
        super().__init__()
        self.in_channels = in_channels
        if edge_score_method is None:
            edge_score_method = self.compute_edge_score_softmax
        self.compute_edge_score = edge_score_method
        self.add_to_edge_score = add_to_edge_score
        self.dropout = dropout

        # Maps the concatenated features of an edge's endpoints to one scalar.
        self.lin = torch.nn.Linear(2 * in_channels, 1)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the parameters of the edge-scoring linear layer."""
        self.lin.reset_parameters()

    @staticmethod
    def compute_edge_score_softmax(raw_edge_score, edge_index, num_nodes):
        """Softmax of the raw scores, grouped by each edge's target node
        (``edge_index[1]``)."""
        return softmax(raw_edge_score, edge_index[1], num_nodes=num_nodes)

    @staticmethod
    def compute_edge_score_tanh(raw_edge_score, edge_index, num_nodes):
        """Element-wise ``tanh`` of the raw scores; the graph arguments are
        unused and only present for signature compatibility."""
        return torch.tanh(raw_edge_score)

    @staticmethod
    def compute_edge_score_sigmoid(raw_edge_score, edge_index, num_nodes):
        """Element-wise sigmoid of the raw scores; the graph arguments are
        unused and only present for signature compatibility."""
        return torch.sigmoid(raw_edge_score)

    def forward(self, x, edge_index, batch):
        """Score all edges and contract the graph once.

        Args:
            x (Tensor): Node features; the last dimension must be
                ``in_channels``.
            edge_index (Tensor): Edge list whose row 0 holds source node
                indices and row 1 holds target node indices.
            batch (Tensor): One integer per node (presumably the index of
                the graph the node belongs to -- confirm with callers).

        Returns:
            tuple: ``(new_x, new_edge_index, new_batch, unpool_info)`` where
            ``unpool_info`` is an ``unpool_description`` accepted by
            :meth:`unpool`.
        """
        endpoint_features = torch.cat(
            [x[edge_index[0]], x[edge_index[1]]], dim=-1)
        raw_score = self.lin(endpoint_features).view(-1)
        raw_score = F.dropout(raw_score, p=self.dropout,
                              training=self.training)
        edge_score = self.compute_edge_score(raw_score, edge_index, x.size(0))
        edge_score = edge_score + self.add_to_edge_score

        return self.__merge_edges__(x, edge_index, batch, edge_score)

    def __merge_edges__(self, x, edge_index, batch, edge_score):
        """Greedily contract edges by decreasing score and build the pooled
        graph.

        Returns:
            tuple: ``(new_x, new_edge_index, new_batch, unpool_info)``.
        """
        num_nodes = x.size(0)

        # Move everything the sequential selection loop needs into plain
        # Python lists once; indexing tensors element by element inside the
        # loop is far slower than list access.
        edge_index_cpu = edge_index.cpu()
        sources = edge_index_cpu[0].tolist()
        targets = edge_index_cpu[1].tolist()
        ranked_edges = torch.argsort(edge_score, descending=True).tolist()

        available = [True] * num_nodes   # node not yet in a contracted edge
        cluster_ids = [0] * num_nodes    # every entry is overwritten below
        new_edge_indices = []
        next_cluster = 0

        # Visit edges from highest to lowest score, selecting an edge only if
        # it is not incident to an already selected edge.
        for edge_idx in ranked_edges:
            source = sources[edge_idx]
            target = targets[edge_idx]
            if not (available[source] and available[target]):
                continue

            new_edge_indices.append(edge_idx)
            # For a self-loop source == target, so this marks a single node.
            cluster_ids[source] = cluster_ids[target] = next_cluster
            available[source] = available[target] = False
            next_cluster += 1

        # The remaining nodes are simply kept, each as its own cluster,
        # numbered in ascending node order.
        num_contracted = next_cluster
        for node_idx in range(num_nodes):
            if available[node_idx]:
                cluster_ids[node_idx] = next_cluster
                next_cluster += 1
        num_kept = next_cluster - num_contracted

        cluster = torch.tensor(cluster_ids, dtype=batch.dtype,
                               device=x.device)

        # New features are the sum of the merged nodes' features, scaled by
        # the score of the contracted edge (1 for untouched nodes).
        new_x = scatter_add(x, cluster, dim=0, dim_size=next_cluster)
        selected = torch.tensor(new_edge_indices, dtype=torch.long,
                                device=edge_score.device)
        new_edge_score = edge_score[selected]
        if num_kept > 0:
            remaining_score = x.new_ones((num_kept, ))
            new_edge_score = torch.cat([new_edge_score, remaining_score])
        new_x = new_x * new_edge_score.view(-1, 1)

        # Relabel every edge endpoint with its cluster id and coalesce the
        # resulting edge list.
        N = new_x.size(0)
        new_edge_index, _ = coalesce(cluster[edge_index], None, N, N)

        # Each cluster takes the batch value of its member node(s).
        new_batch = x.new_empty(new_x.size(0), dtype=torch.long)
        new_batch = new_batch.scatter_(0, cluster, batch)

        unpool_info = self.unpool_description(edge_index=edge_index,
                                              cluster=cluster, batch=batch,
                                              new_edge_score=new_edge_score)

        return new_x, new_edge_index, new_batch, unpool_info

    def unpool(self, x, unpool_info):
        """Revert a pooling step.

        Divides the pooled features by the stored per-cluster scores and
        copies each cluster's feature vector to every node that was merged
        into it.

        Args:
            x (Tensor): Features of the pooled nodes.
            unpool_info (unpool_description): Record returned by
                :meth:`forward` for the step being undone.

        Returns:
            tuple: ``(new_x, edge_index, batch)`` with the edge index and
            batch vector taken from ``unpool_info``.
        """
        new_x = x / unpool_info.new_edge_score.view(-1, 1)
        new_x = new_x[unpool_info.cluster]
        return new_x, unpool_info.edge_index, unpool_info.batch

    def __repr__(self):
        return '{}({})'.format(self.__class__.__name__, self.in_channels)

In [32]:
# Pooling layer for 3-channel node features, scoring edges with a plain sigmoid
# and adding no offset to the scores.
pool = EdgePooling(
    in_channels=3,
    edge_score_method=EdgePooling.compute_edge_score_sigmoid,
    dropout=0,
    add_to_edge_score=0.0,
)

### 4. Test pool step

In [34]:
# Use the first entry of the model's training set as the graph to pool.
test_data = model.trainset[0]

In [35]:
# Show the sample's summary (attribute names and shapes).
test_data

Data(cell_data=[8766, 9], edge_index=[2, 49050], event_file="/global/cscratch1/sd/danieltm/ExaTrkX/trackml-codalab/train_all/event000021394", hid=[8766], layerless_true_edges=[2, 10529], layers=[8766], pid=[8766], pt=[8766], true_weights=[10529], weights=[10529], x=[8766, 3], y=[49050])

In [47]:
# Run one pooling step with an all-zero batch vector (every node gets batch
# index 0). The vector is created directly as int64 on the same device as the
# node features rather than as a float tensor on the default device.
pooled_test = pool(
    test_data.x,
    test_data.edge_index,
    torch.zeros(test_data.x.size(0), dtype=torch.long,
                device=test_data.x.device),
)

In [49]:
# Unpack the pooling result. The pooled batch vector gets a real name instead
# of `_`, because assigning to `_` stops IPython from updating its
# last-output variable for the rest of the session.
pooled_x, pooled_edge_index, pooled_batch, unpool_info = pooled_test

Compare the node and edge counts before and after pooling

In [53]:
# Shape of the node feature matrix before pooling.
test_data.x.shape

torch.Size([8766, 3])

In [54]:
# Shape of the node feature matrix after pooling.
pooled_x.shape

torch.Size([4505, 3])

In [55]:
# Shape of the edge list before pooling.
test_data.edge_index.shape

torch.Size([2, 49050])

In [56]:
# Shape of the edge list after pooling.
pooled_edge_index.shape

torch.Size([2, 32291])