In [None]:
import importlib
import os.path
import sys

print(sys.executable)

dirs_to_try = [
    ".",
    "/gpfs/slac/staas/fs1/g/neutrino/jwolcott/app",
    "/media/hdd1/jwolcott/app",
    "/dune/app/users/jwolcott/dunesoft",
]

modules_required = {
    # module name -> subdir path
    "mlreco": "lartpc_mlreco3d",
    "larcv": "larcv2/python",
}

for module_name, module_path in modules_required.items():
    software_dir = None
    for d in dirs_to_try:
        d = os.path.join(d, module_path)
        if os.path.isdir(d):
            software_dir = d
            break

    success = False
    if software_dir:
        sys.path.insert(0, software_dir)
        try:
            importlib.import_module(module_name)
            success = True
        except:
            pass

    if not success:
        print("ERROR: couldn't find %s package" % module_name)
    else:
        print("Setup of %s ok from:" % module_name, software_dir)

In [None]:
# Third-party packages used by the rest of the notebook.
import numpy as np
import plotly.graph_objs as go
import yaml
from plotly.offline import download_plotlyjs, init_notebook_mode, iplot, plot

# connected=False: load plotly.js from the local plotly package rather than an
# online CDN, so figures render without network access.
init_notebook_mode(connected=False)

## Configuration

In [None]:
# mlreco3d run configuration, kept as a YAML string and parsed with yaml in the
# "Configure" cell below.
#  - iotool: reads the supera ROOT file listed under data_keys, batch_size 32,
#    no shuffling. The visualization cells index one entry of the evaluated
#    batch, so their `entry` presumably has to be < batch_size.
#  - model: uresnet_ppn_chain built from the `uresnet_lonely` and `ppn` modules.
#  - trainval: inference only (train: False) with weights loaded from
#    model_path; gpus: '' presumably means run on CPU -- TODO confirm.
# NOTE(review): ppn.spatial_size (768) differs from uresnet_lonely.spatial_size
#   (24576); the "orig: 768" comments show only the latter was changed --
#   confirm the mismatch is intended.
# NOTE(review): data_keys and model_path are absolute, machine-specific paths.
cfg='''
iotool:
  batch_size: 32
#  batch_size: 1
  shuffle: False
  num_workers: 1
  collate_fn: CollateSparse
  sampler:
    name: SequentialBatchSampler
  dataset:
    name: LArCVDataset
    data_keys:
#      - /gpfs/slac/staas/fs1/g/neutrino/kterao/data/mpvmpr_2020_01_v04/test.root
#      - /gpfs/slac/staas/fs1/g/neutrino/jwolcott/data/nd-lar-reco/supera/nd.fhc.0.supera.root
       - /media/hdd1/jwolcott/data/dune/nd/nd-lar-reco/supera/nd.fhc.0.supera.root
    limit_num_files: 10
    schema:
      input_data:
        - parse_sparse3d_scn
#        - sparse3d_pcluster
        - sparse3d_geant4
      segment_label:
        - parse_sparse3d_scn
        - sparse3d_geant4_semantics
#       particles_label:
#         - parse_particle_points
#         - sparse3d_geant4
#        - sparse3d_pcluster
#        - particle_corrected
model:
  name: uresnet_ppn_chain
  modules:
    ppn:
      num_strides: 6      # orig: 6
      filters: 16
      num_classes: 5
      data_dim: 3
      downsample_ghost: False
      use_encoding: False
      ppn_num_conv: 1
      score_threshold: 0.5
      ppn1_size: 24
      ppn2_size: 96
      spatial_size: 768    # orig: 768
    uresnet_lonely:
      freeze: False
      num_strides: 6
      filters: 16
      num_classes: 5
      data_dim: 3
      spatial_size: 24576    # orig: 768
      ghost: False
      features: 1
  network_input:
    - input_data
  #  - particles_label
  # loss_input:
  #   - segment_label
  #   - particles_label
trainval:
  seed: 123
  learning_rate: 0.001
  unwrapper: unwrap_3d_scn
  gpus: ''
#  gpus: '0'
  weight_prefix: weights/snapshot
  iterations: 10
  report_step: 1
  checkpoint_step: 100
  log_dir: log_inference
#  model_path: '/gpfs/slac/staas/fs1/g/neutrino/jwolcott/data/nd-lar-reco/weights_kazu_sample.ckpt'
  model_path: '/media/hdd1/jwolcott/data/dune/nd/nd-lar-reco/weights_kazu_sample.ckpt'
  train: False
  debug: True
'''

