In [None]:
from coffea.nanoevents import NanoEventsFactory, NanoAODSchema
import awkward as ak
import numpy as np
import hist
import pandas as pd
import uproot
import json
import glob
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import mplhep
from scipy import stats

In [None]:
# Alternative input list built from the PFNano index:
# https://github.com/DAZSLE/boostedzprime/blob/main/data/pfnanoindex.json
# NOTE: `fileset` is reassigned to an explicit UL18 NanoAODv9 file list in the next cell.
with open("../ULpfnano.json") as index_file:
    fileset_in = json.load(index_file)

flat_zprime_files = fileset_in["2016"]["VectorZPrime"]["VectorZPrimeGammaToQQGamma_flat"]
fileset = {
    "flatZprime": ["root://cmseos.fnal.gov/" + path for path in flat_zprime_files],
}

In [None]:
# Explicit RunIISummer20UL18 NanoAODv9 file list for the flat-mass Z'(->qq)+gamma sample,
# read over XRootD. This replaces the `fileset` built from the JSON index in the previous cell.
# Most files go through the FNAL redirector (cmsxrootd.fnal.gov); one file in block 2810000
# uses the xrootd-redir.ultralight.org redirector instead.
fileset = {
    "flatZprime": [
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/2540000/043730A0-60B4-3040-9FFE-319B9562DBBD.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/2540000/14B957AB-B8DB-AF4E-B28C-BBBE96ADD94F.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/2540000/2E7DC466-96C1-8846-B5D0-A9FA7886F026.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/2540000/3F00B5D0-295D-AA4A-B78E-289B5F96081F.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/2540000/60D35927-5A8A-DA45-94B5-B5BFA58645A0.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/2540000/60EA58CB-D8D3-2E41-A296-E6267766BC3F.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/2540000/690C6CDA-48D6-AD48-A6D7-785B9E55DB1F.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/2540000/70C3A007-D5E5-3A46-9F56-F68AD648C2B8.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/2540000/93661981-4EAB-6647-9FE1-C9297FCFB68B.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/2540000/A4C45E5D-1884-5146-9E12-4BAD12574744.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/2540000/D4BBFFBE-4CAD-8747-88F3-A8D8C62B98D2.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/2540000/EBFA65CF-A28B-6C4A-9471-EEE8362F6A88.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/2560000/6E5C0197-5259-C84F-8CB8-34339CA96C3F.root",
        "root://xrootd-redir.ultralight.org//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/2810000/2B09E6AA-3A22-9644-A4DD-41D8DD3E279A.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/2810000/3A701A46-43CB-114C-943F-E21D55853BF4.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/2810000/575696F9-017B-4E46-8556-D1F3261CB9BF.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/2810000/6FC11C92-E386-8145-85F1-5F07E3848CB3.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/2810000/71B4D8A8-B3C9-ED4E-AEA7-753FB4A0E885.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/2810000/78F89553-C53A-184B-B236-65FD3D58198B.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/2810000/7FB550F6-DA5E-E242-B587-B4D711B64AB0.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/2810000/8B13BE67-2771-C543-9E62-9CE0205A8C7D.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/2810000/9941EB5D-14A7-5644-BD08-CA8FD8E2FD37.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/2810000/A7790F4B-FA6F-264A-B6B1-C2C1EC3943C8.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/2810000/DDEB3D12-3A72-3A4C-91F6-210AFF8E7344.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/2810000/E881B42F-2BE0-434B-962D-10CE0A9EBAE7.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/2810000/EEBDF52C-AC28-1A4C-BED0-033C04A057A0.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/30000/0B210882-FC53-7048-8D69-BA8D6BAC0335.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/30000/27230564-ED41-594F-9C43-C305D1E3D222.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/30000/2D41ED50-7EE3-744D-8C9F-2D9D09A16AF5.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/30000/3D608DCB-9B20-1441-90FD-CDCEDB5225F6.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/30000/4273E361-0BA2-6E4A-A402-F41FE3CEAA71.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/30000/584D74E0-0B63-9E4E-A494-8E99A34C44A6.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/30000/6DF8DDB6-9703-6E4D-A58C-E9CE3A982A1B.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/30000/898E4329-371F-2C4F-91CE-C1699246F613.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/30000/BD0DB056-EC94-224A-BA51-21A7BC27E9C0.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/30000/C69D874A-5FD3-2940-919B-4B2BAC4716B5.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/30000/DA484459-3450-B442-A63C-DA6DC3D01BCA.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/30000/EF7A25B9-6214-1C4A-B996-FD2B148E8715.root",
        "root://cmsxrootd.fnal.gov//store/mc/RunIISummer20UL18NanoAODv9/VectorZPrimeGammaToQQGamma_flat_TuneCP5_13TeV-madgraph-pythia8/NANOAODSIM/106X_upgrade2018_realistic_v16_L1v1-v1/80000/13430DEC-9600-1340-AF9C-867E223C4226.root",
    ],
}

