In [None]:
import hist
import numpy as np
import mplhep as hep
import matplotlib.pyplot as plt
from wremnants import boostHistHelpers as hh
from wremnants import histselections as sel
from wremnants import datasets2016
import lz4.frame
import pickle
# Select mplhep's ROOT style preset through hep.style.use.
hep.style.use(hep.style.ROOT)

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Read the lz4-compressed pickle into `boost`; later cells index it by sample name.
# The path is relative to the directory the notebook is run from.
# NOTE(review): pickle.load can execute arbitrary code while deserialising --
# only open files that come from a trusted source.
with lz4.frame.open("../mw_with_mu_eta_pt.pkl.lz4") as f:
    boost = pickle.load(f)

In [None]:
def readForDataset(output, name, histname="nominal"):
    """Fetch one histogram of a sample and apply its normalisation factor.

    Parameters
    ----------
    output : mapping
        Sample name -> record with the keys "dataset" and "output"
        (plus "weight_sum" for non-data samples and "lumi" for "dataPostVFP").
    name : str
        Sample to read.
    histname : str
        Key of the histogram inside the sample's "output" entry.

    Returns
    -------
    The stored histogram multiplied by 1 for data samples, or by
    ``lumi * xsec / weight_sum`` otherwise.

    Raises
    ------
    ValueError
        If ``name`` is not a key of ``output``.
    """
    if name not in output:
        raise ValueError(f"Sample {name} not in file")

    # Luminosity is taken from the "dataPostVFP" record and multiplied by 1000
    # (presumably a unit conversion to match the cross sections -- TODO confirm);
    # without that record the factor falls back to 1.
    if "dataPostVFP" in output:
        lumi = output["dataPostVFP"]["lumi"]*1000
    else:
        lumi = 1.

    sample = output[name]
    if sample["dataset"]["is_data"]:
        norm = 1.
    else:
        norm = lumi*sample["dataset"]["xsec"]/sample["weight_sum"]
    return sample["output"][histname]*norm

In [None]:
def unrolledHist(h, obs=["pt", "eta"]):
    """Project ``h`` onto ``obs`` and flatten the result into a 1D histogram.

    Parameters
    ----------
    h : histogram
        Input histogram; it must provide ``project`` and be convertible with
        ``np.ravel``.
    obs : sequence of str
        Axis names to keep, in the order they are unrolled. The default list
        is only read, never modified.

    Returns
    -------
    hist.Hist
        Weight-storage histogram with one regular axis running from 0 to N in
        N unit-width bins, where N is the product of the projected axis sizes.
        ``np.ravel`` flattens in C order, so the last entry of ``obs`` varies
        fastest.
    """
    projected = h.project(*obs)
    # Size the flat axis from the projected axes themselves, not from the first
    # two axes of the input: the result is then correct for any axis order in
    # `h` and for any number of entries in `obs`.
    bins = int(np.prod([len(axis) for axis in projected.axes]))
    newh = hist.Hist(hist.axis.Regular(bins, 0, bins), storage=hist.storage.Weight())
    newh[...] = np.ravel(projected)
    return newh

In [None]:
def makePlot(datagroups, procs, obs, name, data=None, rrange=[0.9, 1.1], scale=8.5e6):
    """Draw a stacked prediction with data overlaid and a data/prediction ratio panel.

    Parameters
    ----------
    datagroups : dict
        Group name -> dict with at least the keys "hist", "color" and "label".
    procs : iterable of str
        Groups to stack, in drawing order. Groups whose "hist" is None are skipped.
    obs : str
        Axis name to project every histogram onto, or "unrolled" to flatten
        it with ``unrolledHist``.
    name : str
        x-axis label, shown on the ratio panel.
    data : histogram, optional
        Un-projected data histogram. When omitted, ``datagroups["Data"]["hist"]``
        is used if present; with no data only the prediction is drawn.
    rrange : sequence of two floats
        y-limits of the ratio panel. The default list is only read.
    scale : float
        Upper y-limit of the main panel.

    Raises
    ------
    ValueError
        If none of ``procs`` has a histogram to draw.
    """
    if obs == "unrolled":
        op = unrolledHist
    else:
        op = lambda x: x.project(obs)

    # Filter once so histograms, colours and labels are guaranteed to line up.
    drawn = [datagroups[k] for k in procs if datagroups[k]["hist"] is not None]
    if not drawn:
        raise ValueError("No histograms available for the requested processes")
    pred = [op(group["hist"]) for group in drawn]
    colors = [group["color"] for group in drawn]
    labels = [group["label"] for group in drawn]

    # The unrolled distribution has many more bins, so it gets a wider canvas.
    width = 3 if "unrolled" in obs else 1
    fig = plt.figure(figsize=(8*width, 8))
    # Main panel takes the top three quarters, the ratio panel the bottom quarter.
    ax1 = fig.add_subplot(4, 1, (1, 3))
    ax2 = fig.add_subplot(4, 1, 4)

    hep.histplot(
        pred,
        histtype="fill",
        color=colors,
        label=labels,
        stack=True,
        ax=ax1
    )

    # Honour an explicitly passed data histogram; otherwise use the "Data" group.
    if data is None:
        data = datagroups.get("Data", {}).get("hist")
    if data is not None:
        data = op(data)
        hep.histplot(
            data,
            histtype="errorbar",
            yerr=True,
            color="black",
            ax=ax1,
        )
        hep.histplot(
            hh.divideHists(data, sum(pred)),
            histtype="errorbar",
            yerr=False,
            color="black",
            ax=ax2
        )

    ax1.set_xlabel("")
    ax2.set_xlabel(name)
    ax1.set_ylabel("Events/bin")
    ax1.set_xticklabels([])
    # `edges` has one more entry than there are bins; its last entry is the
    # upper edge of the last bin, so the full axis range is shown.
    edges = pred[0].axes[0].edges
    xrange = [edges[0], edges[-1]]
    ax1.set_xlim(xrange)
    ax2.set_xlim(xrange)
    ax2.set_ylabel("data/pred.", fontsize=22)
    ax2.set_ylim(rrange)
    ax1.set_ylim([0, scale])
    ax1.legend(prop={'size' : 20*(0.7 if width == 1 else 1.3)}, ncol=2, loc='upper right')

