### Set autoreloading
The `autoreload` extension automatically re-imports packages whenever their source files change, so edits to `nugraph` or `pynuml` take effect without restarting the kernel.

In [1]:
# Load IPython's autoreload extension so edited modules are picked up without a kernel restart.
%load_ext autoreload
# Mode 2: reload all modules (except those excluded with %aimport) before each cell executes.
%autoreload 2

### Import packages
We'll need the `nugraph` and `pynuml` packages imported in order to plot, and `torch` for some tensor operations later on

In [2]:
import os
import sys

# Every location below hangs off the user's home directory, so resolve it once.
# (Raises KeyError if HOME is not set, exactly like the direct lookups did.)
home = os.environ['HOME']

# Make the local nugraph checkout importable; the pynuml checkout stays disabled.
#sys.path.append(f"{os.environ['HOME']}/nugraph/pynuml")
sys.path.append(f"{home}/nugraph/nugraph")

# Export the NuGraph locations before the package is imported
# (presumably nugraph reads these to find its code, logs and data — TODO confirm).
os.environ.update({
    "NUGRAPH_DIR": f"{home}/nugraph",
    "NUGRAPH_LOG": f"{home}/logs",
    "NUGRAPH_DATA": f"{home}/data",
})

import nugraph as ng
import pynuml
import torch

### Set model and data to use

This allows the user to switch out different model architectures and datasets

In [3]:
# Aliases for the data-module class and the model architecture used by this notebook.
# NOTE(review): `Data` is only referenced by a commented-out line in the next cell, which
# builds an `ng.data.H5DataModule` directly — confirm whether `Data` is still meant to be used.
Data = ng.data.NuGraphDataModule
# `Model` is the class the checkpoint is loaded with further down.
Model = ng.models.NuGraph3

### Configure data module
Declare a data module. If you're working on a standard cluster, the data file location should be configured automatically. If not, you'll need to configure it manually.

In [4]:
# The auto-configured data module is kept for reference but disabled:
#nudata = Data(model=Model)

# Instead, read graphs from one explicitly named processed HDF5 file.
h5_path = "/exp/dune/data/users/hrazafin/iceberg/merged_run9_hdf5/processed_stacked/processed_sg_bg.h5"
nudata = ng.data.H5DataModule(data_path=h5_path)

### Configure network
In order to test a trained model, we instantiate it using a checkpoint file. These are produced during training, so if you've trained a model, there should be an associated checkpoint in your output directory that you can pass here.

In [5]:
# Alternative checkpoint, kept for reference:
#ckpt = os.path.expandvars("/exp/dune/data/users/hrazafin/iceberg/training_ckpt/epoch=79-step=57840.ckpt")

# Checkpoint to load; expandvars substitutes any $VAR references in the path
# (this literal contains none, so it is passed through unchanged).
ckpt_location = "/home/hrazafin/logs/runAug2225_sgbg_instance/version_0/checkpoints/epoch=77-step=56394.ckpt"
ckpt = os.path.expandvars(ckpt_location)

# Restore the model with its tensors mapped onto the CPU.
#model = Model.load_from_checkpoint(ckpt)
model = Model.load_from_checkpoint(ckpt, map_location="cpu")

# Freeze the restored model before using it for inference below.
model.freeze()

In [None]:
# import torch
# from nugraph.models.nugraph3 import NuGraph3 as Model

# ckpt = "/exp/dune/data/users/hrazafin/iceberg/training_ckpt/epoch=79-step=57840.ckpt"
# #ckpt = "/home/hrazafin/logs/runAug2225_sgbg_instance/version_0/checkpoints/epoch=77-step=56394.ckpt"

# # 1. Load checkpoint dict
# ckpt_data = torch.load(ckpt, map_location="cpu")

# # 2. Build model with your data’s feature size
# #model = Model(in_features=6)
# model = Model(
#     in_features = 8, #def 4
#                  hit_features = 128,
#                  nexus_features = 32,
#                  interaction_features = 32,
#                  instance_features = 8, #def 8
#                  planes = ("u","v","y"),
#                  semantic_classes = ('MIP','HIP','shower','michel','diffuse'),
#                  event_classes = ('signal','background')
# )

# # 3. Clean checkpoint keys
# model_state = model.state_dict()
# to_delete = []

# for k, v in ckpt_data["state_dict"].items():
#     if k not in model_state:  # key not in this model
#         print(f"Skipping unexpected key: {k}")
#         to_delete.append(k)
#     elif v.shape != model_state[k].shape:  # shape mismatch
#         print(f"Dropping incompatible key: {k}, ckpt shape={v.shape}, model shape={model_state[k].shape}")
#         to_delete.append(k)

# for k in to_delete:
#     del ckpt_data["state_dict"][k]