In [None]:
import warnings

# Suppress the "Found duplicate branch" warning (presumably raised by the NanoAOD reader).
warnings.filterwarnings("ignore", message="Found duplicate branch")

# Interactive test input: the first 100k events of the first file of the dataset.
ds = "flatZprime"
fn = fileset[ds][0]
events = NanoEventsFactory.from_root(
    fn,
    entry_stop=100000,
    metadata={"dataset": ds},
).events()

In [None]:
from coffea import processor
from correctionlib import CorrectionSet


class MSDProc(processor.ProcessorABC):
    """Profile jet-mass response against the generator-level Z' mass.

    The hard-process Z' (pdgId 55) of each event is matched to the nearest FatJet.
    The output dict holds matching, kinematic and response-vs-rho histograms, plus
    Mean profiles of the reconstructed and true mass for every mass variable listed
    in ``process``. When ``msdcorr.json`` provides a correction for a variable, the
    corrected mass is profiled as well (closure check).
    """

    def __init__(self):
        # Mass corrections used for the closure-check "cmean" profiles.
        # NOTE: CorrectionSet.from_file raises if msdcorr.json is not present.
        self.corrs = CorrectionSet.from_file("msdcorr.json")

    @staticmethod
    def _profile(zprime, msdratio, sample):
        """Book and fill a Mean profile of ``sample``.

        The profile is binned in Z' decay flavor, jet pt, Z' mass, jet eta and
        ``msdratio`` (a jet mass divided by the Z' mass, later used to select a
        response window).
        """
        return (
            hist.Hist.new.Var([1, 4, 5, 6], name="flavor", flow=False)
            .Reg(48, 200, 1500, name="pt", label=r"Jet $p_T$")
            .Reg(48, 0, 400, name="mass", label=r"Z' mass")
            .Reg(48, -2.5, 2.5, name="eta", label=r"Jet $\eta$")
            .Reg(4, 0, 2, name="msdratio", label="Jet $m_{SD}$ / Z' mass")
            .Mean()
            .fill(
                flavor=zprime.flavor,
                pt=zprime.jet.pt,
                mass=zprime.mass,
                eta=zprime.jet.eta,
                msdratio=msdratio,
                sample=sample,
            )
        )

    def process(self, events):
        """Fill all histograms for one chunk of events.

        Returns a dict with the hist.Hist objects "match", "kin", "jetkin" and
        "msd_qcdrho". It also has one entry per mass variable, which is a dict of
        Mean profiles: "mean" (reco mass), "massmean" (Z' mass) and, when a
        correction exists, "cmean" (corrected reco mass).
        """
        # Last copy of the hard-process Z'; [:, 0] assumes exactly one per event.
        zprime = events.GenPart[
            (events.GenPart.pdgId == 55)
            & events.GenPart.hasFlags("fromHardProcess", "isLastCopy")
        ][:, 0]
        # |pdgId| of the first daughter; the flavor axis groups it as 1-3, 4, 5.
        zprime["flavor"] = abs(zprime.children[:, 0].pdgId)
        zprime["qcdrho"] = 2 * np.log(zprime.mass / zprime.pt)
        # Closest FatJet in delta R (None when the event has no FatJet).
        zprime["jet"] = ak.firsts(
            events.FatJet[
                ak.argmin(zprime.delta_r2(events.FatJet), axis=1, keepdims=True)
            ]
        )
        # Raw softdrop mass: subjet sum with the subjet JEC undone; mass2 is
        # clipped at 0 before the sqrt.
        zprime["jet", "msdraw"] = np.sqrt(
            np.maximum(
                0.0,
                (zprime.jet.subjets * (1 - zprime.jet.subjets.rawFactor)).sum().mass2,
            )
        )
        # Raw softdrop mass rescaled by the FatJet-level JEC.
        zprime["jet", "msdfjcorr"] = zprime.jet.msdraw / (1 - zprime.jet.rawFactor)

        zprime = zprime[~ak.is_none(zprime.jet)]
        out = {}

        # match: jet axis within 0.4 of the Z'.
        # altmatch: all daughters within 0.8 of the jet, each with pt > 10% of the Z' pt.
        match = zprime.jet.delta_r(zprime) < 0.4
        altmatch = ak.all(zprime.jet.delta_r(zprime.children) < 0.8, axis=1)
        altmatch = altmatch & ak.all(zprime.children.pt > 0.1 * zprime.pt, axis=1)
        out["match"] = (
            hist.Hist.new
            .IntCat([0, 1], name="match")
            .IntCat([0, 1], name="altmatch")
            .Reg(8, 0, 1536, name="pt", label=r"Z' $p_T$")
            .Reg(8, 0, 512, name="mass", label=r"Z' mass")
            .Double().fill(match, altmatch, zprime.pt, zprime.mass)
        )
        # Everything below uses only altmatch-ed candidates.
        zprime = zprime[altmatch]

        out["kin"] = (
            hist.Hist.new.Reg(64, 0, 1536, name="pt", label=r"Z' $p_T$")
            .Reg(64, 0, 512, name="mass", label=r"Z' mass")
            .Reg(32, -3, 3, name="eta", label=r"Z' $\eta$")
            .Double()
            .fill(pt=zprime.pt, mass=zprime.mass, eta=zprime.eta)
        )
        out["jetkin"] = (
            hist.Hist.new.Reg(64, 0, 1536, name="pt", label=r"Jet $p_T$")
            .Reg(64, 0, 512, name="mass", label=r"Jet mass")
            .Reg(32, -3, 3, name="eta", label=r"Jet $\eta$")
            .Double()
            .fill(pt=zprime.jet.pt, mass=zprime.jet.msoftdrop, eta=zprime.jet.eta)
        )

        out["msd_qcdrho"] = (
            hist.Hist.new.Var([1, 4, 5, 6], name="flavor", flow=False)
            .Reg(32, -8, 0, name="qcdrho", label=r"Z' $\rho=2ln(m/p_T)$")
            .Reg(32, 0, 2, name="msdratio", label="Jet $m_{SD}$ / Z' mass")
            .Double()
            .fill(
                flavor=zprime.flavor,
                qcdrho=zprime.qcdrho,
                msdratio=zprime.jet.msoftdrop / zprime.mass,
            )
        )

        # Keep Z' inside the AK8 cone: m/pt < 0.8/2.
        zprime = zprime[zprime.mass / zprime.pt < 0.4]

        for mname in ["msdraw", "msdfjcorr", "msoftdrop", "particleNet_mass"]:
            reco = zprime.jet[mname]
            msdratio = reco / zprime.mass
            out[mname] = {
                # <reco mass> and <Z' mass> in identical bins; their ratio is the response
                "mean": self._profile(zprime, msdratio, reco),
                "massmean": self._profile(zprime, msdratio, zprime.mass),
            }
            # closure check
            try:
                corr = self.corrs[mname]
                # corr = self.corrs[mname + "_onebin"]
            except (IndexError, KeyError):
                # no correction available for this mass variable
                continue
            cvar = corr.evaluate(
                np.array(reco / zprime.jet.pt),
                np.array(np.log(zprime.jet.pt)),
                np.array(zprime.jet.eta),
            )
            corrected = reco * cvar
            # Note: using corrected mass for the msdratio window
            out[mname]["cmean"] = self._profile(zprime, corrected / zprime.mass, corrected)
        return out

    def postprocess(self, x):
        """No post-processing: return the accumulated output unchanged."""
        return x

