This notebook is responsible for exporting the MLPF trained model from pytorch to ONNX.

In [None]:
# !pip install onnxscript
# !pip install onnxconverter-common

In [None]:
import matplotlib
import copy
import pickle as pkl
import sys
import numpy as np
from tqdm import tqdm
import tensorflow_datasets as tfds
import math

import numba
import awkward
import vector
import fastjet
import matplotlib as mpl
import matplotlib.pyplot as plt
import mplhep

import boost_histogram as bh
import mplhep

import torch
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn
from torch import Tensor

import onnxscript
import onnx
import onnxruntime as rt
from onnxconverter_common import float16
from onnxscript.function_libs.torch_lib.tensor_typing import TFloat

sys.path.append("../../")
import mlpf
from mlpf.model.mlpf import MLPF
from mlpf.model.utils import unpack_predictions, unpack_target
from mlpf.jet_utils import match_jets, to_p4_sph
from mlpf.plotting.plot_utils import cms_label, sample_label, ELEM_NAMES_CMS

In [None]:
# Apply the "CMS" plotting style from mplhep to the figures created in this notebook.
mplhep.style.use("CMS")

In [None]:
# Display the install path and version of the onnxruntime package in use.
# For CUDA, we must use onnxruntime-gpu (not onnxruntime).
rt.__path__, rt.__version__

In [None]:
# Display the PyTorch version used for the export.
torch.__version__

In [None]:
# References:
#   contrib op: https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmultiheadattention
#   CMSSW ONNXRuntime version: https://github.com/cms-sw/cmsdist/blob/REL/CMSSW_14_1_0_pre3/el9_amd64_gcc12/onnxruntime.spec
#   ONNXRuntime compatibility table: https://onnxruntime.ai/docs/reference/compatibility.html

# With pytorch 2.5.0 at least opset 20 is required (previous opsets did not work).
opset_version = 20
from onnxscript import opset20 as op

# Opset that hosts our own onnxscript function, and the Microsoft contrib-op domain
# that provides MultiHeadAttention.
custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1)
msft_op = onnxscript.values.Opset(domain="com.microsoft", version=1)

In [None]:
!ls /home/joosep/particleflow/experiments/pyg-cms_20250722_101813_274478/checkpoints

In [None]:
# Location and name of the TFDS dataset used for validation.
# NOTE(review): absolute, user-specific path; adjust for your environment.
#tfds datasets are here:
data_dir = "/scratch/persistent/joosep/tensorflow_datasets/"
dataset = "cms_pf_ttbar"

# Directory of the training run to export (also user-specific).
#model checkpoints are here:
outdir = "/home/joosep/particleflow/experiments/pyg-cms_20250722_101813_274478/"

# Load the checkpoint onto the CPU. weights_only=True restricts unpickling to tensors/primitive types.
#Load model weights from existing training
model_state = torch.load(
    outdir + "/checkpoints/checkpoint-10-3.812332.pth", map_location=torch.device("cpu"), weights_only=True
)
# The constructor keyword arguments of the trained model are stored next to the checkpoints.
# NOTE(review): pickle.load executes arbitrary code; only load files from a trusted training run.
with open(f"{outdir}/model_kwargs.pkl", "rb") as f:
    model_kwargs = pkl.load(f)


#this is needed to configure com.microsoft.MultiHeadAttention
NUM_HEADS = model_kwargs["num_heads"]

#set this to cuda if you are running the notebook on a GPU, otherwise use cpu
torch_device = "cpu"
# On CPU the attention implementation is overridden to "math".
# Note that this edits model_kwargs in place, so every model built below sees the override.
if torch_device == "cpu":
    model_kwargs["attention_type"] = "math"

In [None]:
# Reference PyTorch model: regular (non-simplified) attention, no fused ONNX op,
# and save_attention enabled.
model = MLPF(**model_kwargs, use_simplified_attention=False, export_onnx_fused=False, save_attention=True)
# strict=False: keys that do not match between checkpoint and model are ignored rather than raising.
model.load_state_dict(model_state["model_state_dict"], strict=False)
# Switch to inference mode and move to the configured device.
model = model.eval().to(device=torch_device)

