In [1]:
import os
# Input (TrackML events) and output (intermediate products) directories; presumably
# read by exatrkx's `utils_dir` (whose `inputdir` is used below) when it is imported.
# Values already set in the environment take precedence, so the notebook can be
# pointed at another copy of the dataset without editing this cell.
# Alternatives used previously:
#   input:  /global/cfs/cdirs/m3443/data/trackml-kaggle/train_all   (full training set)
#   output: /global/cscratch1/sd/xju/heptrkx/iml2020/run200
os.environ.setdefault('TRKXINPUTDIR', "/global/cfs/cdirs/m3443/data/trackml-kaggle/train_10evts")
os.environ.setdefault('TRKXOUTPUTDIR', "/global/cfs/projectdirs/m3443/usr/caditi97/iml2020/run1")

In [2]:
# system import
import pkg_resources
import yaml
import pprint
import random
random.seed(1234)  # seed Python's `random` module only; numpy/torch RNGs are not seeded in this cell
import numpy as np
import pandas as pd
import itertools
import matplotlib.pyplot as plt
# %matplotlib widget

# 3rd party
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from trackml.dataset import load_event
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


# local import
# from heptrkx.dataset import event as master
from exatrkx import config_dict # for accessing predefined configuration files
from exatrkx import outdir_dict # for accessing predefined output directories
from exatrkx.src import utils_dir


# for preprocessing
from exatrkx import FeatureStore
from exatrkx.src import utils_torch

# for embedding
from exatrkx import LayerlessEmbedding
from exatrkx.src import utils_torch

# for filtering
from exatrkx import VanillaFilter

# for GNN
import tensorflow as tf
from graph_nets import utils_tf
from exatrkx import SegmentClassifier
import sonnet as snt

# for labeling
from exatrkx.scripts.tracks_from_gnn import prepare as prepare_labeling
from exatrkx.scripts.tracks_from_gnn import clustering as dbscan_clustering

# track efficiency
from trackml.score import _analyze_tracks
from exatrkx.scripts.eval_reco_trkx import make_cmp_plot, pt_configs, eta_configs
from functools import partial

ImportError: cannot import name 'make_cmp_plot' from 'exatrkx.scripts.eval_reco_trkx' (/global/u2/c/caditi97/exatrkx-iml2020/exatrkx/scripts/eval_reco_trkx.py)

### Set up hyperparameters and choose the event

In [None]:
# --- trained models for each pipeline stage ----------------------------------
# NOTE: despite the "_dir" suffix, the embedding and filtering entries are single
# Lightning checkpoint files; only the GNN entry is a directory.
embed_ckpt_dir = '/global/cfs/cdirs/m3443/data/lightning_models/embedding/checkpoints/epoch=10.ckpt'
filter_ckpt_dir = '/global/cfs/cdirs/m3443/data/lightning_models/filtering/checkpoints/epoch=92.ckpt'
gnn_ckpt_dir = '/global/cfs/cdirs/m3443/data/lightning_models/gnn'
ckpt_idx = -1  # index into the GNN CheckpointManager's checkpoint list (-1 = newest)

# --- output ------------------------------------------------------------------
plots_dir = '/global/homes/c/caditi97/exatrkx-iml2020/exatrkx/src/plots/run1'  # TODO: change to your own directory

# --- track building (DBSCAN) -------------------------------------------------
dbscan_epsilon = 0.25   # DBSCAN epsilon
dbscan_minsamples = 2   # DBSCAN minimum samples

# --- truth matching ----------------------------------------------------------
min_hits = 5                # a particle needs at least this many hits to be "reconstructable"
frac_reco_matched = 0.5     # a track's majority particle must own more than this fraction of the track's hits
frac_truth_matched = 0.5    # the track must contain more than this fraction of that particle's hits

In [None]:
evtid = 1000  # TrackML event number to process
# TrackML event files are named with a zero-padded 9-digit id, e.g. "event000001000".
event_file = os.path.join(utils_dir.inputdir, f'event{evtid:09d}')

### Preprocessing

In [None]:
action = 'build'

# Resolve the YAML config shipped inside the exatrkx package for the 'build'
# (preprocessing) stage, load it and pretty-print it for inspection.
pp = pprint.PrettyPrinter(indent=4)
config_file = pkg_resources.resource_filename(
    "exatrkx", os.path.join('configs', config_dict[action]))