In [None]:
# Quick interactive run of the processor on the events loaded above;
# `out` is overwritten by the full Runner in the next cell.
out = MSDProc().process(events)

In [None]:
# Process the whole fileset with a 4-worker futures executor; this replaces the
# single-file `out` from the interactive test above.
executor = processor.FuturesExecutor(workers=4)
runner = processor.Runner(
    executor,
    schema=NanoAODSchema,
    # xrootdtimeout=120,
)
out = runner(fileset, "Events", MSDProc())

In [None]:
# Write the processor output to disk.
from coffea.util import save as save_coffea

save_coffea(out, "profiles.coffea")

In [None]:
# Fraction of Z' candidates in each (match, altmatch) category, integrated over pt and mass.
match_view = out["match"][:, :, ::sum, ::sum].view()
m = match_view / match_view.sum()
m

In [None]:
def splom(h):
    """Corner/splom plot

    The diagonal shows the 1D projection onto each axis and the lower triangle shows
    the 2D projection onto (column axis, row axis). The upper triangle is blank.
    Projections are requested by axis name, so the axes of ``h`` must be named.

    https://github.com/scikit-hep/hist/issues/381

    Returns the matplotlib Figure.
    """
    naxes = len(h.axes)
    # squeeze=False keeps `axes` 2D even when h has a single axis
    fig, axes = plt.subplots(
        naxes, naxes, figsize=(4*naxes, 4*naxes), facecolor="w", squeeze=False
    )
    for i, axrow in enumerate(axes):
        for j, ax in enumerate(axrow):
            if j > i:
                ax.axis("off")
            elif j == i:
                hp = h.project(h.axes[i].name)
                hp.plot(ax=ax)
                ax.set_xlim(hp.axes[0].edges[0], hp.axes[0].edges[-1])
            else:
                hp = h.project(h.axes[j].name, h.axes[i].name)
                hp.plot(ax=ax, cbar=False)
                ax.set_xlim(hp.axes[0].edges[0], hp.axes[0].edges[-1])
                ax.set_ylim(hp.axes[1].edges[0], hp.axes[1].edges[-1])

    fig.tight_layout()
    return fig

