In [None]:
import os

import pandas as pd
from tqdm import tqdm

from diquark import DATA_KEYS, PATH_DICT, CROSS_SECTION_DICT
from diquark.helpers import create_data_dict, get_col
from diquark.load import read_jet_delphes
from diquark.features import (
    jet_multiplicity,
    leading_jet_arr,
    calculate_delta_r,
    combined_invariant_mass,
    three_jet_invariant_mass,
)
from diquark.plotting import make_histogram, make_histogram_with_double_gaussian_fit


# When the kernel starts inside the "notebooks" folder, step up one level so
# that relative paths used later (e.g. "data/full_sample.parquet") resolve
# against the parent directory.
# os.path.basename honours the platform's path separator, whereas splitting on
# "/" never matches on Windows and would silently skip the chdir.
if os.path.basename(os.getcwd()) == "notebooks":
    os.chdir("..")

In [None]:
# Load one dataset per entry of DATA_KEYS, looking up each input location in
# PATH_DICT; tqdm shows progress while the files are read.
datasets = dict(
    (sample, read_jet_delphes(PATH_DICT[sample]))
    for sample in tqdm(DATA_KEYS)
)

In [None]:
# Per-sample jet multiplicity, keyed the same way as `datasets`.
jet_multiplicities = {
    sample: jet_multiplicity(events)
    for sample, events in tqdm(datasets.items())
}

In [None]:
# Leading-jet arrays for three Delphes jet branches (PT, Eta, Phi), one dict
# per branch, each keyed by sample name. The loop variable is called `sample`
# so it is not confused with leading_jet_arr's `key=` keyword argument.
jet_pts = {
    sample: leading_jet_arr(events, key="Jet/Jet.PT")
    for sample, events in tqdm(datasets.items())
}
jet_etas = {
    sample: leading_jet_arr(events, key="Jet/Jet.Eta")
    for sample, events in tqdm(datasets.items())
}
jet_phis = {
    sample: leading_jet_arr(events, key="Jet/Jet.Phi")
    for sample, events in tqdm(datasets.items())
}

In [None]:
# Combined invariant mass for every sample, keyed like `datasets`.
combined_masses = dict(
    (sample, combined_invariant_mass(events))
    for sample, events in tqdm(datasets.items())
)

In [None]:
# Delta-R features per sample, built from the Eta, Phi and PT leading-jet
# arrays. The positional 6 is forwarded to leading_jet_arr unchanged
# (presumably the number of leading jets to keep — confirm against the helper).
delta_rs = {}
for key, data in tqdm(datasets.items()):
    # Branches are read in Eta, Phi, PT order, matching the argument order
    # expected by calculate_delta_r below.
    etas, phis, pts = (
        leading_jet_arr(data, 6, key=branch)
        for branch in ("Jet/Jet.Eta", "Jet/Jet.Phi", "Jet/Jet.PT")
    )
    delta_rs[key] = calculate_delta_r(etas, phis, pts)

In [None]:
# Three-jet invariant mass for every sample, keyed like `datasets`.
m3j_s = {
    sample: three_jet_invariant_mass(events)
    for sample, events in tqdm(datasets.items())
}

In [None]:
# Histogram of the combined invariant mass across all samples.
fig = make_histogram(combined_masses, 20, clip_top_prc=100)

# x values of the first trace; the gap between the first two is taken as the
# bin width (assumes uniformly spaced bins — confirm against make_histogram).
xs = fig.data[0].x
bin_width = xs[1] - xs[0]
half_bin = bin_width / 2

layout_options = dict(
    title="6-jet Mass",
    xaxis_title="Invariant Mass [GeV]",
    yaxis_title_text="count x sigma",
    barmode="stack",
    bargap=0,
    width=1600 * (2 / 3),
    height=900 * (2 / 3),
    # Start the visible range at the second x value so the first bin is hidden.
    xaxis_range=[xs[1] - half_bin, xs[-1] + half_bin],
    yaxis_type="log",
)
fig.update_layout(**layout_options)

fig.show()

# Echo a reference x range; note this uses a full bin width on the low side
# and no padding on the high side, unlike xaxis_range above.
print([xs[1] - bin_width, xs[-1]])

In [None]:
# Gather every per-sample feature dict under the name it is passed to
# create_data_dict with. Keyword order matches the original mapping.
ds = create_data_dict(
    multiplicity=jet_multiplicities,
    delta_R=delta_rs,
    m3j=m3j_s,
    inv_mass=combined_masses,
    pt=jet_pts,
    eta=jet_etas,
    phi=jet_phis,
)

In [None]:
# Tabulate the feature dict and add a binary label derived from the "Truth"
# column: 1 when the label contains "SIG", 0 otherwise.
df = pd.DataFrame(ds)
df["target"] = df["Truth"].apply(lambda label: int("SIG" in label))

In [None]:
# Preview the first rows of the assembled table (5 is pandas' default).
df.head(n=5)

In [None]:
# Persist the assembled feature table without the DataFrame index.
# Create the target directory first so that a missing "data/" folder does not
# abort the write after all the upstream computation has already been done.
output_path = "data/full_sample.parquet"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
df.to_parquet(output_path, index=False)

# Data Visualization

In [None]:
# Histogram of the jet pT arrays built by leading_jet_arr; col=0 selects the
# first column (presumably the highest-pT jet — confirm against make_histogram).
fig = make_histogram(jet_pts, 20, col=0, clip_top_prc=100)

# Spacing between the first two x values of the first trace, used as the bin
# width when padding the visible x range.
bin_width = fig.data[0].x[1] - fig.data[0].x[0]
fig.update_layout(
    # The title and x-axis label previously read "6-jet Mass" /
    # "Invariant Mass [GeV]", copied from the mass plot; this figure shows
    # jet pT, so the labels now describe the plotted quantity.
    title="Leading Jet pT",
    xaxis_title="pT [GeV]",
    yaxis_title_text="count x sigma",
    barmode="stack",
    bargap=0,
    width=1600 * (2 / 3),
    height=900 * (2 / 3),
    # ignore first bin
    xaxis_range=[fig.data[0].x[1] - bin_width / 2, fig.data[0].x[-1] + bin_width / 2],
    yaxis_type="log",
)

fig.show()

In [None]:
# Restrict the combined-mass dict to the "SIG:suu" sample and draw it with the
# double-Gaussian-fit histogram helper (cross=None is forwarded as given).
suu_mass = {"SIG:suu": combined_masses["SIG:suu"]}
fig = make_histogram_with_double_gaussian_fit(
    suu_mass,
    20,
    clip_top_prc=100,
    cross=None,
)

# x values of the first trace; their first spacing is used as the bin width.
xs = fig.data[0].x
bin_width = xs[1] - xs[0]
half_bin = bin_width / 2

fig.update_layout(
    title="6-jet Mass",
    xaxis_title="Invariant Mass [GeV]",
    yaxis_title_text="probability density",
    barmode="stack",
    bargap=0,
    width=1300 * (2 / 3),
    height=1300 * (2 / 3),
    # Start the visible range at the second x value so the first bin is hidden.
    xaxis_range=[xs[1] - half_bin, xs[-1] + half_bin],
    yaxis_type="log",
)

# Legend without a title, anchored by its top-left corner near the
# bottom-left of the plotting area.
legend_options = {
    "title_text": "",
    "itemsizing": "constant",
    "yanchor": "top",
    "y": 0.1,
    "xanchor": "left",
    "x": 0.01,
    "font": {"size": 16},
}
fig.update_legends(**legend_options)
fig.show()
# Optional export, left disabled:
# fig.write_image("suu_mass.pdf")