In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from torch import nn, optim
import time, sys
from torch.utils.data import Dataset, DataLoader
import torch

# trackML challenge 

This notebook is used to evaluate the model.

The inputs are loaded from the `/scratch/pitt/trackML/train_graphs` folder; this can be changed.

The model is loaded from `/mnt/lustre/agrp/pitt/ML/trackML/sample_code_submission/GraphBuilder/data/`

If running on GPU, run the following cell:


In [None]:
# Select the compute device. GPU index 1 is preferred (as in the original setup);
# fall back to GPU 0 when only one GPU is present and to the CPU when CUDA is
# unavailable, so the cell never fails on a machine with a different layout.
use_cuda = torch.cuda.is_available()
print('Availability of CUDA:',use_cuda)

if use_cuda:
    preferred_gpu = 1
    gpu_index = preferred_gpu if torch.cuda.device_count() > preferred_gpu else 0
    device = torch.device("cuda:%d" % gpu_index)
    # Make the chosen GPU the current CUDA device for this process.
    torch.cuda.set_device(gpu_index)
    idevice = torch.cuda.current_device()
    print('Will work on device number',idevice,', named: ',torch.cuda.get_device_name(idevice))
else:
    device = torch.device("cpu")
    idevice = None  # no CUDA device index when running on the CPU

# Has no effect unless cuDNN is actually used.
torch.backends.cudnn.benchmark = True
print('device = ',device)

Load the model from the `GraphBuilder` folder. The model consists of two NNs:
- First, a FC NN is used to identify bad edges; the efficiency for good edges at a threshold of 0.2 is >98%
- Then a GNN is applied on the full event

In [None]:
# Add the project checkout to the module search path; the project modules imported
# below (model, trainer, utils) are presumably located there -- TODO confirm.
sys.path.append("/scratch/pitt/trackML/WeizmannAI")
from model.data_loader import trackDataLoader, collate_fn
from model.model import GNNmodel, PreTrainModel
# NOTE(review): the star import hides which names come from `trainer`. `get_inputs`,
# used in the evaluation cell, is presumably one of them -- confirm and import it by name.
from trainer import *
from utils import AnalyzeThreshold

Read the files. In total we have 200 files; the first 40 files are used for validation.

In [None]:
from glob import glob

# Number of graph files used as the validation/test subset.
N_TEST_FILES = 40
GRAPH_PATTERN = '/scratch/pitt/trackML/graph_full_6var_250MeV_150mmz0_1eta/*'

# glob() returns paths in arbitrary (filesystem-dependent) order. Sort them so that
# "the first N files" is the same subset on every run and on every machine.
train_files = sorted(glob(GRAPH_PATTERN))
if not train_files:
    # Fail early with a clear message instead of a DataLoader error about batch_size=0.
    raise FileNotFoundError('No graph files match ' + GRAPH_PATTERN)

test_files = train_files[:N_TEST_FILES]
test_dataset = trackDataLoader(test_files, 1)
# batch_size equals the number of test files, so one batch holds the whole test subset.
test_loader = DataLoader(test_dataset, batch_size=len(test_files), collate_fn=collate_fn)
print('Load',len(test_files),'files')

# DNN models

Load two models to classify edges:

In [None]:
# Directory holding the trained checkpoints.
PATH = '/mnt/lustre/agrp/pitt/ML/trackML/sample_code_submission/GraphBuilder/data/'

# First model: deep fully-connected NN for edge estimation.
model1 = PreTrainModel(hidden_features=32)
criterion1 = nn.functional.binary_cross_entropy_with_logits
optimizer1 = optim.Adam(model1.parameters())