with open(config_file) as config_stream:
    b_config = yaml.load(config_stream, Loader=yaml.FullLoader)

pp.pprint(b_config)

In [None]:
# Override the packaged defaults for this demo: pt_min=0 (no pT threshold),
# endcaps=True (presumably keep endcap hits), and a single file on a single worker.
b_config.update(pt_min=0, endcaps=True, n_workers=1, n_files=1)

In [None]:
# This cell is only needed on the first run to produce the dataset; afterwards the
# cached feature file for `evtid` is reused. Set `rebuild_features = True` to force
# regeneration (e.g. after changing `b_config`).
# NOTE(review): assumes FeatureStore writes to utils_dir.feature_outdir, which is where
# the preprocessed event is loaded from in the next section — confirm.
rebuild_features = False
feature_file = os.path.join(utils_dir.feature_outdir, str(evtid))
if rebuild_features or not os.path.exists(feature_file):
    preprocess_dm = FeatureStore(b_config)
    preprocess_dm.prepare_data()

### Read the preprocessed data

In [None]:
# Load the preprocessed event (presumably a torch_geometric `Data` object) and display it.
preprocessed_path = os.path.join(utils_dir.feature_outdir, str(evtid))
data = torch.load(preprocessed_path)
data

### Evaluating Embedding

In [None]:
pp = pprint.PrettyPrinter(indent=4)
# Load the embedding checkpoint onto the CPU and show the hyperparameters stored in it.
e_ckpt = torch.load(embed_ckpt_dir, map_location='cpu')
e_config = e_ckpt['hyper_parameters']
pp.pprint(e_config)

In [None]:
# Override the edge-building settings in the checkpoint's hyperparameters. r_val and
# knn_val are later read back through e_model.hparams when building edges.
e_config = e_ckpt['hyper_parameters']
e_config.update(clustering='build_edges', knn_val=500, r_val=1.7)

Load the checkpoint and put the model in evaluation mode.

In [None]:
# Rebuild the embedding network from its hyperparameters, then restore the trained weights.
e_weights = e_ckpt["state_dict"]
e_model = LayerlessEmbedding(e_config)
e_model.load_state_dict(e_weights)

In [None]:
e_model.train(False)  # same as .eval(): switch modules such as dropout/batch-norm to inference behavior

Map each hit into the embedding space, returning the embedded coordinates of each hit.

In [None]:
%%time
spatial = e_model(torch.cat([data.cell_data, data.x], axis=-1))

### Form doublets from the embedding space

`r_val = 1.7` and `knn_val = 500` are the hyperparameters to be studied.

* `r_val` defines the radius of the clustering method
* `knn_val` defines the maximum number of neighbors in the embedding space

In [None]:
%%time
e_spatial = utils_torch.build_edges(spatial, e_model.hparams['r_val'], e_model.hparams['knn_val'])

In [None]:
e_spatial = e_spatial.to('cpu').numpy()  # host-side numpy edge index; row 0 = senders, row 1 = receivers

Remove edges that point from the outer region to the inner region; this removes almost half of the edges.

In [None]:
# Distance of each hit from the origin, built from columns 0 and 2 of data.x
# (presumably cylindrical r and z — TODO confirm against the FeatureStore output).
# Converted to numpy so the edge selection below is pure numpy indexing instead of
# relying on implicit torch -> numpy conversion of a boolean tensor.
R_dist = torch.sqrt(data.x[:, 0]**2 + data.x[:, 2]**2).numpy()
# Keep only edges whose sender is not farther from the origin than its receiver.
e_spatial = e_spatial[:, R_dist[e_spatial[0]] <= R_dist[e_spatial[1]]]

### Filtering


In [None]:
pp = pprint.PrettyPrinter(indent=4)
# Load the filtering checkpoint onto the CPU and show the hyperparameters stored in it.
f_ckpt = torch.load(filter_ckpt_dir, map_location='cpu')
f_config = f_ckpt['hyper_parameters']
pp.pprint(f_config)

In [None]:
# Evaluation overrides: train_split=[0, 0, 1] (presumably everything in the test split)
# and filter_cut, the score threshold later used to select edges.
f_config.update(train_split=[0, 0, 1], filter_cut=0.18)

