### Set autoreloading
This extension automatically reloads imported modules when their source changes, so edits to packages such as `nugraph` take effect without restarting the kernel.

In [None]:
# Load IPython's autoreload extension. Mode 2 reloads all modules before
# executing code, so edits to installed packages such as `nugraph` are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

### Import packages
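We'll need the `pytorch_lightning` and `nugraph` packages to load and evaluate the trained model, along with `torch`, `torchmetrics`, `tqdm` and `plotly` to compute metrics and draw plots.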

In [None]:
from pathlib import Path

import plotly.graph_objects as go
import pytorch_lightning as pl
import torch
import torchmetrics as tm
import tqdm

import nugraph as ng

### Set default plotting options

Define a dictionary of standard plotting options that we want to apply to every plot we draw; it is unpacked into each `go.Figure` as keyword arguments.

In [None]:
# Shared plotly layout options, unpacked as ``**style`` into each go.Figure.
# Keys use plotly's magic-underscore notation (layout_xaxis_... -> layout.xaxis...).
style = dict(
    # figure size (px)
    layout_width=800,
    layout_height=450,
    # uniform margin on all four sides
    layout_margin_b=20,
    layout_margin_t=20,
    layout_margin_r=20,
    layout_margin_l=20,
    # axis titles and tick labels
    layout_xaxis_title_font_size=24,
    layout_xaxis_tickfont_size=20,
    layout_yaxis_title_font_size=24,
    layout_yaxis_tickfont_size=20,
    # legend text
    layout_legend_font_size=20,
)

### Define label score metrics

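Define a torchmetrics class that builds, for each semantic label, two distributions of the score predicted for that label: a "true" distribution over hits whose true label is that label, and a "false" distribution over hits with any other label.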

In [None]:
class Score(tm.Metric):
    """Per-class score histograms, split by whether the class is the true label.

    For every semantic class ``c`` two histograms of the predicted score
    ``preds[:, c]`` are accumulated:

    * ``true[c]``  - over entries whose target is ``c``;
    * ``false[c]`` - over entries whose target is any other, non-ignored value.

    Args:
        num_classes: number of semantic classes, i.e. columns of ``preds``.
        bins: number of equal-width histogram bins.
        range: ``(min, max)`` bounds of the histogram.
        ignore_index: target value whose entries are excluded from both
            histograms; ``None`` keeps every entry.
        **kwargs: forwarded unchanged to ``torchmetrics.Metric``.
    """

    def __init__(self,
                 num_classes: int,
                 bins: int = 20,
                 range: tuple[float, float] = (0, 1),
                 ignore_index: int = None,
                 **kwargs):
        super().__init__(**kwargs)

        self.bins = bins
        self.range = range
        self.ignore_index = ignore_index

        # One row of bin counts per class; summed across processes on sync.
        self.add_state('true', default=torch.zeros(num_classes, bins), dist_reduce_fx='sum')
        self.add_state('false', default=torch.zeros(num_classes, bins), dist_reduce_fx='sum')

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Add one batch to the histograms.

        Args:
            preds: 2-D float tensor of scores, one row per entry and one
                column per class.
            target: 1-D tensor holding the true class of each row of ``preds``.

        Raises:
            ValueError: if ``preds``/``target`` shapes are inconsistent with
                each other or with ``num_classes``.
        """
        if preds.dim() != 2 or target.shape != preds.shape[:1]:
            raise ValueError(
                f'expected preds of shape (N, C) and target of shape (N,), '
                f'got {tuple(preds.shape)} and {tuple(target.shape)}')
        num_classes = preds.size(1)
        if num_classes != self.true.size(0):
            raise ValueError(
                f'preds has {num_classes} classes, '
                f'metric was built for {self.true.size(0)}')

        # Mask ignored entries explicitly instead of relying on how a tensor
        # compares against ``None``.
        if self.ignore_index is None:
            valid = torch.ones_like(target, dtype=torch.bool)
        else:
            valid = target != self.ignore_index

        for label in range(num_classes):
            scores = preds[:, label]
            is_label = target == label
            hist, _ = scores[valid & is_label].histogram(bins=self.bins, range=self.range)
            self.true[label] += hist
            hist, _ = scores[valid & ~is_label].histogram(bins=self.bins, range=self.range)
            self.false[label] += hist

    def compute(self):
        """Return ``(true, false)`` histograms with each row normalised to unit sum.

        A row with no entries is divided by zero and therefore comes out as NaN.
        """
        true = self.true / self.true.sum(dim=1, keepdim=True)
        false = self.false / self.false.sum(dim=1, keepdim=True)
        return true, false

### Configure data module
Declare a data module. Depending on where you're working, you should edit the data path below to point to a valid data location.

In [None]:
# Data module over the HDF5 dataset; edit data_path for your environment.
nudata = ng.data.H5DataModule(
    data_path='/raid/uboone/NuGraph2/NG2-paper.gnn.h5',
    batch_size=64,
)

### Configure network
In order to test a trained model, we instantiate it using a checkpoint file. These are produced during training, so if you've trained a model, there should be an associated checkpoint in your output directory that you can pass here.

In [None]:
# Restore the trained network from its Lightning checkpoint, mapping tensors
# onto the CPU, then freeze it since it is only used for inference here.
model = ng.models.NuGraph2.load_from_checkpoint(
    '/raid/uboone/NuGraph2/NG2-paper.ckpt',
    map_location='cpu',
)
model.freeze()

### Declare trainer
First we set the testing device. In order to test with a GPU (recommended), pass an integer specifying the index of the GPU to use to the `configure_device()` function; otherwise, this block defaults to CPU testing. We then instantiate a PyTorch Lightning trainer, which the next cell uses to run inference over all batches in the test dataset.

In [None]:
# Called without a GPU index, configure_device() falls back to CPU.
accelerator, devices = ng.util.configure_device()
# No logger: this trainer is only used to run predictions.
trainer = pl.Trainer(
    logger=False,
    accelerator=accelerator,
    devices=devices,
)

### Calculate testing metrics

Run inference over the test dataset, then loop over each batch and plane to accumulate the per-label score distributions and ROC curves that are plotted in the cells below.

In [None]:
num_classes = len(nudata.semantic_classes)

# Target value of hits without a semantic label: excluded from the semantic
# metrics, and treated as the negative class by the filter ROC.
IGNORE_INDEX = -1

score = Score(num_classes=num_classes, ignore_index=IGNORE_INDEX)
roc_filter = tm.classification.BinaryROC(thresholds=1000)
roc_semantic = tm.classification.MulticlassROC(
    num_classes=num_classes,
    thresholds=1000,
    ignore_index=IGNORE_INDEX,
)

# Run the model over the whole test set; all predictions are held in memory.
batches = trainer.predict(model, nudata.test_dataloader())

for batch in tqdm.tqdm(batches):
    for p in nudata.planes:
        x_semantic, y_semantic = batch[p].x_semantic, batch[p].y_semantic
        score.update(x_semantic, y_semantic)
        # Filter target: True for labelled hits, False for IGNORE_INDEX hits.
        roc_filter.update(batch[p].x_filter, y_semantic != IGNORE_INDEX)
        roc_semantic.update(x_semantic, y_semantic)

true, false = score.compute()
# Distinct names so the semantic thresholds no longer overwrite the filter ones.
fpr_filter, tpr_filter, thresholds_filter = roc_filter.compute()
fpr_semantic, tpr_semantic, thresholds_semantic = roc_semantic.compute()

### Filter ROC curve

Draw ROC curve for filter decoder

In [None]:
# write_image does not create missing directories, so make sure the output
# directory exists (otherwise a fresh run fails here).
Path('plots').mkdir(parents=True, exist_ok=True)

# Axis ranges zoom in on the low-FPR / high-TPR corner of the curve.
fig = go.Figure(
    layout_xaxis_title='False positive rate',
    layout_xaxis_range=(0,0.15),
    layout_yaxis_title='True positive rate',
    layout_yaxis_range=(0.7,1),
    layout_legend_xanchor='right',
    layout_legend_x=0.9,
    layout_legend_yanchor='bottom',
    layout_legend_y=0.1,
    **style,
)
fig.add_scatter(x=fpr_filter, y=tpr_filter)
fig.write_image('plots/roc-filter.pdf')
fig.write_image('plots/roc-filter.png')
fig

### Plot semantic ROC curves

Draw semantic ROC curves for each semantic class on the same axes

In [None]:
# write_image does not create missing directories, so make sure the output
# directory exists (otherwise a fresh run fails here).
Path('plots').mkdir(parents=True, exist_ok=True)

# Axis ranges zoom in on the low-FPR / high-TPR corner of the curves.
fig = go.Figure(
    layout_xaxis_title='False positive rate',
    layout_xaxis_range=(0,0.15),
    layout_yaxis_title='True positive rate',
    layout_yaxis_range=(0.7,1),
    layout_legend_xanchor='right',
    layout_legend_x=0.9,
    layout_legend_yanchor='bottom',
    layout_legend_y=0.1,
    **style,
)
# One curve per semantic class; row i of fpr/tpr belongs to class i.
for label, name in enumerate(nudata.semantic_classes):
    fig.add_scatter(x=fpr_semantic[label], y=tpr_semantic[label], name=name)
fig.write_image('plots/roc-semantic.pdf')
fig.write_image('plots/roc-semantic.png')
fig

### Plot score distributions

For true and false predictions, draw the score distributions for each semantic class on the same axes

In [None]:
# Rebuild the bin edges from the metric itself so the x positions always match
# the histogram binning actually filled by `score`.
bins = torch.linspace(*score.range, score.bins + 1)
bin_centers = 0.5 * (bins[:-1] + bins[1:])

# write_image does not create missing directories.
Path('plots').mkdir(parents=True, exist_ok=True)

for name, y in (('true', true), ('false', false)):

    fig = go.Figure(
        layout_xaxis_title=f'Predicted {name} score',
        layout_yaxis_title='# of hits (area-normed)',
        layout_yaxis_dtick=1,  # one tick per decade on the log axis
        layout_yaxis_type='log',
        layout_legend_yanchor='top',
        layout_legend_y=0.9,
        layout_legend_xanchor='center',
        layout_legend_x=0.5,
        **style
    )

    # Row i of the normalised histogram belongs to semantic class i.
    for i, class_name in enumerate(nudata.semantic_classes):
        fig.add_scatter(x=bin_centers, y=y[i], name=class_name, line_shape='spline')

    fig.write_image(f'plots/score-{name}.pdf')
    fig.write_image(f'plots/score-{name}.png')

    fig.show()