## Process the configuration, prepare handlers, and run inference

In [None]:
from mlreco.main_funcs import process_config, inference
# `cfg` contains only plain maps, lists and scalars, so the safe loader yields
# the same dict; yaml.Loader would also construct arbitrary Python objects
# from tagged input, which is never needed here.
cfg_dict = yaml.safe_load(cfg)
# pre-process configuration (checks + certain non-specified default settings)
process_config(cfg_dict)


In [None]:
# Build the run "handlers" from the processed configuration; the forward-pass
# cell below uses hs.trainer and hs.data_io_iter from the result.
# (process_config is already imported by the previous cell and is not used here.)
from mlreco.main_funcs import prepare

hs = prepare(cfg_dict)

In [None]:
# Evaluate the network on the prepared data iterator. `data` carries the
# parsed inputs (e.g. 'input_data', 'segment_label') and `output` the network
# results (e.g. 'segmentation', 'points'), both indexed by batch entry below.
trainer, data_iter = hs.trainer, hs.data_io_iter
data, output = trainer.forward(data_iter)
print("done evaluating")

## Visualize the output!


In [None]:
# List which products the network returned.
output_keys = output.keys()
print(output_keys)

In [None]:

import plotly.express as px
import plotly.graph_objs as go

from mlreco.visualization import scatter_points, plotly_layout3d
from mlreco.utils.ppn import uresnet_ppn_type_point_selector
from larcv import larcv

# Batch entry to plot, and the minimum voxel value (column 4 of input_data) a
# voxel needs to be drawn in the 'energy' trace.
entry = 24
thresh = 0

# Fail with a clear message instead of an IndexError (or a silent wrap-around
# for negative values) when `entry` is not part of the evaluated batch.
n_entries = len(data['input_data'])
assert 0 <= entry < n_entries, \
    "entry=%d is outside the evaluated batch (%d entries)" % (entry, n_entries)

# Retrieve data for the chosen entry
vox   = data  ['input_data'    ][entry]
label = data  ['segment_label' ][entry]
pred  = output['segmentation'  ][entry]
ppn   = output['points'        ][entry]

# 3D scene whose axis ranges span the input voxels. The legend position is set
# only here (x=1.1, y=0.9).
layout = go.Layout(
    showlegend=True,
    legend=dict(x=1.1, y=0.9),
    width=768,
    height=768,
    hovermode='closest',
    margin=dict(l=0, r=0, b=0, t=0),
    template='plotly_dark',
    uirevision='same',
    scene=dict(xaxis=dict(nticks=10, range=(vox[:, 0].min(), vox[:, 0].max()), showticklabels=True, title='x'),
               yaxis=dict(nticks=10, range=(vox[:, 1].min(), vox[:, 1].max()), showticklabels=True, title='y'),
               zaxis=dict(nticks=10, range=(vox[:, 2].min(), vox[:, 2].max()), showticklabels=True, title='z'),
               aspectmode='cube'),
)

# Plot energy depositions (input data); column 4 is the voxel value, shown in MeV
vox_thresh = vox[vox[:, 4] > thresh]
trace = scatter_points(vox_thresh, markersize=1.5, color=vox_thresh[:, 4], colorscale='Jet',
                       cmin=0.01, cmax=1.5,
                       hovertext=['%.2f MeV' % v for v in vox_thresh[:, 4]])
trace[-1].name = 'energy'

# Map larcv semantic-type codes to readable names for the hover text
labels = {}
for name in ['Michel', 'Track', 'Shower', 'LEScatter', 'Delta', 'Ghost', 'Unknown']:
    labels[getattr(larcv, 'kShape%s' % name)] = name

# Distribution of the true semantic-type codes in this entry
f = px.histogram(label[:, 4])
f.show()

# Plot true semantic labels. Codes missing from `labels` get a generic name
# instead of raising KeyError.
trace += scatter_points(label, markersize=1.5, color=label[:, 4], colorscale='Jet',
                        cmin=0, cmax=4,
                        hovertext=[labels.get(int(v), 'type %d' % int(v)) for v in label[:, 4]])
trace[-1].name = 'label'

