# Testing of Hypothesis J

This notebook tests the hypothesis that message passing benefits both global and local embedding quality (albeit not straightforwardly).

In [1]:
# Automatically reload edited local modules (e.g. the `lightning_modules` package imported below)
# before each cell runs, so code changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# System imports
import os
import sys
from pprint import pprint as pp
from time import time as tt
import inspect
import logging
import copy

# External imports
import matplotlib.pyplot as plt
import matplotlib.colors
from sklearn.decomposition import PCA
from sklearn.metrics import auc, f1_score 
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from torch_geometric.data import Data
from torch_geometric.data import DataLoader
from mpl_toolkits.mplot3d import Axes3D
from itertools import permutations, combinations
from itertools import chain

from torch.nn import Linear
from torch_scatter import scatter, segment_csr, scatter_add
from torch_geometric.nn.conv import MessagePassing
from torch_cluster import knn_graph, radius_graph
from trackml.dataset import load_event
from trackml.score import _analyze_tracks, score_event

import yaml

import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import WandbLogger
from torch.utils.checkpoint import checkpoint

# Limit CPU usage on Jupyter
os.environ['OMP_NUM_THREADS'] = '4'

# Pick up local packages
sys.path.append('..')

# Local imports
from lightning_modules.utils import evaluate_set_metrics, get_metrics, build_edges, graph_intersection
from lightning_modules.Filter.utils import edge_model_evaluation
from pytorch_lightning.loggers import WandbLogger
%matplotlib inline

# Get rid of RuntimeWarnings, gross
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)

In [4]:
# Show INFO-level log records (e.g. Lightning's "INFO:lightning:..." setup messages).
# NOTE: basicConfig only configures the root logger if it has no handlers yet;
# any later basicConfig call is ignored unless the handlers are removed first.
logging.basicConfig(level=logging.INFO)

In [63]:
# Quieten logging to WARNING and above.
# `logging.basicConfig` is a no-op once the root logger already has handlers (as it
# does after the INFO-level basicConfig cell), so on its own this call would leave the
# INFO configuration in place. Keep the call (it still installs a handler on a fresh
# kernel), then lower the root logger and its existing handlers explicitly.
logging.basicConfig(level=logging.WARNING)
root_logger = logging.getLogger()
root_logger.setLevel(logging.WARNING)
# Handler levels matter too: records from child loggers that set their own level are
# passed to the root's handlers without checking the root logger's level.
for handler in root_logger.handlers:
    handler.setLevel(logging.WARNING)

# Train Connected Jet Model

## GNN Models

In [5]:
from lightning_modules.GNNEmbedding.Models.agnn import LocalAttentionNodeEmbedding

### Preload Model

In [5]:
# Checkpoint from a previous End2End-JetNodeEmbedding run (saved at epoch 93, per the filename).
ckpnt_path = "End2End-JetNodeEmbedding/jl96feyq/checkpoints/epoch=93.ckpt"
# Load onto the CPU. The cells shown here read only `hyper_parameters`, `state_dict` and
# `epoch` from this dict, and `load_state_dict` copies tensors into the model's own
# parameters. There is therefore no need to restore tensors onto the GPU they were saved
# from, and this also works on machines without that GPU.
ckpnt = torch.load(ckpnt_path, map_location="cpu")

In [6]:
# Rebuild the model from the hyperparameters stored in the checkpoint, so its
# architecture matches the saved weights loaded in the next cell.
model = LocalAttentionNodeEmbedding(ckpnt["hyper_parameters"])

In [7]:
# Restore the trained weights. load_state_dict is strict by default, so any
# missing or unexpected parameter names would raise here.
model.load_state_dict(ckpnt["state_dict"])

<All keys matched successfully>

In [8]:
# Single-GPU trainer that resumes from the same checkpoint file.
# NOTE(review): max_epochs is set to the epoch stored in the checkpoint, so resuming
# may run few or no additional training epochs — confirm this is intended.
trainer = Trainer(gpus=1, max_epochs=ckpnt["epoch"], num_sanity_val_steps=0, accumulate_grad_batches=1, resume_from_checkpoint=ckpnt_path)

