In [1]:
import os

# Input / output locations exported through the process environment.
# NOTE(review): presumably these are read by the exatrkx package when it is
# imported in the next cell, so they must be set first -- confirm.
os.environ.update({
    'TRKXINPUTDIR': "/global/cfs/projectdirs/atlas/xju/heptrkx/trackml_inputs/train_all",
    'TRKXOUTPUTDIR': "/global/cfs/projectdirs/m3443/usr/caditi97/iml2020/outtest",
})

In [2]:
import pkg_resources
import yaml
import pprint
import random
import time
import pickle
random.seed(1234)
import numpy as np
import pandas as pd
import itertools
import matplotlib.pyplot as plt
import tqdm
from os import listdir
from os.path import isfile, join
import matplotlib.cm as cm
import sys
import warnings
warnings.filterwarnings('ignore')
from os import listdir
from os.path import isfile, join

# %matplotlib widget

sys.path.append('/global/homes/c/caditi97/exatrkx-iml2020/exatrkx/src/')

# 3rd party
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from trackml.dataset import load_event
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


# local import
from exatrkx import config_dict # for accessing predefined configuration files
from exatrkx import outdir_dict # for accessing predefined output directories
from exatrkx.src import utils_dir
from exatrkx.src import utils_robust
from utils_robust import *


# for preprocessing
from exatrkx import FeatureStore
from exatrkx.src import utils_torch

# for embedding
from exatrkx import LayerlessEmbedding
from exatrkx.src import utils_torch
from torch_cluster import radius_graph
from utils_torch import build_edges
from embedding.embedding_base import *

# for filtering
from exatrkx import VanillaFilter

# for GNN
import tensorflow as tf
from graph_nets import utils_tf
from exatrkx import SegmentClassifier
import sonnet as snt

# for labeling
from exatrkx.scripts.tracks_from_gnn import prepare as prepare_labeling
from exatrkx.scripts.tracks_from_gnn import clustering as dbscan_clustering

# track efficiency
from trackml.score import _analyze_tracks
from exatrkx.scripts.eval_reco_trkx import make_cmp_plot, pt_configs, eta_configs
from functools import partial

In [3]:
# Noise-level tags; each tag is interpolated into the "n{tag}" input directory
# and the "lists_n{tag}.pickle" output file name.
noise_keep = ["0.2", "0.4", "0.6", "0.8", "1"]

# Trained model locations. Despite the "_dir" suffix, the embedding and
# filtering entries point at individual .ckpt files; only the GNN entry is a
# directory.
embed_ckpt_dir = '/global/cfs/cdirs/m3443/data/lightning_models/embedding/checkpoints/epoch=10.ckpt'
filter_ckpt_dir = '/global/cfs/cdirs/m3443/data/lightning_models/filtering/checkpoints/epoch=92.ckpt'
gnn_ckpt_dir = '/global/cfs/cdirs/m3443/data/lightning_models/gnn'
ckpt_idx = -1  # which GNN checkpoint to load

# Where comparison plots are written.
plots_dir = '/global/homes/c/caditi97/exatrkx-iml2020/exatrkx/src/plots/run1000'  # needs to change...

# Hyperparameters for DBScan (track labelling).
dbscan_epsilon = 0.25
dbscan_minsamples = 2

# Minimum number of hits associated with a particle to define
# "reconstructable particles".
min_hits = 5

# Parameters for track matching.
frac_reco_matched = 0.5
frac_truth_matched = 0.5

In [4]:
# Load the embedding checkpoint onto the CPU and override three of its stored
# hyperparameters in memory, then display the resulting hyperparameter set.
# NOTE(review): calc_evts hands `embed_ckpt_dir` (the path) to emb_eval rather
# than this `emb_ckpt` object, so confirm these overrides actually reach the
# model used during evaluation.
emb_ckpt = torch.load(embed_ckpt_dir, map_location='cpu')

emb_ckpt['hyper_parameters'].update(
    clustering='build_edges',
    knn_val=500,
    r_val=1.7,
)
emb_ckpt['hyper_parameters']

