In [None]:
import os

import joblib
import pandas as pd
import numpy as np

import plotly.graph_objects as go

from sklearn.metrics import (
    auc,
    precision_recall_curve,
)

import diquark.constants as const
from diquark.plotting import make_histogram, make_histogram_with_double_gaussian_fit
from diquark.helpers import mass_score_cut

import tensorflow as tf

# Short aliases for the Keras namespaces.
tfkl = tf.keras.layers
tfk = tf.keras

# Relative paths such as "models/..." are resolved against the current working
# directory. When the kernel was started inside a folder named "notebooks",
# step up one level. os.path.basename honours the platform's path separator,
# unlike splitting the path on "/".
if os.path.basename(os.getcwd()) == "notebooks":
    os.chdir("..")

In [None]:
# Pick the most recent run directory under "models". Entries are sorted by
# name, so this presumably relies on timestamped names such as
# "run_20240201131751" — verify if runs are named differently.
# Only directories are considered, so a stray file in "models" can never be
# selected as the workdir.
run_dirs = sorted(
    entry for entry in os.listdir("models") if os.path.isdir(os.path.join("models", entry))
)
if not run_dirs:
    raise FileNotFoundError("no run directories found under 'models'")
workdir = os.path.join("models", run_dirs[-1])

# manually set the workdir
# workdir = "models/run_20240201131751"

# Later cells save figures into "<workdir>/plots"; make sure the folder exists
# so that write_image does not fail on a fresh run directory.
os.makedirs(os.path.join(workdir, "plots"), exist_ok=True)

## Load Data

In [None]:
# Read the train and test tables stored as parquet files in the run directory.
train_parquet = f"{workdir}/train.parquet"
test_parquet = f"{workdir}/test.parquet"
df_train, df_test = pd.read_parquet(train_parquet), pd.read_parquet(test_parquet)

In [None]:
# Keep only the truth label and the invariant mass, with a fresh 0..n-1 index
# so that row labels can double as positions into the prediction arrays.
test_df = df_test[["Truth", "inv_mass"]].reset_index(drop=True)
train_df = df_train[["Truth", "inv_mass"]].reset_index(drop=True)

# joblib-serialised mass data for each split (presumably the six-jet masses
# used in the mass plots — confirm against the training script).
m6j_test = joblib.load(f"{workdir}/m6j_test.data.joblib")
m6j_train = joblib.load(f"{workdir}/m6j_train.data.joblib")

# np.load on an .npz returns an archive object that keeps the file open and
# reads each array on access; the context manager closes it once the four
# arrays have been pulled into memory.
with np.load(f"{workdir}/data.npz") as data_npz:
    x_train, y_train, x_test, y_test = (
        data_npz["x_train"],
        data_npz["y_train"],
        data_npz["x_test"],
        data_npz["y_test"],
    )

## Load Models

In [None]:
# Restore the three trained classifiers stored in the run directory: two
# joblib-serialised estimators and one Keras model.
rf_clf, gb_clf = (
    joblib.load(saved_estimator)
    for saved_estimator in (f"{workdir}/rfc.joblib", f"{workdir}/gbc.joblib")
)
model = tfk.models.load_model(f"{workdir}/model.keras")

# Inference and Visualization

In [None]:
# Score the test set with each classifier. For the two estimators exposing
# predict_proba, column 1 is kept (presumably the signal-class probability —
# confirm against the label encoding used in training).
y_pred_rf = rf_clf.predict_proba(x_test)[:, 1]
y_pred_gb = gb_clf.predict_proba(x_test)[:, 1]
y_pred_nn = model.predict(x_test)

# One weight per test event: the cross-section looked up by the event's truth label.
cross_sections = const.CROSS_SECTION_ATLAS_130_85
sample_weights = [cross_sections[truth] for truth in test_df["Truth"]]

In [None]:
def _weighted_pr_curve(scores):
    """Return (precision, recall, thresholds, PR-AUC) for `scores` against y_test.

    The curve is weighted with the per-event `sample_weights`; omitting the
    `sample_weight` argument would give the unweighted curve instead.
    """
    precision, recall, thresholds = precision_recall_curve(
        y_test, scores, sample_weight=sample_weights
    )
    return precision, recall, thresholds, auc(recall, precision)


