## Preamble

In [88]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [89]:
import numpy as np

import matplotlib.pyplot as plt
import plotly
import pandas
from plotly.offline import iplot
import plotly.express as px
import plotly.graph_objects as go

import json
import torch

from torch.utils.data import Dataset, DataLoader

import sys
sys.path.append('..')
from dataset import PflowDataset, collate_graphs

In [90]:
%matplotlib notebook

## Settings

In [91]:
# Index (into `ds`) of the event to display.
event_idx = 1
# z coordinate ("layer") at which truth particles and the track start points are
# drawn in the 3D display.
particle_lay = -1
# When False, only cell-to-cell edges whose two endpoints share the same
# calo_region are drawn; when True, inter-layer cell edges are drawn as well.
show_interlayer_edges = False

## Load config

In [92]:
# JSON config with dataset/loader settings (e.g. 'path_to_val', 'batchsize').
config_path = 'cocoa_default_copy.json'
# JSON text is UTF-8 by specification; pass the encoding explicitly so the read
# does not depend on the platform's default locale encoding.
with open(config_path, 'r', encoding='utf-8') as f:
    config = json.load(f)

#### Update config values (small batch size and reduced dataset for display)

In [93]:
# Keep things small for display: a batch size of 5, and `reduce_ds` (passed to
# PflowDataset below) set to the same value.
config.update(batchsize=5)
config['reduce_ds'] = config['batchsize']

## Load dataset

In [94]:
# Validation dataset; `reduce_ds` presumably caps the number of loaded events
# (the output reports 5 events). `PflowDataset` comes from the imports cell.
# NOTE(review): an `entry_start=config['entry_start']` keyword was previously
# tried here (commented out) — presumably selects the first entry to read.
ds = PflowDataset(config['path_to_val'], config, reduce_ds=config['reduce_ds'])

100%|██████████| 48/48 [00:00<00:00, 331.59it/s]


dataset loaded
number of events: 5






In [95]:
# 1) single event
g, _ = ds[event_idx]

# 2) batch of events
# loader = DataLoader(ds, batch_size=config['batchsize'], num_workers=0, shuffle=True, collate_fn=collate_graphs, pin_memory=False)
# for bg in loader:
#     break
# g = bg

In [96]:
# Display the graph summary: node counts per node type and edge counts per relation.
g

Graph(num_nodes={'cell': 266, 'global_node': 1, 'node': 269, 'particle': 11, 'pflow_particle': 11, 'pre_node': 269, 'topo': 7, 'track': 3},
      num_edges={('node', 'node_to_node', 'node'): 2309},
      metagraph=[('node', 'node', 'node_to_node')])

#### Grab the dataset's feature transforms (their `.inverse` is used below)

In [97]:
# Per-feature transform objects keyed by feature name ('pt', 'eta', ...); their
# `.inverse(...)` is used below to undo the dataset's feature transforms.
transform = ds.transform_dicts

## Plotly 3D

### General

In [98]:
# Coloring mode. NOTE: `colorby` is not read anywhere in this notebook yet;
# 'particle' and 'topocluster' (TODO) were the other candidate modes.
colorby = 'class'

# The position in `class_labels` is the integer class id stored in the graph's
# 'class' feature; `class_colors` maps each label to a plotly color name.
class_labels = ['charged', 'electron', 'muon', 'neutral', 'photon']
class_colors = dict(zip(class_labels, ('magenta', 'green', 'red', 'grey', 'blue')))

### Particle features

In [99]:
# One row per truth particle, one column per feature stored on the 'particle' nodes.
_particle_data = g.nodes['particle'].data
part_df = pandas.DataFrame({key: _particle_data[key].detach().numpy()
                            for key in _particle_data.keys()})

# Undo the dataset's feature transforms for pt and eta.
for feat in ('pt', 'eta'):
    part_df[feat] = transform[feat].inverse(part_df[feat])

# Readable class label (indexing class_labels by class id) and its display color.
part_df['particle_class_label'] = part_df['class'].map(lambda cls: class_labels[cls])
part_df['particle_class_color'] = part_df['particle_class_label'].map(lambda lbl: class_colors[lbl])

