In [1]:
import os
# Publish the input-event directory and the output directory as environment
# variables (presumably read by the exatrkx package imported later -- confirm).
os.environ.update({
    'TRKXINPUTDIR': "/global/cfs/cdirs/m3443/data/trackml-kaggle/train_10evts",
    'TRKXOUTPUTDIR': "/global/cfs/projectdirs/m3443/usr/caditi97/iml2020/outtest",
})

In [2]:
# standard library
import itertools
import pprint
import random
import sys
import time
from os import listdir
from os.path import isfile, join

# general third-party
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pkg_resources
import tqdm
import yaml

random.seed(1234)
# %matplotlib widget

sys.path.append('/global/homes/c/caditi97/exatrkx-iml2020/exatrkx/src/')

# 3rd party
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from trackml.dataset import load_event
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


# local import
from exatrkx import config_dict # for accessing predefined configuration files
from exatrkx import outdir_dict # for accessing predefined output directories
from exatrkx.src import utils_dir
from exatrkx.src import utils_robust
from exatrkx.src.processing.cell_direction_utils.utils import get_one_event,load_detector
from utils_robust import *
from inference_fn import *


# for preprocessing
from exatrkx import FeatureStore
from exatrkx.src import utils_torch

# for embedding
from exatrkx import LayerlessEmbedding
from exatrkx.src import utils_torch
from torch_cluster import radius_graph
from utils_torch import build_edges
from embedding.embedding_base import *

# for filtering
from exatrkx import VanillaFilter

# for GNN
import tensorflow as tf
from graph_nets import utils_tf
from exatrkx import SegmentClassifier
import sonnet as snt

# for labeling
from exatrkx.scripts.tracks_from_gnn import prepare as prepare_labeling
from exatrkx.scripts.tracks_from_gnn import clustering as dbscan_clustering

# track efficiency
from trackml.score import _analyze_tracks
from exatrkx.scripts.eval_reco_trkx import make_cmp_plot, pt_configs, eta_configs
from functools import partial

In [3]:
# --- Trained model checkpoints -------------------------------------------
embed_ckpt_dir = '/global/cfs/cdirs/m3443/data/lightning_models/embedding/checkpoints/epoch=10.ckpt'
filter_ckpt_dir = '/global/cfs/cdirs/m3443/data/lightning_models/filtering/checkpoints/epoch=92.ckpt'
gnn_ckpt_dir = '/global/cfs/cdirs/m3443/data/lightning_models/gnn'
# Index of the GNN checkpoint to load.
ckpt_idx = -1

# --- Other input / output locations --------------------------------------
# TODO: the plots directory is expected to change between runs.
plots_dir = '/global/homes/c/caditi97/exatrkx-iml2020/exatrkx/src/plots/run1'
detector_dir = "/global/cfs/cdirs/m3443/data/trackml-kaggle/detectors.csv"

# --- DBScan hyperparameters ----------------------------------------------
dbscan_epsilon = 0.25
dbscan_minsamples = 2

# --- Track-efficiency definitions ----------------------------------------
# Minimum number of hits associated with a particle for it to count as a
# "reconstructable particle".
min_hits = 5
# Thresholds used for track matching.
frac_reco_matched = 0.5
frac_truth_matched = 0.5

In [4]:
def get_cell_data(event_file, detector_dir):
    """Build the per-hit arrays for one event.

    Loads the event's hits and truth tables with ``load_event``, adds the
    cylindrical coordinates ``r`` and ``phi`` to every hit, keeps only hits
    whose ``particle_id`` is non-zero, and joins the per-hit output of
    ``get_one_event`` on ``hit_id``.

    Args:
        event_file: event path/prefix handed to ``load_event`` and
            ``get_one_event``.
        detector_dir: path handed to ``load_detector``.

    Returns:
        Tuple ``(cell_data, hid, x)`` of numpy arrays: the nine cell-feature
        columns, the hit ids, and ``[r, phi, z]`` divided element-wise by
        ``[1000, pi, 1000]``.

    NOTE(review): the output recorded in this notebook ends with
    "TypeError: Can only merge Series or DataFrame objects, a numpy.ndarray
    was passed" right after ``get_one_event`` printed its loading message.
    That suggests ``get_one_event`` may return (or internally merge) an
    ndarray rather than a DataFrame -- confirm against its implementation.
    """
    raw_hits, _particles, truth_table = load_event(
        event_file, parts=['hits', 'particles', 'truth']
    )

    # Attach cylindrical coordinates, then the truth labels.
    labelled_hits = raw_hits.assign(
        r=np.sqrt(raw_hits.x**2 + raw_hits.y**2),
        phi=np.arctan2(raw_hits.y, raw_hits.x),
    ).merge(truth_table, on='hit_id')
    # Drop hits with particle_id == 0 (presumably noise hits -- confirm).
    labelled_hits = labelled_hits[labelled_hits['particle_id'] != 0]

    detector_orig, detector_proc = load_detector(detector_dir)
    cell_info = get_one_event(
        event_file,
        detector_orig,
        detector_proc,
        remove_endcaps=False,
        remove_noise=False,
        pt_cut=0,
    )
    merged_hits = labelled_hits.merge(cell_info, on='hit_id')

    cell_columns = ['cell_count', 'cell_val', 'leta', 'lphi', 'lx', 'ly', 'lz', 'geta', 'gphi']
    position_scale = np.array([1000, np.pi, 1000])

    hid = merged_hits['hit_id'].to_numpy()
    x = merged_hits[['r', 'phi', 'z']].to_numpy() / position_scale
    cell_data = merged_hits[cell_columns].to_numpy()

    return cell_data, hid, x