In [None]:
# Rebuild the filter network from its hyperparameters and restore the trained weights.
# (Alternative: VanillaFilter.load_from_checkpoint(filter_ckpt_dir, hparams=f_config).)
f_weights = f_ckpt['state_dict']
f_model = VanillaFilter(f_config)
f_model.load_state_dict(f_weights)

In [None]:
f_model.train(False)  # same as .eval(): switch modules such as dropout/batch-norm to inference behavior

In [None]:
%%time
emb = None # embedding information was not used in the filtering stage.
output = f_model(torch.cat([data.cell_data, data.x], axis=-1), e_spatial, emb).squeeze()

In [None]:
# Squash the filter's raw outputs into (0, 1) scores.
# NOTE: this overwrites `output`; re-running this cell on its own applies the sigmoid
# twice, so re-run the inference cell first.
output = output.sigmoid()

In [None]:
(output.shape, e_spatial.shape)  # expect one score per candidate edge, i.e. len(output) == e_spatial.shape[1]

In [None]:
# Distribution of filter scores; the dashed line marks the selection cut.
# (This plot may need some time to render for large events.)
fig, ax = plt.subplots()
ax.hist(output.detach().numpy(), bins=50)
ax.axvline(f_model.hparams['filter_cut'], color='k', linestyle='--', label='filter_cut')
ax.set(xlabel='Filter score', ylabel='Candidate edges', yscale='log',
       title='Filter score distribution')
ax.legend();

The filtering network assigns a score to each edge. In the end, edges with scores > `filter_cut` are selected to construct graphs.

In [None]:
# Keep only candidate edges whose filter score exceeds the configured cut.
edge_mask = output > f_model.hparams['filter_cut']
edge_list = e_spatial[:, edge_mask]

In [None]:
np.shape(edge_list)  # (2, number of edges kept by the filter)

### Form a graph
Now we move to TensorFlow for GNN inference.

In [None]:
# Assemble graph_nets inputs from the filtered edges.
nodes = data.x.numpy().astype(np.float32)          # node features: the data.x columns
n_nodes = data.x.shape[0]
n_edges = edge_list.shape[1]
edges = np.zeros((n_edges, 1), dtype=np.float32)   # placeholder edge features (all zeros)
senders, receivers = edge_list[0], edge_list[1]    # row 0 -> senders, row 1 -> receivers

In [None]:
# graph_nets data dict describing a single graph; the node count doubles as the
# graph-level global feature.
input_datadict = dict(
    n_node=n_nodes,
    n_edge=n_edges,
    nodes=nodes,
    edges=edges,
    senders=senders,
    receivers=receivers,
    globals=np.array([n_nodes], dtype=np.float32),
)

In [None]:
# Batch the (single) data dict into a GraphsTuple for the TensorFlow GNN.
graph_dicts = [input_datadict]
input_graph = utils_tf.data_dicts_to_graphs_tuple(graph_dicts)

### Apply GNN

In [None]:
num_processing_steps_tr = 8  # number of GNN processing steps used at inference
optimizer = snt.optimizers.Adam(0.001)
model = SegmentClassifier()

# Restore the trained GNN. The optimizer is part of the Checkpoint object, presumably so
# its structure matches the one used when the checkpoints were written.
output_dir = gnn_ckpt_dir
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
ckpt_manager = tf.train.CheckpointManager(checkpoint, directory=output_dir, max_to_keep=10)
# Fail with a clear message instead of an IndexError when the directory has no checkpoints.
if not ckpt_manager.checkpoints:
    raise FileNotFoundError("No GNN checkpoints found in {}".format(output_dir))
ckpt_path = ckpt_manager.checkpoints[ckpt_idx]  # list is ordered oldest -> newest
status = checkpoint.restore(ckpt_path)
print("Loaded checkpoint {} (index {}) from {}".format(ckpt_path, ckpt_idx, output_dir))

In [None]:
%%time
# Run the GNN on the input graph for `num_processing_steps_tr` steps. The result is
# indexed as a sequence of output graphs (presumably one per step); only the last is kept.
outputs_gnn = model(input_graph, num_processing_steps_tr)
output_graph = outputs_gnn[-1]

### Track labeling

In [None]:
# Squeeze the GNN's per-edge outputs and combine them with the edge endpoints into
# the matrix consumed by the DBSCAN-based labeling step.
edge_scores = tf.squeeze(output_graph.edges).numpy()
input_matrix = prepare_labeling(edge_scores, senders, receivers, n_nodes)

