In [None]:
import os

import awkward as ak
import dask
import dask_awkward as dak
import hist
import numpy as np
from coffea.lumi_tools import LumiMask
from coffea.nanoevents import NanoEventsFactory
from distributed import Client
from hist.dask import Hist

In [None]:
# Variable-width probe-pT bin edges for the efficiency histograms.
# Bins are 2 units wide between 26 and 40 (around the default pt=32
# threshold used by perform_tnp) and widen progressively in the tails.
ptbins = [
    5, 10, 15, 20, 22, 26, 28, 30, 32, 34, 36,
    38, 40, 45, 50, 60, 80, 100, 150, 250, 400,
]

In [None]:
# Lazily open the "Events" tree of the input NanoAOD file as a dask-backed
# NanoEvents collection, split into 10 chunks for parallel processing.
# Nothing is read until a dask.compute() later in the notebook.
factory = NanoEventsFactory.from_root(
    {"root_files/d45e45dd-5dd8-4cf3-a03a-141a6ff45d44.root": "Events"},
    chunks_per_file=10,
    permit_dask=True,
)
events = factory.events()

In [None]:
def filter_events(events, pt):
    """Keep events with at least two electrons passing the tag/probe selection.

    An electron is selected when all of the following hold:
    ``pt`` above the threshold, ``|eta| <= 2.5``, outside the
    1.4442 < |eta| < 1.566 barrel/endcap gap, and ``cutBased == 4``
    (tight ID).

    Parameters
    ----------
    events : dask-awkward NanoEvents array
        Must provide an ``Electron`` collection.
    pt : float
        Electron pT threshold (strict ``>``).

    Returns
    -------
    tuple
        ``(good_events, good_locations)``. ``good_events`` holds the events
        with >= 2 selected electrons. ``good_locations`` is the per-electron
        boolean selection mask for those same events.
    """
    # Cheap pre-filter: events with fewer than two electrons can never form a pair.
    multi_electron = events[dak.num(events.Electron) >= 2]
    electrons = multi_electron.Electron
    electron_abs_eta = abs(electrons.eta)

    is_selected = (
        (electrons.pt > pt)
        & (electron_abs_eta <= 2.5)
        # reject the EB/EE transition region
        & ((electron_abs_eta < 1.4442) | (electron_abs_eta > 1.566))
        & (electrons.cutBased == 4)
    )

    enough_selected = dak.sum(is_selected, axis=1) >= 2
    return multi_electron[enough_selected], is_selected[enough_selected]


def trigger_match(electrons, trigobjs, pt, dr_max=0.1):
    """Flag, per electron, whether it is matched to a qualifying trigger object.

    A trigger object qualifies when its ``pt`` is above ``pt``,
    ``abs(id) == 11``, and bit 1 of ``filterBits`` is set (the WPTight bit
    used by this notebook).

    Parameters
    ----------
    electrons : dask-awkward array
        Jagged per event, e.g. the ``tag`` or ``probe`` slot of
        ``dak.combinations``.
    trigobjs : dask-awkward ``TrigObj`` collection
        Must be aligned event-by-event with ``electrons``.
    pt : float
        Trigger-object pT threshold (strict ``>``).
    dr_max : float, optional
        Maximum ΔR for a match (strict ``<``). Defaults to 0.1.

    Returns
    -------
    Boolean dask-awkward array with the same jagged structure as
    ``electrons``. It is True where that electron lies within ``dr_max`` of
    at least one qualifying trigger object.
    """
    pass_pt = trigobjs.pt > pt
    pass_id = abs(trigobjs.id) == 11
    # Parentheses are for readability only: in Python `&` already binds tighter than `==`.
    pass_wptight = (trigobjs.filterBits & (0x1 << 1)) == 2
    trigger_cands = trigobjs[pass_pt & pass_id & pass_wptight]

    # metric_table yields an (event, electron, trigger-object) table of ΔR values.
    delta_r = electrons.metric_table(trigger_cands)
    pass_delta_r = delta_r < dr_max
    # Count matches per electron by summing ONLY over the trigger-object axis.
    # Summing over axis=1 as well collapses the result to one flag per event.
    # Then an unmatched electron would "match" whenever any other electron
    # in the same event did.
    n_of_trigger_matches = dak.sum(pass_delta_r, axis=2)
    trig_matched_locs = n_of_trigger_matches >= 1

    return trig_matched_locs


