In [None]:
import os
# Input and output locations exported as environment variables
# (presumably picked up by the exatrkx configuration -- TODO confirm).
os.environ.update({
    'TRKXINPUTDIR': "/global/cfs/projectdirs/atlas/xju/heptrkx/trackml_inputs/train_all",
    'TRKXOUTPUTDIR': "/global/cfs/projectdirs/m3443/usr/caditi97/iml2020/outtest",
})

In [2]:
import pkg_resources
import yaml
import pprint
import random
import time
import pickle
random.seed(1234)
import numpy as np
import pandas as pd
import itertools
import matplotlib.pyplot as plt
import tqdm
from tqdm import tqdm
import tqdm.notebook as tq
from pathlib import Path
from os import listdir
from os.path import isfile, join
import matplotlib.cm as cm
import sys
import warnings
warnings.filterwarnings('ignore')
from os import listdir
from os.path import isfile, join
import gc

# %matplotlib widget

sys.path.append('/global/homes/c/caditi97/exatrkx-iml2020/exatrkx/src/')

# 3rd party
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from trackml.dataset import load_event
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


# local import
from exatrkx import config_dict # for accessing predefined configuration files
from exatrkx import outdir_dict # for accessing predefined output directories
from exatrkx.src import utils_dir
from exatrkx.src import utils_robust
from utils_robust import *


# for preprocessing
from exatrkx import FeatureStore
from exatrkx.src import utils_torch

# for embedding
from exatrkx import LayerlessEmbedding
from exatrkx.src import utils_torch
from torch_cluster import radius_graph
from utils_torch import build_edges
from embedding.embedding_base import *

# for filtering
from exatrkx import VanillaFilter

# for GNN
import tensorflow as tf
from graph_nets import utils_tf
from exatrkx import SegmentClassifier
import sonnet as snt

# for labeling
from exatrkx.scripts.tracks_from_gnn import prepare as prepare_labeling
from exatrkx.scripts.tracks_from_gnn import clustering as dbscan_clustering

# track efficiency
from trackml.score import _analyze_tracks
from exatrkx.scripts.eval_reco_trkx import make_cmp_plot, pt_configs, eta_configs
from functools import partial

In [3]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [4]:
# noise_keep = ["0","0.2", "0.4", "0.6", "0.8", "1"]

# Checkpoints of the trained models, one per pipeline stage.
embed_ckpt_dir = '/global/cfs/cdirs/m3443/data/lightning_models/embedding/checkpoints/epoch=10.ckpt'
filter_ckpt_dir = '/global/cfs/cdirs/m3443/data/lightning_models/filtering/checkpoints/epoch=92.ckpt'
gnn_ckpt_dir = '/global/cfs/cdirs/m3443/data/lightning_models/gnn'

# Where plots are written. TODO: update this per run.
plots_dir = '/global/homes/c/caditi97/exatrkx-iml2020/exatrkx/src/plots/run1000'

# Index into the list of available GNN checkpoints (-1 selects the last one).
ckpt_idx = -1

# DBSCAN hyperparameters used for track labelling.
dbscan_epsilon = 0.25
dbscan_minsamples = 2

# Minimum number of hits associated with a particle for it to count as "reconstructable".
min_hits = 5

# Thresholds used when matching reconstructed tracks to truth particles.
frac_reco_matched = 0.5
frac_truth_matched = 0.5

In [5]:
# emb_ckpt = torch.load(embed_ckpt_dir, map_location='cpu')

# emb_ckpt['hyper_parameters']['clustering'] = 'build_edges'
# emb_ckpt['hyper_parameters']['knn_val'] = 500
# emb_ckpt['hyper_parameters']['r_val'] = 1.7
# emb_ckpt['hyper_parameters']

In [6]:
def get_data_np(mypath, n_files=10):
    """Load the first ``n_files`` serialized events found directly in ``mypath``.

    Sub-directories are skipped. File names are sorted before the cut is
    applied because ``listdir`` returns entries in arbitrary order; without
    the sort, the selected subset (and therefore every downstream metric)
    could change from one run or filesystem to the next.

    Parameters
    ----------
    mypath : str
        Directory containing files readable by ``torch.load``.
    n_files : int or None, optional
        Maximum number of files to load (default 10). ``None`` loads all.

    Returns
    -------
    list
        The loaded objects, ordered by file name.
    """
    onlyfiles = sorted(f for f in listdir(mypath) if isfile(join(mypath, f)))[:n_files]
    # NOTE: torch.load unpickles its input; only point this at trusted files.
    return [torch.load(join(mypath, file)) for file in onlyfiles]