In [None]:
# Corner plot of the matched Z' kinematics (pt, mass, eta)
fig = splom(out["kin"])
fig.savefig("zprime_kinematics.pdf")

In [None]:
# Corner plot of the matched jet kinematics (pt, softdrop mass, eta)
fig = splom(out["jetkin"])
fig.savefig("jet_kinematics.pdf")

In [None]:
fig, ax = plt.subplots()

# Jet m_SD / Z' mass vs. Z' rho, summed over flavor and normalized within each rho column
h = out["msd_qcdrho"][::sum, :, :]
H = h.counts()
colsum = H.sum(axis=1, keepdims=True)
# Empty rho columns stay at 0 (masked by LogNorm) instead of 0/0 -> NaN with warnings
h.view()[:] = np.divide(H, colsum, out=np.zeros_like(H), where=colsum > 0)
art = h.plot(ax=ax, norm=LogNorm(vmin=1e-3, vmax=1))
art.cbar.set_label("Y-density")

# msdratio window used to average the profiles in the cells below
window = slice(0.5j, 1.5j, sum)
ax.axhline(0.5, linestyle="--", color="orange", label="Avg. window")
ax.axhline(1.5, linestyle="--", color="orange")
# rho at m/pt = 0.8/2
ax.axvline(2*np.log(0.8/2), linestyle="--", color="red", label="AK8 cone threshold")
ax.axvline(-2.1, linestyle=":", color="red", label=r"$-6 < \rho < -2.1$")
ax.axvline(-6, linestyle=":", color="red")
ax.legend(loc="upper right")
fig.savefig("response_vs_rho.pdf")

In [None]:
from typing import List, Tuple
from scipy.optimize import lsq_linear
import correctionlib.schemav2 as clib


def ndpolyfit(
    points: List[np.ndarray],
    values: np.ndarray,
    weights: np.ndarray,
    varnames: List[str],
    degree: Tuple[int, ...],
) -> "clib.Formula":
    """Fit an n-dimensional polynomial to data points with weight

    The model is the full tensor-product polynomial: every monomial
    x0^p0 * x1^p1 * ... with 0 <= p_i <= degree[i]. The fit minimizes
    sum((weights * (model - values))**2), so weights act as 1/sigma.
    The fit status, chi2 and chi2 p-value are printed.

    Parameters
    ----------
    points : one flat array per variable, each with the same shape as ``values``
    values : flat array of target values
    weights : per-point weights, same shape as ``values``
    varnames : correctionlib input names, in the same order as ``points``
    degree : maximum power of each variable

    Returns
    -------
    A correctionlib Formula node in TFormula syntax, whose x, y, z, t map to
    ``varnames`` in order.

    Raises
    ------
    ValueError for inconsistent shapes or lengths; NotImplementedError for
    more than 4 variables.
    """
    if len(values.shape) != 1:
        raise ValueError("Expecting flat array of values")
    if not all(x.shape == values.shape for x in points):
        raise ValueError("Incompatible shapes for points and values")
    if values.shape != weights.shape:
        raise ValueError("Incompatible shapes for values and weights")
    if len(points) != len(varnames):
        raise ValueError("Dimension mismatch between points and varnames")
    if len(degree) != len(varnames):
        raise ValueError("Dimension mismatch between varnames and degree")
    if len(degree) > 4:
        raise NotImplementedError(
            "correctionlib Formula not available for more than 4 variables?"
        )
    degree = np.array(degree, dtype=int)
    npoints = len(values)
    # Design matrix: powergrid[k, p0, p1, ...] = prod_i points[i][k] ** p_i
    powergrid = np.ones(shape=(npoints,) + tuple(degree + 1))
    for i, (x, deg) in enumerate(zip(points, degree)):
        shape = np.ones(1 + len(degree), dtype=int)
        shape[0] = npoints
        shape[i + 1] = deg + 1
        powergrid *= np.power.outer(x, np.arange(deg + 1)).reshape(shape)
    # Weighted linear least squares
    fit = lsq_linear(
        A=powergrid.reshape(npoints, -1) * weights[:, None],
        b=values * weights,
    )
    print(fit.message)
    # lsq_linear reports cost = 0.5 * ||A x - b||^2, so the weighted chi2 is twice the cost
    chi2 = 2 * fit.cost
    # plain int so that {dof=} prints "dof=N" rather than a numpy scalar repr
    dof = int(npoints - np.prod(degree + 1))
    prob = stats.chi2.sf(chi2, df=dof)
    print(f"chi2 = {chi2:.4g}, P({dof=}) = {prob:.3f}")
    params = fit.x.reshape(degree + 1)
    # TODO: n-dimensional Horner form
    expr = []
    for index in np.ndindex(tuple(degree + 1)):
        term = [str(params[index])] + [
            f"{var}^{p}" if p > 1 else var
            for var, p in zip("xyzt", index)
            if p > 0
        ]
        expr.append("*".join(term))
    return clib.Formula(
        nodetype="formula",
        expression="+".join(expr),
        parser="TFormula",
        variables=varnames,
    )


