In [None]:
import os

import pendulum
import joblib
import pandas as pd
import numpy as np
from tqdm import tqdm

import plotly.express as px
import plotly.graph_objects as go

from sklearn.utils import shuffle
from sklearn.utils import resample

from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.model_selection import cross_validate, cross_val_score
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import (
    auc,
    roc_curve,
    precision_recall_curve,
)

from diquark import CROSS_SECTION_DICT
from diquark.plotting import make_histogram, make_histogram_with_double_gaussian_fit
from diquark.helpers import mass_score_cut

import tensorflow as tf

# Short aliases for the Keras API and its layers namespace.
tfk = tf.keras
tfkl = tfk.layers

# If the kernel was started inside notebooks/, step up one directory so that
# relative paths such as "models" are resolved from the parent directory.
if os.getcwd().rsplit("/", 1)[-1] == "notebooks":
    os.chdir("..")

In [None]:
# Select the most recent run directory under models/.
# Only directories are considered: a stray file (e.g. a README) sorting last
# must not be picked as the run directory.
# NOTE(review): "latest" assumes run directory names sort lexicographically in
# chronological order (e.g. timestamp names) — TODO confirm.
_run_dirs = sorted(
    entry for entry in os.listdir("models") if os.path.isdir(os.path.join("models", entry))
)
if not _run_dirs:
    raise FileNotFoundError("No run directories found under 'models/'.")
workdir = os.path.join("models", _run_dirs[-1])

## Load Data

In [None]:
# Train/test event tables written by the training run in `workdir`.
_train_parquet = f"{workdir}/train.parquet"
_test_parquet = f"{workdir}/test.parquet"
df_train, df_test = pd.read_parquet(_train_parquet), pd.read_parquet(_test_parquet)

In [None]:
# Truth label and invariant mass per event. reset_index(drop=True) gives a fresh
# 0..n-1 index; later cells use this index to pick entries out of prediction arrays.
# NOTE(review): assumes row order of df_test matches x_test — TODO confirm.
test_df = df_test[["Truth", "inv_mass"]].reset_index(drop=True)
train_df = df_train[["Truth", "inv_mass"]].reset_index(drop=True)

# m6j arrays saved with the run (presumably per-event 6-jet invariant masses).
# NOTE(review): joblib.load unpickles arbitrary objects — only load trusted run dirs.
m6j_test = joblib.load(f"{workdir}/m6j_test.data.joblib")
m6j_train = joblib.load(f"{workdir}/m6j_train.data.joblib")

# Feature/label arrays. Use the NpzFile as a context manager so the underlying
# zip file is closed once the arrays have been read into memory; indexing an
# NpzFile returns a fully loaded ndarray, so the arrays remain valid afterwards.
# (`data_npz` stays bound but is closed after this block.)
with np.load(f"{workdir}/data.npz") as data_npz:
    x_train, y_train, x_test, y_test = (
        data_npz["x_train"],
        data_npz["y_train"],
        data_npz["x_test"],
        data_npz["y_test"],
    )

## Load Models

In [None]:
# Trained classifiers from the selected run: the Keras network, plus two
# joblib-serialized models (file stems suggest random forest and gradient boosting).
# NOTE(review): joblib.load unpickles arbitrary objects — only load trusted run dirs.
model = tfk.models.load_model(f"{workdir}/model.keras")
rf_clf, gb_clf = (joblib.load(f"{workdir}/{stem}.joblib") for stem in ("rfc", "gbc"))

## Inference and Visualization

In [None]:
# Score the held-out set with each classifier. For the sklearn models keep
# column 1 of predict_proba (probability of class 1).
# NOTE(review): Keras predict may return shape (n, 1) rather than (n,) — TODO confirm.
y_pred_nn = model.predict(x_test)
y_pred_gb, y_pred_rf = (clf.predict_proba(x_test)[:, 1] for clf in (gb_clf, rf_clf))

# Per-event weight looked up in CROSS_SECTION_DICT by the event's truth label.
# Currently only used by the optional weighted PR-curve variant.
sample_weights = list(map(CROSS_SECTION_DICT.__getitem__, test_df["Truth"]))

In [None]:
def _pr_curve_with_auc(scores):
    """Return (precision, recall, thresholds, PR-AUC) of `scores` against `y_test`.

    The curve is unweighted. To weight events by cross section instead, pass
    ``sample_weight=sample_weights`` to ``precision_recall_curve``.
    The AUC is the trapezoidal ``auc(recall, precision)``.
    NOTE(review): sklearn suggests ``average_precision_score`` as an alternative
    summary of a PR curve.
    """
    precision, recall, thresholds = precision_recall_curve(y_test, scores)
    return precision, recall, thresholds, auc(recall, precision)