In [None]:
# Same weights, but with the simplified attention implementation and no fused ONNX op.
# This variant is exported as the "unfused" ONNX model.
model_simple_unfused = MLPF(**model_kwargs, use_simplified_attention=True, export_onnx_fused=False)
# strict=False: keys that do not match between checkpoint and model are ignored rather than raising.
model_simple_unfused.load_state_dict(model_state["model_state_dict"], strict=False)
# Switch to inference mode and move to the configured device.
model_simple_unfused = model_simple_unfused.eval().to(device=torch_device)

In [None]:
# Same weights, simplified attention, with export_onnx_fused enabled.
# This variant is exported as the "fused" ONNX model.
model_simple_fused = MLPF(**model_kwargs, use_simplified_attention=True, export_onnx_fused=True)
# strict=False: keys that do not match between checkpoint and model are ignored rather than raising.
model_simple_fused.load_state_dict(model_state["model_state_dict"], strict=False)
# Switch to inference mode and move to the configured device.
model_simple_fused = model_simple_fused.eval().to(device=torch_device)

In [None]:
!rm -f *.onnx

In [None]:
# Example inputs for tracing the export: one event with 256 elements.
# Only the shapes matter here, the actual values are irrelevant.
dummy_features = torch.randn(1, 256, model_kwargs["input_dim"]).to(torch.float32)
# The mask is passed as float32 (1.0 / 0.0), derived from a random threshold.
dummy_mask = torch.randn(1, 256).gt(0.5).to(torch.float32)

In [None]:
# Sanity check: summed absolute difference between output [1] of the reference model
# and of the simplified-attention model on the dummy inputs.
(model(dummy_features, dummy_mask)[1] - model_simple_unfused(dummy_features, dummy_mask)[1]).abs().sum()

In [None]:
# Export the simplified-attention model with naive (unfused) attention to ONNX.
# Batch and element dimensions (axes 0 and 1) are dynamic on every input and output.
torch.onnx.export(
    model_simple_unfused,
    (dummy_features, dummy_mask),
    "test_fp32_unfused.onnx",
    opset_version=opset_version,
    verbose=False,
    input_names=["Xfeat_normed", "mask"],
    output_names=["bid", "id", "momentum", "pu"],
    dynamic_axes={
        tensor_name: {0: "num_batch", 1: "num_elements"}
        for tensor_name in ("Xfeat_normed", "mask", "bid", "id", "momentum", "pu")
    },
)

In [None]:
#register our custom op that calls out to the fast MultiHeadAttention implementation
# This is an onnxscript function: the decorator translates the Python source below into an
# ONNX function in the "onnx-script" domain, so the statements are kept deliberately simple.
# Inputs are cast to float16, passed to com.microsoft.MultiHeadAttention with NUM_HEADS heads,
# and the result is cast back to float32.
@onnxscript.script(custom_opset)
def SDPA(
    query: TFloat,
    key: TFloat,
    value: TFloat,
) -> TFloat:

    # Unlike pytorch scaled_dot_product_attention,
    # the input here MUST BE (batch, seq_len, num_head*head_dim).
    # Also, for the op to be fast on GPU, it needs to run in float16.
    query = op.Cast(query, to=onnx.TensorProto.FLOAT16)
    key = op.Cast(key, to=onnx.TensorProto.FLOAT16)
    value = op.Cast(value, to=onnx.TensorProto.FLOAT16)
    # Only the first of the three outputs (the attention result) is used.
    output, _, _ = msft_op.MultiHeadAttention(query, key, value, num_heads=NUM_HEADS)
    # Cast back so the rest of the graph keeps running in float32.
    output = op.Cast(output, to=onnx.TensorProto.FLOAT)

    return output