### Cell features

In [100]:
cell_df = pandas.DataFrame()
_cell_data = g.nodes['cell'].data

### GEOMETRY
cell_df['cell_phi']   = _cell_data['phi'].detach().numpy()
cell_df['cell_eta']   = _cell_data['eta'].detach().numpy()
cell_df['cell_layer'] = _cell_data['calo_region'].detach().numpy().astype(int)

### KINEMATICS
cell_df['cell_e']     = _cell_data['e'].detach().numpy()

### TRANSFORM
# Undo the dataset's eta transform so cells are placed at their original eta.
cell_df['cell_eta']   = transform['eta'].inverse(cell_df['cell_eta'])

# Marker size = log(cell_e * 1000) clipped to [1, 10]. The log argument is floored
# at Euler's number (log -> 1, the lower clip bound), so zero- or negative-energy
# cells get the minimum size. Without the floor, the log gives -inf/NaN, and NaN is
# not removed by np.clip.
cell_df['cell_marker_size'] = np.clip(np.log(np.maximum(cell_df['cell_e'] * 1000, np.e)), 1, 10)

### PARENT PROPERTIES
# 'node' holds cells and tracks together, separated by the is_track mask.
# NOTE(review): assumes the non-track 'node' entries appear in the same order as
# the 'cell' nodes — confirm against the dataset.
is_track = g.nodes['node'].data['is_track'].detach().numpy().astype(bool)
cell_df['cell_parent'] = g.nodes['node'].data['particle_idx'][~is_track].detach().numpy().astype(int)

### NOISE
# A negative parent index means no truth particle: class -1, labelled 'noise' and
# drawn black. This is handled explicitly so class -1 never indexes class_labels,
# where it would silently wrap around to the last entry.
cell_df['cell_class'] = [part_df['class'][p] if p >= 0 else -1 for p in cell_df['cell_parent']]
cell_df['cell_class_label'] = ['noise' if p < 0 else class_labels[c]
                               for p, c in zip(cell_df['cell_parent'], cell_df['cell_class'])]
cell_df['cell_class_color'] = ['black' if lbl == 'noise' else class_colors[lbl]
                               for lbl in cell_df['cell_class_label']]

### Track features

In [101]:
track_df = pandas.DataFrame()

# Number of per-layer extrapolations stored for each track
# (feature suffixes _layer_0 ... _layer_{nlays-1}).
nlays = 6

### INPUTS
# Track-level features, followed by the per-layer extrapolated eta, sin(phi)
# and cos(phi) (all eta layers first, then all sinphi, then all cosphi).
track_inputs = (['d0', 'z0', 'phi', 'cosphi', 'sinphi', 'pt', 'eta']
                + ['{}_layer_{}'.format(var, lay)
                   for var in ('eta', 'sinphi', 'cosphi')
                   for lay in range(nlays)])

_track_data = g.nodes['track'].data
for var in track_inputs:
    track_df[var] = _track_data[var].detach().numpy()

# Recover the extrapolated phi at each layer from its (sin, cos) encoding.
for lay in range(nlays):
    track_df['phi_layer_{}'.format(lay)] = np.arctan2(track_df['sinphi_layer_{}'.format(lay)],
                                                      track_df['cosphi_layer_{}'.format(lay)])

### PARENT PROPERTIES
# `is_track` is the boolean track mask over 'node' entries built in the cell-features cell.
track_df['track_parent'] = g.nodes['node'].data['particle_idx'][is_track].detach().numpy().astype(int)
track_df['track_class']  = [part_df['class'][p] if p >= 0 else -1 for p in track_df['track_parent']]

# Tracks with no truth particle (negative parent index) are labelled 'noise' and
# drawn black, as is done for cells. Previously class -1 indexed class_labels and
# silently became class_labels[-1] ('photon').
track_df['track_class_label'] = ['noise' if p < 0 else class_labels[c]
                                 for p, c in zip(track_df['track_parent'], track_df['track_class'])]
