# ONNX vs PyTorch example

This example loads an event inspired by ProtoDUNE-SP simulated data and puts it
through the `DUNEdn` denoising pipeline.

The models implemented in PyTorch are exported to ONNX format, and both
versions are used to run inference separately.

The outputs are then used to compare accuracy and performance.

- **Accuracy**  
  Denoised events are analysed against ground truth labels from Monte
  Carlo simulation.  
  Four different metrics (namely `mse`, `psnr`, `ssim` and `iMAE`) are evaluated
  as in the [paper](https://doi.org/10.1007/s41781-021-00077-9).
- **Performance**  
  The elapsed time of batch prediction is measured for the PyTorch and ONNX
  models at different batch sizes.
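
As a rough illustration of the first two metrics, the sketch below computes the mean squared error and the peak signal-to-noise ratio from their common textbook definitions. It is not the `dunedn` implementation: the function names, the choice of the clean signal's maximum as peak value, and the toy waveforms are assumptions made for this example. It uses only the standard library on flat sequences, whereas real events are NumPy arrays.

```python
import math
from statistics import fmean


def mse(denoised, target):
    """Mean squared error between two equally long sequences."""
    if len(denoised) != len(target):
        raise ValueError("inputs must have the same length")
    return fmean((d - t) ** 2 for d, t in zip(denoised, target))


def psnr(denoised, target, peak=None):
    """Peak signal-to-noise ratio in dB: 10 * log10(peak**2 / mse).

    Assumption: when `peak` is not given, the maximum of the clean
    target is used as the peak value.
    """
    error = mse(denoised, target)
    if error == 0:
        return math.inf  # identical inputs
    if peak is None:
        peak = max(target)
    return 10 * math.log10(peak**2 / error)


# toy waveforms (made-up numbers, not ProtoDUNE data)
target = [0.0, 0.0, 10.0, 0.0]
denoised = [1.0, -1.0, 9.0, 1.0]

print(f"mse  = {mse(denoised, target):.3f}")
print(f"psnr = {psnr(denoised, target):.3f} dB")
```

A lower `mse` and a higher `psnr` both indicate a denoised event closer to the ground truth.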

In [None]:
from pathlib import Path
import numpy as np
import pandas as pd
from plot_event_example import plot_example
from assets.functions import (
    check_in_output_folder,
    inference,
    compare_performance_onnx,
    plot_comparison_catplot,
)
from dunedn.inference.hitreco import DnModel
from dunedn.inference.analysis import analysis_main
from dunedn.utils.utils import load_runcard

Define user inputs.

The user might want to tweak the following variables to experiment with the `DUNEdn` package.

- **modeltype** -> available options: `cnn`, `gcnn`, `uscg`.
- **version** -> available options: `v08`, `v09`  
  The dataset version on which the model was trained.  
  For `cnn` and `gcnn` networks, only version `v08` is available.
- **pytorch_dev** -> available options: `cpu`, `cuda:0` or `cuda:id`.  
  The device hosting the PyTorch computation.  
  It is recommended to run PyTorch on a GPU.  
  The default `batch_size` settings ensure that the computation fits on a 16 GB GPU.
- **base_folder** -> the output folder.  
  Make sure you have write permission for it.

In [None]:
# user inputs
modeltype = "cnn"
version = "v08"
pytorch_dev = "cuda:0"  # device hosting PyTorch computation
base_folder = Path("../../output/tmp")
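
The constraints on `modeltype` and `version` listed above can be checked before any model is loaded. The helper below is a hypothetical sketch and not part of `DUNEdn`: `AVAILABLE_VERSIONS` and `check_user_inputs` are names invented here, and the table only restates the options from the list, assuming that `uscg` is available for both versions.

```python
# Hypothetical table restating the documented options
AVAILABLE_VERSIONS = {
    "cnn": ("v08",),
    "gcnn": ("v08",),
    "uscg": ("v08", "v09"),  # assumption: both versions are available
}


def check_user_inputs(modeltype, version):
    """Raise ValueError if the model/version pair is not a documented option."""
    if modeltype not in AVAILABLE_VERSIONS:
        options = ", ".join(AVAILABLE_VERSIONS)
        raise ValueError(f"unknown modeltype '{modeltype}', choose one of: {options}")
    if version not in AVAILABLE_VERSIONS[modeltype]:
        options = ", ".join(AVAILABLE_VERSIONS[modeltype])
        raise ValueError(
            f"version '{version}' is not available for '{modeltype}', "
            f"choose one of: {options}"
        )


check_user_inputs("cnn", "v08")  # valid pair, passes silently
```

Failing early with a readable message is cheaper than discovering a missing checkpoint folder at model loading time.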

Set up the environment

In [None]:
# base folders
ckpt_folder = Path(f"../saved_models/{modeltype}_{version}")

# relative folders
folders = {
    "base": base_folder,
    "out": base_folder / "models/onnx",
    "ckpt": ckpt_folder,
    "cards": base_folder / "cards",
    "onnx_save": base_folder / f"models/onnx/saved_models/{modeltype}_{version}",
    "plot": base_folder / "models/onnx/plots",
    "id_plot": base_folder / "models/onnx/plots/identity",
    "pytorch_plot": base_folder / "models/onnx/plots/torch",
    "onnx_plot": base_folder / "models/onnx/plots/onnx",
}

In [None]:
# path to files
paths = {
    "input": folders["out"] / "p2GeV_cosmics_inspired_rawdigit_evt8.npy",
    "target": folders["out"] / "p2GeV_cosmics_inspired_rawdigit_noiseoff_evt8.npy",
    "pytorch": folders["out"]
    / f"p2GeV_cosmics_inspired_rawdigit_torch_{modeltype}_evt8.npy",
    "onnx": folders["out"]
    / f"p2GeV_cosmics_inspired_rawdigit_onnx_{modeltype}_evt8.npy",
    "performance_csv": folders["out"] / f"{modeltype}_performance_comparison.csv",
}

In [None]:
check_in_output_folder(folders)

Plot an example

In [None]:
plot_example(
    paths["input"], paths["target"], outdir=folders["id_plot"], with_graphics=True
)

evt = np.load(paths["input"])[:, 2:]
print(f"Loaded event at {paths['input']}")

## PyTorch inference

### Model loading

In [None]:
setup = load_runcard(base_folder / "cards/runcard.yaml")  # settings
model = DnModel(setup, modeltype, ckpt_folder)
print(f"Loaded model from {ckpt_folder} folder")

### PyTorch inference

In [None]:
pytorch_time = inference(model, evt, paths["pytorch"], pytorch_dev)
print(f"PyTorch inference done in {pytorch_time}s")

### Analysis: accuracy assessment

In [None]:
# compute metrics
analysis_main(paths["pytorch"], paths["target"])

In [None]:
# make plot
plot_example(
    paths["pytorch"],
    paths["target"],
    outdir=folders["pytorch_plot"],
    with_graphics=True,
)

## ONNX inference

### Export the model to ONNX format and load it

In [None]:
# export
model.onnx_export(folders["onnx_save"])

In [None]:
# load model
model_onnx = DnModel(setup, modeltype, folders["onnx_save"], should_use_onnx=True)

### ONNX inference

In [None]:
# ONNX inference
onnx_time = inference(model_onnx, evt, paths["onnx"])
print(f"ONNX inference done in {onnx_time}s")

### Analysis: accuracy assessment

In [None]:
# compute metrics
analysis_main(paths["onnx"], paths["target"])

In [None]:
# make plot
plot_example(
    paths["onnx"], paths["target"], outdir=folders["onnx_plot"], with_graphics=True
)

## PyTorch vs ONNX Performance

The goal is to compare the performance of the models for different input batch sizes.

The collected inference timings are loaded into a `pandas.DataFrame` for easier manipulation.

In [None]:
batch_size_list = [32, 64, 128, 256, 512]
nb_batches = 5
performance_df = compare_performance_onnx(
    model, model_onnx, pytorch_dev, batch_size_list, nb_batches
)
performance_df.to_csv(paths["performance_csv"])
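
The internals of `compare_performance_onnx` are not shown in this notebook. As a minimal sketch of the general idea, the code below times a prediction function on several batches per batch size and collects the mean and standard deviation. The names `time_batches`, `dummy_predict` and the record layout are assumptions made for illustration, and the stand-in "model" is a plain Python function.

```python
import time
from statistics import fmean, stdev


def time_batches(predict, make_batch, batch_sizes, nb_batches):
    """Time `predict` on `nb_batches` batches for each batch size.

    Returns one record per batch size with the mean and standard
    deviation (in seconds) of the elapsed wall-clock time.
    """
    records = []
    for batch_size in batch_sizes:
        batch = make_batch(batch_size)
        timings = []
        for _ in range(nb_batches):
            start = time.perf_counter()
            predict(batch)
            timings.append(time.perf_counter() - start)
        records.append(
            {
                "batch": batch_size,
                "mean": fmean(timings),
                "std": stdev(timings) if len(timings) > 1 else 0.0,
            }
        )
    return records


def dummy_predict(batch):
    """Hypothetical stand-in for a model: sums the squares of the inputs."""
    return sum(x * x for x in batch)


records = time_batches(dummy_predict, lambda n: [1.0] * n, [32, 64, 128], nb_batches=5)
for rec in records:
    print(rec)
```

When the timed function runs on a GPU, the device usually has to be synchronised before reading the clock, otherwise only the kernel launch time is measured.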

In [None]:
performance_df = pd.read_csv(paths["performance_csv"]).set_index(["batch", "value"])
plot_comparison_catplot(performance_df, folders["plot"], with_graphics=True)