In [None]:
import os

import numpy as np
import pandas as pd
from tqdm import tqdm

import plotly.express as px

from diquark.constants import DATA_KEYS, PATH_DICT_ATLAS_130_85
from diquark.helpers import create_data_dict
from diquark.load import read_jet_delphes
from diquark.features import (
    jet_multiplicity,
    leading_jet_arr,
    calculate_delta_r,
    combined_invariant_mass,
    n_jet_invariant_mass,
    n_jet_vector_sum_pt,
)
from diquark.plotting import make_histogram, make_histogram_with_double_gaussian_fit


# Later cells use paths relative to the working directory (e.g. "data/..."),
# so step up one level when the kernel was started inside a `notebooks` folder.
# os.path.basename honours the platform's path separator, whereas splitting on
# "/" never matches on Windows, where getcwd() returns backslash paths.
if os.path.basename(os.getcwd()) == "notebooks":
    os.chdir("..")

In [None]:
# Read one Delphes jet dataset per entry of DATA_KEYS, keyed by that entry.
# The file location for each key is looked up in PATH_DICT_ATLAS_130_85.
datasets = dict(
    (sample, read_jet_delphes(PATH_DICT_ATLAS_130_85[sample]))
    for sample in tqdm(DATA_KEYS)
)

In [None]:
# Jet multiplicity for every loaded sample, keyed like `datasets`.
jet_multiplicities = dict(
    (sample, jet_multiplicity(events))
    for sample, events in tqdm(datasets.items())
)

In [None]:
# Leading-jet kinematics per sample. The three branches are processed in the
# order pT, eta, phi, each with its own pass (and progress bar) over `datasets`.
jet_pts, jet_etas, jet_phis = (
    {
        sample: leading_jet_arr(events, key=branch)
        for sample, events in tqdm(datasets.items())
    }
    for branch in ("Jet/Jet.PT", "Jet/Jet.Eta", "Jet/Jet.Phi")
)

In [None]:
# Combined invariant mass for every sample, keyed like `datasets`.
combined_masses = dict(
    (sample, combined_invariant_mass(events))
    for sample, events in tqdm(datasets.items())
)

In [None]:
# Per-sample containers for the features derived from the six leading jets.
# Every dict is keyed by the same sample name as `datasets`.
delta_rs = {}
avg_delta_rs = {}
m3j_s = {}
m3j_m6j = {}
m2j_s = {}
m2j_m6j = {}
max_delta_rs = {}
smallest_delta_r_masses = {}
n_jet_pairs_near_W_mass = {}
max_vector_sum_pt = {}
max_vector_sum_pt_delta_r = {}

for key, data in tqdm(datasets.items()):
    etas = leading_jet_arr(data, 6, key="Jet/Jet.Eta")
    phis = leading_jet_arr(data, 6, key="Jet/Jet.Phi")
    pts = leading_jet_arr(data, 6, key="Jet/Jet.PT")

    # Calculate ΔR for each pair of jets
    delta_rs[key] = calculate_delta_r(etas, phis, pts)
    # Zeros are excluded from the mean, i.e. treated as "no such pair".
    # NOTE(review): no `axis` is given, so this is a single scalar per sample,
    # unlike the per-event reductions below — confirm that is intended.
    avg_delta_rs[key] = np.mean(delta_rs[key], where=delta_rs[key] != 0)
    max_delta_rs[key] = np.max(delta_rs[key], axis=1)

    # Calculate invariant masses for 3-jet combinations, then the per-event
    # mean of the non-zero masses divided by the combined mass. Events whose
    # combined mass is zero keep the 0 written by `out`.
    m3j_s[key] = n_jet_invariant_mass(data, n=6, k=3)
    m3j_m6j[key] = np.divide(
        m3j_s[key].mean(axis=-1, where=m3j_s[key] != 0),
        combined_masses[key],
        out=np.zeros_like(combined_masses[key]),
        where=combined_masses[key] != 0,
    )

    # Same ratio for 2-jet combinations
    m2j_s[key] = n_jet_invariant_mass(data, n=6, k=2)
    m2j_m6j[key] = np.divide(
        m2j_s[key].mean(axis=-1, where=m2j_s[key] != 0),
        combined_masses[key],
        out=np.zeros_like(combined_masses[key]),
        where=combined_masses[key] != 0,
    )

    # Find the mass of the jet pair with the smallest ΔR.
    # A zero ΔR is treated as a missing pair (as in the mean above), so it is
    # masked to +inf; otherwise argmin would select a non-existent pair for any
    # event with a zero entry. An all-zero row still resolves to index 0.
    # NOTE(review): this assumes the columns of `delta_rs` and `m2j_s` list the
    # jet pairs in the same order — confirm against the diquark.features helpers.
    valid_delta_rs = np.where(delta_rs[key] != 0, delta_rs[key], np.inf)
    smallest_delta_r_indices = np.argmin(valid_delta_rs, axis=1)
    # take_along_axis picks one column per row; unlike np.choose it has no cap
    # on the number of candidate columns and needs no transpose.
    smallest_delta_r_masses[key] = np.take_along_axis(
        m2j_s[key], smallest_delta_r_indices[:, None], axis=1
    )[:, 0]

    # Count jet pairs within 20 GeV of the W mass (60-100 GeV window, inclusive)
    n_jet_pairs_near_W_mass[key] = np.sum((m2j_s[key] >= 60) & (m2j_s[key] <= 100), axis=1)

    # Calculate vector sum pT for 2-jet combinations
    vector_sum_pts = n_jet_vector_sum_pt(data, n=6, k=2)
    max_vector_sum_pt[key] = np.max(vector_sum_pts, axis=1)

    # Per-event column index of the jet pair with the largest vector sum pT
    jet_pair_indices = np.argmax(vector_sum_pts, axis=1)

    # ΔR between the two jets with the largest vector sum pT
    max_vector_sum_pt_delta_r[key] = np.take_along_axis(
        delta_rs[key], jet_pair_indices[:, None], axis=1
    )[:, 0]