# Symbolic function that replaces aten::scaled_dot_product_attention during ONNX export.
# The signature must match the pytorch aten::scaled_dot_product_attention from
# https://github.com/pytorch/pytorch/blob/16676fd17b10b06e692656bbba8db5e0d6052a20/aten/src/ATen/native/transformers/attention.cpp#L699
def custom_scaled_dot_product_attention(
    g, query: TFloat, key: TFloat, value: TFloat, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False
):
    """Emit the custom SDPA onnxscript op in place of scaled_dot_product_attention.

    Only query, key and value are forwarded to SDPA; attn_mask, dropout_p, is_causal,
    scale and enable_gqa are accepted to match the aten signature but are not used.
    NOTE(review): this means any attention mask passed by the model is dropped in the
    fused export; confirm the model handles masking outside of this op.

    Returns:
        The graph value produced by the SDPA node, typed like `query`.
    """
    fused_node = g.onnxscript_op(SDPA, query, key, value)
    # setType provides shape/type information to ONNX shape/type inference.
    return fused_node.setType(query.type())


# Register the symbolic override for aten::scaled_dot_product_attention at the opset used for export.
# The registration is global for torch.onnx, so exports performed after this point use the custom op.
#the warning 'MultiHeadAttention' is not a known op in 'com.microsoft' is not actually important
print("registering custom op for scaled_dot_product_attention")
torch.onnx.register_custom_op_symbolic(
    symbolic_name="aten::scaled_dot_product_attention",
    symbolic_fn=custom_scaled_dot_product_attention,
    opset_version=opset_version,
)

# Export the model that uses the fused attention path to ONNX.
# Batch and element dimensions (axes 0 and 1) are dynamic on every input and output.
torch.onnx.export(
    model_simple_fused,
    (dummy_features, dummy_mask),
    "test_fp32_fused.onnx",
    opset_version=opset_version,
    verbose=False,
    input_names=["Xfeat_normed", "mask"],
    output_names=["bid", "id", "momentum", "pu"],
    dynamic_axes={
        tensor_name: {0: "num_batch", 1: "num_elements"}
        for tensor_name in ("Xfeat_normed", "mask", "bid", "id", "momentum", "pu")
    },
)

In [None]:
print("Available ONNX runtime providers:", rt.get_available_providers())

# Options shared by both inference sessions.
sess_options = rt.SessionOptions()
# need to explicitly set this to get rid of onnxruntime error
sess_options.intra_op_num_threads = 32
sess_options.log_severity_level = 1
sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_EXTENDED

# Run on CUDA only when the PyTorch side is configured for it, otherwise on CPU.
execution_provider = "CUDAExecutionProvider" if torch_device == "cuda" else "CPUExecutionProvider"

In [None]:
# Inference session for the exported model with unfused attention.
onnx_sess_unfused = rt.InferenceSession("test_fp32_unfused.onnx", sess_options, providers=[execution_provider])

In [None]:
# Inference session for the exported model that uses the custom fused attention op.
onnx_sess_fused = rt.InferenceSession("test_fp32_fused.onnx", sess_options, providers=[execution_provider])

In [None]:
def diffs_vec(pred_reference, pred_test):
    """Mean absolute difference between corresponding tensors of two predictions.

    Args:
        pred_reference: indexable collection of reference tensors.
        pred_test: indexable collection of tensors to compare; its length
            determines how many entries are compared.

    Returns:
        A list of Python floats, one per entry of ``pred_test``.
    """
    mean_abs_diffs = []
    for idx in range(len(pred_test)):
        delta = pred_reference[idx] - pred_test[idx]
        mean_abs_diffs.append(delta.flatten().abs().mean().item())
    return mean_abs_diffs

