In [None]:
import os

import pandas as pd
import numpy as np
import json

import plotly.express as px
import plotly.graph_objects as go
from tqdm import tqdm

import diquark.constants as const
from diquark.plotting import make_histogram, make_histogram_with_double_gaussian_fit
from diquark.helpers import mass_score_cut

import tensorflow as tf

# Short aliases for the Keras namespaces.
tfk = tf.keras
tfkl = tfk.layers

# The notebook reads and writes paths relative to the parent of notebooks/
# (e.g. "models/..."), so step up one level when the kernel was started inside
# that directory. os.path.basename honours the platform's path separator,
# which a manual split("/") does not.
if os.path.basename(os.getcwd()) == "notebooks":
    os.chdir("..")

In [None]:
# get the latest workdir
# Run-directory prefixes of the form run_<detector>_<ECM code>_<mSuu code>,
# enumerated detector-major, then ECM code, then mSuu code.
PREFIXES = [
    f"run_{detector}_{ecm_code}_{msuu_code}"
    for detector in ("ATLAS", "CMS")
    for ecm_code in (130, 136)
    for msuu_code in (65, 80)
]


def get_workdir(prefix="run_ATLAS_130_65"):
    """Return the name of the last-sorting entry of ``models`` that starts with ``prefix``.

    Entries are compared as plain strings, so "last" means lexicographically
    greatest. NOTE(review): this picks the most recent run only if the part of
    the directory name after the prefix sorts chronologically -- confirm
    against how run directories are named.

    Parameters
    ----------
    prefix : str
        Leading part of the run-directory name to look for.

    Returns
    -------
    str
        The matching entry name (not a full path).

    Raises
    ------
    IndexError
        If no entry of ``models`` starts with ``prefix``.
    """
    candidates = [name for name in os.listdir("models") if name.startswith(prefix)]
    if not candidates:
        # Same exception type as indexing an empty list, but with a message
        # that says which prefix had no run directory.
        raise IndexError(f"no entry in 'models' starts with {prefix!r}")
    return max(candidates)


# For every run prefix, load the per-cut counts table together with the fit
# and precision/recall summaries stored in the selected run directory, and tag
# each of them with the detector, ECM and mSuu decoded from the prefix.
count_frames = []
fit_records = []
pr_records = []
for prefix in tqdm(PREFIXES):
    run_dir = f"models/{get_workdir(prefix)}"

    # Prefix fields: run_<detector>_<ECM code>_<mSuu code>.
    fields = prefix.split("_")
    run_meta = {
        "detector": fields[1],
        "ECM": int(fields[2]) / 10,
        "mSuu": (int(fields[3]) / 10) + 0.5,
    }

    # Wide table (one column per cut) -> long table (one row per label and cut).
    counts = pd.read_csv(f"{run_dir}/counts_rf.csv").melt(
        id_vars=["Unnamed: 0"], var_name="cut", value_name="counts"
    )
    for column, value in run_meta.items():
        counts[column] = value
    count_frames.append(counts)

    for filename, records in (("fits.json", fit_records), ("pr_curves.json", pr_records)):
        with open(f"{run_dir}/{filename}", "r") as handle:
            record = json.load(handle)
        for key, value in run_meta.items():
            record[key] = value
        records.append(record)


counts_df = pd.concat(count_frames).rename(columns={"Unnamed: 0": "label"})
fits_df = pd.json_normalize(fit_records)
pr_df = pd.json_normalize(pr_records)

In [None]:
# Save the combined long-format counts of all runs; index=False leaves the
# row index out of the file.
counts_df.to_csv("models/counts.csv", index=False)

In [None]:
# S/B against the discriminator cut: one facet row per mSuu, one colour per
# detector, marker symbol and dash pattern per ECM.
sb_rows = counts_df[counts_df["label"] == "S/B"]
fig = px.line(
    sb_rows,
    x="cut",
    y="counts",
    color="detector",
    symbol="ECM",
    line_dash="ECM",
    facet_row="mSuu",
    markers=True,
)
fig.update_yaxes(type="log")
# Title every facet's y-axis "S/B" instead of the plotted column name.
fig.for_each_yaxis(lambda axis: axis.title.update(text="S/B"))
fig.update_layout(
    title="S/B vs. discriminator cut for different detectors and ECMs",
    width=600,
    height=600,
)
fig.write_image("sb_vs_cut.pdf")
fig.show()

