In [None]:
import copy
import pickle as pkl
import sys
import numpy as np
import tqdm
import tensorflow_datasets as tfds

import awkward
import vector
import fastjet
import matplotlib.pyplot as plt

import torch
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn
from torch import Tensor

import onnxscript
import onnx
import onnxruntime as rt
from onnxconverter_common import float16
from onnxscript.function_libs.torch_lib.tensor_typing import TFloat

sys.path.append("../../mlpf")
from pyg.mlpf import MLPF
from pyg.utils import unpack_predictions, unpack_target

In [None]:
# Standard ONNX operators (opset 17), used as `op.*` inside the custom onnxscript function.
from onnxscript import opset17 as op
# Opset version passed to torch.onnx.export and to the custom symbolic registration.
opset_version = 17
# Opset handle under which the custom SDPA onnxscript function is declared.
custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1)
# Opset handle for the "com.microsoft" domain, used to reference MultiHeadAttention.
msft_op = onnxscript.values.Opset("com.microsoft", 1)

In [None]:
#tfds datasets are here:
# NOTE(review): absolute, user-specific path; adjust to the local tensorflow_datasets location.
data_dir = "/scratch/persistent/joosep/tensorflow_datasets/"
dataset = "cms_pf_ttbar"

#model checkpoints are here:
outdir = "../../experiments/pyg-cms_20240430_094836_751206"

#Load model arguments from existing training
# map_location places all checkpoint tensors on the CPU.
model_state = torch.load(
    outdir + "/checkpoints/checkpoint-25-17.631161.pth", map_location=torch.device("cpu")
)
# NOTE(review): pickle.load (like torch.load) can execute arbitrary code; only open trusted files.
with open(f"{outdir}/model_kwargs.pkl", "rb") as f:
    model_kwargs = pkl.load(f)

#this is needed to configure com.microsoft.MultiHeadAttention
# NUM_HEADS is read as a global by the SDPA onnxscript function.
NUM_HEADS = model_kwargs["num_heads"]

In [None]:
#Load model from our codebase
# Reference model: built from the stored kwargs, switched to eval mode,
# then populated with the weights stored under "model_state_dict" in the checkpoint.
model = MLPF(**model_kwargs)
model.eval()
model.load_state_dict(model_state["model_state_dict"])

In [None]:
#The definitions below are copied from mlpf.py so that the exported architecture is explicit in this notebook
def get_activation(activation):
    """Return the ``torch.nn`` activation class registered under ``activation``.

    Args:
        activation: one of ``"elu"``, ``"relu"``, ``"relu6"`` or ``"leakyrelu"``.

    Returns:
        The activation module class (not an instance); callers instantiate it with ``act()``.

    Raises:
        ValueError: if ``activation`` is not one of the supported names. Previously an
            unknown name surfaced as an ``UnboundLocalError`` on the ``return`` statement.
    """
    activations = {
        "elu": nn.ELU,
        "relu": nn.ReLU,
        "relu6": nn.ReLU6,
        "leakyrelu": nn.LeakyReLU,
    }
    if activation not in activations:
        raise ValueError(
            "unknown activation {!r}, expected one of {}".format(activation, sorted(activations))
        )
    return activations[activation]

def ffn(input_dim, output_dim, width, act, dropout):
    """Build a two-layer feed-forward network.

    The layers are, in order: Linear(input_dim -> width), ``act()``, LayerNorm(width),
    Dropout(dropout) and Linear(width -> output_dim).

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
        width: hidden size.
        act: activation class (called with no arguments to create the layer).
        dropout: dropout probability.

    Returns:
        An ``nn.Sequential`` holding the five layers.
    """
    layers = [
        nn.Linear(input_dim, width),
        act(),
        nn.LayerNorm(width),
        nn.Dropout(dropout),
        nn.Linear(width, output_dim),
    ]
    return nn.Sequential(*layers)

class RegressionOutput(nn.Module):
    """Regression head that predicts a linear correction to an input value.

    The internal network emits two numbers per element; the output is
    ``orig_value * out[..., 0:1] + out[..., 1:2]``.
    ``mode`` and ``elemtypes`` are stored on the module but not used by ``forward``.
    """

    def __init__(self, mode, embed_dim, width, act, dropout, elemtypes):
        super().__init__()
        self.mode = mode
        self.elemtypes = elemtypes
        # two outputs per element: multiplicative and additive term
        self.nn = ffn(embed_dim, 2, width, act, dropout)

    def forward(self, elems, x, orig_value):
        """Return the corrected value; ``elems`` is accepted but not used."""
        scale_and_shift = self.nn(x)
        scale = scale_and_shift[..., 0:1]
        shift = scale_and_shift[..., 1:2]
        return orig_value * scale + shift

class SimpleMultiheadAttention(nn.MultiheadAttention):
    """Multi-head attention that reuses ``nn.MultiheadAttention`` parameters with an explicit forward.

    The packed ``in_proj_weight`` / ``in_proj_bias`` are split into the q/k/v projections,
    attention is computed with ``torch.nn.functional.scaled_dot_product_attention`` and the
    result goes through ``out_proj``. The module is always ``batch_first`` and uses biases.

    Attributes:
        export_onnx: when ``False`` (default) the heads are split explicitly before the
            attention call. When ``True`` the projected q/k/v are passed as 3D tensors
            ``(batch, seq_len, num_heads * head_dim)``, which is the layout the custom ONNX
            symbolic registered in this notebook expects.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        bias = True
        batch_first = True
        super().__init__(embed_dim, num_heads, dropout, bias=bias, batch_first=batch_first, **factory_kwargs)
        self.head_dim = int(embed_dim // num_heads)

        # replace the parent's output projection with a plain Linear layer
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs)
        self.export_onnx = False

    def _split_heads(self, x: Tensor, bs: int, seq_len: int) -> Tensor:
        """Reshape (bs, seq_len, num_heads*head_dim) into (bs*num_heads, seq_len, head_dim)."""
        x = x.reshape(bs, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        return x.reshape(bs * self.num_heads, seq_len, self.head_dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> "tuple[Tensor, None]":
        """Compute attention for inputs of shape (batch, seq_len, embed_dim).

        ``k`` and ``v`` are reshaped with the sequence length of ``q``, so all three inputs
        must have the same shape.

        Returns:
            ``(attn_output, None)``; the ``None`` stands in for the attention weights of the
            ``nn.MultiheadAttention`` interface.
        """
        bs, seq_len, _ = q.size()
        head_dim = self.head_dim
        num_heads = self.num_heads

        # `.data` is kept from the original: the projections bypass autograd here.
        split_sizes = [self.embed_dim, self.embed_dim, self.embed_dim]
        wq, wk, wv = torch.split(self.in_proj_weight.data, split_sizes, dim=0)
        bq, bk, bv = torch.split(self.in_proj_bias.data, split_sizes, dim=0)

        q = torch.matmul(q, wq.T) + bq
        k = torch.matmul(k, wk.T) + bk
        v = torch.matmul(v, wv.T) + bv

        if not self.export_onnx:
            q = self._split_heads(q, bs, seq_len)
            k = self._split_heads(k, bs, seq_len)
            v = self._split_heads(v, bs, seq_len)

        # scaled_dot_product_attention applies dropout whenever dropout_p > 0, independently
        # of the module mode, so it has to be disabled explicitly outside of training.
        dropout_p = self.dropout if self.training else 0.0

        #this function will have different shape signatures in native torch and in ONNX com.microsoft.MultiHeadAttention
        attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

        if not self.export_onnx:
            # merge the heads back: (bs*num_heads, seq_len, head_dim) -> (bs, seq_len, num_heads*head_dim)
            attn_output = attn_output.reshape(bs, num_heads, seq_len, head_dim).transpose(1, 2)
            attn_output = attn_output.reshape(bs, seq_len, num_heads * head_dim)

        assert list(attn_output.size()) == [bs, seq_len, num_heads * head_dim]
        attn_output = self.out_proj(attn_output)
        return attn_output, None

class SimpleSelfAttentionLayer(nn.Module):
    """Self-attention block: attention and feed-forward, each followed by a residual sum and LayerNorm.

    ``forward`` applies, in order: self-attention, residual + ``norm0``, feed-forward,
    residual + ``norm1``, dropout, and finally multiplies the output by ``mask``.
    The mask is only used for that final multiplication; it is not passed to the attention.
    """

    def __init__(
        self,
        activation="elu",
        embedding_dim=128,
        num_heads=2,
        width=128,
        dropout_mha=0.1,
        dropout_ff=0.1,
    ):
        super().__init__()

        self.act = get_activation(activation)
        self.mha = SimpleMultiheadAttention(embedding_dim, num_heads, dropout=dropout_mha)
        self.norm0 = nn.LayerNorm(embedding_dim)
        self.norm1 = nn.LayerNorm(embedding_dim)
        feed_forward = [
            nn.Linear(embedding_dim, width),
            self.act(),
            nn.Linear(width, embedding_dim),
            self.act(),
        ]
        self.seq = nn.Sequential(*feed_forward)
        self.dropout = nn.Dropout(dropout_ff)

    def forward(self, x: Tensor, mask: Tensor):
        """Transform ``x`` and multiply the result by ``mask`` broadcast over the last axis."""
        attended, _ = self.mha(x, x, x)

        x = self.norm0(x + attended)
        x = self.norm1(x + self.seq(x))
        x = self.dropout(x)
        return x * mask.unsqueeze(-1)

class SimpleMLPF(nn.Module):
    """Simplified, attention-only MLPF model used for the ONNX export.

    Two parallel stacks of ``SimpleSelfAttentionLayer`` (one for classification, one for
    regression) run on separately embedded inputs. The classification head works on the
    input features concatenated with the last id embedding; the five regression heads work
    on the input features, the last regression embedding and the class logits.

    Notes:
        * ``embedding_dim`` and ``width`` are accepted for signature compatibility, but both
          are overridden by ``num_heads * head_dim``.
        * ``layernorm`` is accepted but not used.
        * ``elemtypes_nonzero`` is copied, so the list used as default argument is never
          shared between instances.
    """

    def __init__(
        self,
        input_dim=34,
        num_classes=8,
        embedding_dim=128,
        width=128,
        num_convs=2,
        dropout_ff=0.0,
        activation="elu",
        layernorm=True,
        # element types which actually exist in the dataset
        elemtypes_nonzero=[1, 4, 5, 6, 8, 9, 10, 11],
        # self-attention specific parameters
        num_heads=16,
        head_dim=16,
        dropout_conv_reg_mha=0.0,
        dropout_conv_reg_ff=0.0,
        dropout_conv_id_mha=0.0,
        dropout_conv_id_ff=0.0,
    ):
        super().__init__()

        self.act = get_activation(activation)

        self.input_dim = input_dim
        self.num_convs = num_convs

        # copy to avoid aliasing the mutable default argument (or the caller's list)
        self.elemtypes_nonzero = list(elemtypes_nonzero)

        # the attention layout fixes both sizes, whatever was passed in
        embedding_dim = num_heads * head_dim
        width = num_heads * head_dim

        # embedding of the inputs
        self.nn0_id = ffn(self.input_dim, embedding_dim, width, self.act, dropout_ff)
        self.nn0_reg = ffn(self.input_dim, embedding_dim, width, self.act, dropout_ff)

        self.conv_id = nn.ModuleList()
        self.conv_reg = nn.ModuleList()

        for _ in range(num_convs):
            self.conv_id.append(
                SimpleSelfAttentionLayer(
                    activation=activation,
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    width=width,
                    dropout_mha=dropout_conv_id_mha,
                    dropout_ff=dropout_conv_id_ff,
                )
            )
            self.conv_reg.append(
                SimpleSelfAttentionLayer(
                    activation=activation,
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    width=width,
                    dropout_mha=dropout_conv_reg_mha,
                    dropout_ff=dropout_conv_reg_ff,
                )
            )

        decoding_dim = self.input_dim + embedding_dim

        # DNN that acts on the node level to predict the PID
        self.nn_id = ffn(decoding_dim, num_classes, width, self.act, dropout_ff)

        # elementwise DNN for node momentum regression
        embed_dim = decoding_dim + num_classes
        self.nn_pt = RegressionOutput("linear", embed_dim, width, self.act, dropout_ff, self.elemtypes_nonzero)
        self.nn_eta = RegressionOutput("linear", embed_dim, width, self.act, dropout_ff, self.elemtypes_nonzero)
        self.nn_sin_phi = RegressionOutput("linear", embed_dim, width, self.act, dropout_ff, self.elemtypes_nonzero)
        self.nn_cos_phi = RegressionOutput("linear", embed_dim, width, self.act, dropout_ff, self.elemtypes_nonzero)
        self.nn_energy = RegressionOutput("linear", embed_dim, width, self.act, dropout_ff, self.elemtypes_nonzero)

    # @torch.compile
    def forward(self, X_features, mask):
        """Run the model.

        Args:
            X_features: input features; the last axis has ``input_dim`` entries, and
                columns 1..5 are used as the reference values for the pt, eta, sin_phi,
                cos_phi and energy regressions respectively.
            mask: per-element mask forwarded to every attention layer.

        Returns:
            ``(preds_id, preds_momentum)``: class logits and the five regressed quantities
            concatenated along the last axis.
        """
        Xfeat_normed = X_features

        embedding_id = self.nn0_id(Xfeat_normed)
        embedding_reg = self.nn0_reg(Xfeat_normed)

        # each layer consumes the output of the previous one; only the last output is decoded
        for conv in self.conv_id:
            embedding_id = conv(embedding_id, mask)
        for conv in self.conv_reg:
            embedding_reg = conv(embedding_reg, mask)

        final_embedding_id = torch.cat([Xfeat_normed, embedding_id], dim=-1)
        preds_id = self.nn_id(final_embedding_id)

        final_embedding_reg = torch.cat([Xfeat_normed, embedding_reg, preds_id], dim=-1)

        # The PFElement feature order in X_features defined in fcc/postprocessing.py
        preds_pt = self.nn_pt(X_features, final_embedding_reg, X_features[..., 1:2])
        preds_eta = self.nn_eta(X_features, final_embedding_reg, X_features[..., 2:3])
        preds_sin_phi = self.nn_sin_phi(X_features, final_embedding_reg, X_features[..., 3:4])
        preds_cos_phi = self.nn_cos_phi(X_features, final_embedding_reg, X_features[..., 4:5])
        preds_energy = self.nn_energy(X_features, final_embedding_reg, X_features[..., 5:6])
        preds_momentum = torch.cat([preds_pt, preds_eta, preds_sin_phi, preds_cos_phi, preds_energy], dim=-1)

        return preds_id, preds_momentum

In [None]:
# Build the simplified model with the same hyperparameters as the trained model.
# embedding_dim and width are passed as num_heads*head_dim, which is also what
# SimpleMLPF computes internally from num_heads and head_dim.
model_simple = SimpleMLPF(
        input_dim=model_kwargs["input_dim"],
        num_classes=model_kwargs["num_classes"],
        embedding_dim=model_kwargs["num_heads"]*model_kwargs["head_dim"],
        width=model_kwargs["num_heads"]*model_kwargs["head_dim"],
        num_convs=model_kwargs["num_convs"],
        dropout_ff=model_kwargs["dropout_ff"],
        activation=model_kwargs["activation"],
        layernorm=True,
        # element types which actually exist in the dataset
        elemtypes_nonzero=model_kwargs["elemtypes_nonzero"],
        # self-attention specific parameters
        num_heads=model_kwargs["num_heads"],
        head_dim=model_kwargs["head_dim"],
        dropout_conv_reg_mha=model_kwargs["dropout_conv_reg_mha"],
        dropout_conv_reg_ff=model_kwargs["dropout_conv_reg_ff"],
        dropout_conv_id_mha=model_kwargs["dropout_conv_id_mha"],
        dropout_conv_id_ff=model_kwargs["dropout_conv_id_ff"],
)
# Inference mode; the trailing semicolon suppresses the module repr in the cell output.
model_simple.eval();

In [None]:
# Load the trained weights into the simplified model (same state-dict as the reference model).
model_simple.load_state_dict(model_state["model_state_dict"])

# Dummy inputs for the export: one event with 256 elements.
dummy_features = torch.randn(1, 256, model_kwargs["input_dim"]).float()
# NOTE(review): randn(...).bool() is True for every nonzero sample, i.e. effectively an all-True mask.
dummy_mask = torch.randn(1, 256).bool()

# Export the unfused model; the batch and element axes of all inputs and outputs are dynamic.
torch.onnx.export(
    model_simple,
    (dummy_features, dummy_mask),
    "test_fp32_unfused.onnx",
    opset_version=opset_version,
    verbose=False,
    input_names=[
        "Xfeat_normed", "mask",
    ],
    output_names=["id", "momentum"],
    dynamic_axes={
        "Xfeat_normed": {0: "num_batch", 1: "num_elements"},
        "mask": {0: "num_batch", 1: "num_elements"},
        "id": {0: "num_batch", 1: "num_elements"},
        "momentum": {0: "num_batch", 1: "num_elements"},
    },
)

In [None]:
# Independent copy, so that switching the attention layout does not affect model_simple.
model_simple_fused = copy.deepcopy(model_simple)
#configure the model to run in (batch, seq_len, num_heads*head_dim) 3d-mode.
# Applies to every attention layer of both the id and the regression stacks.
for conv in model_simple_fused.conv_id + model_simple_fused.conv_reg:
    conv.mha.export_onnx = True

#register our custom op that calls out to the fast MultiHeadAttention implementation
# onnxscript function in the custom opset: wraps com.microsoft.MultiHeadAttention so that it
# takes float query/key/value and returns a float output.
@onnxscript.script(custom_opset)
def SDPA(
    query: TFloat,
    key: TFloat,
    value: TFloat,
) -> TFloat:

    # Unlike pytorch scaled_dot_product_attention,
    # the input here MUST BE (batch, seq_len, num_head*head_dim).
    # Also, for the op to be fast on GPU, it needs to run in float16.
    query = op.Cast(query, to=onnx.TensorProto.FLOAT16)
    key = op.Cast(key, to=onnx.TensorProto.FLOAT16)
    value = op.Cast(value, to=onnx.TensorProto.FLOAT16)
    # Only the first of the three outputs is used; NUM_HEADS is the notebook-level constant.
    output, _, _ = msft_op.MultiHeadAttention(query, key, value, num_heads=NUM_HEADS)
    # Cast back to float32 so the rest of the graph stays in single precision.
    output = op.Cast(output, to=onnx.TensorProto.FLOAT)

    return output


# setType provides the output shape/type to ONNX shape/type inference (here: the same type as the query)
def custom_scaled_dot_product_attention(
    g, query: TFloat, key: TFloat, value: TFloat, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
):
    """Symbolic function that emits the custom ``SDPA`` op for scaled_dot_product_attention.

    Only ``query``, ``key`` and ``value`` are forwarded; ``attn_mask``, ``dropout_p``,
    ``is_causal`` and ``scale`` are accepted but ignored. The output node is given the
    same type as ``query``.
    """
    fused_attention = g.onnxscript_op(SDPA, query, key, value)
    return fused_attention.setType(query.type())


#the warning 'MultiHeadAttention' is not a known op in 'com.microsoft' is not actually important
print("registering custom op for scaled_dot_product_attention")
# From here on, exports at this opset version translate aten::scaled_dot_product_attention
# through custom_scaled_dot_product_attention.
torch.onnx.register_custom_op_symbolic(
    symbolic_name="aten::scaled_dot_product_attention",
    symbolic_fn=custom_scaled_dot_product_attention,
    opset_version=opset_version,
)

# Export the fused model with the same inputs, names and dynamic axes as the unfused export.
torch.onnx.export(
    model_simple_fused,
    (dummy_features, dummy_mask),
    "test_fp32_fused.onnx",
    opset_version=opset_version,
    verbose=False,
    input_names=[
        "Xfeat_normed", "mask",
    ],
    output_names=["id", "momentum"],
    dynamic_axes={
        "Xfeat_normed": {0: "num_batch", 1: "num_elements"},
        "mask": {0: "num_batch", 1: "num_elements"},
        "id": {0: "num_batch", 1: "num_elements"},
        "momentum": {0: "num_batch", 1: "num_elements"},
    },
)

# ONNX Runtime sessions for both exported files, both on the CPU execution provider.
sess_options = rt.SessionOptions()
onnx_sess_unfused = rt.InferenceSession("test_fp32_unfused.onnx", sess_options, providers=["CPUExecutionProvider"])
onnx_sess_fused = rt.InferenceSession("test_fp32_fused.onnx", sess_options, providers=["CPUExecutionProvider"])

In [None]:
def diffs_vec(preds, reference=None):
    """Mean absolute difference between each tensor in ``preds`` and a reference prediction.

    Args:
        preds: sequence of tensors to compare (in this notebook: ``(id, momentum)``).
        reference: sequence of tensors with at least ``len(preds)`` entries of matching
            shapes. When omitted, the notebook-level variable ``pred`` (set in the
            evaluation loop) is used, which is the original behaviour.

    Returns:
        A list with one float per entry of ``preds``.
    """
    if reference is None:
        # backward-compatible fallback to the notebook global
        reference = pred
    return [
        torch.mean(torch.abs(torch.flatten(reference[i] - preds[i]))).item()
        for i in range(len(preds))
    ]

def particles_to_jets(pred):
    """Cluster the predicted particles into anti-kt (R=0.4) jets and return their pt.

    Relies on the notebook-level ``mask`` of the current batch to select the elements and
    to count them per event. Elements with predicted ``cls_id == 0`` are dropped before
    clustering, and only jets passing ``min_pt=10`` are kept.

    Args:
        pred: model output, as accepted by ``unpack_predictions``.

    Returns:
        Flat numpy array with the pt of all jets in the batch.
    """
    jetdef = fastjet.JetDefinition(fastjet.antikt_algorithm, 0.4)

    unpacked = unpack_predictions(pred)
    for name, values in unpacked.items():
        # keep only the masked-in elements, flattened over the batch
        unpacked[name] = values[mask].detach().cpu().contiguous().numpy()

    # number of masked-in elements per event, used to restore the event structure
    elems_per_event = torch.sum(mask, axis=1).cpu().numpy()
    class_ids = awkward.unflatten(unpacked["cls_id"], elems_per_event)
    is_particle = class_ids != 0
    p4_selected = awkward.unflatten(unpacked["p4"], elems_per_event)[is_particle]

    components = {
        "pt": p4_selected[:, :, 0],
        "eta": p4_selected[:, :, 1],
        "phi": p4_selected[:, :, 2],
        "e": p4_selected[:, :, 3],
    }
    vec = vector.awk(awkward.zip(components))

    cluster = fastjet.ClusterSequence(vec.to_xyzt(), jetdef)
    jets = cluster.inclusive_jets(min_pt=10)
    return awkward.to_numpy(awkward.flatten(jets.pt))

In [None]:
# Open the test split of the dataset as a random-access data source.
builder = tfds.builder(dataset, data_dir=data_dir)
ds = builder.as_data_source(split="test")
# Number of events to validate and number of events per batch.
max_events = 50
events_per_batch = 1
# Index of the first event of each batch.
inds = range(0, max_events, events_per_batch)

# Per-batch jet pt arrays, one list per model variant.
jets_mlpf = []
jets_mlpf_simple = []
jets_onnx_unfused = []
jets_onnx_fused = []

# For each batch: run the reference model, the simplified model and both ONNX exports,
# cluster jets from each prediction and assert agreement with the reference model.
# The loop sets the notebook-level `pred` and `mask`, which diffs_vec and
# particles_to_jets read as globals.
for ind in inds:
    ds_elems = [ds[i] for i in range(ind,ind+events_per_batch)]
    X_features = [torch.tensor(elem["X"]).to(torch.float32) for elem in ds_elems]
    y_targets = [torch.tensor(elem["ygen"]).to(torch.float32) for elem in ds_elems]

    #batch the data into [batch_size, num_elems, num_features]
    X_features_padded = pad_sequence(X_features, batch_first=True)
    # NOTE(review): y_targets_padded is not used further in this loop.
    y_targets_padded = pad_sequence(y_targets, batch_first=True)
    print("batch", ind, X_features_padded.shape)
    # elements whose first feature is zero are treated as padding
    mask = X_features_padded[:, :, 0]!=0

    with torch.no_grad():
        print("running base model")
        pred = model(X_features_padded, mask)
        print("running simplified model")
        pred_simple = model_simple(X_features_padded, mask)

    pred = tuple(p.detach() for p in pred)
    jets_mlpf.append(particles_to_jets(pred))
    
    pred_simple = tuple(p.detach() for p in pred_simple)
    jets_mlpf_simple.append(particles_to_jets(pred_simple))
    
    # the simplified PyTorch model must reproduce the reference model
    torch.testing.assert_close(pred[0], pred_simple[0], atol=0.01, rtol=0.01)
    torch.testing.assert_close(pred[1], pred_simple[1], atol=0.01, rtol=0.01)
    
    # mean absolute differences with respect to `pred` for the (id, momentum) outputs
    diffs = diffs_vec(pred_simple)
    print("diffs: {:.4f} {:.4f}".format(*diffs))

    print("running ONNX unfused model")
    pred_onnx_unfused = onnx_sess_unfused.run(None, {"Xfeat_normed": X_features_padded.numpy(), "mask": mask.numpy()})
    pred_onnx_unfused = tuple(torch.tensor(p) for p in pred_onnx_unfused)
    jets_onnx_unfused.append(particles_to_jets(pred_onnx_unfused))
    diffs = diffs_vec(pred_onnx_unfused)
    print("diffs: {:.4f} {:.4f}".format(*diffs))
    torch.testing.assert_close(pred[0], pred_onnx_unfused[0], atol=0.01, rtol=0.01)
    torch.testing.assert_close(pred[1], pred_onnx_unfused[1], atol=0.01, rtol=0.01)
    
    print("running ONNX fused model")
    pred_onnx_fused = onnx_sess_fused.run(None, {"Xfeat_normed": X_features_padded.numpy(), "mask": mask.numpy()})
    pred_onnx_fused = tuple(torch.tensor(p) for p in pred_onnx_fused)
    jets_onnx_fused.append(particles_to_jets(pred_onnx_fused))
    diffs = diffs_vec(pred_onnx_fused)
    print("diffs: {:.4f} {:.4f}".format(*diffs))
    torch.testing.assert_close(pred[0], pred_onnx_fused[0], atol=0.01, rtol=0.01)
    torch.testing.assert_close(pred[1], pred_onnx_fused[1], atol=0.01, rtol=0.01)

In [None]:
import boost_histogram as bh
import mplhep

In [None]:
def sum_overflow_into_last_bin(all_values):
    """Fold the flow bins into the outermost regular bins.

    Despite the name, both flow bins are handled: the overflow is added to the last
    regular bin and the underflow to the first one.

    Args:
        all_values: 1D sequence laid out as ``[underflow, bin_0, ..., bin_n, overflow]``
            with at least one regular bin.

    Returns:
        A new sequence holding only the regular bins, with the flow contents added.
        ``all_values`` itself is left unchanged.
    """
    # A slice of a numpy array is a view; copy it so the caller's data is not modified.
    values = all_values[1:-1].copy()
    values[-1] = values[-1] + all_values[-1]
    values[0] = values[0] + all_values[0]
    return values
    
def to_bh(data, bins, cumulative=False):
    """Histogram ``data`` into a boost_histogram with variable bin edges ``bins``.

    With ``cumulative=True`` the bin contents are first replaced by the total of the
    regular bins minus their running sum. In both cases the underflow and overflow are
    then folded into the first and last bin.

    Returns:
        The filled ``bh.Histogram``.
    """
    hist = bh.Histogram(bh.axis.Variable(bins))
    hist.fill(data)
    if cumulative:
        total = np.sum(hist.values())
        hist[:] = total - np.cumsum(hist)
    values_with_flow = hist.values(flow=True)[:]
    hist[:] = sum_overflow_into_last_bin(values_with_flow)
    return hist

In [None]:
# Jet pt histograms with 50 bins between 10 and 100 for the PyTorch reference model (h0)
# and the unfused (h1) and fused (h2) ONNX models.
b = np.linspace(10,100,51)
h0 = to_bh(np.concatenate(jets_mlpf), bins=b)
h1 = to_bh(np.concatenate(jets_onnx_unfused), bins=b)
h2 = to_bh(np.concatenate(jets_onnx_fused), bins=b)

In [None]:
# Overlay the jet pt spectra of the three model variants.
mplhep.histplot(h0, label="pytorch", lw=1)
mplhep.histplot(h1, label="onnx unfused", lw=1)
mplhep.histplot(h2, label="onnx fused", lw=1)
plt.legend()
plt.xlabel("Jet pt")

In [None]:
# Coarser binning (20 bins) for the per-bin ratio of each ONNX model to the PyTorch reference.
b = np.linspace(10,100,21)
h0 = to_bh(np.concatenate(jets_mlpf), bins=b)
h1 = to_bh(np.concatenate(jets_onnx_unfused), bins=b)
h2 = to_bh(np.concatenate(jets_onnx_fused), bins=b)

# First curve: unfused/pytorch, second curve: fused/pytorch.
# NOTE(review): bins where h0 is empty lead to a division by zero in the ratio.
plt.plot(h0.axes[0].centers, (h1/h0).values(), marker="o", ms=2.0, lw=1.0)
plt.plot(h0.axes[0].centers, (h2/h0).values(), marker="o", ms=2.0, lw=1.0)
plt.ylim(0.8,1.2)