# Neural network
precision_nn, recall_nn, thresholds_nn, pr_auc_nn = _weighted_pr_curve(y_pred_nn)

# Gradient boosting
precision_gb, recall_gb, thresholds_gb, pr_auc_gb = _weighted_pr_curve(y_pred_gb)

# Random forest
precision_rf, recall_rf, thresholds_rf, pr_auc_rf = _weighted_pr_curve(y_pred_rf)

In [None]:
def _pr_trace(recall, precision, thresholds, legend_name):
    """Build one precision-vs-recall line with the thresholds attached as hover data."""
    return go.Scatter(
        x=recall,
        y=precision,
        customdata=thresholds,
        hovertemplate="Threshold=%{customdata}<br>Recall=%{x}<br>Precision=%{y}",
        mode="lines",
        name=legend_name,
    )


# One line per classifier, each labelled with its PR-AUC.
fig = go.Figure()
fig.add_trace(_pr_trace(recall_gb, precision_gb, thresholds_gb, f"BDT - AUC={pr_auc_gb:.3f}"))
fig.add_trace(_pr_trace(recall_rf, precision_rf, thresholds_rf, f"RF - AUC={pr_auc_rf:.3f}"))
fig.add_trace(_pr_trace(recall_nn, precision_nn, thresholds_nn, f"NN - AUC={pr_auc_nn:.3f}"))

# Axis labels, figure size and a horizontal legend placed above the plot area.
fig.update_layout(
    title="Cross-Section Weighted Precision-Recall Curves",
    xaxis_title="Recall",
    yaxis_title="Precision",
    width=1200 * (2 / 3),
    height=800 * (2 / 3),
    xaxis_range=[0.1, 1],
    legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
)

# Export in both formats, then render inline.
for image_path in (f"{workdir}/plots/PR-curve-w.pdf", f"{workdir}/plots/PR-curve-w.jpg"):
    fig.write_image(image_path)
fig.show()

In [None]:
def _split_scores_by_truth(scores):
    """Group flattened `scores` by truth label.

    Keys follow the order in which labels first appear in test_df["Truth"];
    the row labels of test_df are used as positions into the score array.
    """
    flat_scores = scores.flatten()
    truth = test_df["Truth"]
    return {label: flat_scores[test_df.index[truth == label]] for label in truth.unique()}


scores_test_rf = _split_scores_by_truth(y_pred_rf)
scores_test_gb = _split_scores_by_truth(y_pred_gb)

In [None]:
# Histogram of the random-forest scores, built from the per-truth-label score
# dict (50 is the second positional argument, presumably the bin count).
rf_output_layout = dict(
    title_text="Data Sample Content by Model Output Cut",
    barmode="stack",
    yaxis_type="log",
    xaxis_title="RF Output",
    yaxis_title="Probability Density",
    width=1600 * (5 / 6),
    height=900 * (5 / 6),
)
fig = make_histogram(scores_test_rf, 50, clip_top_prc=100, clip_bottom_prc=0, cross=None)
fig.update_layout(**rf_output_layout)

# Export in both formats, then render inline.
for image_path in (f"{workdir}/plots/RF-output.pdf", f"{workdir}/plots/RF-output.png"):
    fig.write_image(image_path)
fig.show()

In [None]:
# Histogram of the gradient-boosting scores, built from the per-truth-label
# score dict (50 is the second positional argument, presumably the bin count).
gb_output_layout = dict(
    title_text="Data Sample Content by Model Output Cut",
    barmode="stack",
    yaxis_type="log",
    xaxis_title="GB Output",
    yaxis_title="Probability Density",
    width=1600 * (5 / 6),
    height=900 * (5 / 6),
)
fig = make_histogram(scores_test_gb, 50, clip_top_prc=100, clip_bottom_prc=0, cross=None)
fig.update_layout(**gb_output_layout)

# Export as PDF, then render inline.
fig.write_image(f"{workdir}/plots/GB-output.pdf")
fig.show()

In [None]:
# Feature importances as reported by each fitted estimator
rf_importance = rf_clf.feature_importances_
gb_importance = gb_clf.feature_importances_