In [None]:
# Reshape the precision/recall summary from one column per metric and model
# (stub + "_" + model suffix) to one row per run and model, then replace the
# suffix by its upper-case model label and drop the helper "index" column.
long_df = (
    pd.wide_to_long(
        pr_df.reset_index(),
        stubnames=["precision", "recall", "threshold", "auc"],
        i="index",
        j="model",
        sep="_",
        suffix=r"\w+",
    )
    .reset_index()
    .assign(model=lambda frame: frame["model"].map({"nn": "NN", "gb": "GB", "rf": "RF"}))
    .drop(columns=["index"])
)

In [None]:
# Display the reshaped precision/recall table.
long_df

In [None]:
# Grouped bars of the AUC per model for each detector, faceted by mSuu (rows)
# and ECM (columns), drawn on a logarithmic y-axis.
fig = (
    px.bar(
        long_df,
        x="detector",
        y="auc",
        color="model",
        barmode="group",
        facet_row="mSuu",
        facet_col="ECM",
    )
    .update_yaxes(type="log")
    .update_layout(title="AUC for different detectors and ECMs")
)
fig.show()

In [None]:
# Plain-text rendering of the fit summary table.
str(fits_df)

In [None]:
# Fitted mean with its standard deviation as error bars: one trace per model,
# one x position per run configuration.
fig = go.Figure()

# The x labels are the same for every model, so build them once.
run_labels = (
    fits_df["mSuu"].astype(str)
    + " TeV, ECM="
    + fits_df["ECM"].astype(str)
    + " TeV"
    + " ("
    + fits_df["detector"]
    + ")"
)

model_colors = {"nn": "blue", "rf": "green", "gb": "red"}
for model, color in model_colors.items():
    fig.add_trace(
        go.Scatter(
            x=run_labels,
            y=fits_df[f"{model}.mean"],
            error_y=dict(type="data", array=fits_df[f"{model}.std"], visible=True),
            mode="markers",
            name=f"{model.upper()} mean ± std",
            marker=dict(color=color),
        )
    )

fig.update_layout(
    title="m6j fit Mean and Standard Deviation by Model and mSuu",
    xaxis_title="mSuu (Detector)",
    yaxis_title="Mean Value",
    legend_title="Model",
    xaxis=dict(tickangle=45),
)

fig.show()

In [None]:
# Distinct ECM, mSuu and detector values present in the combined counts table.
ecms, msuus, detectors = (
    counts_df[column].unique() for column in ("ECM", "mSuu", "detector")
)

msuus

In [None]:
def generate_latex_table(df, luminosity=3000, cuts=(0.8, 0.90, 0.95, 0.99)):
    """Render background, signal and S/B per discriminator cut as a LaTeX tabular.

    One table row is written for every (detector, ECM, mSuu) combination found
    in ``df`` and every requested cut. The "BKG:sum" and "SIG:Suu" counts are
    multiplied by ``luminosity``; the "S/B" value is reported as stored.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format table with the columns ``label``, ``cut``, ``counts``,
        ``detector``, ``ECM`` and ``mSuu``. It is left unmodified.
    luminosity : float, default 3000
        Factor applied to the background and signal counts.
    cuts : iterable of float, default (0.8, 0.90, 0.95, 0.99)
        Cut values to tabulate. A cut for which a combination lacks any of the
        three labels is skipped instead of raising.

    Returns
    -------
    str
        LaTeX source of the table.
    """
    # Compare cuts numerically on a private copy, so the caller's frame keeps
    # the ``cut`` dtype it came in with.
    table = df.assign(cut=df["cut"].astype(float))

    def lookup(frame, label, cut):
        """Return the first ``counts`` value for ``label`` at ``cut``, or None."""
        hits = frame.loc[(frame["label"] == label) & (frame["cut"] == cut), "counts"]
        return hits.iloc[0] if len(hits) else None

    # The second header row names what each data row actually contains.
    lines = [
        "\\begin{tabular}{|l|c|c|c|c|}",
        "\\hline",
        " & \\multicolumn{4}{c|}{Counts for cuts} \\\\ \\cline{2-5}",
        "Event & cut & Background & Signal & S/B \\\\ \\hline",
    ]

    # Combinations are visited in order of first appearance in the data.
    for detector in table["detector"].unique():
        for ecm in table["ECM"].unique():
            for msuu in table["mSuu"].unique():
                selection = table[
                    (table["ECM"] == ecm)
                    & (table["mSuu"] == msuu)
                    & (table["detector"] == detector)
                ]
                if selection.empty:
                    continue
                for cut in cuts:
                    bkg = lookup(selection, "BKG:sum", cut)
                    sig = lookup(selection, "SIG:Suu", cut)
                    sb_ratio = lookup(selection, "S/B", cut)
                    if bkg is None or sig is None or sb_ratio is None:
                        continue
                    lines.append(
                        f"{detector}, ECM={ecm}, mSuu={msuu} & {cut:g} & "
                        f"{bkg * luminosity:.2e} & {sig * luminosity:.2e} & {sb_ratio:.2e} "
                        "\\\\ \\hline"
                    )

    lines.append("\\end{tabular}")
    return "\n".join(lines)