def particles_to_jets(pred, mask):
    """Cluster the predicted particles of each event into anti-kt (R=0.4) jets.

    Args:
        pred: model output, in the form accepted by ``unpack_predictions``.
        mask: boolean tensor selecting the valid (non-padded) elements per event;
            it is used to index the unpacked predictions and is summed along axis 1
            to obtain per-event counts.

    Returns:
        The inclusive jets with pt above 3 from the fastjet cluster sequence.
    """
    jet_definition = fastjet.JetDefinition(fastjet.antikt_algorithm, 0.4)

    # Keep only the valid elements and convert every predicted quantity to numpy.
    ypred = unpack_predictions(pred)
    for name in list(ypred):
        ypred[name] = ypred[name][mask].detach().cpu().contiguous().numpy()

    # Restore the per-event structure from the flattened arrays.
    elems_per_event = torch.sum(mask, axis=1).cpu().numpy()
    class_ids = awkward.unflatten(ypred["cls_id"], elems_per_event)
    momenta = awkward.unflatten(ypred["p4"], elems_per_event)

    # Elements predicted as class 0 are not clustered.
    selected = momenta[class_ids != 0]

    components = {
        label: selected[:, :, column]
        for column, label in enumerate(("pt", "eta", "phi", "e"))
    }
    four_vectors = vector.awk(awkward.zip(components))

    sequence = fastjet.ClusterSequence(four_vectors.to_xyzt(), jet_definition)
    return sequence.inclusive_jets(min_pt=3)

In [None]:
# Loop over the dataset and run three models per batch: the PyTorch reference model,
# the unfused ONNX export and the fused ONNX export. The jets clustered from each
# prediction, the raw predictions and the input features/targets are stored for the
# comparisons in the following cells.
builder = tfds.builder(dataset, data_dir=data_dir)
ds = builder.as_data_source(split="train")

# Number of events to process and batch size (events are read by index from the data source).
max_events = 10
events_per_batch = 1
inds = range(0, max_events, events_per_batch)

# Per-batch jet collections, one list per inference mode.
# NOTE(review): jets_mlpf_simple is never filled in this cell.
jets_mlpf = []
jets_mlpf_simple = []
jets_onnx_unfused = []
jets_onnx_fused = []

model = model.to(torch_device)

# Raw predictions and per-event inputs/targets (unpadded), kept for later inspection.
preds = []
preds_onnx_fused = []
X_feat_list = []
y_target_list = []
for ind in tqdm(inds):
    ds_elems = [ds[i] for i in range(ind,ind+events_per_batch)]
    X_features = [torch.tensor(elem["X"]).to(torch.float32).to(torch_device) for elem in ds_elems]
    y_targets = [torch.tensor(elem["ytarget"]).to(torch.float32).to(torch_device) for elem in ds_elems]
    X_feat_list += X_features
    y_target_list += y_targets

    #batch the data into [batch_size, num_elems, num_features]
    X_features_padded = pad_sequence(X_features, batch_first=True).contiguous()
    # NOTE(review): y_targets_padded is not used further in this cell.
    y_targets_padded = pad_sequence(y_targets, batch_first=True).contiguous()
    # print("batch", ind, X_features_padded.shape)
    # Elements whose first feature is zero are treated as padding.
    mask = X_features_padded[:, :, 0]!=0
    # The ONNX models take the mask as float, matching the float dummy mask used at export.
    mask_f = mask.float()

    with torch.no_grad():
        # print("running base pytorch model")
        pred = model(X_features_padded.to(torch_device), mask.to(torch_device))
        pred = tuple(pred[x].cpu() for x in range(len(pred)))
        preds.append(pred)

    j0 = particles_to_jets(pred, mask.cpu())
    jets_mlpf.append(j0)

    # print("running ONNX unfused model")
    pred_onnx_unfused = onnx_sess_unfused.run(["bid", "id", "momentum", "pu"], {"Xfeat_normed": X_features_padded.cpu().numpy(), "mask": mask_f.cpu().numpy()})
    pred_onnx_unfused = tuple(torch.tensor(p) for p in pred_onnx_unfused)
    j2 = particles_to_jets(pred_onnx_unfused, mask.cpu())
    jets_onnx_unfused.append(j2)

    # print("running ONNX fused model")
    pred_onnx_fused = onnx_sess_fused.run(["bid", "id", "momentum", "pu"], {"Xfeat_normed": X_features_padded.cpu().numpy(), "mask": mask_f.cpu().numpy()})
    pred_onnx_fused = tuple(torch.tensor(p) for p in pred_onnx_fused)
    preds_onnx_fused.append(pred_onnx_fused)
    
    j3 = particles_to_jets(pred_onnx_fused, mask.cpu())
    jets_onnx_fused.append(j3)

    # Advance the attention-matrix index of every conv layer after each batch.
    # presumably this index selects the file suffix of the saved attention outputs
    # ("attn_conv_*_<idx>.npz" files are read in later cells) - TODO confirm in mlpf.model.
    # Note: the counters are not reset here, so re-running this cell keeps incrementing them.
    for conv in model.conv_id + model.conv_reg:
        conv.att_mat_idx += 1