GPU available: True, used: True
INFO:lightning:GPU available: True, used: True
TPU available: False, using: 0 TPU cores
INFO:lightning:TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [140]:
# Call the LightningModule's `setup` hook for the "fit" stage (presumably prepares the
# datasets — confirm in LocalAttentionNodeEmbedding).
# NOTE(review): this cell's execution count (140) is later than the fit cell below (9),
# so it was run out of order — check whether it is needed before `trainer.fit`.
model.setup(stage="fit")

In [9]:
# Continue training from the checkpoint the trainer was configured to resume from.
trainer.fit(model)

  warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.")
Set SLURM handle signals.
INFO:lightning:Set SLURM handle signals.

  | Name              | Type           | Params
-----------------------------------------------------
0 | input_network     | Sequential     | 8 K   
1 | edge_network      | EdgeNetwork    | 18 K  
2 | node_network      | NodeNetwork    | 16 K  
3 | embedding_network | Sequential     | 13 K  
4 | multi_loss        | MultiNoiseLoss | 0     
INFO:lightning:
  | Name              | Type           | Params
-----------------------------------------------------
0 | input_network     | Sequential     | 8 K   
1 | edge_network      | EdgeNetwork    | 18 K  
2 | node_network      | NodeNetwork    | 16 K  
3 | embedding_network | Sequential     | 13 K  
4 | multi_loss        | MultiNoiseLoss | 0     


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…




1

### Train Model

In [6]:
# Training hyperparameters for the connected-jet GNN are read from this YAML config.
with open("../lightning_modules/GNNEmbedding/train_jet_gnn.yaml") as f:
    hparams = yaml.load(stream=f, Loader=yaml.FullLoader)

In [7]:
# Build a fresh (untrained) model from the YAML hyperparameters and log the run to
# Weights & Biases under the End2End-ConnectedJetNodeEmbedding project.
model = LocalAttentionNodeEmbedding(hparams)
wandb_logger = WandbLogger(project='End2End-ConnectedJetNodeEmbedding')
# Ask W&B to watch the model (what gets logged follows WandbLogger.watch defaults).
wandb_logger.watch(model)
# Record which model class was used alongside the run's hyperparameters.
wandb_logger.log_hyperparams({"model": type(model)})
# Single-GPU trainer: no sanity validation steps, no gradient accumulation.
trainer = Trainer(gpus=1, max_epochs=hparams["max_epochs"], logger=wandb_logger, num_sanity_val_steps=0, accumulate_grad_batches=1)