# Human-readable description of each mass variable profiled by MSDProc
titles = {
    "msdraw": "Softdrop mass (raw)",
    "msdfjcorr": "Softdrop mass (jet JEC)",
    "msoftdrop": "Softdrop mass (subjet JEC)",
    "particleNet_mass": "ParticleNet regressed mass"
}
cset = clib.CorrectionSet(schema_version=2, corrections=[])


def msdcorr_inputs(mname):
    """Fresh list of the three correction inputs (mdivpt, logpt, eta) for mass variable ``mname``."""
    return [
        clib.Variable(name="mdivpt", type="real", description=f"{mname} divided by jet pt"),
        clib.Variable(name="logpt", type="real", description="log(jet pt)"),
        clib.Variable(name="eta", type="real", description="jet eta"),
    ]


for mname, mdesc in titles.items():
    # No correction is derived for the ParticleNet mass
    if mname == "particleNet_mass":
        continue
    msdhists = out[mname]
    # flavor, pt, mass, eta, msdratio -> summed over flavor and the msdratio window
    hmreco = msdhists["mean"][::sum, :, :, :, window]
    hmtrue = msdhists["massmean"][::sum, :, :, :, window]

    ptgrid, massgrid, etagrid = np.meshgrid(*[ax.centers for ax in hmreco.axes], indexing="ij")
    msdgrid = hmreco.values()
    rgrid = msdgrid / ptgrid
    dgrid = np.log(ptgrid)

    # populated bins with reco 2*ln(m/pt) < -1.8
    mask = (hmreco.counts() > 5) & (rgrid < np.exp(-1.8/2)) # & (rgrid > np.exp(-6/2))
    # target: multiplicative correction <Z' mass> / <reco mass>
    msdcorr = np.ones_like(msdgrid)
    msdcorr[mask] = hmtrue.values()[mask] / hmreco.values()[mask]
    # fit weight = 1 / uncertainty of the correction, propagating only the reco-mean error
    msdcorrw = np.zeros_like(msdgrid)
    msdcorrw[mask] = 1 / ( msdcorr[mask] * np.sqrt(hmreco.variances()[mask]) / hmreco.values()[mask] )

    # (printed suffix, bin selection, polynomial degree in (mdivpt, logpt, eta)).
    # The first entry is one global fit; the others split eta at +-1.25.
    # Equality in the eta masks is on purpose to help the fits be continuous.
    fit_regions = [
        ("", mask, (3, 2, 6)),
        (" low eta", mask & (etagrid <= -1.25), (2, 2, 2)),
        (" mid eta", mask & (etagrid >= -1.25) & (etagrid <= 1.25), (3, 2, 4)),
        (" hi eta", mask & (etagrid >= 1.25), (2, 2, 2)),
    ]
    formulas = []
    for suffix, m2, degree in fit_regions:
        print(f"=== {mname}{suffix}")
        formulas.append(
            ndpolyfit(
                points=[rgrid[m2], dgrid[m2], etagrid[m2]],
                values=msdcorr[m2],
                weights=msdcorrw[m2],
                varnames=["mdivpt", "logpt", "eta"],
                degree=degree,
            )
        )
    formula, formulalo, formulamid, formulahi = formulas

    cset.corrections.append(
        clib.Correction(
            name=mname + "_onebin",
            description=f"Correction to {mname} '{mdesc}' fit to polynomial",
            version=1,
            inputs=msdcorr_inputs(mname),
            output=clib.Variable(name="output", type="real", description=f"Multiplicative correction to {mname}"),
            data=formula,
        )
    )
    cset.corrections.append(
        clib.Correction(
            name=mname,
            description=f"Correction to {mname} '{mdesc}' fit to polynomial in three eta bins",
            version=1,
            inputs=msdcorr_inputs(mname),
            output=clib.Variable(name="output", type="real", description=f"Multiplicative correction to {mname}"),
            data=clib.Binning(
                nodetype="binning",
                input="eta",
                edges=[-3, -1.25, 1.25, 3],
                content=[formulalo, formulamid, formulahi],
                flow="clamp",
            )
        )
    )