"adjacent":       False
"clustering":     build_edges
"emb_dim":        8
"emb_hidden":     512
"endcaps":        True
"factor":         0.3
"in_channels":    12
"input_dir":      /global/cscratch1/sd/danieltm/ExaTrkX/trackml/feature_store_endcaps
"knn":            20
"knn_val":        500
"layerless":      True
"layerwise":      False
"lr":             0.002
"margin":         1
"max_epochs":     100
"nb_layer":       6
"noise":          False
"output_dir":     global/cscratch1/sd/danieltm/ExaTrkX/trackml_processed/embedding_processed/0_pt_cut_endcaps
"overwrite":      True
"patience":       5
"project":        EmbeddingStudy
"pt_min":         0
"r_train":        1
"r_val":          1.7
"randomisation":  2
"regime":         ['rp', 'hnm', 'ci']
"train_split":    [900, 50, 50]
"wandb_save_dir": /global/cscratch1/sd/danieltm/ExaTrkX/wandb_data
"warmup":         500
"weight":         4

In [5]:
def get_data_np(mypath, max_files=10):
    """Load torch-serialised objects from the regular files in a directory.

    Parameters
    ----------
    mypath : str
        Directory to scan. Sub-directories are ignored.
    max_files : int or None, optional
        Maximum number of files to load (default 10, the value that used to be
        hard-coded). ``None`` loads every file.

    Returns
    -------
    list
        The objects returned by ``torch.load``, ordered by file name.
    """
    # os.listdir returns entries in arbitrary order, so sort before slicing;
    # otherwise the "first N" events can differ between runs or filesystems.
    filenames = sorted(f for f in listdir(mypath) if isfile(join(mypath, f)))
    return [torch.load(join(mypath, fname)) for fname in filenames[:max_files]]

In [6]:
def calc_evts(noise_dir,data_n):
    """Run the reconstruction chain on every event and collect results and timings.

    For each element of ``data_n`` this calls, in order, ``emb_eval``,
    ``filtering``, ``build_graph`` and ``track_eff``, using the module-level
    checkpoint paths and matching/clustering settings.

    Parameters
    ----------
    noise_dir : str
        Not used inside this function; kept for the existing call signature.
    data_n : iterable
        Event objects; each must expose an ``event_file`` attribute.

    Returns
    -------
    dict
        Lists with one entry per event:
        ``matched_idx``, ``peta``, ``par_pt`` hold the three values returned by
        ``track_eff``; ``esp_times``, ``f_times``, ``avg_build_time``,
        ``trkeff_times`` and ``all_steps`` hold wall-clock durations in
        seconds. Note that ``avg_build_time`` stores the individual per-event
        build times, not an average.
    """
    results = {
        'matched_idx': [],
        'peta': [],
        'par_pt': [],
        'avg_build_time': [],
        'all_steps': [],
        'esp_times': [],
        'f_times': [],
        'trkeff_times': [],
    }

    for event in data_n:
        t_event_begin = time.time()

        # Stage 1: embedding.
        t_emb_begin = time.time()
        spatial = emb_eval(embed_ckpt_dir, event)
        t_emb_end = time.time()

        # Stage 2: filtering.
        t_filter_begin = time.time()
        filter_out, filter_model = filtering(filter_ckpt_dir, event, spatial)
        t_filter_end = time.time()

        # Stage 3: graph building (the captured log shows the GNN and track
        # labelling steps running inside this call).
        t_build_begin = time.time()
        tracks = build_graph(filter_out, filter_model, event, spatial,
                             gnn_ckpt_dir, ckpt_idx,
                             dbscan_epsilon, dbscan_minsamples)
        t_build_end = time.time()
        print("build time---- ")
        print(t_build_end - t_build_begin)

        event_path = event.event_file

        # Stage 4: compare predicted tracks against truth.
        t_eff_begin = time.time()
        matched, eta_vals, pt_vals = track_eff(event_path, tracks, min_hits,
                                               frac_reco_matched, frac_truth_matched)
        t_eff_end = time.time()

        t_event_end = time.time()
        print("all steps time---- ")
        print(t_event_end - t_event_begin)

        results['avg_build_time'].append(t_build_end - t_build_begin)
        results['esp_times'].append(t_emb_end - t_emb_begin)
        results['f_times'].append(t_filter_end - t_filter_begin)
        results['trkeff_times'].append(t_eff_end - t_eff_begin)
        results['all_steps'].append(t_event_end - t_event_begin)

        results['matched_idx'].append(matched)
        results['peta'].append(eta_vals)
        results['par_pt'].append(pt_vals)

    return results