def find_probes(tags, probes, trigobjs, pt):
    """Build Z->ee tag/probe pairs and split the probes into all vs. passing.

    A pair is kept when:
    - its tag is trigger-matched;
    - the pair's invariant mass is within 30 of 91.1876 (Z mass);
    - the two charges are opposite;
    - the tag-probe ΔR is > 0.

    Parameters
    ----------
    tags, probes : dask-awkward arrays
        Pair-aligned electrons, e.g. the two fields of ``dak.combinations``.
    trigobjs : dask-awkward ``TrigObj`` collection
        Aligned event-by-event with ``tags``/``probes``.
    pt : float
        Trigger-object pT threshold forwarded to ``trigger_match``.

    Returns
    -------
    tuple
        ``(passing_probes, all_probes)``. ``all_probes`` are the probes of
        every kept pair. ``passing_probes`` is the subset that is itself
        trigger-matched.
    """
    # Per-pair mask: drop pairs whose tag is not trigger-matched. Events are
    # not removed, so trigobjs stays aligned and must NOT be masked here.
    # A per-pair mask does not fit TrigObj's jagged structure.
    trig_matched_tag = trigger_match(tags, trigobjs, pt)
    tags = tags[trig_matched_tag]
    probes = probes[trig_matched_tag]

    dr = tags.delta_r(probes)
    mass = (tags + probes).mass

    in_mass_window = abs(mass - 91.1876) < 30
    opposite_charge = tags.charge * probes.charge == -1

    isZ = in_mass_window & opposite_charge
    # Guards against a tag and probe that are the same object (ΔR == 0).
    dr_condition = dr > 0.0

    all_probes = probes[isZ & dr_condition]
    trig_matched_probe = trigger_match(all_probes, trigobjs, pt)
    passing_probes = all_probes[trig_matched_probe]

    return passing_probes, all_probes


def perform_tnp(events, pt=32):
    """Run tag-and-probe on ``events`` using both orderings of each electron pair.

    Each unordered pair of selected electrons is evaluated twice: first with
    the first electron as tag, then with the roles swapped. This way either
    electron can serve as the probe.

    Parameters
    ----------
    events : dask-awkward NanoEvents array
        Must provide ``Electron`` and ``TrigObj`` collections.
    pt : float, optional
        Threshold passed to both the electron selection and the trigger
        matching. Defaults to 32.

    Returns
    -------
    tuple
        ``(p1, a1, p2, a2)``: passing and all probes for the
        (tag, probe) ordering, then for the swapped (probe, tag) ordering.
    """
    selected_events, selection_mask = filter_events(events, pt)
    candidates = selected_events.Electron[selection_mask]
    trigger_objects = selected_events.TrigObj

    outputs = []
    for field_order in (["tag", "probe"], ["probe", "tag"]):
        pairs = dak.combinations(candidates, 2, fields=field_order)
        passing, every = find_probes(pairs.tag, pairs.probe, trigger_objects, pt)
        outputs.extend((passing, every))

    return tuple(outputs)

In [None]:
p1, a1, p2, a2 = perform_tnp(events)

# Probe-pT histograms: denominator (all probes) and numerator
# (trigger-matched probes). Both tag/probe orderings fill the same histograms.
ptaxis = hist.axis.Variable(ptbins, name="pt")
hpt_all = Hist(ptaxis)
hpt_pass = Hist(ptaxis)

for every_probe, passing_probe in ((a1, p1), (a2, p2)):
    hpt_all.fill(dak.flatten(every_probe.pt))
    hpt_pass.fill(dak.flatten(passing_probe.pt))

# NOTE(review): re-running this cell starts an additional Client; consider
# creating it once in its own cell.
client = Client()
# Evaluate the identical graph four times, presumably to compare run-to-run
# reproducibility of the results.
tmp1, tmp2, tmp3, tmp4 = (dask.compute(hpt_all, hpt_pass) for _ in range(4))

for tmp in (tmp1, tmp2, tmp3, tmp4):
    computed_all, computed_pass = tmp
    print(computed_all.values(flow=True))
    print(computed_pass.sum(flow=True))
    print()