In [None]:
# Display the combined counts table.
counts_df

In [None]:
# Print the LaTeX source of the flat table built from the combined counts.
print(generate_latex_table(counts_df))

In [None]:
def generate_latex_table_with_multirow(df, luminosity=3000, cuts=(0.8, 0.90, 0.95, 0.99)):
    """Render the counts as a LaTeX tabular with one multirow group per run.

    Every (detector, ECM, mSuu) group gets one row per requested cut; the three
    identifying cells span the group via ``\\multirow``. The "BKG:sum" and
    "SIG:Suu" counts are multiplied by ``luminosity``; "S/B" is reported as
    stored. A value that is missing for a cut is written as 0.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format table with the columns ``label``, ``cut``, ``counts``,
        ``detector``, ``ECM`` and ``mSuu``. It is left unmodified.
    luminosity : float, default 3000
        Factor applied to the background and signal counts.
    cuts : iterable of float, default (0.8, 0.90, 0.95, 0.99)
        Cut values to tabulate, one table row each.

    Returns
    -------
    str
        LaTeX source of the table (uses the ``multirow`` command).
    """
    # Compare cuts numerically on a private copy, so the caller's frame keeps
    # the ``cut`` dtype it came in with.
    table = df.assign(cut=df["cut"].astype(float))
    cuts = list(cuts)

    def lookup(frame, label, cut):
        """Return the first ``counts`` value for ``label`` at ``cut``, or 0."""
        hits = frame.loc[(frame["label"] == label) & (frame["cut"] == cut), "counts"]
        return hits.iloc[0] if len(hits) else 0

    lines = [
        "\\begin{tabular}{|l|l|l|c|c|c|c|}",
        "\\hline",
        "Detector & ECM & mSuu & \\multicolumn{4}{c|}{Counts for cuts} \\\\ \\cline{4-7}",
        " & & & cut & Background & Signal & S/B \\\\ \\hline",
    ]

    # The span must equal the number of rows written per group, which is the
    # number of requested cuts, not the number of cuts present in the data.
    span = len(cuts)

    for (detector, ecm, msuu), group in table.groupby(["detector", "ECM", "mSuu"]):
        for position, cut in enumerate(cuts):
            bkg_sum = lookup(group, "BKG:sum", cut) * luminosity
            sig_sum = lookup(group, "SIG:Suu", cut) * luminosity
            sb_ratio = lookup(group, "S/B", cut)

            if position == 0:
                # Literal braces are doubled so the f-string does not evaluate them.
                lead = (
                    f"\\multirow{{{span}}}{{*}}{{{detector}}} & "
                    f"\\multirow{{{span}}}{{*}}{{{ecm}}} & "
                    f"\\multirow{{{span}}}{{*}}{{{msuu}}}"
                )
            else:
                lead = " & &"

            # Partial rule inside a group, full rule after its last row. The
            # rule is kept out of the f-string so "{4-7}" is never evaluated.
            rule = "\\hline" if position == span - 1 else "\\cline{4-7}"
            lines.append(
                f"{lead} & {cut} & {bkg_sum:.2e} & {sig_sum:.2e} & {sb_ratio:.2e} \\\\ " + rule
            )

    lines.append("\\end{tabular}")
    return "\n".join(lines)

In [None]:
# Print the LaTeX source of the multirow table built from the combined counts.
print(generate_latex_table_with_multirow(counts_df))