# # 4. Load with strict=False
# missing, unexpected = model.load_state_dict(ckpt_data["state_dict"], strict=False)
# print("Missing keys:", missing)
# print("Unexpected keys:", unexpected)

# # 5. Freeze if needed
# model.freeze()


## Configure plotting utility
Instantiate the **pynuml** utility for plotting graph objects, which will do the heavy lifting for us here!

In [6]:
# Build the pynuml plotting helper from the plane names and semantic class
# labels exposed by the data module, so figures use the same labelling as the dataset.
plot = pynuml.plot.GraphPlot(planes=nudata.planes,
                             classes=nudata.semantic_classes)

## Plot ground truth labels

### Iterable dataset

First we define an iterator over the test dataset:

In [7]:
# Iterator over the test dataset; the retrieval cell below advances it one graph at a time.
# Re-running this cell creates a fresh iterator.
test_iter = iter(nudata.test_dataset)

### Retrieve the next graph

This block retrieves a graph from the testing dataset, and passes it through the trained model. Since we defined `test_iter` as an iterator over the dataset, the following block can be executed multiple times, and each time it's executed, it will step to the next graph in the dataset.

In [43]:
# Advance the iterator: each execution of this cell pulls the next test graph.
data = next(test_iter)

# Optional debugging aids, left disabled:
#for k in data.keys():
#    print(k, type(data[k]))
#print(data.x)

# File-name stem built from the run / subrun / event stored in the graph metadata.
md = data['metadata']
name = f'r{md.run}_sr{md.subrun}_e{md.event}'

# Append the hit positions as extra feature columns behind the existing hit
# features (6 of them, per the original author's note) before running the model.
hit = data["hit"]
hit.x = torch.cat((hit.x, hit.pos), dim=1)

print("Updated hit.x shape:", hit.x.shape)

# Run the frozen model on the graph; the trailing semicolon keeps the
# return value out of the cell output.
model(data);

Updated hit.x shape: torch.Size([430, 8])


### Plot a single graph

We can now use pynuml's plotting utilities to plot the graph as a figure. Each time you call the above block to retrieve a new graph, you can then re-execute the plotting blocks to re-plot with the new graph.

In [44]:
# Plot the current graph with target='instance' and how='true'
# (per the section heading above, the ground-truth labels).
fig = plot.plot(data, target='instance', how='true')
# Bare expression so the notebook renders the figure as the cell output.
fig