In [5]:
def _label_candidate_edges(candidate_edges, true_edges):
    """Return a boolean mask marking which candidate edges are also true edges.

    Both arguments are integer arrays of shape (2, n_edges). Edges are
    compared without regard to direction, so (a, b) matches (b, a).
    """
    candidate_edges = np.asarray(candidate_edges, dtype=np.int64)
    true_edges = np.asarray(true_edges, dtype=np.int64)
    if candidate_edges.shape[1] == 0 or true_edges.shape[1] == 0:
        return np.zeros(candidate_edges.shape[1], dtype=bool)

    # Encode every undirected edge as one integer key: low * n_nodes + high.
    n_nodes = int(max(candidate_edges.max(), true_edges.max())) + 1
    candidate_keys = candidate_edges.min(axis=0) * n_nodes + candidate_edges.max(axis=0)
    true_keys = true_edges.min(axis=0) * n_nodes + true_edges.max(axis=0)
    return np.isin(candidate_keys, true_keys)


def gnn_metrics(data_path):
    """Compute purity and efficiency of the embedding-stage edge construction.

    Loads a preprocessed event from ``data_path``, embeds its hits with the
    ``LayerlessEmbedding`` checkpoint at ``embed_ckpt_dir``, builds candidate
    edges in the embedded space with ``utils_torch.build_edges``, keeps only
    edges whose source is not farther out than its target, and compares the
    remaining edges against ``data.layerless_true_edges``.

    Args:
        data_path: file readable by ``torch.load``; the loaded object must
            expose ``cell_data``, ``x`` and ``layerless_true_edges``.

    Returns:
        Tuple ``(purity, eff)`` where purity = true positives / candidate
        edges and eff = true positives / true edges. Either value is 0.0 when
        its denominator is zero.

    NOTE(review): the original cell used ``y_cluster`` and ``e_spatial_n``
    without defining them (NameError). They are reconstructed here as the
    truth mask over the candidate edges and the candidate edge array itself;
    confirm this matches the intended definition.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Load onto the CPU so the file is readable regardless of where it was
    # saved; only the model inputs are moved to `device` below.
    data = torch.load(data_path, map_location='cpu')

    # ### Evaluating Embedding
    e_ckpt = torch.load(embed_ckpt_dir, map_location=device)
    e_config = e_ckpt['hyper_parameters']
    e_config['clustering'] = 'build_edges'
    e_config['knn_val'] = 500
    e_config['r_val'] = 1.7

    e_model = LayerlessEmbedding(e_config).to(device)
    e_model.load_state_dict(e_ckpt["state_dict"])
    e_model.eval()

    # Map each hit to the embedding space. The input must live on the same
    # device as the model, otherwise the forward pass fails on GPU.
    with torch.no_grad():
        hit_features = torch.cat([data.cell_data, data.x], axis=-1).to(device)
        spatial = e_model(hit_features)

    e_spatial = utils_torch.build_edges(spatial.to(device), e_model.hparams['r_val'], e_model.hparams['knn_val'])
    # Post-processing happens on the CPU so the edge indices and R_dist share
    # a device.
    e_spatial = e_spatial.cpu()

    # Remove edges that point from the outer region to the inner region,
    # which removes roughly half of the edges.
    # Columns 0 and 2 of data.x are presumably the scaled r and z -- confirm.
    R_dist = torch.sqrt(data.x[:, 0]**2 + data.x[:, 2]**2)
    e_spatial = e_spatial[:, (R_dist[e_spatial[0]] <= R_dist[e_spatial[1]])]

    # Label every remaining candidate edge as true / fake.
    e_spatial_n = e_spatial.numpy()
    true_edges = torch.as_tensor(data.layerless_true_edges).cpu().numpy()
    y_cluster = _label_candidate_edges(e_spatial_n, true_edges)

    cluster_true = len(data.layerless_true_edges[0])
    cluster_true_positive = int(y_cluster.sum())
    cluster_positive = len(e_spatial_n[0])
    # Guard the ratios so an event with no candidate or no true edges yields
    # 0.0 instead of a division error / NaN.
    purity = cluster_true_positive / cluster_positive if cluster_positive else 0.0
    eff = cluster_true_positive / cluster_true if cluster_true else 0.0

    print("-----------")
    print(f"cluster true = {cluster_true}")
    print(f"cluste true positive = {cluster_true_positive}")
    print(f"cluster positive = {cluster_positive}")
    print(f"purity = {purity}")
    print(f"efficiency = {eff}")

    return purity, eff

In [7]:
# Run the full track-finding chain on a single event and report the wall time.
# The event lives in the input directory configured through TRKXINPUTDIR at
# the top of the notebook, so derive the path from it instead of repeating it.
event_file = os.path.join(os.environ['TRKXINPUTDIR'], "event000001000")
cell_data, hid, x = get_cell_data(event_file, detector_dir)
print("hid:", hid.shape)
print("x:", x.shape)
print("cell data:", cell_data.shape)

print("start track finding")
start_time = time.time()
tracks = gnn_track_finding(hid, x, cell_data)
end_time = time.time()
print(tracks[0])
print(tracks[1])
# "{:.2}" would format the float with two *significant* digits and fall back
# to scientific notation (12.3 -> "1.2e+01"); "{:.2f}" gives two decimals.
print("total {:.2f} seconds".format(end_time - start_time))

Loading detector...
Detector loaded.
Loading event /global/cfs/cdirs/m3443/data/trackml-kaggle/train_10evts/event000001000 with a 0 pT cut


TypeError: Can only merge Series or DataFrame objects, a <class 'numpy.ndarray'> was passed