In [None]:
def sum_overflow_into_last_bin(all_values):
    """Fold the flow bins into the neighbouring regular bins.

    Despite the name, both flow bins are handled: the underflow (first entry) is
    added to the first regular bin and the overflow (last entry) to the last one.

    Args:
        all_values: 1-D sequence of bin contents including the flow bins, i.e.
            ``[underflow, bin_0, ..., bin_n, overflow]``.

    Returns:
        The regular bins ``[bin_0, ..., bin_n]`` with the flow contents added.
        The input is left unmodified.

    Raises:
        IndexError: if there is no regular bin between the two flow bins.
    """
    # Slicing a numpy array yields a view, so copy explicitly; otherwise the
    # additions below would silently modify the caller's array as well.
    values = copy.copy(all_values[1:-1])
    values[-1] = values[-1] + all_values[-1]
    values[0] = values[0] + all_values[0]
    return values

def to_bh(data, bins, cumulative=False):
    """Fill a variable-width boost-histogram and fold the flow bins into the edge bins.

    Args:
        data: values to histogram.
        bins: bin edges of the variable-width axis.
        cumulative: if True, replace the contents by the total minus the running sum
            before the flow bins are folded in.

    Returns:
        The filled ``bh.Histogram``.
    """
    hist = bh.Histogram(bh.axis.Variable(bins))
    hist.fill(data)

    if cumulative:
        total = np.sum(hist.values())
        hist[:] = total - np.cumsum(hist)

    # Contents including underflow/overflow, merged into the first/last regular bin.
    contents_with_flow = hist.values(flow=True)[:]
    hist[:] = sum_overflow_into_last_bin(contents_with_flow)
    return hist

In [None]:
# Different inference modes can produce a very slightly different number of jets because of
# floating point differences. Therefore jets are matched pairwise based on delta-R (third argument, 0.001),
# and the pT of the matched jets is compared in the following cells.
# match_inds1/2: PyTorch reference vs. ONNX unfused; match_inds3/4: PyTorch reference vs. ONNX fused.
match_inds1, match_inds2 = match_jets(to_p4_sph(awkward.concatenate(jets_mlpf)), to_p4_sph(awkward.concatenate(jets_onnx_unfused)), 0.001)
match_inds3, match_inds4 = match_jets(to_p4_sph(awkward.concatenate(jets_mlpf)), to_p4_sph(awkward.concatenate(jets_onnx_fused)), 0.001)

In [None]:
# Matched-jet pT: PyTorch reference model (x) vs. ONNX export with unfused attention (y).
b = np.logspace(0, 2, 200)
plt.figure(figsize=(6,5))
plt.hist2d(
    awkward.to_numpy(awkward.flatten(awkward.concatenate(jets_mlpf)[match_inds1].pt)),
    awkward.to_numpy(awkward.flatten(awkward.concatenate(jets_onnx_unfused)[match_inds2].pt)),
    bins=b,
    norm=mpl.colors.LogNorm(),
    cmap="Reds"
);
# Raw strings: "\m" is not a valid escape sequence in a regular string literal.
plt.xlabel(r"jet $p_{\mathrm{T,pytorch}}$")
# The y axis shows the unfused ONNX jets (jets_onnx_unfused), not a second pytorch model.
plt.ylabel(r"jet $p_{\mathrm{T,ONNX,unfused}}$")
plt.xscale("log")
plt.yscale("log")
plt.colorbar()