In [None]:
# with open("msdcorr.json", "w") as fout:
#     fout.write(cset.json(exclude_unset=True))
# cset = cset.to_evaluator()

In [None]:
fig, ax = plt.subplots()

mname = "msdraw"
msdhists = out[mname]
title = titles[mname]
ax.set_title(title)

# flavor, pt, mass, eta, msdratio
hnum = msdhists["mean"][::sum, :, :, :, window]
hden = msdhists["massmean"][::sum, :, :, :, window]

# Bin-center grids for THIS histogram. Grids left over from the fitting loop belong to
# the last fitted mass variable, not to `mname`.
ptgrid, massgrid, etagrid = np.meshgrid(*[a.centers for a in hnum.axes], indexing="ij")

mask = (hnum.counts() > 10)
msdgrid = np.where(mask, hnum.values(), massgrid)
msdcorr = np.ones_like(msdgrid)
msdcorr[mask] = hden.values()[mask] / hnum.values()[mask]
msdcorrw = np.zeros_like(msdgrid)
msdcorrw[mask] = 1 / ( msdcorr[mask] * np.sqrt(hnum.variances()[mask]) / hnum.values()[mask] )

# x-axis: <reco mass> / jet pt, i.e. the `mdivpt` input of the correction.
# Marker size is the relative uncertainty of the correction (1 / weight).
x = msdgrid / ptgrid
ax.scatter(
    x[mask],
    msdcorr[mask],
    s=1/msdcorrw[mask],
    label="Before",
)

if "cmean" in msdhists:
    hcnum = msdhists["cmean"][::sum, :, :, :, window]
    ax.scatter(
        x[mask],
        hden.values()[mask] / hcnum.values()[mask],
        s=1/msdcorrw[mask],
        label="After",
    )
ax.set_xlabel(r"$\langle m \rangle / p_T$")
ax.set_ylabel(r"$\langle m_{Z'} \rangle / \langle m \rangle$")
ax.legend()

In [None]:
def plresp(hnum, hden, ax, **pl):
    """Draw the response <num> / <den> against the first axis of ``hnum``.

    Only the numerator's uncertainty is propagated; the denominator's is neglected.
    Extra keyword arguments are forwarded to ``ax.errorbar``, whose return value is
    passed back to the caller.
    """
    den = hden.values()
    ratio = hnum.values() / den
    # denom is negligible
    ratio_err = np.sqrt(hnum.variances() / den**2)
    return ax.errorbar(x=hnum.axes[0].centers, y=ratio, yerr=ratio_err, **pl)


fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True, facecolor="w")

# One panel per softdrop-mass variant (zip stops after the first three titles)
for ax, (mname, title) in zip(axes, titles.items()):
    msdhists = out[mname]
    ax.set_title(title)

    # flavor, pt, mass, eta, msdratio; the sliced mass axis is the generator-level Z' mass
    hnum = msdhists["mean"][::sum, :, 60j:100j:sum, -1.3j:1.3j:sum, window]
    hden = msdhists["massmean"][::sum, :, 60j:100j:sum, -1.3j:1.3j:sum, window]
    plresp(hnum, hden, ax, label=r"$|\eta| < 1.3$")

    # forward region: sum of the eta < -1.3 and eta > 1.3 slices
    hnum = (
        msdhists["mean"][::sum, :, 60j:100j:sum, :-1.3j:sum, window]
        + msdhists["mean"][::sum, :, 60j:100j:sum, 1.3j::sum, window]
    )
    hden = (
        msdhists["massmean"][::sum, :, 60j:100j:sum, :-1.3j:sum, window]
        + msdhists["massmean"][::sum, :, 60j:100j:sum, 1.3j::sum, window]
    )
    plresp(hnum, hden, ax, label=r"$|\eta| > 1.3$")

    ax.set_ylabel("Response $<m_{SD}> / <m_{Z'}>$")
    ax.set_xlabel(hnum.axes[0].label)
    ax.legend(title="$60 < m_{Z'} < 100$")
    ax.set_ylim(0.8, 1.2)