In [None]:
# Group hits (identified by data.hid) into track candidates with DBSCAN.
predict_tracks = dbscan_clustering(data.hid, input_matrix,
                                   dbscan_epsilon, dbscan_minsamples)

### Track Efficiency

In [None]:
# Load the truth record and keep only hits from particles with at least `min_hits` hits.
hits, particles, truth = load_event(event_file, parts=['hits', 'particles', 'truth'])
hits = (
    hits.merge(truth, on='hit_id', how='left')
        .loc[lambda df: df.particle_id > 0]            # drop noise hits
        .merge(particles, on='particle_id', how='left')
        .loc[lambda df: df.nhits >= min_hits]
)
particles = particles[particles.nhits >= min_hits]

# Kinematics of the reconstructable particles: transverse momentum and pseudorapidity.
pt_sq = particles.px**2 + particles.py**2
par_pt = np.sqrt(pt_sq)
momentum = np.sqrt(pt_sq + particles.pz**2)
ptheta = np.arccos(particles.pz/momentum)
peta = -np.log(np.tan(0.5*ptheta))

In [None]:
# Match each predicted track to its majority truth particle using trackml's private
# `_analyze_tracks` helper; the per-track columns nhits, major_nhits, major_particle_id,
# major_particle_nhits and major_weight are consumed by the metrics cell.
tracks = _analyze_tracks(hits, predict_tracks)

In [None]:
# Fraction of each track's hits coming from its majority particle, and the fraction
# of that particle's hits captured by the track.
purity_rec = tracks['major_nhits'] / tracks['nhits']
purity_maj = tracks['major_nhits'] / tracks['major_particle_nhits']
# A track is "good" when both fractions strictly exceed their thresholds.
good_track = (purity_rec > frac_reco_matched) & (purity_maj > frac_truth_matched)

matched_pids = tracks.loc[good_track, 'major_particle_id'].values
score = tracks.loc[good_track, 'major_weight'].sum()

n_recotable_trkx = len(particles)
n_reco_trkx = len(tracks)
n_good_recos = np.sum(good_track)
matched_idx = particles.particle_id.isin(matched_pids).values  # per-particle "was matched" mask

In [None]:
# Per-event summary. Efficiency = matched / reconstructable particles;
# purity = matched / reconstructed tracks. NaN is reported when a denominator is zero.
efficiency = n_good_recos / n_recotable_trkx if n_recotable_trkx else float('nan')
purity = n_good_recos / n_reco_trkx if n_reco_trkx else float('nan')
print("Processed event {} from {}".format(evtid, utils_dir.inputdir))
print("Reconstructable tracks:         {}".format(n_recotable_trkx))
print("Reconstructed tracks:           {}".format(n_reco_trkx))
print("Reconstructable tracks Matched: {}".format(n_good_recos))
print("Tracking efficiency:            {:.4f}".format(efficiency))
print("Tracking purity:                {:.4f}".format(purity))

In [None]:
# Shared styling for the efficiency plots (matched vs. reconstructable particles).
# NOTE: the `make_cmp_plot` import in the import cell currently fails (ImportError);
# this cell needs that import to be fixed first.
make_cmp_plot_fn = partial(
    make_cmp_plot,
    xlegend="Matched",
    ylegend="Reconstructable",
    ylabel="Events",
    ratio_label='Track efficiency',
)

In [None]:
# Track efficiency vs. particle pT, saved under plots_dir.
# Create the output directory first in case make_cmp_plot does not (plots_dir is user-specific).
os.makedirs(plots_dir, exist_ok=True)
make_cmp_plot_fn(par_pt[matched_idx], par_pt,
                 configs=pt_configs,
                 xlabel="pT [GeV]",
                 outname=os.path.join(plots_dir, "{}_pt".format(evtid)))

In [None]:
# Track efficiency vs. particle pseudorapidity, saved under plots_dir.
# Create the output directory first in case make_cmp_plot does not (plots_dir is user-specific).
os.makedirs(plots_dir, exist_ok=True)
make_cmp_plot_fn(peta[matched_idx], peta,
                 configs=eta_configs,
                 xlabel=r"$\eta$",
                 outname=os.path.join(plots_dir, "{}_eta".format(evtid)))