In [7]:
def calc_evts(noise_dir,data_n):
    """Run the tracking pipeline on every event and collect per-event metrics.

    For each event the stages are: embedding, edge building, doublet
    purity/efficiency, edge filtering, GNN edge scoring (TensorFlow),
    DBSCAN track labelling and end-to-end track matching.

    Parameters
    ----------
    noise_dir : str
        Not used by this function; kept so existing callers keep working.
    data_n : iterable
        Events to process. Each event is accessed through the attributes
        ``cell_data``, ``x``, ``layerless_true_edges``, ``hid`` and
        ``event_file``.

    Returns
    -------
    dict
        One list per key with one entry per event:
        ``matched_idx``, ``peta``, ``par_pt`` (the three values returned by
        ``track_eff``), ``doublet_purity``, ``doublet_efficiency``, and the
        wall-clock durations in seconds ``build_edges``, ``build_graphs``
        (the GNN forward pass), ``filter_times`` and ``predict_times``
        (track labelling).
    """
    matched_idx = []
    peta = []
    par_pt = []
    # Not called `build_edges`, to avoid shadowing the imported function of that name.
    edge_build_times = []
    build_graphs = []
    predict_times = []
    filter_times = []
    doub_pur = []
    doub_eff = []

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # The checkpoints are the same for every event, so read them from disk once
    # instead of once per event. The models themselves are still rebuilt per
    # event because they are deleted before the TensorFlow stage to free memory.
    e_ckpt = torch.load(embed_ckpt_dir, map_location=device)
    e_config = e_ckpt['hyper_parameters']
    e_config['clustering'] = 'build_edges'
    e_config['knn_val'] = 500
    e_config['r_val'] = 1.7

    f_ckpt = torch.load(filter_ckpt_dir, map_location='cpu')
    f_config = f_ckpt['hyper_parameters']
    f_config['train_split'] = [0, 0, 1]
    f_config['filter_cut'] = 0.18

    num_processing_steps_tr = 8
    chunks = 10  # the filter is applied to the edge list in this many pieces

    for data in tq.tqdm(data_n):

        #############################################
        #                EMBEDDING                  #
        #############################################
        e_model = LayerlessEmbedding(e_config).to(device)
        e_model.load_state_dict(e_ckpt["state_dict"])
        e_model.eval()

        # Input features shared by the embedding and filter networks; built
        # once per event rather than once per network call.
        hit_features = torch.cat([data.cell_data.to(device), data.x.to(device)], axis=-1)

        with torch.no_grad():
            spatial = e_model(hit_features)

        #############################################
        #               BUILD EDGES                 #
        #############################################
        edges_start = time.time()
        e_spatial = utils_torch.build_edges(spatial.to(device), e_model.hparams['r_val'], e_model.hparams['knn_val'])
        edges_end = time.time()

        # Keep only edges whose first hit is not farther out than the second,
        # where the distance is computed from columns 0 and 2 of data.x.
        R_dist = torch.sqrt(data.x[:,0]**2 + data.x[:,2]**2) # distance away from origin...
        e_spatial = e_spatial[:, (R_dist[e_spatial[0]] <= R_dist[e_spatial[1]])]

        #############################################
        #              DOUBLET METRICS              #
        #############################################
        # True edges in both directions, so edge orientation does not matter
        # when comparing against the constructed edges.
        e_bidir = torch.cat([data.layerless_true_edges,torch.stack([data.layerless_true_edges[1],
                        data.layerless_true_edges[0]], axis=1).T], axis=-1)
        # graph_intersection comes from the utils_robust star import; presumably
        # it returns the edge list plus a per-edge truth flag -- TODO confirm.
        e_spatial_n, y_cluster = graph_intersection(e_spatial, e_bidir)
        cluster_true = len(data.layerless_true_edges[0])
        cluster_true_positive = y_cluster.sum()
        cluster_positive = len(e_spatial_n[0])
        pur = cluster_true_positive/cluster_positive
        eff = cluster_true_positive/cluster_true

        #############################################
        #                  FILTER                   #
        #############################################
        f_model = VanillaFilter(f_config).to(device)
        f_model.load_state_dict(f_ckpt['state_dict'])
        f_model.eval()

        filter_start = time.time()
        emb = None # embedding information was not used in the filtering stage.
        output_list = []
        # Iterate over whatever torch.chunk returns: it may yield fewer than
        # `chunks` pieces when there are few edges, so indexing by a fixed
        # range(chunks) could raise IndexError.
        for subset_ind in torch.chunk(torch.arange(e_spatial.shape[1]), chunks):
            with torch.no_grad():
                # reshape(-1) instead of squeeze() so a chunk holding a single
                # edge stays 1-D and can still be concatenated below.
                output = f_model(hit_features, e_spatial[:, subset_ind], emb).reshape(-1)
            output_list.append(output)
            gc.collect()
        output = torch.sigmoid(torch.cat(output_list))

        # The filtering network assigns a score to each edge.
        # In the end, edges with scores > `filter_cut` are selected to construct graphs.
        edge_list = e_spatial[:, output > f_model.hparams['filter_cut']]
        filter_end = time.time()

        #############################################
        #               BUILD GRAPH                 #
        #############################################
        # ### Form a graph
        # Now moving to TensorFlow for GNN inference.
        n_nodes = data.x.shape[0]
        n_edges = edge_list.shape[1]
        nodes = data.x.cpu().numpy().astype(np.float32)
        edges = np.zeros((n_edges, 1), dtype=np.float32)
        senders = edge_list[0].cpu()
        receivers = edge_list[1].cpu()

        input_datadict = {
            "n_node": n_nodes,
            "n_edge": n_edges,
            "nodes": nodes,
            "edges": edges,
            "senders": senders,
            "receivers": receivers,
            "globals": np.array([n_nodes], dtype=np.float32)
        }

        input_graph = utils_tf.data_dicts_to_graphs_tuple([input_datadict])

        # The GNN is re-created and restored for every event, as before, so
        # the recorded GNN timings stay comparable with earlier runs.
        optimizer = snt.optimizers.Adam(0.001)
        model = SegmentClassifier()

        checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
        ckpt_manager = tf.train.CheckpointManager(checkpoint, directory=gnn_ckpt_dir, max_to_keep=10)
        checkpoint.restore(ckpt_manager.checkpoints[ckpt_idx]).expect_partial()

        # clean up GPU memory
        del e_spatial
        del e_model
        del f_model
        del hit_features
        gc.collect()
        if device == 'cuda':
            torch.cuda.empty_cache()

        graph_start = time.time()
        outputs_gnn = model(input_graph, num_processing_steps_tr)
        output_graph = outputs_gnn[-1]
        graph_end = time.time()

        #############################################
        #             TRACK LABELLING               #
        #############################################
        predict_start = time.time()
        input_matrix = prepare_labeling(tf.squeeze(output_graph.edges).cpu().numpy(), senders, receivers, n_nodes)
        predict_tracks = dbscan_clustering(data.hid.cpu(), input_matrix, dbscan_epsilon, dbscan_minsamples)
        predict_end = time.time()

        #############################################
        #            END-TO-END METRICS             #
        #############################################
        evt_path = data.event_file
        m_idx, pt, p_pt = track_eff(evt_path, predict_tracks,min_hits,frac_reco_matched, frac_truth_matched)

        #############################################
        #               SAVE TO LIST                #
        #############################################
        edge_build_times.append(edges_end-edges_start)
        predict_times.append(predict_end-predict_start)
        filter_times.append(filter_end-filter_start)
        build_graphs.append(graph_end-graph_start)

        matched_idx.append(m_idx)
        peta.append(pt)
        par_pt.append(p_pt)

        doub_pur.append(pur)
        doub_eff.append(eff)

    this_dict = {
        'matched_idx' : matched_idx,
        'peta' : peta,
        'par_pt' : par_pt,
        'doublet_purity' : doub_pur,
        'doublet_efficiency' : doub_eff,
        'build_edges' : edge_build_times,
        'build_graphs' : build_graphs,
        'filter_times' : filter_times,
        'predict_times' : predict_times
    }

    return this_dict