# map_location remaps tensors saved on another device (e.g. a GPU that does not
# exist on this host) onto the device selected for this session.
checkpoint1 = torch.load(PATH+'training_pretrain.pt', map_location=device)
model1.load_state_dict(checkpoint1['model_state_dict'])
model1.to(device)
# This notebook only evaluates: switch to inference mode so that layers such as
# dropout / batch-norm (if the model has any) behave deterministically.
model1.eval()
# The optimizer state is restored after the model is on its device.
optimizer1.load_state_dict(checkpoint1['optimizer_state_dict'])
cache1 = checkpoint1['cache']

# Second model: GNN for edge classification with 8 iterations over the neighbours.
model2 = GNNmodel(edge_dim = 16, hidden_dim = 32, niter = 8)
criterion2 = nn.functional.binary_cross_entropy
optimizer2 = optim.Adam(model2.parameters())

checkpoint2 = torch.load(PATH+'training_gnn_filtered2.pt', map_location=device)
model2.load_state_dict(checkpoint2['model_state_dict'])
model2.to(device)
model2.eval()  # inference mode, same reason as for model1
optimizer2.load_state_dict(checkpoint2['optimizer_state_dict'])
cache2 = checkpoint2['cache']

### Check model performance

Evaluate the total weight obtained by combining both models

In [None]:
%%time
# Two-stage evaluation on the test batch: model1 scores every edge, the edges whose
# score exceeds 0.2 are re-scored by model2, and the combined scores are handed to
# AnalyzeThreshold. test_loader comes from the file-reading cell above.
with torch.no_grad():
    # Take one batch; batch_size was set to the number of test files when the loader was built.
    inputs, test_target = next(iter(test_loader))
    X, Is = inputs
    inputs = get_inputs(X, Is, device)
    # Sigmoid turns the raw model1 output into a score in (0, 1).
    test_pred = torch.sigmoid(model1(inputs))
    
    # First-stage filter: indices of the edges whose model1 score exceeds 0.2.
    # NOTE(review): squeeze() without a dim also collapses a single surviving index
    # to a 0-d tensor -- confirm that case cannot occur, or use squeeze(-1).
    mask_edges = (test_pred > 0.2).nonzero().squeeze().cpu()
    Is_filter = Is[mask_edges]
    e_masked = test_pred[mask_edges]
    # Targets of the selected edges; not used further in this cell.
    test_target_masked = test_target[mask_edges]
    inputs = get_inputs(X, Is_filter, device)
    # The first-stage scores are passed to model2 as an additional input.
    inputs.append(e_masked)
    
    # Second stage: model2 output overwrites the scores of the selected edges;
    # edges that failed the 0.2 cut keep their model1 score.
    test_pred_masked = model2(inputs)
    test_pred[mask_edges] = test_pred_masked
    
    # Convert the targets to a float tensor for the analysis below and the following cells.
    test_target = torch.FloatTensor(test_target)
AnalyzeThreshold(test_pred.cpu(), test_target, log=True)


Evaluate purity and acceptance for a few working points:

In [None]:
# Working point: edges with a combined score above this value are accepted.
cut_edge_weight = 0.9

# test_pred may still live on the GPU while test_target is a CPU tensor; bring the
# predictions to the CPU once so that masking test_target never mixes devices.
pred_cpu = test_pred.detach().cpu()
passed = pred_cpu > cut_edge_weight

n_good = test_target.sum().item()                # truth edges in the sample
n_fake = (test_target == 0).sum().item()         # fake edges in the sample
n_good_passed = test_target[passed].sum().item() # truth edges above the threshold
n_passed = passed.sum().item()                   # all edges above the threshold

# Guard the ratios: an empty selection (or a sample without truth edges) gives NaN
# instead of a division error.
signal_eff = n_good_passed / n_good * 100 if n_good else float('nan')
signal_purity = n_good_passed / n_passed * 100 if n_passed else float('nan')

print('total stats: %d good and %d fake edges'%(n_good, n_fake))
print('signal eff = %2.2f%% are above threshold'
      %(signal_eff))
print('signal purity = above the threshold %2.2f%% are truth edges'
      %(signal_purity))


#### Study of misclassified edges: