In [1]:
import awkward as ak
import hist
import jax
import matplotlib.pyplot as plt
import numpy as np
import uproot
import vector

# Register vector's behaviors with awkward so records created with
# with_name="Momentum4D" (see the data-loading code below) get vector methods
# such as `+` and `.mass`.
vector.register_awkward()
# Enable awkward's JAX backend; presumably required before the
# ak.to_backend(..., "jax") call below — confirm against the awkward docs.
ak.jax.register_and_check()

In [2]:
# Small public CMS Open Data ttbar NanoAOD file shipped with scikit-hep-testdata.
ttbar_file = (
    "https://github.com/scikit-hep/scikit-hep-testdata/"
    "raw/main/src/skhep_testdata/data/nanoAOD_2015_CMS_Open_Data_ttbar.root"
)

def correct_jets(jets, alpha, pt_shift_per_unit=25):
    """Perform a toy jet energy calibration controlled by nuisance parameter ``alpha``.

    Every jet's ``pt`` is shifted additively by ``pt_shift_per_unit * alpha``;
    all other fields are left as they are.

    Args:
        jets: array of jet records with a ``pt`` field (wrapped with ``ak.Array``).
        alpha: nuisance parameter value; ``0`` leaves ``pt`` unchanged.
        pt_shift_per_unit: additive ``pt`` shift per unit of ``alpha``
            (defaults to 25, the previously hard-coded value).

    Returns:
        A new ``ak.Array`` wrapper holding the shifted ``pt``.
    """
    # Wrap in a fresh ak.Array so the field assignment below updates this local
    # wrapper rather than the object the caller passed in.
    jets = ak.Array(jets)
    jets["pt"] = jets["pt"] + pt_shift_per_unit * alpha
    return jets

def get_mass(jets):
    """Return the invariant mass of the system built from the first two jets of each event.

    ``jets`` is indexed as ``jets[:, 0]`` and ``jets[:, 1]``, so every event must
    contain at least two jets. The selected elements must support ``+`` and
    expose a ``.mass`` attribute.
    """
    # NOTE: returning only (jets[:, 0]).mass instead triggers a different error with jax
    first_jet = jets[:, 0]
    second_jet = jets[:, 1]
    dijet_system = first_jet + second_jet
    return dijet_system.mass

def pipeline(jets, a):
    """Run the toy analysis for nuisance parameter value ``a``.

    Calibrates the jets with ``correct_jets``, computes per-event dijet masses
    with ``get_mass``, and returns their mean as computed by ``np.mean``.
    """
    calibrated_jets = correct_jets(jets, a)
    dijet_masses = get_mass(calibrated_jets)
    return np.mean(dijet_masses)


JET_FIELDS = ("pt", "eta", "phi", "mass")

with uproot.open(ttbar_file) as f:
    # read the NanoAOD jet branches "Jet_pt", "Jet_eta", "Jet_phi", "Jet_mass"
    arr = f["Events"].arrays([f"Jet_{field}" for field in JET_FIELDS])

# keep only events with at least two jets, since get_mass indexes jets[:, 1]
evtfilter = ak.num(arr["Jet_pt"]) >= 2
# Map each field to its branch by name rather than zipping names with the
# positional output of ak.unzip, so a reordered branch list cannot silently
# mislabel the Lorentz-vector components.
jets = ak.zip(
    {field: arr[f"Jet_{field}"] for field in JET_FIELDS},
    with_name="Momentum4D",
)[evtfilter]
jets = ak.to_backend(jets, "jax")

    
# scan the analysis pipeline for various nuisance parameter values
NP_MIN, NP_MAX, NP_STEPS = -5, 5, 20
np_vals = np.linspace(NP_MIN, NP_MAX, NP_STEPS)
mass_means = [pipeline(jets, a) for a in np_vals]  # this causes issues with jax backend

# Value and gradient of the pipeline w.r.t. the nuisance parameter (argnums=1),
# evaluated at the nominal value. The comprehension variable `a` above is not in
# scope here (Python 3 comprehensions have their own scope), so pass an explicit float.
nominal_alpha = 0.0
jax.value_and_grad(pipeline, argnums=1)(jets, nominal_alpha)  # this is what we would ultimately want

TypeError: Unexpected input type for array: <class 'awkward.contents.numpyarray.NumpyArray'>

This error occurred while calling

    numpy.add.__call__(
        <MomentumArray4D [{eta: -3.1967773, ...}, ..., {...}] type='140 * M...'>
        <MomentumArray4D [{eta: 1.2490234, ...}, ..., {...}] type='140 * Mo...'>
    )

In [None]:
# mean dijet mass as a function of the jet calibration nuisance parameter
fig, ax = plt.subplots()
ax.plot(np_vals, mass_means)
ax.set_xlabel("nuisance parameter value (alpha)")
ax.set_ylabel("mean dijet mass")
ax.set_title("Mean dijet mass vs. jet calibration nuisance parameter");

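Our input file is tiny. The following cell can be used instead to download a larger file (about 500 MB) to `ttbar.root`; to process it, set `ttbar_file` to that local path in the analysis cell.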

In [None]:
import os
import urllib.request

ttbar_file = (
    "https://xrootd-local.unl.edu:1094//store/user/AGC/nanoAOD/"
    "TT_TuneCUETP8M1_13TeV-powheg-pythia8/cmsopendata2015_ttbar_19981_PU25nsData2015v1_76X_"
    "mcRun2_asymptotic_v12_ext4-v1_80000_0007.root"
)

# download for subsequent use (skipped if the file is already present)
local_file_name = "ttbar.root"
if not os.path.exists(local_file_name):
    # Download to a temporary name and rename only after success, so an
    # interrupted download cannot leave a truncated ttbar.root that the
    # existence check above would then never replace.
    partial_file_name = local_file_name + ".part"
    urllib.request.urlretrieve(ttbar_file, filename=partial_file_name)
    os.replace(partial_file_name, local_file_name)