In [None]:
# Output column name -> per-sample feature dictionary computed in the cells
# above. The whole mapping is passed to create_data_dict as keyword arguments.
feature_columns = {
    "multiplicity": jet_multiplicities,
    "delta_R": delta_rs,
    "m3j": m3j_s,
    "m2j_s": m2j_s,
    "inv_mass": combined_masses,
    "m3j_m6j": m3j_m6j,
    "m2j_m6j": m2j_m6j,
    "pt": jet_pts,
    "eta": jet_etas,
    "phi": jet_phis,
    "max_delta_R": max_delta_rs,
    "m2j_min_delta_R": smallest_delta_r_masses,
    "nj_mW_pm20": n_jet_pairs_near_W_mass,
    "max_vector_sum_pt": max_vector_sum_pt,
    "max_vector_sum_pt_delta_r": max_vector_sum_pt_delta_r,
}
ds = create_data_dict(**feature_columns)

In [None]:
# Tabulate the assembled features and add a binary label:
# 1 when the row's "Truth" value contains "SIG", 0 otherwise.
df = pd.DataFrame(ds)
df["target"] = df["Truth"].map(lambda truth_label: int("SIG" in truth_label))

In [None]:
# Preview the first rows of the assembled table.
df.head()

In [None]:
# Save the assembled table without its index. The path is relative to the
# current working directory.
output_path = "data/full_sample.parquet"
df.to_parquet(output_path, index=False)

# Data Visualization

In [None]:
# Histogram of column 0 of `jet_pts` (presumably the leading jet, given
# `leading_jet_arr` — confirm). The title and x-axis label previously read
# "6-jet Mass" / "Invariant Mass [GeV]", copied from the mass plot; they now
# describe the quantity that is actually drawn.
fig = make_histogram(jet_pts, 20, col=0, clip_top_prc=100)
# Spacing between the first two x positions of the first trace, used as the bin width.
bin_width = fig.data[0].x[1] - fig.data[0].x[0]
fig.update_layout(
    title="Leading Jet pT",
    xaxis_title="Jet pT [GeV]",
    yaxis_title_text="count x sigma",
    barmode="stack",
    bargap=0,
    width=1600 * (2 / 3),
    height=900 * (2 / 3),
    # ignore first bin: the visible range starts half a bin below the second x position
    xaxis_range=[fig.data[0].x[1] - bin_width / 2, fig.data[0].x[-1] + bin_width / 2],
    yaxis_type="log",
)

fig.show()

In [None]:
# Combined-mass distribution of the "SIG:suu" sample only, drawn with the
# double-Gaussian fit helper.
suu_mass = {"SIG:suu": combined_masses["SIG:suu"]}
fig = make_histogram_with_double_gaussian_fit(suu_mass, 20, clip_top_prc=100, cross=None)

# x positions of the first trace; their spacing is used as the bin width.
first_trace_x = fig.data[0].x
bin_width = first_trace_x[1] - first_trace_x[0]

layout_options = {
    "title": "6-jet Mass",
    "xaxis_title": "Invariant Mass [GeV]",
    "yaxis_title_text": "probability density",
    "barmode": "stack",
    "bargap": 0,
    "width": 1300 * (2 / 3),
    "height": 1300 * (2 / 3),
    # ignore first bin
    "xaxis_range": [first_trace_x[1] - bin_width / 2, first_trace_x[-1] + bin_width / 2],
    "yaxis_type": "log",
}
fig.update_layout(**layout_options)

# Untitled legend anchored by its top-left corner near the lower-left of the plot.
legend_options = {
    "title_text": "",
    "itemsizing": "constant",
    "yanchor": "top",
    "y": 0.1,
    "xanchor": "left",
    "x": 0.01,
    "font": {"size": 16},
}
fig.update_legends(**legend_options)

fig.show()
# Optional export, disabled by default:
# fig.write_image("suu_mass.pdf")

In [None]:
# List the column names of the assembled table.
df.columns

In [None]:
# Histogram of the jet multiplicity of every sample. The figure previously had
# no title or x-axis label (only commented-out labels copied from the mass
# plot), so it could not be read on its own; both are now set.
fig = make_histogram(jet_multiplicities, 16, clip_top_prc=100)
# Spacing between the first two x positions of the first trace, used as the bin width.
bin_width = fig.data[0].x[1] - fig.data[0].x[0]
fig.update_layout(
    title="Jet Multiplicity",
    xaxis_title="Number of jets",
    yaxis_title_text="count x sigma",
    barmode="stack",
    bargap=0,
    width=1600 * (2 / 3),
    height=900 * (2 / 3),
    # ignore first bin
    xaxis_range=[fig.data[0].x[1] - bin_width / 2, fig.data[0].x[-1] + bin_width / 2],
    yaxis_type="log",
)

fig.show()
# Print the first and last x positions of the first trace. Note these are not
# the same values as the `xaxis_range` set above.
print(
    [fig.data[0].x[1] - bin_width, fig.data[0].x[-1]],
)