In [None]:
SUPERA_INPUT_FILE  = "/media/hdd1/jwolcott/data/dune/nd/nd-lar-reco/supera/nd.fhc.1.supera.voxpitch=4mm.nothresh.filter-emptytens+nopart.root"
WEIGHTS_FILE = "/media/hdd1/jwolcott/data/dune/nd/nd-lar-reco/train/track+showergnn-1file-15k/snapshot-10799.ckpt"
CONFIG_BASE  = "/media/hdd1/jwolcott/app/personal/dune/nd/nd-lar-reco/config.inference.fullchain.yaml"

BATCH_SIZE = 45

import os.path
# Fail fast on any missing input.  The weights file is only consumed later (it is handed to
# the config loader as `model_file`), so check it here as well rather than discovering a typo
# deep inside model preparation.
for f in (SUPERA_INPUT_FILE, WEIGHTS_FILE, CONFIG_BASE):
    assert os.path.isfile(f), "Can't find file: " + f

In [None]:
import importlib
import os.path
import sys

print(sys.executable)

# Parent directories that may hold checkouts of the required packages, searched in order.
dirs_to_try = [
    ".",
    "/gpfs/slac/staas/fs1/g/neutrino/jwolcott/app",
    "/media/hdd1/jwolcott/app",
    "/dune/app/users/jwolcott/dunesoft",
]

modules_required = {
    # module name -> subdir path
    "mlreco": "lartpc_mlreco3d",
    "larcv": "larcv2/python",
}

for module_name, module_path in modules_required.items():
    # first candidate directory that actually contains this package's subdirectory
    software_dir = None
    for d in dirs_to_try:
        candidate = os.path.join(d, module_path)
        if os.path.isdir(candidate):
            software_dir = candidate
            break

    import_error = None
    if software_dir:
        sys.path.insert(0, software_dir)
        try:
            importlib.import_module(module_name)
        except Exception as exc:
            # Don't swallow the reason.  A package that is found but fails to import
            # (e.g. because of a missing dependency) must be distinguishable from one
            # that is simply absent.  Catching Exception rather than using a bare
            # `except` also lets KeyboardInterrupt through.
            import_error = exc

    if not software_dir or import_error is not None:
        print("ERROR: couldn't find %s package" % module_name)
        if import_error is not None:
            print("  import from %s failed with: %r" % (software_dir, import_error))
    else:
        print("Setup of %s ok from:" % module_name, software_dir)

In [None]:
import numpy as np
import yaml
import plotly.graph_objs as go
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot

# Offline mode: embed the bundled plotly.js in the notebook instead of loading it from a CDN.
init_notebook_mode(connected=False)

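## Configuration

Load the inference configuration, overriding the input file, model weights, batch size and GPU usage.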

In [None]:
from load_helpers import LoadConfig

# Keyword overrides handed to LoadConfig along with the base YAML config path
# (presumably applied on top of that file's settings -- see load_helpers).
config_overrides = dict(
    input_files=[SUPERA_INPUT_FILE],
    model_file=WEIGHTS_FILE,
    batch_size=BATCH_SIZE,
    use_gpu=True,
)
cfg_dict = LoadConfig(CONFIG_BASE, **config_overrides)

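## Configure

Define the voxel-to-detector coordinate conversion, then prepare the model handlers (data loader and trainer).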

In [None]:
import itertools
def convert_to_geom_coords(values, metadata, evnums=None):
    """
    Convert voxel-index coordinates into detector (geometry) coordinates, in place.

    Columns 0, 1 and 2 of every selected array become ``index * voxel_size + min``, with the
    voxel size and lower bound taken from the metadata.  Other columns are left untouched.

    :param values: sequence of 2-D arrays whose first three columns are voxel indices (x, y, z).
                   Each selected array is modified in place.
    :param metadata: sequence of metadata objects providing ``size_voxel_{x,y,z}()`` and
                     ``min_{x,y,z}()``.  Only the first entry is used, since all entries are
                     assumed to describe the same geometry.
    :param evnums: optional collection of indices into ``values``.  When given and non-empty,
                   only those entries are converted; otherwise every entry is.
    :return: None (the conversion happens in place).
    """
    meta = metadata[0]  # they are all the same
    scale = (meta.size_voxel_x(), meta.size_voxel_y(), meta.size_voxel_z())
    offset = (meta.min_x(), meta.min_y(), meta.min_z())

    # A set gives O(1) membership tests.  None or an empty collection means "every entry".
    selected = set(evnums) if evnums is not None and len(evnums) > 0 else None

    for i, ev in enumerate(values):
        if selected is not None and i not in selected:
            continue
        # An empty entry has no points to convert (and may not even be 2-D).
        if len(ev) == 0:
            continue
        for axis in range(3):
            ev[:, axis] = ev[:, axis] * scale[axis] + offset[axis]



In [None]:
from mlreco.main_funcs import prepare, process_config

# Build the handler object from the configuration; the cells below use its
# `data_io` (batch source) and `trainer` (runs the network) members.
hs = prepare(cfg_dict)

In [None]:
def cycle(data_io):
    """
    Yield the items of ``data_io`` endlessly, starting a new pass each time it is exhausted.

    Unlike a single pass, this never runs dry mid-analysis, so callers may pull as many
    batches as they like.  If a pass yields nothing (an empty or already-consumed one-shot
    source), the generator ends instead of spinning forever.

    :param data_io: iterable of batches.  Each ``for`` over it should start a fresh pass.
    """
    while True:
        yielded_any = False
        for x in data_io:
            yielded_any = True
            yield x
        if not yielded_any:
            return

# `cycle` is a generator function, so its return value is already an iterator.
it = cycle(hs.data_io)

# Run the network on the next batch: `data` holds the inputs, `output` the network products.
data, output = hs.trainer.forward(it)


def _describe(value):
    """Short description of an output product: its type and, when it has one, its length."""
    try:
        return "%s[%d]" % (type(value).__name__, len(value))
    except TypeError:
        return type(value).__name__


# Summarize instead of dumping every per-entry array of the batch into the notebook.
print({k: _describe(v) for k, v in output.items()})
print("done evaluating")

## Visualize the output!



In [None]:
import pprint

# Show which products the data loader and the network produced for this batch.
print(data.keys())
pprint.pprint(sorted(output))

In [None]:
import numpy
from mlreco.utils.ppn import uresnet_ppn_type_point_selector


# PPN post-processing has to happen while coordinates are still voxel indices,
# i.e. before the geometry conversion applied further below.
def _select_ppn_points(entry):
    """Run the PPN point selector on one batch entry, logging that entry's input size first."""
    print("ppn-post: for entry", entry, "input_data has length", len(data['input_data'][entry]))
    return uresnet_ppn_type_point_selector(data['input_data'][entry], output, entry=entry,
                                           score_threshold=0.5,
                                           type_threshold=2)  # thresholds are from Laura D...


ppn = [_select_ppn_points(entry) for entry in range(len(data["input_data"]))]
output["ppn_post"] = ppn

# The last column of each selected point is its type; type 1 is counted as 'track' here.
for entry, ppn_points in enumerate(ppn):
    n_track_points = numpy.count_nonzero(ppn_points[:, -1] == 1)
    print("there are", n_track_points, "'track' PPN points in entry", entry)

In [None]:
import pprint
# Sanity check: these per-entry collections are indexed by batch entry below,
# so presumably each holds one item per entry.
for key in ("particles_label", "segment_label", "input_data"):
    print(len(data[key]))

In [None]:
# Products whose per-entry arrays hold voxel-index coordinates in columns 0-2
# that should be converted to detector coordinates.
convert_list = ["input_data", "segment_label", "ppn_post", "particles_label"]

# convert_to_geom_coords rewrites arrays in place, so re-running this cell would apply the
# transform twice.  Remember which lists have already been converted and skip them on re-runs.
# Identity is checked with `is`, and holding the references keeps ids from being recycled.
# A freshly produced list (e.g. after re-running the forward pass or the PPN cell) is a
# new object and therefore does get converted.
_GEOM_CONVERTED = globals().get("_GEOM_CONVERTED", [])

for collection in (data, output):
    for key in collection:
        if key not in convert_list:
            continue

        vals = collection[key]
        if any(vals is done for done in _GEOM_CONVERTED):
            continue
        convert_to_geom_coords(vals, data["metadata"])
        _GEOM_CONVERTED.append(vals)



In [None]:

def collection_range(coord, *arrays, default=None):
    """
    Get the pair of (min, max) extrema over a collection of arrays, using just the indicated coordinate.

    Empty arrays are ignored.  If every array is empty there are no extrema, so ``default``
    is returned instead of raising.  With plotly, an axis ``range`` of None should leave that
    axis autoranged.

    :param coord: which column (coordinate) to take the extrema over
    :param arrays: the 2-D arrays to be compared
    :param default: value returned when none of the arrays has any rows
    :return: tuple (min, max) of extrema found, or ``default``
    """
    non_empty = [a for a in arrays if len(a) > 0]
    if not non_empty:
        return default
    lowest = min(a[:, coord].min() for a in non_empty)
    highest = max(a[:, coord].max() for a in non_empty)
    return lowest, highest

In [None]:
import pprint
#print(len(data["metadata"]))
#print(output["points"][4])
#print(data["input_data"][0])
# pprint.pprint([(p.pdg_code(), 
#                 ["%g" % getattr(p.first_step().as_point3d(), coord) for coord in ("x", "y", "z")],
#                 ["%g" % getattr(p.last_step().as_point3d(), coord) for coord in ("x", "y", "z")]) 
#                for p in data["particles"][3]])
#print(data["particles_label"][3])

#print(data['input_data'][0])

#print(output["ppn_post"][0])
#print(output["track_fragments"][0])
#print(output["track_group_pred"][0])

In [None]:

import numpy
import plotly.graph_objs as go


from mlreco.visualization import scatter_points, plotly_layout3d
from mlreco.visualization.gnn import network_topology
from larcv import larcv

markersize = 2  # pixels...

# Plot a specific entry of the batch
entry = 2

# Retrieve this entry's products
vox       = data  ['input_data'      ][entry]
label     = data  ['segment_label'   ][entry]
pred      = output['segmentation'    ][entry]
ppn       = output['ppn_post'        ][entry]
clus      = output['clust_fragments' ][entry]
clus_seg  = output['clust_frag_seg'  ][entry]
tracks    = output['track_fragments' ][entry]
trk_grp   = output['track_group_pred'][entry]
show_grp  = output['frag_group_pred' ][entry]
showers   = output['fragments'       ][entry]

particles =       data['particles'][entry]
particle_points = data['particles_label'][entry]

# Point collections whose columns 0-2 are (converted) x, y, z positions.  Together they set
# the common axis ranges.  `pred` is deliberately excluded: it holds per-voxel class scores
# (it is argmax'ed over axis 1 to get the predicted type), not coordinates, so its columns
# would distort the x/y/z ranges.
arrays = (vox, label, ppn, particle_points)

# Shared 3-D scene: each axis spans the extent of all the point collections in `arrays`.
scene_axes = {
    axis_name: dict(nticks=10, range=collection_range(i, *arrays), showticklabels=True, title=axis_title)
    for i, (axis_name, axis_title) in enumerate((("xaxis", 'x (cm)'),
                                                 ("yaxis", 'y (cm)'),
                                                 ("zaxis", 'z (cm)')))
}

layout = go.Layout(
    showlegend=True,
    legend=dict(x=1.01, y=0.95),
    width=600,
    height=600,
    hovermode='closest',
    margin=dict(l=0, r=0, b=0, t=0),
    template='plotly_dark',
    uirevision='same',
    scene=dict(aspectmode='cube', **scene_axes),
)


# Plot energy depositions (input data)
thresh = 0  # 0.01
saturate = 5
color_min = thresh
color_max = saturate
vox_thresh = vox[vox[:, 4] > thresh]
# Per-voxel marker size that grows with deposited energy.  It gets its own name: rebinding the
# shared `markersize` here would hand an array sized for `vox_thresh` to the label/prediction
# traces, which are drawn from `label`, whose rows need not line up with the thresholded voxels.
edep_markersize = numpy.tanh(vox_thresh[:, 4]) * 3
vox_E_saturate = numpy.minimum(vox_thresh[:, 4], saturate)
#vox_E_saturate = numpy.full_like(vox_thresh[:,4], color_min)  # use this to make all edep colors white
trace = scatter_points(vox_thresh, markersize=edep_markersize, symbol="square", color=vox_E_saturate,
                       colorscale='Reds',
                       cmin=color_min, cmax=color_max,
                       hovertext=['%.2f MeV' % v for v in vox_thresh[:, 4]])
trace[-1].name = 'energy'

# Map larcv semantic-type codes to readable names (used as hover text).
labels = {getattr(larcv, 'kShape%s' % name): name
          for name in ['Michel', 'Track', 'Shower', 'LEScatter', 'Delta', 'Ghost', 'Unknown']}

print(labels)

# True semantic labels, colored by type (column 4)
truth_types = label[:, 4]
trace += scatter_points(label, markersize=markersize, symbol="square", color=truth_types, colorscale='Jet',
                        cmin=0, cmax=4,
                        hovertext=[labels[int(v)] for v in truth_types])
trace[-1].name = 'truth'

# Predicted semantic labels: the highest-scoring class per voxel, drawn at the label voxels
pred_types = np.argmax(pred, axis=1)
trace += scatter_points(label, markersize=markersize, symbol="square", color=pred_types, colorscale='Jet',
                        cmin=0, cmax=4,
                        hovertext=[labels[v] for v in pred_types])
trace[-1].name = 'prediction'

# Points of interest found by PPN, colored by the type stored in their last column
ppn_types = ppn[:, -1]
trace += scatter_points(ppn[:, :3], symbol="diamond", markersize=3,
                        color=ppn_types, cmin=0, cmax=5,
                        colorscale="mygbm",
                        hovertext=[labels[v] for v in ppn_types])
trace[-1].name = 'points'

# True points of interest, colored by the type stored in column 4
true_point_types = particle_points[:, 4]
trace += scatter_points(particle_points, markersize=3, symbol="circle",
                        color=true_point_types, cmin=0, cmax=5, colorscale="mygbm",
                        hovertext=[labels[v] for v in true_point_types])
trace[-1].name = "True point labels"

#print(trace)

#trace = []
# Line color per |PDG code| for the true particle trajectories.  Any code not listed
# (nuclei, neutral kaons, hyperons, ...) is drawn in `default_color`.
colors = {
    11:   "orange",
    12:   "black",
    13:   "blue",
    14:   "black",
    22:   "yellow",
    111:  "white",
    211:  "purple",
    321:  "cyan",
    2112: "white",
    2212: "green",
}
default_color = "gray"

vals = dict([(t, []) for t in ("x", "y", "z", "line_color", "text")])
for particle in particles:
    pdg = particle.pdg_code()
    # .get() with a default: plain indexing raised KeyError for any PDG code missing from the table
    color = colors.get(abs(pdg), default_color)
    text = "pdg=" + str(pdg)
    # Each particle contributes three entries per coordinate (start, end, None separator), so
    # repeat its color/text three times to keep every list the same length.
    vals["line_color"].extend([color] * 3)
    vals["text"].extend([text] * 3)

    for coord in ("x", "y", "z"):
        for step in ("first_step", "last_step"):
            vals[coord].append(getattr(getattr(particle, step)(), coord)())
        # separator between consecutive particles' line segments
        vals[coord].append(None)

trace.append(go.Scatter3d(vals, mode="lines", line_dash="dot", line_width=3, hovertext=vals["text"]))
trace[-1].name = "True trajs"

def _add_fragment_trace(fragments, frag_labels, name, cmin, cmax):
    """Append the network_topology traces for `fragments` (colored by `frag_labels`) to `trace`
    and give the last appended trace the legend name `name`."""
    trace.extend(network_topology(vox, fragments, edge_index=[],
                                  clust_labels=frag_labels, edge_labels=[],
                                  mode='scatter', markersize=2, linewidth=2,
                                  colorscale='mygbm',
                                  cmin=cmin, cmax=cmax))
    trace[-1].name = name


# every fragment, each with its own label
_add_fragment_trace(clus, range(len(clus)), "All fragments",
                    cmin=0,
                    cmax=0 if len(clus) == 0 else len(clus))

# only the regrouped track fragments, colored by predicted group
_add_fragment_trace(tracks, trk_grp, "Regrouped track",
                    cmin=0 if len(trk_grp) == 0 else min(trk_grp),
                    cmax=0 if len(trk_grp) == 0 else max(trk_grp) + 1)

# only the regrouped EM fragments, colored by predicted group
_add_fragment_trace(showers, show_grp, "Regrouped shower",
                    cmin=0 if len(show_grp) == 0 else min(show_grp),
                    cmax=0 if len(show_grp) == 0 else max(show_grp))

n_track_particles = sum(1 for p in particles if p.shape() == 1)
print(n_track_particles, "'track' particles")

# Assemble all traces with the shared layout, reposition the legend, and display.
fig = go.Figure(data=trace, layout=layout)
fig.update_layout(legend=dict(x=1.1, y=0.9))
fig.show()