# Feature names: every df_train column except the label and bookkeeping columns.
# NOTE(review): assumes these columns are the model inputs, in the same order
# as the columns of x_train — confirm against the training script.
feature_names = df_train.drop(["target", "Truth", "inv_mass"], axis=1).columns

# Grouped bar chart: one bar per feature and classifier
fig = go.Figure()

# Add bars for Random Forest
fig.add_trace(
    go.Bar(
        x=feature_names,
        y=rf_importance,
        name="Random Forest",
        offsetgroup=1,
        marker=dict(color="#E4D91B"),
    )
)

# Add bars for Gradient Boosting
fig.add_trace(
    go.Bar(
        x=feature_names,
        y=gb_importance,
        name="Gradient Boosting",
        offsetgroup=2,
        marker=dict(color="#D91BE4"),
    )
)

fig.update_layout(
    # The chart shows both classifiers, so the title names both.
    title="Feature Importances for Random Forest and Gradient Boosting",
    xaxis_title="Features",
    yaxis_title="Importance Value",
    # legend_title='Classifier',
    xaxis=dict(tickangle=45),
    barmode="group",
    width=1200,
)
fig.write_image(f"{workdir}/plots/feature_importances.pdf")
fig.show()

In [None]:
# Positions of the ten largest importances, most important first.
# Computed once and reused for both the printout and the chart.
top10_rf = np.argsort(rf_importance)[::-1][:10]
top10_gb = np.argsort(gb_importance)[::-1][:10]

# print top 10 features by importance
print("Top 10 features by importance")
print("Random Forest")
print(feature_names[top10_rf])
print("Gradient Boosting")
print(feature_names[top10_gb])

# Horizontal bar chart for the random forest top 10. The selection is reversed
# so that the most important feature is listed last, which plotly draws at the
# top of a horizontal bar chart.
fig = go.Figure()
fig.add_trace(
    go.Bar(
        y=feature_names[top10_rf][::-1],
        x=rf_importance[top10_rf][::-1],
        name="Random Forest",
        offsetgroup=1,
        marker=dict(color="#E4D91B"),
        orientation="h",
    )
)
fig.update_layout(
    title="Top 10 Features by Importance for Random Forest",
    # With orientation="h" the importance values run along x and the feature
    # names along y, so the axis titles follow that assignment.
    xaxis_title="Importance Value",
    yaxis_title="Features",
    # legend_title='Classifier',
    xaxis=dict(tickangle=45),
    barmode="group",
    width=600,
    height=800,
)
fig.write_image(f"{workdir}/plots/top10_RF.pdf")
fig.write_image(f"{workdir}/plots/top10_RF.png")
fig.show()

In [None]:
# Mass histogram with a double-Gaussian fit for events selected by the
# random-forest score cut (0.99 with prc=True — presumably a percentile cut;
# see mass_score_cut). Cross-sections are passed through `cross`.
mass_after_rf_cut = mass_score_cut(m6j_test, scores_test_rf, 0.99, prc=True)
fig = make_histogram_with_double_gaussian_fit(
    mass_after_rf_cut,
    20,
    clip_top_prc=100,
    cross=const.CROSS_SECTION_ATLAS_130_85,
)
# Plain-histogram alternative:
# fig = make_histogram(mass_score_cut(m6j_test, scores_test_rf, 0.95, prc=False), 20, clip_top_prc=100)

rf_mass_layout = dict(
    title="6-jet Mass",
    xaxis_title="Invariant Mass [GeV]",
    yaxis_title_text="count x sigma",
    # yaxis_type="log",
    barmode="stack",
    bargap=0,
    width=1600 * (2 / 3),
    height=800 * (2 / 3),
)
# Legend anchored inside the top-left corner of the plot area.
rf_mass_legend = dict(
    title_text="",
    itemsizing="constant",
    yanchor="top",
    y=0.99,
    xanchor="left",
    x=0.01,
    font=dict(size=16),
)
fig.update_layout(**rf_mass_layout)
fig.update_legends(**rf_mass_legend)
# fig.write_image(f"{workdir}/plots/6jet_mass_RF_cut_05_fit.pdf")
fig.show()

