### Set autoreloading
This extension automatically reloads imported modules when their source code changes, so edits to packages such as `nugraph` or `pynuml` take effect without restarting the kernel.

In [None]:
%load_ext autoreload
%autoreload 2

### Import packages
We'll need `nugraph` for the data module and model, `pynuml` for plotting, `torch` for some tensor operations later on, and `os` to expand environment variables in file paths.

In [None]:
import os
import nugraph as ng
import pynuml
import torch

### Configure data module
Instantiate the data module. If you're working on a standard cluster, the data file location should be configured automatically; if not, you'll need to configure it manually.

In [None]:
nudata = ng.data.H5DataModule()

### Configure network
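In order to test a trained model, we load it from a checkpoint file. Checkpoints are produced during training, so if you've trained a model, there should be an associated checkpoint in your output directory that you can pass here. Passing `map_location="cpu"` loads the weights onto the CPU, and `freeze()` puts the model in evaluation mode with gradients disabled.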

In [None]:
ckpt = os.path.expandvars("$NUGRAPH_DATA/uboone-opendata/hierarchical.ckpt")
model = ng.models.NuGraph3.load_from_checkpoint(ckpt, map_location="cpu")
model.freeze()
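`os.path.expandvars` leaves a reference to an unset variable untouched instead of raising, so if `NUGRAPH_DATA` is not defined the checkpoint path stays literally `$NUGRAPH_DATA/...` and the failure only surfaces later, inside `load_from_checkpoint`. A minimal sketch of an explicit check before loading; `resolve_checkpoint` is a hypothetical helper, not part of nugraph, and it assumes a real path never contains a literal `$`.

```python
import os


def resolve_checkpoint(path: str) -> str:
    """Expand environment variables in `path` and confirm the file exists.

    Hypothetical helper (not part of nugraph): it only wraps
    os.path.expandvars and os.path.isfile.
    """
    expanded = os.path.expandvars(path)
    # expandvars leaves "$VAR" untouched when VAR is unset
    if "$" in expanded:
        raise RuntimeError(f"unresolved environment variable in {expanded!r}")
    if not os.path.isfile(expanded):
        raise FileNotFoundError(f"checkpoint not found: {expanded}")
    return expanded
```

In the cell above, `ckpt = resolve_checkpoint("$NUGRAPH_DATA/uboone-opendata/hierarchical.ckpt")` would replace the plain `expandvars` call.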

### Configure plotting utility
Instantiate the **pynuml** utility for plotting graph objects, which will do the heavy lifting for us here!

In [None]:
plot = pynuml.plot.GraphPlot(planes=nudata.planes,
                             classes=nudata.semantic_classes)

## Plot ground truth labels

### Iterable dataset

First we define an iterator over the test dataset:

In [None]:
test_iter = iter(nudata.test_dataset)

### Retrieve the next graph

This block retrieves a graph from the testing dataset, and passes it through the trained model. Since we defined `test_iter` as an iterator over the dataset, the following block can be executed multiple times, and each time it's executed, it will step to the next graph in the dataset.

In [None]:
data = next(test_iter)
md = data['metadata']
name = f'r{md.run}_sr{md.subrun}_e{md.event}'
model(data);
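Because `test_iter` is a plain Python iterator, it is consumed as you go: once the last graph has been returned, `next` raises `StopIteration`, and the only way back to the start is to call `iter` on the dataset again. A minimal sketch of that behaviour, using a list of strings as a stand-in for `nudata.test_dataset`; `next_or_restart` is a hypothetical helper.

```python
from typing import Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")


def next_or_restart(dataset: Iterable[T], it: Iterator[T]) -> Tuple[T, Iterator[T]]:
    """Return the next item and the iterator to keep using.

    When `it` is exhausted, a fresh iterator over `dataset` is created,
    so repeated calls wrap around to the first item.
    """
    try:
        return next(it), it
    except StopIteration:
        it = iter(dataset)
        return next(it), it


# Stand-in for nudata.test_dataset
dataset = ["graph0", "graph1", "graph2"]
it = iter(dataset)
seen = []
for _ in range(5):
    item, it = next_or_restart(dataset, it)
    seen.append(item)
print(seen)  # ['graph0', 'graph1', 'graph2', 'graph0', 'graph1']
```

In the notebook this would read `data, test_iter = next_or_restart(nudata.test_dataset, test_iter)`. An empty dataset still raises `StopIteration`, which is the honest outcome there.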

### Plot a single graph

We can now use pynuml's plotting utilities to plot the graph as a figure. Each time you run the block above to retrieve a new graph, re-run the plotting blocks to plot the new graph.

In [None]:
fig = plot.plot(data, target='semantic', how='true', filter='show')
fig

### Save plots to disk

We can also use plotly's `write_html` and `write_image` methods to save the figure as an interactive webpage, or in a raster (e.g. PNG, JPEG) or vector (e.g. PDF) graphics format. The cell below writes to a `plots/evd` subdirectory; if you see an error that this directory does not exist, simply create it, or change the path to a valid output location. Note that `write_image` needs the `kaleido` package installed for static image export.

In [None]:
fig.write_html(f'plots/evd/{name}_semantic_true.html')
fig.write_image(f'plots/evd/{name}_semantic_true.png')
fig.write_image(f'plots/evd/{name}_semantic_true.pdf')
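Since the writers expect the output directory to exist already, one option is to create it up front. A minimal sketch that builds one output path per format with `pathlib` and creates the parent directory first; `evd_paths` is a hypothetical helper, and the `{name}_{target}_{how}.{ext}` pattern simply mirrors the filenames used in this notebook.

```python
from pathlib import Path
from typing import Dict, Iterable


def evd_paths(outdir: str, name: str, target: str, how: str,
              exts: Iterable[str] = ("html", "png", "pdf")) -> Dict[str, Path]:
    """Map each file extension to an output path, creating `outdir` if needed.

    Hypothetical helper: it only builds paths, it does not write any figure.
    """
    base = Path(outdir)
    base.mkdir(parents=True, exist_ok=True)  # no error if it already exists
    return {ext: base / f"{name}_{target}_{how}.{ext}" for ext in exts}
```

In the notebook, `paths = evd_paths("plots/evd", name, "semantic", "true")` followed by `fig.write_html(str(paths["html"]))` and `fig.write_image(str(paths["png"]))` would replace the hard-coded strings.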

### (Optional) Select example events

The following blocks select the representative events shown in the NuGraph2 paper. Run only the one you want: each block overwrites `data` and `name`.

### Event 1

Run 5189, subrun 225, event 11300

In [None]:
data = nudata.test_dataset[64]
md = data['metadata']
name = f'r{md.run}_sr{md.subrun}_e{md.event}'
model.step(data);

### Event 2

Run 6999, subrun 11, event 595

In [None]:
data = nudata.test_dataset[36]
md = data['metadata']
name = f'r{md.run}_sr{md.subrun}_e{md.event}'
model.step(data);

### Event 3

Run 7048, subrun 177, event 8858

In [None]:
data = nudata.test_dataset[11]
md = data['metadata']
name = f'r{md.run}_sr{md.subrun}_e{md.event}'
model.step(data);

### Event 4

Run 5459, subrun 94, event 4738

In [None]:
data = nudata.test_dataset[91]
md = data['metadata']
name = f'r{md.run}_sr{md.subrun}_e{md.event}'
model.step(data);

### Event 5

Run 6780, subrun 200, event 10006

In [None]:
data = nudata.test_dataset[27]
md = data['metadata']
name = f'r{md.run}_sr{md.subrun}_e{md.event}'
model.step(data);
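The five cells above differ only in the test-dataset index. A minimal sketch that keeps the index-to-event mapping in one table and checks an event's metadata against it; `PAPER_EVENTS`, `event_name` and `check_event` are hypothetical names, and the indices are presumably only valid for the particular test split this notebook was written against.

```python
from typing import Dict, Tuple

# Test-dataset index -> (run, subrun, event), copied from the headings above
PAPER_EVENTS: Dict[int, Tuple[int, int, int]] = {
    64: (5189, 225, 11300),
    36: (6999, 11, 595),
    11: (7048, 177, 8858),
    91: (5459, 94, 4738),
    27: (6780, 200, 10006),
}


def event_name(run: int, subrun: int, event: int) -> str:
    """Format an event identifier the same way as the `name` variable above."""
    return f"r{run}_sr{subrun}_e{event}"


def check_event(index: int, run: int, subrun: int, event: int) -> str:
    """Return the event name, raising if the metadata disagrees with the table."""
    expected = PAPER_EVENTS[index]
    if (run, subrun, event) != expected:
        raise ValueError(f"index {index}: expected {expected}, got {(run, subrun, event)}")
    return event_name(run, subrun, event)
```

Assuming the metadata fields convert cleanly to `int`, a cell could call `check_event(i, int(md.run), int(md.subrun), int(md.event))` after loading `nudata.test_dataset[i]`, so a reordered dataset is caught before any plots are written.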

### Plot event displays

For the currently selected event, write the filter and semantic event displays, both true and predicted, to disk in PDF format for use in the NuGraph2 paper.

In [None]:
plot.plot(data, target='filter', how='true', filter='none').write_image(f'plots/evd/{name}_filter_true.pdf')
plot.plot(data, target='filter', how='pred', filter='none').write_image(f'plots/evd/{name}_filter_pred.pdf')
plot.plot(data, target='semantic', how='true', filter='true').write_image(f'plots/evd/{name}_semantic_true.pdf')
plot.plot(data, target='semantic', how='pred', filter='pred').write_image(f'plots/evd/{name}_semantic_pred.pdf')

### Print model performance

Print how many hits in the current event the model misclassifies, for both the filter and the semantic classification, along with a breakdown of which semantic classes are confused with which.

In [None]:
# Concatenate per-plane hit tensors into single tensors across all planes
tf = torch.cat([(data[p].y_semantic!=-1) for p in nudata.planes])  # true filter label: True for signal hits
pf = torch.cat([data[p].x_filter.round() for p in nudata.planes])  # predicted filter label (0 or 1)
ts = torch.cat([data[p].y_semantic for p in nudata.planes])[tf]  # true semantic class, signal hits only
ps = torch.cat([data[p].x_semantic.argmax(dim=1) for p in nudata.planes])[tf]  # predicted semantic class, signal hits only

print(f'There are {tf.size(0)} hits overall, of which {tf.sum()} are signal.')

print('\n### Filter\n')

# Hits where the filter prediction disagrees with the truth
mask = tf != pf
print(f'{mask.sum()} hits were misclassified. Of those, {(tf[mask]==0).sum()} are false positives and {(tf[mask]==1).sum()} are false negatives.')

print('\n### Semantic\n')

print(f'Of the {tf.sum()} signal hits, {(ps==ts).sum()} are correctly classified.')

# Signal hits whose predicted semantic class differs from the true class
mask = ts != ps
print(f'Of the {mask.sum()} misclassified hits:')

# Break misclassifications down by true class, then by predicted class
for i, c in enumerate(nudata.semantic_classes):
    tm = ts[mask] == i
    if tm.sum() == 0:
        continue
    print(f'- {tm.sum()} {c} hits were misclassified.')
    for j, cj in enumerate(nudata.semantic_classes):
        pm = ps[mask][tm] == j
        if pm.sum() == 0:
            continue
        print(f'  - {pm.sum()} as {cj}')
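The printout above reports raw counts. A minimal sketch that condenses the same quantities into filter precision and recall plus a semantic confusion matrix, assuming `tf`/`pf` are 0/1 (or boolean) hit labels and `ts`/`ps` are integer class indices as constructed in the cell above; `filter_metrics` and `confusion_matrix` are hypothetical helpers, not part of nugraph or pynuml, and the example tensors are made up purely to exercise the code.

```python
from typing import Tuple

import torch


def filter_metrics(tf: torch.Tensor, pf: torch.Tensor) -> Tuple[float, float]:
    """Precision and recall of the signal/background filter.

    tf: true labels (True/1 = signal); pf: predicted labels (True/1 = signal).
    """
    tf, pf = tf.bool(), pf.bool()
    tp = (tf & pf).sum().item()   # signal predicted as signal
    fp = (~tf & pf).sum().item()  # background predicted as signal
    fn = (tf & ~pf).sum().item()  # signal predicted as background
    precision = tp / (tp + fp) if tp + fp else float("nan")
    recall = tp / (tp + fn) if tp + fn else float("nan")
    return precision, recall


def confusion_matrix(ts: torch.Tensor, ps: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Entry [i, j] counts hits of true class i predicted as class j."""
    idx = ts.long() * num_classes + ps.long()
    return torch.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)


# Made-up example: 6 hits, 4 of them signal
tf = torch.tensor([1, 1, 1, 1, 0, 0])
pf = torch.tensor([1., 1., 0., 1., 1., 0.])
print(filter_metrics(tf, pf))  # (0.75, 0.75)

ts = torch.tensor([0, 0, 1, 2])
ps = torch.tensor([0, 1, 1, 2])
print(confusion_matrix(ts, ps, 3))
```

With the notebook's tensors, `confusion_matrix(ts, ps, len(nudata.semantic_classes))` would put correctly classified hits on the diagonal, and each off-diagonal row would match the per-class breakdown printed above.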