In [None]:
# Matched-jet pT: PyTorch reference model (x) vs. ONNX export with fused attention (y).
b = np.logspace(0, 2, 200)
plt.figure(figsize=(6,5))
plt.hist2d(
    awkward.to_numpy(awkward.flatten(awkward.concatenate(jets_mlpf)[match_inds3].pt)),
    awkward.to_numpy(awkward.flatten(awkward.concatenate(jets_onnx_fused)[match_inds4].pt)),
    bins=b,
    norm=mpl.colors.LogNorm(),
    cmap="Reds"
);
# Raw strings: "\m" is not a valid escape sequence in a regular string literal.
# The x axis shows the reference model (jets_mlpf), the y axis the fused ONNX export.
plt.xlabel(r"jet $p_{\mathrm{T,pytorch}}$")
plt.ylabel(r"jet $p_{\mathrm{T,ONNX,fused}}$")
plt.xscale("log")
plt.yscale("log")
plt.colorbar()

In [None]:
# Distribution of the per-jet pT ratio (fused ONNX export over PyTorch reference) for matched jets.
b = np.linspace(0.9,1.1, 500)
plt.figure(figsize=(10, 10))
ax = plt.axes()
pt_pytorch = awkward.concatenate(jets_mlpf)[match_inds3].pt
pt_onnx_fused = awkward.concatenate(jets_onnx_fused)[match_inds4].pt
# The ratio is ONNX / pytorch so that it agrees with the axis label below
# (the reference model is in the denominator).
plt.hist(
    awkward.to_numpy(awkward.flatten(pt_onnx_fused/pt_pytorch)),
    bins=b, histtype="step");
plt.yscale("log")
# Raw string: "\m" is not a valid escape sequence in a regular string literal.
plt.xlabel(r"jet $p_{\mathrm{T,ONNX,fused}}/p_{\mathrm{T,pytorch}}$")
plt.ylabel("Number of matched jets")
cms_label(ax)
sample_label(ax, dataset)
plt.savefig("pytorch_onnx_jet_ratio.pdf")

In [None]:
# Jet pT spectrum for the three inference modes (top) and their ratio to PyTorch (bottom).
f, (a0, a1) = plt.subplots(2, 1, gridspec_kw={"height_ratios": [4, 1]}, sharex=True, figsize=(10, 10))

plt.sca(a0)

b = np.logspace(0,2,101)
h0 = to_bh(awkward.flatten(awkward.concatenate(jets_mlpf).pt), bins=b)
h1 = to_bh(awkward.flatten(awkward.concatenate(jets_onnx_unfused).pt), bins=b)
h2 = to_bh(awkward.flatten(awkward.concatenate(jets_onnx_fused).pt), bins=b)

mplhep.histplot(h0, label="pytorch", lw=0.5, yerr=0)
mplhep.histplot(h1, label="ONNX, unfused", lw=0.5, yerr=0)
mplhep.histplot(h2, label="ONNX, fused", lw=0.5, yerr=0)

plt.legend()
plt.yscale("log")
plt.xscale("log")
plt.ylabel("Number of jets")
cms_label(a0)
sample_label(a0, dataset)

plt.sca(a1)
# Same order and labels as in the upper panel: h1 is the unfused export, h2 the fused one.
mplhep.histplot(h0/h0, label="pytorch", lw=0.5, yerr=0)
mplhep.histplot(h1/h0, label="ONNX, unfused", lw=0.5, yerr=0)
mplhep.histplot(h2/h0, label="ONNX, fused", lw=0.5, yerr=0)
plt.ylim(0.5, 1.5)
plt.xlim(1, 100)
plt.ylabel("vs. pytorch")
plt.xlabel("jet $p_T$ [GeV]")
plt.savefig("pytorch_onnx_jet_pt.pdf")

In [None]:
# Attention matrix saved for layer "reg" 2 of the first event (file suffix _0), on a log colour scale.
plt.figure(figsize=(6,5))
# np.load returns an NpzFile that keeps the archive open; the context manager closes it
# once the array has been read into memory.
with np.load("attn_conv_reg_2_0.npz") as att_file:
    att_mat_reg2 = att_file["att"]