FigureWidget({
    'data': [{'customdata': array([['michel', '1', 'muon', '-1'],
                                   ['michel', '1', 'muon', '-1'],
                                   ['michel', '1', 'muon', '-1'],
                                   ['michel', '1', 'muon', '-1'],
                                   ['michel', '1', 'muon', '-1'],
                                   ['michel', '1', 'diffuse', '-1'],
                                   ['michel', '1', 'diffuse', '-1'],
                                   ['michel', '1', 'diffuse', '-1'],
                                   ['michel', '1', 'diffuse', '-1'],
                                   ['michel', '1', 'diffuse', '-1'],
                                   ['michel', '1', 'muon', '-1'],
                                   ['michel', '1', 'muon', '-1'],
                                   ['michel', '1', 'muon', '-1'],
                                   ['michel', '1', 'muon', '-1'],
                                   ['michel', 

In [None]:
# fig = plot.plot(data, target='instance', how='true', filter='none')
# fig

In [45]:
# Plot the current graph with target='semantic' and how='pred', i.e. after the model has
# been run on `data`. filter='none' is passed through to pynuml
# (presumably disables hit filtering in the display — TODO confirm).
fig = plot.plot(data, target='semantic', how='pred', filter='none')
# Bare expression so the notebook renders the figure as the cell output.
fig

FigureWidget({
    'data': [{'customdata': array([['michel', '1', 'muon', '-1'],
                                   ['muon', '2', 'muon', '-1'],
                                   ['muon', '2', 'muon', '-1'],
                                   ...,
                                   ['muon', '0', 'muon', '-1'],
                                   ['muon', '0', 'muon', '-1'],
                                   ['muon', '0', 'muon', '-1']], dtype=object),
              'hovertemplate': ('semantic prediction=%{customda' ... 'tomdata[3]:.4f}<extra></extra>'),
              'legendgroup': 'muon',
              'marker': {'color': '#EF553B', 'symbol': 'circle'},
              'mode': 'markers',
              'name': 'muon',
              'orientation': 'v',
              'showlegend': True,
              'type': 'scatter',
              'uid': 'a287a7b6-1759-40cd-8c1e-2ed1c7633051',
              'x': {'bdata': ('AAD2QgAACEMAAAlDAAAKQwAAC0MAAA' ... '5DAAAwQwAAMUMAADJDAAAzQwAANEM='),
         

In [None]:
# Check test 1: confirm the hit feature matrix has the expected number of columns
# after the positions were appended in the retrieval cell.
print("hit.x shape:", data["hit"].x.shape)

In [None]:
# Check test 2: dump the graph structure and the tensors most useful for debugging.
print(data)

# `keys` may be a plain attribute or a method depending on the installed
# PyTorch Geometric release (NOTE(review): assumes `data` is a PyG data object);
# call it when it is callable so the actual key list is printed rather than a
# bound-method repr.
print(data.keys() if callable(data.keys) else data.keys)

# Shape of the hit feature matrix.
print("hit.x:", data["hit"].x.shape)

# Hit-level and graph-level labels; None is printed when the attribute is absent.
print("hit.y:", getattr(data["hit"], "y", None))
print("event.y:", getattr(data, "y", None))

In [None]:
# Check test 3
# ckpt_path = os.path.expandvars("/exp/dune/data/users/hrazafin/iceberg/training_ckpt/epoch=79-step=57840.ckpt")

# # Load the raw checkpoint dictionary
# ckpt = torch.load(ckpt_path, map_location="cpu")

# print(ckpt.keys())  # should show ['state_dict', 'epoch', ...] or ['model_state_dict', ...]


In [None]:
# Check test 4
# state_dict = ckpt["state_dict"]

# for k, v in state_dict.items():
#     if "input_norm.norm.mean" in k:
#         print(k, v.shape)
#     if "input_norm.norm.var" in k:
#         print(k, v.shape)

In [None]:
# Check test 5
# print("hit.x shape:", data["hit"].x.shape)  # currently [339, 6]
# print("hit.pos shape:", data["hit"].pos.shape)  # should be [339, 2]

In [None]:
# Alternative view (semantic prediction), kept disabled:
#fig = plot.plot(data, target='semantic', how='pred', filter='none')

# Re-plot the current graph with target='instance' and how='true'.
fig = plot.plot(data, target='instance', how='true')
# Bare expression so the notebook renders the figure as the cell output.
fig

In [None]:
# Re-plot the current graph with target='semantic' and how='pred'; after this cell
# `fig` holds this figure, which is what the save cell below writes to disk.
fig = plot.plot(data, target='semantic', how='pred', filter='none')
# Bare expression so the notebook renders the figure as the cell output.
fig

### Save plots to disk

We can also use plotly's `write_html` and `write_image` methods to save the figure as an interactive webpage, or in a raster graphics (i.e. PNG, JPEG etc.) or vector graphics (i.e. PDF) format. By default this writes to a `plots/evd` subdirectory – if you're seeing an error that this directory does not exist, simply create one, or change the path to a valid output location!

In [None]:
# Create the output directory up front so the writes below do not fail on a
# fresh checkout; exist_ok=True makes this a no-op when it is already there.
os.makedirs('plots/evd', exist_ok=True)

# Write the most recently created figure as an interactive page, a PNG and a PDF.
# NOTE(review): the files are labelled `_semantic_true`, but when the notebook is run
# top to bottom `fig` was last created with target='semantic', how='pred'; the PDF name
# is also reused by the event-display cell further down — confirm the intended label.
fig.write_html(f'plots/evd/{name}_semantic_true.html')
fig.write_image(f'plots/evd/{name}_semantic_true.png')
fig.write_image(f'plots/evd/{name}_semantic_true.pdf')

### (Optional) Select example events

The following blocks will select the representative events from the NuGraph2 paper

### Event 1

Run 5189, subrun 225, event 11300

In [None]:
# Pick a fixed entry of the test dataset by index instead of stepping the iterator.
data = nudata.test_dataset[64]
md = data['metadata']
# File-name stem from the run / subrun / event stored in the graph metadata;
# compare it against the heading above to confirm the index selects the intended event.
name = f'r{md.run}_sr{md.subrun}_e{md.event}'
# NOTE(review): the retrieval cell earlier appends hit.pos to hit.x and then calls
# `model(data)`; this cell calls `model.step(data)` on the graph as loaded — confirm
# that both the entry point and the input features are what this model expects.
model.step(data);

### Event 2

Run 6999, subrun 11, event 595

In [None]:
# Pick a fixed entry of the test dataset by index instead of stepping the iterator.
data = nudata.test_dataset[36]
md = data['metadata']
# File-name stem from the run / subrun / event stored in the graph metadata;
# compare it against the heading above to confirm the index selects the intended event.
name = f'r{md.run}_sr{md.subrun}_e{md.event}'
# NOTE(review): the retrieval cell earlier appends hit.pos to hit.x and then calls
# `model(data)`; this cell calls `model.step(data)` on the graph as loaded — confirm
# that both the entry point and the input features are what this model expects.
model.step(data);

### Event 3

Run 7048, subrun 177, event 8858

In [None]:
# Pick a fixed entry of the test dataset by index instead of stepping the iterator.
data = nudata.test_dataset[11]
md = data['metadata']
# File-name stem from the run / subrun / event stored in the graph metadata;
# compare it against the heading above to confirm the index selects the intended event.
name = f'r{md.run}_sr{md.subrun}_e{md.event}'
# NOTE(review): the retrieval cell earlier appends hit.pos to hit.x and then calls
# `model(data)`; this cell calls `model.step(data)` on the graph as loaded — confirm
# that both the entry point and the input features are what this model expects.
model.step(data);

### Event 4

Run 5459, subrun 94, event 4738

In [None]:
# Pick a fixed entry of the test dataset by index instead of stepping the iterator.
data = nudata.test_dataset[91]
md = data['metadata']
# File-name stem from the run / subrun / event stored in the graph metadata;
# compare it against the heading above to confirm the index selects the intended event.
name = f'r{md.run}_sr{md.subrun}_e{md.event}'
# NOTE(review): the retrieval cell earlier appends hit.pos to hit.x and then calls
# `model(data)`; this cell calls `model.step(data)` on the graph as loaded — confirm
# that both the entry point and the input features are what this model expects.
model.step(data);

### Event 5

Run 6780, subrun 200, event 10006

In [None]:
# Pick a fixed entry of the test dataset by index instead of stepping the iterator.
data = nudata.test_dataset[27]
md = data['metadata']
# File-name stem from the run / subrun / event stored in the graph metadata;
# compare it against the heading above to confirm the index selects the intended event.
name = f'r{md.run}_sr{md.subrun}_e{md.event}'
# NOTE(review): the retrieval cell earlier appends hit.pos to hit.x and then calls
# `model(data)`; this cell calls `model.step(data)` on the graph as loaded — confirm
# that both the entry point and the input features are what this model expects.
model.step(data);

### Plot event displays

Write event displays to disk in PDF format for use in the NuGraph2 paper.

In [None]:
# Create the output directory up front so write_image does not fail on a
# fresh checkout; exist_ok=True makes this a no-op when it is already there.
os.makedirs('plots/evd', exist_ok=True)

# (target, how, filter) combinations to export, in the original order. The
# filter plots use filter='none'; the semantic plots pass the matching
# true/pred value as the filter argument.
evd_views = (
    ('filter', 'true', 'none'),
    ('filter', 'pred', 'none'),
    ('semantic', 'true', 'true'),
    ('semantic', 'pred', 'pred'),
)

# One PDF per view, named `<event>_<target>_<how>.pdf`.
for target, how, filt in evd_views:
    plot.plot(data, target=target, how=how, filter=filt).write_image(f'plots/evd/{name}_{target}_{how}.pdf')

### Print model performance

Print out information on the rate at which the model makes mistakes, and some information on common failure modes.

In [None]:
# NOTE(review): this cell indexes the graph by plane name (`data[p]`), whereas the
# retrieval cell works on `data["hit"]` — confirm the graph layout matches.
planes = nudata.planes

# Flatten the per-plane tensors into graph-wide vectors.
# A hit counts as signal when its semantic label is not -1.
true_filter = torch.cat([data[p].y_semantic != -1 for p in planes])
pred_filter = torch.cat([data[p].x_filter.round() for p in planes])

# Semantic truth and prediction (argmax over the class axis), signal hits only.
true_sem = torch.cat([data[p].y_semantic for p in planes])[true_filter]
pred_sem = torch.cat([data[p].x_semantic.argmax(dim=1) for p in planes])[true_filter]

n_signal = true_filter.sum()
print(f'there are {true_filter.shape[0]} hits overall, of which {n_signal} are signal.')

print('\n### Filter\n')

# Hits where the rounded filter score disagrees with the truth flag.
filter_wrong = true_filter != pred_filter
truth_of_wrong = true_filter[filter_wrong]
print(f'{filter_wrong.sum()} hits were classified wrong. of those, {(truth_of_wrong==0).sum()} are false positives, and {(truth_of_wrong==1).sum()} are false negatives.')

print('\n### Semantic\n')

print(f'of the {n_signal} signal hits, {(pred_sem==true_sem).sum()} are correctly classified.')

semantic_wrong = true_sem != pred_sem
print(f'of the {semantic_wrong.sum()} misclassified hits:')

# Break the misclassified hits down by true class, then by the class they were assigned.
wrong_truth = true_sem[semantic_wrong]
wrong_pred = pred_sem[semantic_wrong]
for i, c in enumerate(nudata.semantic_classes):
    of_class = wrong_truth == i
    if not of_class.any():
        continue
    print(f'- {of_class.sum()} {c} hits were misclassified.')
    assigned = wrong_pred[of_class]
    for j, cj in enumerate(nudata.semantic_classes):
        n_assigned = (assigned == j).sum()
        if n_assigned == 0:
            continue
        print(f'  - {n_assigned} as {cj}')