In [8]:
def create_pickle(mis_dir,save_path,mis):
    """Process the events in ``mis_dir`` and pickle the resulting metrics.

    The metrics dict returned by ``calc_evts`` is written to
    ``<save_path>/list_<mis>.pickle``; ``save_path`` is created if missing.

    Parameters
    ----------
    mis_dir : str
        Directory holding the events to load with ``get_data_np``.
    save_path : str
        Output directory. A trailing slash is optional: the file name is
        joined as a path component, so the pickle always lands inside the
        directory that was just created.
    mis
        Level label used in the progress banner and the output file name.
    """
    data_n = get_data_np(mis_dir)
    print(f"------ Level {mis}------")
    dictn = calc_evts(mis_dir,data_n)
    print("--------------------")

    out_dir = Path(save_path)
    out_dir.mkdir(parents=True, exist_ok=True)
    # Join rather than concatenate: plain string concatenation would write
    # next to (not inside) the directory when save_path lacks a trailing '/'.
    pickle_path = out_dir / f'list_{mis}.pickle'

    with open(pickle_path, 'wb') as handle:
        pickle.dump(dictn, handle)

In [9]:
def open_pickle(pickle_dir):
    """Return the object stored in the pickle file at ``pickle_dir``.

    Unpickling can execute arbitrary code, so only use this on files
    produced by a trusted source.
    """
    with open(pickle_dir, 'rb') as stream:
        return pickle.load(stream)