In [None]:
fig, ax = plt.subplots(facecolor="w")
mname = "msdraw"
msdhists = out[mname]
title = titles[mname]
ax.set_title(title)

# Windows in jet pt and in generator-level Z' mass (the second sliced axis)
ptwindow, masswindow = slice(600j, 1000j, sum), slice(100j, 300j, sum)
# flavor, pt, mass, eta, msdratio -> response vs. jet eta
hden = msdhists["massmean"][::sum, ptwindow, masswindow, :, window]

hnum = msdhists["mean"][::sum, ptwindow, masswindow, :, window]
plresp(hnum, hden, ax, label="Before")

hnum = msdhists["cmean"][::sum, ptwindow, masswindow, :, window]
plresp(hnum, hden, ax, label="After")

ax.set_ylabel("Response $<m_{SD}> / <m_{Z'}>$")
ax.set_xlabel(hnum.axes[0].label)
ax.axvline(-1.25, linestyle=":")
ax.axvline(1.25, linestyle=":")
ax.axhline(1, linestyle=":", color="grey")

# Inverse of the fitted correction at m/pt = 100/800, pt = 800, i.e. the response it expects
etas = np.linspace(-2.5, 2.5, 50)
ax.plot(
    etas,
    1/MSDProc().corrs["msdraw"].evaluate(100/800, np.log(800), etas),
    label="1 / correction",
)
# legend drawn last so the correction curve is included
ax.legend(title="Corrections")

ax.set_ylim(0.9, 1.1)

In [None]:
fig, axes = plt.subplots(1, 4, figsize=(12, 4), sharey=True, facecolor="w")

# jet pt window and generator-level Z' mass window
ptwindow = slice(600j, 1500j, sum)
masswindow = slice(100j, 300j, sum)

for ax, mname in zip(axes, titles):
    msdhists = out[mname]
    title = titles[mname]
    ax.set_title(title)

    # flavor, pt, mass, eta, msdratio
    hden = msdhists["massmean"][::sum, ptwindow, masswindow, :, window]

    hnum = msdhists["mean"][::sum, ptwindow, masswindow, :, window]
    plresp(hnum, hden, ax, label="Before")

    # "cmean" only exists when MSDProc found a correction for this mass variable
    # (none is derived for particleNet_mass)
    if "cmean" in msdhists:
        hnum = msdhists["cmean"][::sum, ptwindow, masswindow, :, window]
        plresp(hnum, hden, ax, label="After")

    ax.set_ylabel("Response $<m_{SD}> / <m_{Z'}>$")
    ax.set_xlabel(hnum.axes[0].label)
    ax.legend(title="Corrections")
    ax.axvline(-1.25, linestyle=":")
    ax.axvline(1.25, linestyle=":")
    ax.axhline(1, linestyle=":", color="grey")
    ax.set_ylim(0.9, 1.1)

In [None]:
def addrhos(ax):
    """Overlay reference lines of constant rho = 2 ln(m/pt) on a (pt, mass) plot.

    Draws m = pt * 0.8/2 (AK8 cone), rho = -2.1 and rho = -6.0 for 0 <= pt <= 1536,
    each with a legend label. Returns None.
    """
    ptval = np.linspace(0, 1024+512, 20)
    references = (
        (0.8 * ptval / 2, "--", "r", "AK8 cone"),
        (np.exp(-2.1/2)*ptval, "--", "k", r"$\rho=-2.1$"),
        (np.exp(-6.0/2)*ptval, ":", "k", r"$\rho=-6.0$"),
    )
    for mval, style, colour, text in references:
        ax.plot(ptval, mval, linestyle=style, color=colour, label=text)

In [None]:
fig, ax = plt.subplots(facecolor="w")

# Explicit mass variable instead of the loop leftovers from an earlier cell
# (whose last iteration was particleNet_mass).
mname = "particleNet_mass"
msdhists = out[mname]
title = titles[mname]
ax.set_title(title)

# flavor, pt, mass, eta, msdratio -> (jet pt, Z' mass) plane
hnum = msdhists["mean"][::sum, :, :, ::sum, window]
hden = msdhists["massmean"][::sum, :, :, ::sum, window]

# Entries per bin (for the response use hnum.values() / hden.values() with vmin=0.9, vmax=1.1)
art = hist.Hist(
    *hnum.axes,
    data=hnum.counts(),
).plot(ax=ax)
art.cbar.set_label("Counts")

addrhos(ax)
ax.legend(loc="upper left")

In [None]:
fig, ax = plt.subplots(facecolor="w")

mn = "particleNet_mass"
ax.set_title(titles[mn])