# Plot predicted semantic types: the argmax over axis 1 of the segmentation
# output (presumably one score per class per voxel).
# NOTE(review): drawn at the segment_label voxel coordinates, which assumes the
# rows of `pred` line up with the rows of `label` -- confirm.
pred_types = np.argmax(pred, axis=1)
trace += scatter_points(label, markersize=1.5, color=pred_types, colorscale='Jet',
                        cmin=0, cmax=4,
                        hovertext=[labels.get(int(v), 'type %d' % int(v)) for v in pred_types])
trace[-1].name = 'prediction'

# Plot points of interest from PPN (first three columns used as coordinates)
trace += scatter_points(ppn[:, :3], markersize=5)
trace[-1].name = 'points'

# show
fig = go.Figure(data=trace, layout=layout)
fig.show()


In [None]:

# NOTE(review): this cell duplicates the visualization in the preceding cell
# (same entry, same traces); change `entry` below or drop one of the two cells.
import plotly.express as px
import plotly.graph_objs as go

from mlreco.visualization import scatter_points, plotly_layout3d
from mlreco.utils.ppn import uresnet_ppn_type_point_selector
from larcv import larcv

# Batch entry to plot, and the minimum voxel value (column 4 of input_data) a
# voxel needs to be drawn in the 'energy' trace.
entry = 24
thresh = 0

# Fail with a clear message instead of an IndexError (or a silent wrap-around
# for negative values) when `entry` is not part of the evaluated batch.
n_entries = len(data['input_data'])
assert 0 <= entry < n_entries, \
    "entry=%d is outside the evaluated batch (%d entries)" % (entry, n_entries)

# Retrieve data for the chosen entry
vox   = data  ['input_data'    ][entry]
label = data  ['segment_label' ][entry]
pred  = output['segmentation'  ][entry]
ppn   = output['points'        ][entry]

# 3D scene whose axis ranges span the input voxels. The legend position is set
# only here (x=1.1, y=0.9).
layout = go.Layout(
    showlegend=True,
    legend=dict(x=1.1, y=0.9),
    width=768,
    height=768,
    hovermode='closest',
    margin=dict(l=0, r=0, b=0, t=0),
    template='plotly_dark',
    uirevision='same',
    scene=dict(xaxis=dict(nticks=10, range=(vox[:, 0].min(), vox[:, 0].max()), showticklabels=True, title='x'),
               yaxis=dict(nticks=10, range=(vox[:, 1].min(), vox[:, 1].max()), showticklabels=True, title='y'),
               zaxis=dict(nticks=10, range=(vox[:, 2].min(), vox[:, 2].max()), showticklabels=True, title='z'),
               aspectmode='cube'),
)

# Plot energy depositions (input data); column 4 is the voxel value, shown in MeV
vox_thresh = vox[vox[:, 4] > thresh]
trace = scatter_points(vox_thresh, markersize=1.5, color=vox_thresh[:, 4], colorscale='Jet',
                       cmin=0.01, cmax=1.5,
                       hovertext=['%.2f MeV' % v for v in vox_thresh[:, 4]])
trace[-1].name = 'energy'

# Map larcv semantic-type codes to readable names for the hover text
labels = {}
for name in ['Michel', 'Track', 'Shower', 'LEScatter', 'Delta', 'Ghost', 'Unknown']:
    labels[getattr(larcv, 'kShape%s' % name)] = name

# Distribution of the true semantic-type codes in this entry
f = px.histogram(label[:, 4])
f.show()

# Plot true semantic labels. Codes missing from `labels` get a generic name
# instead of raising KeyError.
trace += scatter_points(label, markersize=1.5, color=label[:, 4], colorscale='Jet',
                        cmin=0, cmax=4,
                        hovertext=[labels.get(int(v), 'type %d' % int(v)) for v in label[:, 4]])
trace[-1].name = 'label'

# Plot predicted semantic types: the argmax over axis 1 of the segmentation
# output (presumably one score per class per voxel).
# NOTE(review): drawn at the segment_label voxel coordinates, which assumes the
# rows of `pred` line up with the rows of `label` -- confirm.
pred_types = np.argmax(pred, axis=1)
trace += scatter_points(label, markersize=1.5, color=pred_types, colorscale='Jet',
                        cmin=0, cmax=4,
                        hovertext=[labels.get(int(v), 'type %d' % int(v)) for v in pred_types])
trace[-1].name = 'prediction'

# Plot points of interest from PPN (first three columns used as coordinates)
trace += scatter_points(ppn[:, :3], markersize=5)
trace[-1].name = 'points'

# show
fig = go.Figure(data=trace, layout=layout)
fig.show()