In [None]:
# List the sample names stored in the input file.
boost.keys()

In [None]:
# Plot groups: each entry lists the member samples read from `boost`, the
# legend label and fill colour, and a "hist" slot that starts out empty.
datagroups = {
    "Data": {
        "members": ["dataPostVFP"],
        "color": "black",
        "label": "Data",
        "hist": None,
    },
    # Built from every sample in the file.
    "Fake": {
        "members": list(boost.keys()),
        "label": "Nonprompt",
        "color": "grey",
        "hist": None,
    },
    "Zmumu": {
        "members": ["ZmumuPostVFP"],
        "label": r"Z$\to\mu\mu$",
        "color": "lightblue",
        "hist": None,
    },
    "Wtau": {
        "members": ["WminustaunuPostVFP", "WplustaunuPostVFP"],
        "label": r"W$^{\pm}\to\tau\nu$",
        "color": "orange",
        "hist": None,
    },
    "W": {
        "members": ["WminusmunuPostVFP", "WplusmunuPostVFP"],
        "label": r"W$^{\pm}\to\mu\nu$",
        "color": "darkred",
        "hist": None,
    },
    "Ztt": {
        "members": ["ZtautauPostVFP"],
        "label": r"Z$\to\tau\tau$",
        "color": "darkblue",
        "hist": None,
    },
    "Top": {
        "members": ["TTSemileptonicPostVFP", "TTLeptonicPostVFP"],
        "label": "Top",
        "color": "green",
        "hist": None,
    },
    "Diboson": {
        "members": ["WWPostVFP"],
        "label": "Diboson",
        "color": "pink",
        "hist": None,
    },
}
# Fill each group's "hist" with the sum of its member samples, then apply the
# region selection from wremnants.histselections.
for k,v in datagroups.items():
    for sample in v["members"]:
        try:
            h = readForDataset(boost, sample)
        except ValueError as e:
            # Sample is missing from the input file: report it and move on.
            print(e)
            continue
        # In the "Fake" group every sample without "data" in its name is subtracted.
        scale = -1 if (k == "Fake" and "data" not in sample) else 1
        if v["hist"] is None:
            # h*scale is a new histogram, so the first member also carries its
            # sign, whatever the order of the member list.
            v["hist"] = h*scale
        else:
            v["hist"] += h*scale
        if k == "Fake":
            # Zero out negative bin contents of the running sum.
            # NOTE(review): this only changes the stored histogram if values()
            # returns a view of its storage -- confirm. Clipping after every
            # member also makes the result depend on member order; check
            # whether a single clip after the full sum is what is wanted.
            vals = v["hist"].values()
            vals[vals<0] = 0
    if v["hist"] is None:
        # None of the members was found; leave the group empty.
        continue
    if k == "Fake":
        v["hist"] = sel.fakeHistABCD(v["hist"])
    else:
        v["hist"] = sel.signalHistWmass(v["hist"])

In [None]:
# Display the data histogram stored in the "Data" group.
datagroups["Data"]["hist"]

In [None]:
# Names of all non-data samples in the input file (anything called "Fake" is excluded).
mcnames = [
    sample_name
    for sample_name, record in boost.items()
    if not record["dataset"]["is_data"] and sample_name != "Fake"
]
# Every plot group except the data itself is part of the stacked prediction.
prednames = [group_name for group_name in datagroups if group_name != "Data"]

In [None]:
# "pt" projection, ratio panel limited to 0.95-1.05.
makePlot(
    datagroups,
    prednames,
    "pt",
    name=r"p$_{T}$ (GeV)",
    rrange=[0.95, 1.05],
)

In [None]:
# "eta" projection with a lower y-axis ceiling, ratio panel limited to 0.95-1.05.
makePlot(
    datagroups,
    prednames,
    "eta",
    scale=5e6,
    name=r"$\eta$",
    rrange=[0.95, 1.05],
)

In [None]:
# Unrolled 2D distribution, one x-axis bin per flattened cell.
makePlot(
    datagroups,
    prednames,
    "unrolled",
    name=r"($\eta^{\ell}, p_{\mathrm{T}}^{\ell})$ bin",
    scale=1.5e5,
    rrange=[0.9, 1.1],
)