# flavor, pt, mass, eta, msdratio -> (jet pt, Z' mass) plane
hnum = out[mn]["mean"][::sum, :, :, ::sum, window]
hden = out[mn]["massmean"][::sum, :, :, ::sum, window]

# Inclusive response: overall mean reco mass over overall mean Z' mass
avg = hnum.sum().value / hden.sum().value
art = hist.Hist(
    *hnum.axes,
    data=hnum.values() / hden.values(),
).plot(ax=ax, cmap="bwr", vmin=0.8, vmax=1.2)
art.cbar.set_label("Response $<m_{SD}> / <m_{Z'}>$" + f" (avg {avg:.2f})")

addrhos(ax)

ax.legend(loc="upper left")

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(15, 5), sharey=True, facecolor="w")

mn = "particleNet_mass"

# flavor-axis bin index -> label (bins [1,4), [4,5), [5,6) of |pdgId|)
fnames = {0: "light jets", 1: "c jets", 2: "b jets"}

for i, (ax, flavor) in enumerate(zip(axes, fnames)):
    # single flavor bin, pt, mass, eta, msdratio -> (jet pt, Z' mass) plane
    hnumf = out[mn]["mean"][flavor, :, :, ::sum, window]
    hdenf = out[mn]["massmean"][flavor, :, :, ::sum, window]

    avg = hnumf.sum().value / hdenf.sum().value
    ax.set_title(f"{titles[mn]}, {fnames[flavor]} (avg {avg:.2f})")

    # one shared colorbar, drawn on the last panel only
    last = i == len(fnames) - 1
    art = hist.Hist(
        *hnumf.axes,
        data=hnumf.values() / hdenf.values(),
    ).plot(ax=ax, cmap="bwr", vmin=0.8, vmax=1.2, cbar=last)
    if last:
        art.cbar.set_label("Response $<m_{SD}> / <m_{Z'}>$")

    addrhos(ax)
    ax.legend(loc="upper left")

In [None]:
def err(meanhist):
    """Per-bin spread (standard deviation) of the samples filled into a Mean histogram.

    For Mean storage, boost-histogram's ``variances()`` is the variance of the mean,
    so multiplying by the counts recovers the sample variance.
    """
    return np.sqrt(meanhist.variances()*meanhist.counts())

fig, ax = plt.subplots(facecolor="w")

mn = "msoftdrop"
ax.set_title(titles[mn])

# flavor, pt, mass, eta, msdratio -> corrected-mass profile in the (jet pt, Z' mass) plane
hmean = out[mn]["cmean"][::sum, :, :, ::sum, window]

art = hist.Hist(
    *hmean.axes,
    data=err(hmean),
).plot(ax=ax, cmin=0, cmax=50)
art.cbar.set_label(r"Response-corr. resolution $\sigma(m_{SD})$ [GeV]")

In [None]:
# Uses err(), mn and hmean (corrected profile of mn) from the previous cell.
fig, ax = plt.subplots(facecolor="w")

# mn2 = "msdfjcorr"
mn2 = "msoftdrop"
# mn2 = "msdraw"
# mn2 = "particleNet_mass"
ax.set_title(titles[mn2] + " vs. " + titles[mn] + " (corrected)")

# flavor, pt, mass, eta, msdratio
hmean2 = out[mn2]["mean"][::sum, :, :, ::sum, window]

# Ratio of relative resolutions sigma/mean: uncorrected mn2 over corrected mn
art = hist.Hist(
    *hmean.axes,
    data=(err(hmean2) / err(hmean)) * (hmean.values() / hmean2.values()),
).plot(ax=ax, cmap="bwr", vmin=0.5, vmax=1.5)
art.cbar.set_label("Relative resolution")

addrhos(ax)
ax.legend()

In [None]:
# Uses mn and hmean (corrected profile of mn) from an earlier cell.
fig, ax = plt.subplots(facecolor="w")

# mn2 = "msdfjcorr"
# mn2 = "msoftdrop"
# mn2 = "msdraw"
mn2 = "particleNet_mass"
ax.set_title(titles[mn2] + " vs. " + titles[mn] + " (corrected)")

# flavor, pt, mass, eta, msdratio
hmean2 = out[mn2]["mean"][::sum, :, :, ::sum, window]

# Entries inside each variable's own msdratio window: mn2 over corrected mn
art = hist.Hist(
    *hmean.axes,
    data=hmean2.counts() / hmean.counts(),
).plot(ax=ax, cmap="bwr", vmin=0.8, vmax=1.2)
art.cbar.set_label("Relative counts")

addrhos(ax)
ax.legend()