[34m[1mwandb[0m: Currently logged in as: [33mmurnanedaniel[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.30 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


GPU available: True, used: True
INFO:lightning:GPU available: True, used: True
TPU available: None, using: 0 TPU cores
INFO:lightning:TPU available: None, using: 0 TPU cores


### Training

In [8]:
# Train the freshly built model (no resume_from_checkpoint on this trainer).
trainer.fit(model)

Set SLURM handle signals.
INFO:lightning:Set SLURM handle signals.

  | Name              | Type           | Params
-----------------------------------------------------
0 | input_network     | Sequential     | 9.0 K 
1 | edge_network      | EdgeNetwork    | 18.2 K
2 | node_network      | NodeNetwork    | 16.8 K
3 | embedding_network | Sequential     | 13.0 K
4 | multi_loss        | MultiNoiseLoss | 0     
-----------------------------------------------------
57.0 K    Trainable params
0         Non-trainable params
57.0 K    Total params
0.228     Total estimated model params size (MB)
INFO:lightning:
  | Name              | Type           | Params
-----------------------------------------------------
0 | input_network     | Sequential     | 9.0 K 
1 | edge_network      | EdgeNetwork    | 18.2 K
2 | node_network      | NodeNetwork    | 16.8 K
3 | embedding_network | Sequential     | 13.0 K
4 | multi_loss        | MultiNoiseLoss | 0     
------------------------------------------------

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




1

### Testing

In [9]:
from lightning_modules.utils import get_metrics, embedding_model_evaluation

In [148]:
# Evaluate on the test dataloaders using the weights currently held in memory
# (ckpt_path=None), rather than reloading a saved "best" checkpoint.
results = trainer.test(ckpt_path=None)

Set SLURM handle signals.
INFO:lightning:Set SLURM handle signals.


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'edge_eff': tensor(0.9620, device='cuda:0'),
 'edge_pur': tensor(0.9563, device='cuda:0'),
 'loss': tensor(0.4085, device='cuda:0'),
 'truth': array([ True,  True,  True, ..., False, False, False]),
 'truth_graph': array([[ 7055,  7055,  7055, ...,  6760,  6760,  6760],
       [ 3856, 10637,  9096, ...,  5876,  5047,  6720]])}
--------------------------------------------------------------------------------
DATALOADER:1 TEST RESULTS
{'edge_eff': tensor(0.9572, device='cuda:0'),
 'edge_pur': tensor(0.9555, device='cuda:0'),
 'loss': tensor(0.3724, device='cuda:0'),
 'truth': array([ True,  True,  True, ..., False, False, False]),
 'truth_graph': array([[10276, 10276, 10276, ...,  2371,  2371,  2371],
       [ 7138, 11519,  4065, ...,  2419,   482, 11538]])}
--------------------------------------------------------------------------------
DATALOADER:2 TEST RESULTS
{'edge_eff': tensor

In [151]:
# Evaluate the embedding at a working point selected on efficiency ("eff") with a
# target of 0.98; the result is displayed as this cell's output.
# NOTE(review): the returned tuple appears to be ((efficiency, purity), radius) —
# confirm against embedding_model_evaluation in lightning_modules.utils.
embedding_model_evaluation(model, trainer, "eff", 0.98)

Set SLURM handle signals.
INFO:lightning:Set SLURM handle signals.


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'edge_eff': tensor(0.9620, device='cuda:0'),
 'edge_pur': tensor(0.9563, device='cuda:0'),
 'loss': tensor(0.3497, device='cuda:0'),
 'truth': array([ True,  True,  True, ..., False, False, False]),
 'truth_graph': array([[ 7055,  7055,  7055, ...,  6760,  6760,  6760],
       [ 3856, 10637,  9096, ...,  5876,  5047,  6720]])}
--------------------------------------------------------------------------------
DATALOADER:1 TEST RESULTS
{'edge_eff': tensor(0.9572, device='cuda:0'),
 'edge_pur': tensor(0.9555, device='cuda:0'),
 'loss': tensor(0.3102, device='cuda:0'),
 'truth': array([ True,  True,  True, ...,  True,  True, False]),
 'truth_graph': array([[10276, 10276, 10276, ...,  2371,  2371,  2371],
       [ 7138, 11519,  4065, ...,  2419,   482, 11538]])}
--------------------------------------------------------------------------------
DATALOADER:2 TEST RESULTS
{'edge_eff': tensor

Set SLURM handle signals.
INFO:lightning:Set SLURM handle signals.


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

Set SLURM handle signals.
INFO:lightning:Set SLURM handle signals.


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'edge_eff': tensor(0.9620, device='cuda:0'),
 'edge_pur': tensor(0.9563, device='cuda:0'),
 'loss': tensor(0.3927, device='cuda:0'),
 'truth': array([ True,  True,  True, ..., False,  True,  True]),
 'truth_graph': array([[ 7055,  7055,  7055, ...,  6760,  6760,  6760],
       [ 3856, 10637,  9096, ...,  5876,  5047,  6720]])}
--------------------------------------------------------------------------------
DATALOADER:1 TEST RESULTS
{'edge_eff': tensor(0.9572, device='cuda:0'),
 'edge_pur': tensor(0.9555, device='cuda:0'),
 'loss': tensor(0.3599, device='cuda:0'),
 'truth': array([ True,  True,  True, ..., False, False, False]),
 'truth_graph': array([[10276, 10276, 10276, ...,  2371,  2371,  2371],
       [ 7138, 11519,  4065, ...,  2419,   482, 11538]])}
--------------------------------------------------------------------------------
DATALOADER:2 TEST RESULTS
{'edge_eff': tensor

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'edge_eff': tensor(0.9620, device='cuda:0'),
 'edge_pur': tensor(0.9563, device='cuda:0'),
 'loss': tensor(0.3573, device='cuda:0'),
 'truth': array([ True,  True,  True, ..., False, False, False]),
 'truth_graph': array([[ 7055,  7055,  7055, ...,  6760,  6760,  6760],
       [ 3856, 10637,  9096, ...,  5876,  5047,  6720]])}
--------------------------------------------------------------------------------
DATALOADER:1 TEST RESULTS
{'edge_eff': tensor(0.9572, device='cuda:0'),
 'edge_pur': tensor(0.9555, device='cuda:0'),
 'loss': tensor(0.3294, device='cuda:0'),
 'truth': array([ True,  True,  True, ..., False, False, False]),
 'truth_graph': array([[10276, 10276, 10276, ...,  2371,  2371,  2371],
       [ 7138, 11519,  4065, ...,  2419,   482, 11538]])}
--------------------------------------------------------------------------------
DATALOADER:2 TEST RESULTS
{'edge_eff': tensor

Set SLURM handle signals.
INFO:lightning:Set SLURM handle signals.


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'edge_eff': tensor(0.9620, device='cuda:0'),
 'edge_pur': tensor(0.9563, device='cuda:0'),
 'loss': tensor(0.3602, device='cuda:0'),
 'truth': array([ True,  True,  True, ..., False, False, False]),
 'truth_graph': array([[ 7055,  7055,  7055, ...,  6760,  6760,  6760],
       [ 3856, 10637,  9096, ...,  5876,  5047,  6720]])}
--------------------------------------------------------------------------------
DATALOADER:1 TEST RESULTS
{'edge_eff': tensor(0.9572, device='cuda:0'),
 'edge_pur': tensor(0.9555, device='cuda:0'),
 'loss': tensor(0.3322, device='cuda:0'),
 'truth': array([ True,  True,  True, ..., False, False, False]),
 'truth_graph': array([[10276, 10276, 10276, ...,  2371,  2371,  2371],
       [ 7138, 11519,  4065, ...,  2419,   482, 11538]])}
--------------------------------------------------------------------------------
DATALOADER:2 TEST RESULTS
{'edge_eff': tensor

Set SLURM handle signals.
INFO:lightning:Set SLURM handle signals.


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'edge_eff': tensor(0.9620, device='cuda:0'),
 'edge_pur': tensor(0.9563, device='cuda:0'),
 'loss': tensor(0.3605, device='cuda:0'),
 'truth': array([ True,  True,  True, ..., False, False, False]),
 'truth_graph': array([[ 7055,  7055,  7055, ...,  6760,  6760,  6760],
       [ 3856, 10637,  9096, ...,  5876,  5047,  6720]])}
--------------------------------------------------------------------------------
DATALOADER:1 TEST RESULTS
{'edge_eff': tensor(0.9572, device='cuda:0'),
 'edge_pur': tensor(0.9555, device='cuda:0'),
 'loss': tensor(0.3323, device='cuda:0'),
 'truth': array([ True,  True,  True, ..., False, False, False]),
 'truth_graph': array([[10276, 10276, 10276, ...,  2371,  2371,  2371],
       [ 7138, 11519,  4065, ...,  2419,   482, 11538]])}
--------------------------------------------------------------------------------
DATALOADER:2 TEST RESULTS
{'edge_eff': tensor

((0.9800002405026307, 0.4984457079042242), 1.0684876976490727)