plt.imshow(att_mat_reg2[0], cmap="hot_r", norm=matplotlib.colors.LogNorm(vmin=1e-6, vmax=1))
plt.colorbar()

In [None]:
import sklearn
import sklearn.manifold

In [None]:
# Display the number of conv layers; it sets the number of columns in the layer-wise plots.
model_kwargs["num_convs"]

In [None]:
#take the intermediate outputs of the first event for all layers
att_filenames = (
    ["attn_conv_reg_{}_0".format(ilayer) for ilayer in range(model_kwargs["num_convs"])] +
    ["attn_conv_id_{}_0".format(ilayer) for ilayer in range(model_kwargs["num_convs"])]
    )


#separate elements by type
typs = [
    1, #track
    4,5, #calo
    6, #FSG

    #HF
    8,9,

    #don't visualize superclusters and HO for simplicity
    #10,11
]

#take the data for the first event
# .cpu() is required before .numpy() when the tensors live on a CUDA device.
X = X_feat_list[0].cpu().numpy()
y = y_target_list[0].cpu().numpy()

# Marker size per element, binned in feature column 5 (presumably the element energy - TODO confirm).
energy_marker_sizes = np.linspace(1,30,10)
energy_bins = np.logspace(0,3,10)
# np.searchsorted returns len(energy_bins) for values above the last edge, which would
# index past the end of energy_marker_sizes; clip such values into the largest marker size.
energy_bin_idx = np.clip(np.searchsorted(energy_bins, X[:, 5]), 0, len(energy_marker_sizes) - 1)
energy_markers = energy_marker_sizes[energy_bin_idx]

# Elements whose first target column is 0 are drawn faintly. This does not depend on the
# layer or element type, so it is computed once.
alpha = np.ones(len(X), dtype=np.float32)
alpha[y[:, 0]==0] = 0.2

fig, axs = plt.subplots(2, model_kwargs["num_convs"], figsize=(5*model_kwargs["num_convs"],2*5))
axs = axs.flatten()
iattn = 0
for layertype in ["reg", "id"]:
    for layernum in range(model_kwargs["num_convs"]):
        att_fn = "attn_conv_{}_{}_0.npz".format(layertype, layernum)
        # Read the saved layer output and close the archive right away.
        with np.load(att_fn) as att_file:
            layer_output = att_file["x"][0]
        # Fixed seed so that the embedding is reproducible between runs.
        tsne = sklearn.manifold.TSNE(random_state=0)
        embed_2d = tsne.fit_transform(layer_output)
        ax = axs[iattn]
        plt.sca(ax)
        for typ in typs:
            # One scatter call per element type (first feature column), so each gets a legend entry.
            msk = X[:, 0]==typ
            plt.scatter(
                embed_2d[msk, 0],
                embed_2d[msk, 1],
                label=ELEM_NAMES_CMS[typ],
                alpha=alpha[msk],
                s=energy_markers[msk]
            )
        ax.set_xticks([])
        ax.set_yticks([])
        plt.legend(fontsize=12, ncols=2)
        plt.title(r"$z_{{{}}}$".format(layertype+str(layernum)), fontsize=12)
        plt.xlim(-150,150)
        plt.ylim(-150,150)
        iattn += 1

In [None]:
# Feature columns 2:5 of every event as numpy arrays; these are passed to calculate_delta_r,
# which expects rows of [eta, sin(phi), cos(phi)].
# .cpu() is required before .numpy() when the tensors live on a CUDA device.
X_etaphi = [event_features[:, 2:5].cpu().numpy() for event_features in X_feat_list]

In [None]:
from numba import njit

