In [1]:
from time import time as tm
from pathlib import Path
import numpy as np
from dunedn.inference.hitreco import DnModel

In [2]:
# Benchmark configuration: which model to load and where the data lives.
# NOTE(review): the saved output of the model-loading cell reports
# `gcnn_v08`, which does not match `modeltype` below — re-run to refresh it.
modeltype = "cnn"
version = "v08"
outdir = Path("../benchmarks/onnx/onnx_benchmark")  # benchmark working folder
ckpt = Path("../saved_models") / f"{modeltype}_{version}"  # checkpoint folder

In [3]:
# data loading
# The input event is shipped inside an archive; uncomment these lines to
# extract it into `outdir` before the first run.
# print(f"Extracting data into {outdir}...")
# !mkdir -p $outdir
# !tar -xf dunetpc_inspired_v09_p2GeV_rawdigits.tar.gz -C $outdir

fname = outdir / "p2GeV_cosmics_inspired_rawdigit_evt8.npy"
if not fname.is_file():
    # fail early with an actionable hint instead of np.load's bare error
    raise FileNotFoundError(
        f"Input event {fname} not found: extract the benchmark archive first "
        "(see the commented extraction commands at the top of this cell)"
    )
evt = np.load(fname)
print(f"Loaded event at {fname}")

Loaded event at ../benchmarks/onnx/onnx_benchmark/p2GeV_cosmics_inspired_rawdigit_evt8.npy


In [4]:
# inference function
def inference(model, evt, fname):
    """Run inference on an event, time it, and save the prediction.

    Only the ``model.predict`` call is timed; saving the output to `fname`
    happens after the clock is stopped.

    Parameters
    ----------
    model: DnModel
        The pytorch or onnx based model. Anything exposing a
        ``predict(evt)`` method works.
    evt: np.ndarray
        The input raw data.
    fname: Path
        The output file name, passed to ``np.save``.

    Returns
    -------
    inference_time: float
        The elapsed time for inference, in seconds.
    """
    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which can jump if the system clock is adjusted mid-benchmark.
    from time import perf_counter

    start = perf_counter()
    evt_dn = model.predict(evt)
    inference_time = perf_counter() - start

    # save the model output (PyTorch or ONNX, depending on `model`)
    np.save(fname, evt_dn)
    return inference_time

In [4]:
# Load the PyTorch model from its checkpoint folder.
model = DnModel(modeltype, ckpt)
print("Loaded model from", ckpt, "folder")

Loaded model from ../saved_models/gcnn_v08 folder


In [None]:
# Run and time the PyTorch model; the output is written into `outdir`.
fname = outdir / "pytorch_inference_results.npy"
pytorch_time = inference(model, evt, fname)
print(f"PyTorch inference done in {pytorch_time!s}s")

In [5]:
# Folder holding the exported ONNX model. A single name is shared by the
# export and load steps so the two paths cannot drift apart.
onnx_model_dir = outdir / f"saved_models/{modeltype}_{version}"

# uncomment this line to export model to ONNX format
# model.export_onnx(onnx_model_dir)

model_onnx = DnModel(
    modeltype,
    onnx_model_dir,
    should_use_onnx=True,
)



In [6]:
# ONNX inference: run and time the ONNX model, saving its output in `outdir`
fname = outdir / "onnx_inference_results.npy"
onnx_time = inference(model_onnx, evt, fname)
# this cell times the ONNX model, so label it as such (was "PyTorch")
print(f"ONNX inference done in {onnx_time}s")

  1%|          | 1/156 [00:49<2:09:01, 49.95s/it]