precision_nn, recall_nn, thresholds_nn, pr_auc_nn = _pr_curve_with_auc(y_pred_nn)  # neural network
precision_gb, recall_gb, thresholds_gb, pr_auc_gb = _pr_curve_with_auc(y_pred_gb)  # gradient boosting
precision_rf, recall_rf, thresholds_rf, pr_auc_rf = _pr_curve_with_auc(y_pred_rf)  # random forest

In [None]:
# Overlay the precision-recall curves of the three classifiers.
# The PR arrays are currently computed without sample_weight, so the title no
# longer claims cross-section weighting. The output filename keeps its existing
# "-w" suffix so that references to the file keep working.
_pr_curves = [
    ("BDT", recall_gb, precision_gb, thresholds_gb, pr_auc_gb),
    ("RF", recall_rf, precision_rf, thresholds_rf, pr_auc_rf),
    ("NN", recall_nn, precision_nn, thresholds_nn, pr_auc_nn),
]

fig = go.Figure()
for _label, _rec, _prec, _thr, _pr_auc in _pr_curves:
    fig.add_trace(
        go.Scatter(
            x=_rec,
            y=_prec,
            # precision_recall_curve returns one more precision/recall point than
            # thresholds (the final endpoint); pad with NaN so hover data aligns.
            customdata=np.append(_thr, np.nan),
            hovertemplate="Threshold=%{customdata}<br>Recall=%{x}<br>Precision=%{y}",
            mode="lines",
            name=f"{_label} - AUC={_pr_auc:.3f}",
        )
    )
fig.update_layout(
    title="Precision-Recall Curves",
    xaxis_title="Recall",
    yaxis_title="Precision",
    width=1200 * (2 / 3),
    height=800 * (2 / 3),
    legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
)
fig.write_image(f"{workdir}/plots/PR-curve-w.pdf")
fig.show()

In [None]:
# RF scores of the test set grouped by truth label (keys in order of first
# appearance). test_df's index is used to select entries of the flattened scores.
_rf_scores_flat = y_pred_rf.flatten()
_truth = test_df["Truth"]
scores_test_rf = {
    label: _rf_scores_flat[_truth.index[_truth == label]] for label in _truth.unique()
}

In [None]:
# Stacked histogram (log y-axis) of the RF output, one component per truth label.
fig = make_histogram(
    scores_test_rf,
    50,
    clip_top_prc=100,
    clip_bottom_prc=0,
    cross=None,
)
_rf_output_layout = dict(
    title_text="Data Sample Content by Model Output Cut",
    barmode="stack",
    yaxis_type="log",
    xaxis_title="RF Output",
    yaxis_title="Probability Density",
    width=1600 * (5 / 6),
    height=900 * (5 / 6),
)
fig.update_layout(**_rf_output_layout)
fig.write_image(f"{workdir}/plots/RF-output.pdf")
fig.show()

In [None]:
# 6-jet invariant mass of test events that pass an RF score cut, histogrammed
# with a double-Gaussian fit.
RF_SCORE_CUT = 0.5  # prc=False below: presumably an absolute score cut, not a percentile — confirm in diquark.helpers
N_MASS_BINS = 20
SAVE_MASS_FIG = False  # set True to write the PDF (previously a commented-out line)

fig = make_histogram_with_double_gaussian_fit(
    mass_score_cut(m6j_test, scores_test_rf, RF_SCORE_CUT, prc=False),
    N_MASS_BINS,
    clip_top_prc=100,
)
fig.update_layout(
    title="6-jet Mass",
    xaxis_title="Invariant Mass [GeV]",
    yaxis_title_text="count x sigma",
    # add yaxis_type="log" for a logarithmic y-axis
    barmode="stack",
    bargap=0,
    width=1600 * (2 / 3),
    height=1400 * (2 / 3),
)
fig.update_legends(
    title_text="",
    itemsizing="constant",
    yanchor="top",
    y=0.99,
    xanchor="left",
    x=0.01,
    font=dict(size=16),
)
if SAVE_MASS_FIG:
    # Encode the cut in the filename, e.g. 0.5 -> "05" (matches the previous hard-coded name).
    _cut_tag = str(RF_SCORE_CUT).replace(".", "")
    fig.write_image(f"{workdir}/plots/6jet_mass_RF_cut_{_cut_tag}_fit.pdf")
fig.show()