# ONNX vs PyTorch example

This example loads an event inspired by ProtoDUNE-SP simulated data and passes it
through the `DUNEdn` denoising pipeline.

The PyTorch models are exported to the ONNX format, and inference is then run
separately with the original and the exported models.

Denoised events are analysed against ground truth labels from Monte Carlo
simulation.  
Four different metrics (namely `mse`, `psnr`, `ssim` and `iMAE`) are evaluated
as in the [paper](https://doi.org/10.1007/s41781-021-00077-9).
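Two of these metrics have standard textbook definitions. The sketch below computes them with NumPy. It assumes that PSNR is taken relative to the peak absolute value of the ground truth, which may differ from the normalisation used by `DUNEdn` itself (see the paper for the exact definitions used there).

```python
import numpy as np


def mse(pred: np.ndarray, target: np.ndarray) -> float:
    """Mean squared error between a denoised event and its ground truth."""
    return float(np.mean((pred - target) ** 2))


def psnr(pred: np.ndarray, target: np.ndarray) -> float:
    """Peak signal-to-noise ratio in dB.

    Assumption: the peak is the maximum absolute value of the target.
    """
    err = mse(pred, target)
    if err == 0:
        return float("inf")  # identical arrays: no noise at all
    peak = np.max(np.abs(target))
    return float(10 * np.log10(peak**2 / err))


# Toy 2x2 "event": only one pixel differs, by 2 units.
toy_target = np.array([[0.0, 2.0], [4.0, 0.0]])
toy_pred = np.array([[0.0, 2.0], [2.0, 0.0]])
print(mse(toy_pred, toy_target))   # 1.0
print(psnr(toy_pred, toy_target))  # about 12.04 (= 10 * log10(16))
```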

The image below shows the pipeline used in this example.

<div style="text-align:center">
<img src="assets/accuracy.png" alt="Onnx accuracy example" width=60%/>
</div>

```python
from pathlib import Path
import numpy as np
from plot_event_example import plot_example
from assets.functions import (
    prepare_folders_and_paths,
    check_in_output_folder,
    inference,
)
from dunedn.inference.hitreco import DnModel
from dunedn.inference.analysis import analysis_main
from dunedn.utils.utils import load_runcard
```

Define user inputs.

The user might want to tweak the following variables to experiment with the `DUNEdn` package.

- **modeltype** -> available options: `cnn`, `gcnn`, `uscg`.
- **version** -> available options: `v08`, `v09`.  
  The dataset version the model was trained on.  
  For `cnn` and `gcnn` networks, only version `v08` is available.
- **pytorch_dev** -> available options: `cpu` or `cuda:<id>` (e.g. `cuda:0`).  
  The device hosting the PyTorch computation.  
  Running PyTorch on a GPU is recommended.  
  The default `batch_size` settings ensure that the computation fits on a 16 GB GPU.  
- **base_folder** -> the output folder.  
  Make sure you have write permissions on it.
- **ckpt_folder** -> the checkpoint folder.  
  Ensure the folder has the structure explained in the package documentation.
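The constraints in the list above can be checked before anything runs. The helper below is a hypothetical sketch (it is not part of `DUNEdn`) that encodes the stated combinations: `cnn` and `gcnn` only with `v08`, `uscg` with `v08` or `v09`, and a device string of the form `cpu` or `cuda:<id>`.

```python
import re

# Options as listed above, encoded here for illustration only (hypothetical helper).
AVAILABLE_VERSIONS = {
    "cnn": {"v08"},
    "gcnn": {"v08"},
    "uscg": {"v08", "v09"},
}
DEVICE_PATTERN = re.compile(r"cpu|cuda:\d+")


def check_user_inputs(modeltype: str, version: str, pytorch_dev: str) -> None:
    """Raise ValueError if the chosen combination is not among the listed options."""
    if modeltype not in AVAILABLE_VERSIONS:
        raise ValueError(f"unknown modeltype {modeltype!r}")
    if version not in AVAILABLE_VERSIONS[modeltype]:
        raise ValueError(f"version {version!r} is not available for {modeltype!r}")
    if DEVICE_PATTERN.fullmatch(pytorch_dev) is None:
        raise ValueError(f"unsupported device {pytorch_dev!r}")


check_user_inputs("cnn", "v08", "cpu")      # valid: returns None
check_user_inputs("uscg", "v09", "cuda:0")  # valid: returns None
```

Failing early like this avoids discovering a missing checkpoint only after the output folders have been created.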

```python
# user inputs
modeltype = "cnn"
version = "v08"
pytorch_dev = "cpu"  # device hosting PyTorch computation
base_folder = Path("../../output/tmp")  # output folder
ckpt_folder = Path(f"../saved_models/{modeltype}_{version}")  # checkpoint folder

# set up the environment: output folders and input/output file paths
folders, paths = prepare_folders_and_paths(modeltype, version, base_folder, ckpt_folder)
```

Create output directories

```python
check_in_output_folder(folders)
```
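The body of `check_in_output_folder` is not shown here. As an assumption about what such a step typically involves, a minimal `pathlib` equivalent creates every folder in a dictionary of paths, including missing parents, and leaves existing folders untouched. The helper name and the folder layout below are hypothetical.

```python
import tempfile
from pathlib import Path


def make_output_folders(folders: dict[str, Path]) -> None:
    """Create each folder (and its parents); existing folders are left as they are."""
    for folder in folders.values():
        folder.mkdir(parents=True, exist_ok=True)


# Hypothetical layout rooted in a temporary directory.
root = Path(tempfile.mkdtemp())
demo_folders = {
    "id_plot": root / "plots" / "input",
    "onnx_save": root / "models" / "onnx",
}
make_output_folders(demo_folders)
print(all(f.is_dir() for f in demo_folders.values()))  # True
```

Using `exist_ok=True` makes the step safe to rerun, which matters in a notebook where cells are often executed more than once.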

Plot an example

```python
plot_example(
    paths["input"], paths["target"], outdir=folders["id_plot"], with_graphics=True
)

# load the raw event, discarding its first two columns
evt = np.load(paths["input"])[:, 2:]
print(f"Loaded event at {paths['input']}")
```
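The `[:, 2:]` slice keeps every row of the loaded array and drops its first two columns. This sketch assumes those columns hold per-row bookkeeping (for example event and channel indices) rather than signal samples; check the dataset description before relying on that. On a toy array:

```python
import numpy as np

# Toy stand-in for a saved event: 3 rows, 2 leading bookkeeping columns + 4 samples.
raw = np.array([
    [0, 100, 1.0, 2.0, 3.0, 4.0],
    [0, 101, 5.0, 6.0, 7.0, 8.0],
    [0, 102, 9.0, 10.0, 11.0, 12.0],
])
samples = raw[:, 2:]  # drop the first two columns, keep all rows
print(samples.shape)  # (3, 4)
```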

Model loading: PyTorch

```python
setup = load_runcard(base_folder / "cards/runcard.yaml")  # load run settings
model = DnModel(setup, modeltype, ckpt_folder)
print(f"Loaded model from {ckpt_folder} folder")
```

Model loading: ONNX

```python
# export the PyTorch model to ONNX format
model.onnx_export(folders["onnx_save"])
```

```python
# load the exported ONNX model
model_onnx = DnModel(setup, modeltype, folders["onnx_save"], should_use_onnx=True)
print(f"Loaded model from {folders['onnx_save']} folder")
```

## PyTorch inference

```python
# run PyTorch inference on `pytorch_dev`; the output is written to paths["pytorch"]
pytorch_time = inference(model, evt, paths["pytorch"], pytorch_dev)
print(f"PyTorch inference done in {pytorch_time}s")
```
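The `inference` helper returns the elapsed time that is printed above. Its implementation is not shown here; as an assumption about how such a measurement is usually taken, the sketch below wraps a call with `time.perf_counter`, a monotonic high-resolution clock from the standard library. The `timed` helper is hypothetical.

```python
import time
from typing import Any, Callable


def timed(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> tuple[Any, float]:
    """Call fn(*args, **kwargs) and return (result, elapsed wall-clock seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return result, elapsed


demo_result, demo_elapsed = timed(sum, range(1_000_000))
print(demo_result, f"{demo_elapsed:.3f}s")
```

When PyTorch runs on a GPU, CUDA kernels execute asynchronously, so the device should be synchronised (e.g. with `torch.cuda.synchronize()`) before the clock is read; otherwise the measured time can underestimate the real cost and skew the comparison with ONNX.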

### Analysis: accuracy assessment

```python
# compute metrics
analysis_main(paths["pytorch"], paths["target"])
```

```python
# make plot
plot_example(
    paths["pytorch"],
    paths["target"],
    outdir=folders["pytorch_plot"],
    with_graphics=True,
)
```

## ONNX inference

```python
# run inference with the exported ONNX model; the output is written to paths["onnx"]
onnx_time = inference(model_onnx, evt, paths["onnx"])
print(f"ONNX inference done in {onnx_time}s")
```

### Analysis: accuracy assessment

```python
# compute metrics
analysis_main(paths["onnx"], paths["target"])
```

```python
# make plot
plot_example(
    paths["onnx"], paths["target"], outdir=folders["onnx_plot"], with_graphics=True
)
```
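Beyond the metrics, the two backends can also be compared with each other directly. A minimal sketch, assuming both denoised outputs can be loaded as NumPy arrays of the same shape (for instance with `np.load(paths["pytorch"])` and `np.load(paths["onnx"])`, assuming they are stored like the input event): small discrepancies are expected from floating-point differences between runtimes, so a tolerance-based check is more meaningful than exact equality. The helper and its tolerances are illustrative, not values prescribed by `DUNEdn`.

```python
import numpy as np


def compare_backends(out_a: np.ndarray, out_b: np.ndarray,
                     rtol: float = 1e-4, atol: float = 1e-5) -> tuple[float, bool]:
    """Return the largest absolute difference and whether the arrays agree within tolerance."""
    if out_a.shape != out_b.shape:
        raise ValueError(f"shape mismatch: {out_a.shape} vs {out_b.shape}")
    max_diff = float(np.max(np.abs(out_a - out_b)))
    return max_diff, bool(np.allclose(out_a, out_b, rtol=rtol, atol=atol))


# Toy outputs differing only by float32-level noise.
rng = np.random.default_rng(0)
toy_pytorch = rng.normal(size=(4, 6)).astype(np.float32)
toy_onnx = toy_pytorch + np.float32(1e-7)
print(compare_backends(toy_pytorch, toy_onnx))
```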