In [None]:
def _weighted_yields_rf(score_cut):
    """Per-truth-label entry count after the RF score cut, scaled by cross-section."""
    passing = mass_score_cut(m6j_test, scores_test_rf, score_cut, prc=True)
    return {
        label: len(entries) * const.CROSS_SECTION_ATLAS_130_85[label]
        for label, entries in passing.items()
    }


# One column per cut value, one row per truth label.
df_counts_rf = pd.DataFrame({cut: _weighted_yields_rf(cut) for cut in (0.8, 0.90, 0.95, 0.99)})

In [None]:
# Work on the per-process rows only. A later cell appends the summary rows
# "BKG:sum" and "S/B" to df_counts_rf; without this filter, re-running this
# cell afterwards would fold those rows into the background sum.
process_counts_rf = df_counts_rf.drop(index=["BKG:sum", "S/B"], errors="ignore")

# NOTE(review): the last row is taken to be the signal sample and every row
# before it background — this relies on the row order of df_counts_rf; confirm.
bkg_counts = process_counts_rf.iloc[:-1].sum(axis=0)
sig_counts = process_counts_rf.iloc[-1]
s_over_b = sig_counts / bkg_counts
s_over_b

In [None]:
# Append the summed background and the signal-to-background ratio as extra rows.
summary_rows = {"BKG:sum": bkg_counts, "S/B": s_over_b}
for row_label, row_values in summary_rows.items():
    df_counts_rf.loc[row_label] = row_values

In [None]:
# Display the RF yield table (per-label rows plus the "BKG:sum" and "S/B" rows added above).
df_counts_rf

In [None]:
# Mass histogram with a double-Gaussian fit for events selected by the
# gradient-boosting score cut (0.99 with prc=True — presumably a percentile
# cut; see mass_score_cut). Cross-sections are passed through `cross`.
mass_after_gb_cut = mass_score_cut(m6j_test, scores_test_gb, 0.99, prc=True)
fig = make_histogram_with_double_gaussian_fit(
    mass_after_gb_cut,
    20,
    clip_top_prc=100,
    cross=const.CROSS_SECTION_ATLAS_130_85,
)
# Plain-histogram alternative:
# fig = make_histogram(mass_score_cut(m6j_test, scores_test_gb, 0.95, prc=False), 20, clip_top_prc=100)

gb_mass_layout = dict(
    title="6-jet Mass",
    xaxis_title="Invariant Mass [GeV]",
    yaxis_title_text="count x sigma",
    # yaxis_type="log",
    barmode="stack",
    bargap=0,
    width=1600 * (2 / 3),
    height=1400 * (2 / 3),
)
# Legend anchored inside the top-left corner of the plot area.
gb_mass_legend = dict(
    title_text="",
    itemsizing="constant",
    yanchor="top",
    y=0.99,
    xanchor="left",
    x=0.01,
    font=dict(size=16),
)
fig.update_layout(**gb_mass_layout)
fig.update_legends(**gb_mass_legend)
# fig.write_image(f"{workdir}/plots/6jet_mass_GB_cut_05_fit.pdf")
fig.show()

In [None]:
# Yield per truth label after the gradient-boosting score cut, one column per cut.
res = {}
for cut in (0.8, 0.90, 0.95, 0.99):
    scores = mass_score_cut(m6j_test, scores_test_gb, cut, prc=True)
    # Scale the raw entry counts by cross-section, exactly as the RF table
    # does; unweighted counts would make the two S/B rows incomparable.
    counts = {k: len(v) * const.CROSS_SECTION_ATLAS_130_85[k] for k, v in scores.items()}
    res[cut] = counts
df_counts_gb = pd.DataFrame(res)

# NOTE(review): the last row is taken to be the signal sample and every row
# before it background — this relies on the row order of df_counts_gb; confirm.
bkg_counts = df_counts_gb.iloc[:-1].sum(axis=0)
sig_counts = df_counts_gb.iloc[-1]
s_over_b = sig_counts / bkg_counts

# add the summed background and s_over_b as rows
df_counts_gb.loc["BKG:sum"] = bkg_counts
df_counts_gb.loc["S/B"] = s_over_b

# Last expression of the cell, so the finished table is actually displayed.
df_counts_gb