In [None]:
def generate_latex_table_with_multirow_corrected(df, luminosity=3000):
    """Render the counts as a nested-multirow LaTeX tabular, one row per cut.

    Rows are grouped by detector, then ECM, then mSuu; each of those cells is
    written once and spans all rows of its group via ``\\multirow``. Every cut
    present in the data yields exactly one row holding the "BKG:sum" and
    "SIG:Suu" counts, both multiplied by ``luminosity``, and the "S/B" value as
    stored. A label that is missing for a cut is written as 0.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format table with the columns ``label``, ``cut``, ``counts``,
        ``detector``, ``ECM`` and ``mSuu``. It is left unmodified.
    luminosity : float, default 3000
        Factor applied to the background and signal counts.

    Returns
    -------
    str
        LaTeX source of the table (uses the ``multirow`` command).
    """
    # Sort and compare cuts numerically on a private copy, so the caller's
    # frame keeps the ``cut`` dtype it came in with.
    table = df.assign(cut=df["cut"].astype(float))

    def lookup(frame, label):
        """Return the first ``counts`` value stored for ``label``, or 0."""
        hits = frame.loc[frame["label"] == label, "counts"]
        return hits.iloc[0] if len(hits) else 0

    def multirow(span, text):
        """Format a multirow cell covering ``span`` table rows."""
        return f"\\multirow{{{span}}}{{*}}{{{text}}}"

    header = (
        "\\begin{tabular}{|l|l|l|c|c|c|c|}\n\\hline\n"
        "Detector & ECM & mSuu & \\multicolumn{4}{c|}{Event Counts for cuts} \\\\ \\cline{4-7}\n"
        " & & & cut & Background & Signal & S/B \\\\ \\hline\n"
    )

    # Grouping by a bare column name (not a one-element list) keeps the group
    # keys scalar, so they print as "ATLAS" rather than "('ATLAS',)".
    detector_blocks = []
    for detector, detector_df in table.groupby("detector"):
        # Each span is the number of table rows (distinct cuts) in the group.
        detector_cell = multirow(detector_df.groupby(["ECM", "mSuu", "cut"]).ngroups, detector)
        ecm_blocks = []
        for ecm, ecm_df in detector_df.groupby("ECM"):
            ecm_cell = multirow(ecm_df.groupby(["mSuu", "cut"]).ngroups, ecm)
            msuu_blocks = []
            for msuu, msuu_df in ecm_df.groupby("mSuu"):
                msuu_cell = multirow(msuu_df.groupby("cut").ngroups, msuu)
                rows = []
                for cut, cut_df in msuu_df.groupby("cut"):
                    bkg_sum = lookup(cut_df, "BKG:sum") * luminosity
                    sig_sum = lookup(cut_df, "SIG:Suu") * luminosity
                    sb_ratio = lookup(cut_df, "S/B")
                    rows.append(
                        f"{detector_cell} & {ecm_cell} & {msuu_cell} & {cut:g} & "
                        f"{bkg_sum:.2e} & {sig_sum:.2e} & {sb_ratio:.2e} \\\\\n"
                    )
                    # A multirow cell belongs to the first row of its span only.
                    detector_cell = ecm_cell = msuu_cell = ""
                msuu_blocks.append("".join(rows))
            # Rules between groups leave the columns of still-open spans clear.
            ecm_blocks.append("\\cline{3-7}\n".join(msuu_blocks))
        detector_blocks.append("\\cline{2-7}\n".join(ecm_blocks))

    return header + "\\hline\n".join(detector_blocks) + "\\hline\n\\end{tabular}"


# Build the nested-multirow table from the combined counts and print its LaTeX source.
latex_table_corrected = generate_latex_table_with_multirow_corrected(counts_df)
print(latex_table_corrected)

In [None]:
# Counts of a single configuration: ATLAS, ECM = 13.6, mSuu = 7.0.
is_selected_run = (
    (counts_df["detector"] == "ATLAS") & (counts_df["ECM"] == 13.6) & (counts_df["mSuu"] == 7.0)
)
df_0 = counts_df.loc[is_selected_run, ["label", "cut", "counts"]]
# Show the counts of every label except S/B; the rescaling is left disabled.
df_0.loc[df_0["label"] != "S/B", "counts"]  # *= 3000 / 100_000

In [None]:
# Display the long-format counts of the selected configuration.
df_0

In [None]:
# Wide view of the selected configuration: one row per label, one column per cut.
df_0.pivot(index="label", columns="cut", values="counts")

In [None]:
For an integrated luminosity of $3000\,\mathrm{fb}^{-1}$, expected by the end of the HL-LHC phase, we are able to isolate high-purity samples ($S/B \ge 100$) containing hundreds of events.