@njit
def calculate_delta_r(particles):
    """
    Build the pairwise Delta R matrix of a set of particles.

    Args:
        particles (np.ndarray): N x 3 array, one row per particle, laid out as
                                [eta, sin(phi), cos(phi)].

    Returns:
        np.ndarray: N x N float64 array whose entry (i, j) is the Delta R between
                    particles i and j; the diagonal is zero.
    """
    count = particles.shape[0]
    result = np.empty((count, count), dtype=np.float64)

    # Recover the azimuthal angle of every particle once: phi = atan2(sin(phi), cos(phi))
    phi = np.empty(count, dtype=np.float64)
    for idx in range(count):
        phi[idx] = np.arctan2(particles[idx, 1], particles[idx, 2])

    for row in range(count):
        eta_row = particles[row, 0]
        phi_row = phi[row]
        for col in range(count):
            if row != col:
                d_eta = eta_row - particles[col, 0]

                # Wrap the angular difference into (-pi, pi]
                d_phi = phi_row - phi[col]
                d_phi = np.arctan2(np.sin(d_phi), np.cos(d_phi))

                result[row, col] = np.sqrt(d_eta**2 + d_phi**2)
            else:
                result[row, col] = 0.0

    return result

In [None]:
# Pairwise Delta R matrix between all elements of the first event.
dr_mat = calculate_delta_r(X_etaphi[0])

In [None]:
def plot_attention_dr(layertype, dr=None):
    """Plot the saved attention matrices of all layers and their relation to Delta R.

    For every layer the file "attn_conv_<layertype>_<n>_0.npz" is read. The top row
    shows the first attention matrix of each layer, the bottom row a 2D histogram of
    the attention values against the pairwise Delta R of the elements.

    Args:
        layertype: layer family used in the file name, e.g. "id" or "reg".
        dr: pairwise Delta R matrix matching the attention matrices. Defaults to the
            notebook-level ``dr_mat``.
    """
    if dr is None:
        dr = dr_mat
    num_convs = model_kwargs["num_convs"]

    # squeeze=False keeps `axes` two-dimensional even for a single layer.
    # Only tight_layout (called at the end) is used for the layout: combining it with
    # constrained_layout=True makes matplotlib switch layout engines and emit a warning.
    fig, axes = plt.subplots(
        2, num_convs, figsize=(num_convs*6, 2*6), gridspec_kw={'height_ratios': [1, 1]}, squeeze=False
    )

    # Plotting parameters that do not depend on the layer
    hist2d_bins = (np.logspace(-2, 2, 100), np.logspace(-10, 0, 100))
    cmap_choice = "hot_r"

    for nlayer in range(num_convs):
        # Read the attention matrix and close the archive right away.
        with np.load("attn_conv_{}_{}_0.npz".format(layertype, nlayer)) as att_file:
            att_mat = att_file["att"]

        # Top row: Attention Matrices
        ax = axes[0, nlayer]
        imshow_norm = matplotlib.colors.LogNorm(vmin=1e-6, vmax=1)
        im = ax.imshow(att_mat[0], cmap=cmap_choice, norm=imshow_norm)
        ax.set_title("self-attention $A^{" + (layertype + str(nlayer)) + "}_{ij}$")
        ax.set_xlabel("elem. i")
        ax.set_ylabel("elem. j")
        ax.set_xticks([], [])
        ax.set_yticks([], [])

        # Bottom row: DR-Attention Correlations
        ax = axes[1, nlayer]
        counts, xedges, yedges, im = ax.hist2d(
            dr.flatten(),
            att_mat.flatten(),
            bins=hist2d_bins,
            cmap=cmap_choice,
            norm=matplotlib.colors.LogNorm(vmin=1, vmax=1e6)
        )
        ax.set_yscale("log")
        ax.set_xscale("log")
        # Raw string: "\D" is not a valid escape sequence in a regular string literal.
        ax.set_xlabel(r"$\Delta R_{ij}$")
        ax.set_ylabel("$A_{ij}$")
        ax.set_box_aspect(1)
    plt.tight_layout()

In [None]:
# Attention matrices and Delta R correlation for the "id" layers.
plot_attention_dr("id")

In [None]:
# Attention matrices and Delta R correlation for the "reg" layers.
plot_attention_dr("reg")