track_df['track_class_color'] = ['black' if lbl == 'noise' else class_colors[lbl]
                                 for lbl in track_df['track_class_label']]

### TRANSFORM
# Undo the dataset's feature transforms: pt with the 'pt' transform; the track eta
# and every per-layer extrapolated eta with the 'eta' transform.
track_df['pt'] = transform['pt'].inverse(track_df['pt'])
for col in ['eta'] + ['eta_layer_{}'.format(lay) for lay in range(nlays)]:
    track_df[col] = transform['eta'].inverse(track_df[col])

### Cell trace

In [102]:
# Calorimeter cells in (phi, eta, layer) space: size from cell_marker_size, color
# from the class of the parent particle ('noise' cells are black).
cell_trace = go.Scatter3d(
    x=cell_df['cell_phi'],
    y=cell_df['cell_eta'],
    z=cell_df['cell_layer'],
    mode='markers',
    name='cells',
    marker=dict(symbol='circle',
                size=cell_df['cell_marker_size'],
                color=cell_df['cell_class_color'],
                line=dict(color='rgb(50,50,50)', width=0.5)),
    hovertemplate=('<b>%{text}</b><br>'
                   '<i>(eta,phi,lay)=(%{y:.2f},%{x:.2f},%{z:.2f})</i><br>'),
    # The <i> around the energy value is closed here. It is substituted inside the
    # template's <b>...</b>, and an unclosed <i> leaves that markup malformed.
    text=['{}<br>E=<i>{:.4f}</i> GeV'.format(cl, en)
          for cl, en in zip(cell_df['cell_class_label'], cell_df['cell_e'])],
)

### Track trace

In [103]:
### TRACK PROJECTIONS ###
# One polyline per track: its (phi, eta) at z=particle_lay, followed by its
# extrapolated (phi, eta) at each of the nlays layers, colored by the track class.
nlays = 6
proj_traces = []
for track_i in range(len(track_df)):
    projlay = [particle_lay] + list(range(nlays))
    projphi = [track_df.at[track_i, 'phi']] + [track_df.at[track_i, 'phi_layer_{}'.format(lay)] for lay in range(nlays)]
    projeta = [track_df.at[track_i, 'eta']] + [track_df.at[track_i, 'eta_layer_{}'.format(lay)] for lay in range(nlays)]

    proj_traces.append(go.Scatter3d(x=np.array(projphi),
                                    y=np.array(projeta),
                                    z=np.array(projlay),
                                    marker=dict(symbol='circle', size=0.5),
                                    opacity=0.2,
                                    line=dict(color=track_df.at[track_i, 'track_class_color'], width=5.0)))

### Particle trace

In [104]:
# Truth particles drawn as crosses in the plane z=particle_lay, colored by class.
particle_trace = go.Scatter3d(
    x=part_df['phi'],
    y=part_df['eta'],
    z=particle_lay * np.ones_like(part_df['eta']),
    mode='markers',
    name='particles',
    marker=dict(symbol='x',
                size=2,
                color=part_df['particle_class_color'],
                line=dict(color='rgb(50,50,50)', width=0.5)),
    hovertemplate=('<b>%{text}</b><br>'
                   '<i>(eta,phi)=(%{y:.2f},%{x:.2f})</i><br>'),
    # The <i> around the pT value is closed here. It is substituted inside the
    # template's <b>...</b>, and an unclosed <i> leaves that markup malformed.
    text=['{}<br>pT=<i>{:.2f}</i> GeV'.format(cl, pt)
          for cl, pt in zip(part_df['particle_class_label'], part_df['pt'])],
)

### Edges trace

In [105]:
def _edge_endpoint_features(edges):
    """Copy the endpoint node features needed for drawing onto each node_to_node edge."""
    return {
        # NOTE(review): edges use node 'eta_raw', while cell_trace uses the
        # inverse-transformed cell 'eta' — confirm the two agree.
        'cell_eta_src': edges.src['eta_raw'],
        'cell_eta_dst': edges.dst['eta_raw'],
        'cell_phi_src': edges.src['phi'],
        'cell_phi_dst': edges.dst['phi'],
        'cell_lay_src': edges.src['calo_region'],
        'cell_lay_dst': edges.dst['calo_region'],
        # Equals 1 only when neither endpoint is a track
        # (assumes is_track is stored as numeric 0/1 — TODO confirm).
        'cell_to_cell': 1 - edges.src['is_track'] - edges.dst['is_track'],
    }


