### Set autoreloading
This extension automatically reloads imported modules before each cell is executed, so changes to package source code (such as `nugraph`) take effect without restarting the kernel.

In [None]:
# Reload all imported modules (e.g. nugraph) before each cell runs, so edits to
# package source take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

### Import packages
We'll need the `pytorch_lightning` and `nugraph` packages imported in order to run inference, the `time` package to measure inference time, and `plotly.graph_objects` to plot.

In [None]:
import time
from pathlib import Path

import pytorch_lightning as pl
import nugraph as ng
import plotly.graph_objects as go

### Set default plotting options

Define a dictionary of standard plotting options (figure size, margins and font sizes) that we apply to every plot we draw.

In [None]:
# Layout options shared by every figure in this notebook. Unpack them into
# go.Figure(**style) so all plots get the same size, margins and font sizes.
style = dict(
    layout_width=800,
    layout_height=450,
    # Equal 20px margin on the bottom, top, right and left edges.
    **{f'layout_margin_{side}': 20 for side in 'btrl'},
    layout_xaxis_title_font_size=24,
    layout_xaxis_tickfont_size=20,
    layout_yaxis_title_font_size=24,
    layout_yaxis_tickfont_size=20,
    layout_legend_font_size=24,
)

### Configure network
In order to test a trained model, we instantiate it from a checkpoint file. Checkpoints are produced during training, so if you've trained a model, there should be an associated checkpoint in your output directory that you can pass here.

In [None]:
# Trained NuGraph2 checkpoint to benchmark. This is a site-specific location;
# point it at the checkpoint produced by your own training run.
CHECKPOINT_PATH = '/raid/uboone/NuGraph2/NG2-paper'

# Load the weights onto the CPU; the Lightning Trainer used for testing moves
# the model to the selected accelerator. freeze() disables gradients and puts
# the model in evaluation mode, since we only run inference here.
model = ng.models.NuGraph2.load_from_checkpoint(CHECKPOINT_PATH, map_location='cpu')
model.freeze()

### Benchmark inference time
Loop over batch sizes from 1 to 256 in powers of two. For each one, run the test loop and record the average wall-clock inference time per graph.

In [None]:
# Site-specific preprocessed test data; edit to point at your own file.
DATA_PATH = '/raid/uboone/NuGraph2/NG2-paper.gnn.h5'
# Batch sizes to benchmark: 1, 2, 4, ..., 256.
BATCH_SIZES = [2 ** i for i in range(9)]

# Device selection does not depend on the batch size, so configure it once.
# NOTE(review): the argument 0 presumably selects the first GPU — confirm
# against ng.util.configure_device.
accelerator, devices = ng.util.configure_device(0)

x = []  # benchmarked batch sizes
y = []  # average wall-clock seconds per test graph, aligned with x
for batch_size in BATCH_SIZES:
    nudata = ng.data.H5DataModule(
        data_path=DATA_PATH,
        batch_size=batch_size,
    )
    # The progress bar and the results table are rendered inside the timed
    # trainer.test call, so disable them to keep them out of the measurement
    # (and to avoid flooding the notebook output with one table per batch size).
    trainer = pl.Trainer(accelerator=accelerator,
                         devices=devices, logger=False,
                         enable_progress_bar=False)
    # The timing covers the whole trainer.test call (data loading and any
    # datamodule setup included), not only the forward passes.
    # NOTE(review): the first iteration may also include one-time device
    # initialisation costs; consider a warm-up run if batch size 1 looks anomalous.
    t0 = time.perf_counter()  # monotonic, high-resolution clock for benchmarking
    trainer.test(model, datamodule=nudata, verbose=False)
    elapsed = time.perf_counter() - t0
    # Append both values only after a successful run so x and y stay aligned.
    x.append(batch_size)
    y.append(elapsed / len(nudata.test_dataset))

### Plot inference time
Draw a scatter plot using the batch sizes and inference times from the previous step. Save the resulting plot to disk and visualize it.

In [None]:
# Output directory for saved figures. Create it if missing, because
# write_image raises FileNotFoundError when the parent directory does not exist.
PLOT_DIR = Path('plots')
PLOT_DIR.mkdir(parents=True, exist_ok=True)

# Log-scaled x axis with one labelled tick per benchmarked batch size.
fig = go.Figure(
    layout_xaxis_title='Batch size',
    layout_yaxis_title='Inference time per graph [s]',
    layout_xaxis_type='log',
    layout_xaxis_tickmode='array',
    layout_xaxis_ticktext=x,
    layout_xaxis_tickvals=x,
    **style,
)
fig.add_scatter(x=x, y=y)

# Save a raster (PNG) and a vector (PDF) copy of the figure.
# NOTE: plotly static image export needs an export engine such as kaleido.
for suffix in ('png', 'pdf'):
    fig.write_image(PLOT_DIR / f'inference-time.{suffix}')

fig