In [7]:
def create_pickle(n):
    """Process one noise level end to end and pickle the collected results.

    Loads events from the ``n{n}/feature_store`` directory with
    ``get_data_np``, runs ``calc_evts`` on them, adds the overall wall-clock
    duration under ``'total_time'`` and writes the resulting dict to
    ``lists_n{n}.pickle``.

    Parameters
    ----------
    n : str
        Noise-level tag interpolated into the input and output paths.
    """
    feature_dir = f'/global/cfs/projectdirs/m3443/usr/caditi97/iml2020/layerless_check/n{n}/feature_store'
    events = get_data_np(feature_dir)
    print(f"------Noise Level {n}------")

    # Only the calc_evts call is timed; loading the events is excluded.
    t_begin = time.time()
    summary = calc_evts(feature_dir, events)
    elapsed = time.time() - t_begin
    print("total time---- ")
    print(elapsed)
    summary['total_time'] = elapsed

    print("--------------------")

    out_path = f'/global/cfs/projectdirs/m3443/usr/caditi97/iml2020/layerless_check/lists_n{n}.pickle'
    with open(out_path, 'wb') as out_file:
        pickle.dump(summary, out_file)

In [8]:
# Process every configured noise level in turn; each create_pickle call prints
# its progress/timings and writes one lists_n{n}.pickle file.
for n in noise_keep:
    create_pickle(n)

------Noise Level 0.2------
 APPLYING GNN.....
Loaded -1 checkpoint from /global/cfs/cdirs/m3443/data/lightning_models/gnn
TRACK LABELLING.....
build time---- 
18.066550493240356
----------
Processed 0 events from /global/cfs/projectdirs/atlas/xju/heptrkx/trackml_inputs/train_all
Reconstructable tracks:         7095
Reconstructed tracks:           10735
Reconstructable tracks Matched: 6511
Tracking efficiency:            0.9177
Tracking purity:               0.6065
----------
all steps time---- 
97.65140390396118
 APPLYING GNN.....
Loaded -1 checkpoint from /global/cfs/cdirs/m3443/data/lightning_models/gnn
TRACK LABELLING.....
build time---- 
10.566899061203003
----------
Processed 0 events from /global/cfs/projectdirs/atlas/xju/heptrkx/trackml_inputs/train_all
Reconstructable tracks:         7647
Reconstructed tracks:           12112
Reconstructable tracks Matched: 6910
Tracking efficiency:            0.9036
Tracking purity:               0.5705
----------
all steps time---- 
95.96455

In [9]:
# with open('/global/cfs/projectdirs/m3443/usr/caditi97/iml2020/layerless_check/lists_n0.pickle', 'rb') as handle:
#     unpickler = pickle.Unpickler(handle)
#     b = unpickler.load()
# b

{'matched_idx': [array([ True,  True,  True, ...,  True,  True,  True]),
  array([ True,  True,  True, ...,  True,  True,  True]),
  array([False,  True,  True, ..., False,  True, False]),
  array([ True,  True,  True, ...,  True,  True, False]),
  array([ True,  True,  True, ...,  True,  True, False]),
  array([ True,  True,  True, ...,  True,  True,  True]),
  array([False,  True,  True, ...,  True,  True,  True]),
  array([ True,  True,  True, ...,  True,  True,  True]),
  array([ True,  True,  True, ...,  True,  True,  True]),
  array([ True,  True,  True, ...,  True,  True, False])],
 'peta': [0       1.009692
  1      -2.667920
  2      -0.687386
  3      -0.556053
  4      -3.270726
            ...   
  9591    3.458632
  9593   -0.849981
  9595    0.280119
  9596    0.395756
  9597   -0.337825
  Length: 7095, dtype: float32,
  0        2.338404
  1        0.454788
  2       -0.220477
  3        0.398200
  4       -0.525143
             ...   
  10534    3.633767
  10536    2.32

In [None]:
# make_cmp_plot_fn = partial(make_cmp_plot, xlegend="Matched", ylegend="Reconstructable",
#                     ylabel="Events", ratio_label='Track efficiency')

In [None]:
# make_cmp_plot_fn(par_pt[matched_idx], par_pt,
#                  configs=pt_configs,
#                  xlabel="pT [GeV]",
#                  outname=os.path.join(plots_dir, "{}_pt".format('1000')))

In [None]:
# make_cmp_plot_fn(peta[matched_idx], peta,
#                  configs=eta_configs,
#                  xlabel=r"$\eta$",
#                  outname=os.path.join(plots_dir, "{}_eta".format('1000')))