# One pass over the edges instead of one apply_edges call per feature.
g.apply_edges(_edge_endpoint_features, etype='node_to_node')

edata = g.edges['node_to_node'].data
c2c_mask = edata['cell_to_cell'] == 1
if not show_interlayer_edges:
    # Keep only edges whose two cells share the same calo_region.
    c2c_mask = torch.logical_and(c2c_mask, edata['cell_lay_src'] == edata['cell_lay_dst'])
c2c = torch.where(c2c_mask)

x1list = edata['cell_phi_src'][c2c].detach().numpy()
x2list = edata['cell_phi_dst'][c2c].detach().numpy()
y1list = edata['cell_eta_src'][c2c].detach().numpy()
y2list = edata['cell_eta_dst'][c2c].detach().numpy()
z1list = edata['cell_lay_src'][c2c].detach().numpy()
z2list = edata['cell_lay_dst'][c2c].detach().numpy()

# Plotly draws disconnected segments from a single trace when consecutive
# segments are separated by None.
Xe, Ye, Ze = [], [], []
for x1, x2, y1, y2, z1, z2 in zip(x1list, x2list, y1list, y2list, z1list, z2list):
    Xe += [x1, x2, None]
    Ye += [y1, y2, None]
    Ze += [z1, z2, None]

edge_trace = go.Scatter3d(x=Xe,
                          y=Ye,
                          z=Ze,
                          mode='lines',
                          # rgb() takes three components; an alpha channel requires rgba().
                          line=dict(color='rgba(125,125,125,0.6)', width=0.5),
                          hoverinfo='none')

### Layout

In [106]:
def getaxis(var, axis_range=None):
    """Build the settings dict for one axis of the 3D display scene.

    Parameters
    ----------
    var : str
        Scene axis: 'x' (titled phi), 'y' (titled eta) or 'z' (titled 'layer').
        Any other value gets an empty title.
    axis_range : sequence of two numbers, optional
        If given, it is stored as the axis 'range' ([min, max]). If omitted, no
        'range' key is set and plotly chooses the range.

    Returns
    -------
    dict
        Axis settings for ``scene.xaxis`` / ``scene.yaxis`` / ``scene.zaxis``.
    """
    titles = {'x': '\u03C6', 'y': '\u03B7', 'z': 'layer'}
    axis = dict(showbackground=False,
                showline=True,
                zeroline=False,
                showgrid=True,
                showticklabels=True,
                title=titles.get(var, ''))
    if axis_range is not None:
        axis['range'] = list(axis_range)
    return axis

# Square 600x600 figure titled with the event index. The 3D scene uses an
# orthographic camera, and its z (layer) axis has twice the aspect of x and y.
layout = go.Layout(
    title='Event {}'.format(event_idx),
    width=600,
    height=600,
    margin=dict(t=100),
    showlegend=False,
    legend=dict(font=dict(size=10), orientation='h'),
    hovermode='closest',
    dragmode='pan',
    scene=dict(
        xaxis=getaxis('x'),
        yaxis=getaxis('y'),
        zaxis=getaxis('z'),
        aspectratio=dict(x=1, y=1, z=2),
        camera=dict(projection=dict(type='orthographic')),
    ),
    scene_xaxis_visible=True,
    scene_yaxis_visible=True,
    scene_zaxis_visible=True,
)

### Plot

In [107]:
# Trace order: edges, cells, particles, then one line per track projection.
data = [edge_trace, cell_trace, particle_trace, *proj_traces]
fig = go.Figure(data=data, layout=layout)

# Save a standalone HTML copy of the figure, then display it inline.
fig.write_html("pflow_display.html")
iplot(fig, filename='pflow_graph')