In [10]:
# Levels to process; each one names a sub-directory of the shifted-volume inputs.
misl = [0.0025,0.005,0.0075,0.01,0.012,0.015,0.017,0.02,0.1,0.4,0.6,0.8,1]
# Volume ids; not used in this cell, only `vol` below is processed.
vols = [7,8,9,12,13,14,16,17,18]

vol = 16
# All levels of one volume share the same output directory.
save_path = f'/global/cfs/projectdirs/m3443/usr/caditi97/iml2020/misaligned/volumes_shifted/pickles/shift_x_{vol}/'
for mis in tqdm(misl):
    mis_dir = f'/global/cfs/projectdirs/m3443/usr/caditi97/iml2020/misaligned/volumes_shifted/shift_x_{vol}/pre/{mis}/feature_store/'
    create_pickle(mis_dir,save_path,mis)

  0%|          | 0/13 [00:00<?, ?it/s]

------ Level 0.0025------


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Instructions for updating:
Use tf.identity instead.


  8%|▊         | 1/13 [03:05<37:05, 185.49s/it]


--------------------
------ Level 0.005------


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

 15%|█▌        | 2/13 [06:06<33:44, 184.08s/it]


--------------------
------ Level 0.0075------


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

 23%|██▎       | 3/13 [09:04<30:24, 182.46s/it]


--------------------
------ Level 0.01------


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

 31%|███       | 4/13 [12:02<27:08, 180.97s/it]


--------------------
------ Level 0.012------


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

 38%|███▊      | 5/13 [15:00<24:00, 180.10s/it]


--------------------
------ Level 0.015------


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

 46%|████▌     | 6/13 [17:56<20:51, 178.80s/it]


--------------------
------ Level 0.017------


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

 54%|█████▍    | 7/13 [20:53<17:50, 178.37s/it]


--------------------
------ Level 0.02------


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

 62%|██████▏   | 8/13 [23:51<14:51, 178.35s/it]


--------------------
------ Level 0.1------


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

 69%|██████▉   | 9/13 [26:53<11:56, 179.22s/it]


--------------------
------ Level 0.4------


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

 77%|███████▋  | 10/13 [29:55<09:00, 180.07s/it]


--------------------
------ Level 0.6------


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

 85%|████████▍ | 11/13 [33:01<06:03, 181.95s/it]


--------------------
------ Level 0.8------


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

 92%|█████████▏| 12/13 [36:07<03:03, 183.10s/it]


--------------------
------ Level 1------


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

100%|██████████| 13/13 [39:10<00:00, 180.82s/it]


--------------------





In [11]:
# mis_dir = f'/global/cfs/projectdirs/m3443/usr/caditi97/iml2020/misaligned/volumes_shifted/shift_x_12/0/'
# [f for f in listdir(mis_dir) if isfile(join(mis_dir, f))][:50]