In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, BatchNormalization, Dropout, Flatten
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve
import tensorflow as tf
strategy = tf.distribute.MirroredStrategy()
print(f"Number of devices: {strategy.num_replicas_in_sync}")


2025-04-25 08:51:03.767706: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-04-25 08:51:03.767742: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-04-25 08:51:03.768863: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-25 08:51:03.775251: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')


2025-04-25 08:51:15.030910: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 38366 MB memory:  -> device: 0, name: NVIDIA A100-SXM4-40GB, pci bus id: 0000:03:00.0, compute capability: 8.0
2025-04-25 08:51:15.032638: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 38366 MB memory:  -> device: 1, name: NVIDIA A100-SXM4-40GB, pci bus id: 0000:41:00.0, compute capability: 8.0
2025-04-25 08:51:15.034408: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:2 with 38366 MB memory:  -> device: 2, name: NVIDIA A100-SXM4-40GB, pci bus id: 0000:82:00.0, compute capability: 8.0
2025-04-25 08:51:15.037181: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:3 with 38366 MB memory:  -> device: 3, name: NVIDIA A100-SXM4-40GB, pci bu

Number of devices: 4


In [2]:
###############################################################################
# 1. Plot Overlapping Histograms (Signal vs. Background)
###############################################################################
def plot_signal_vs_background(df: pd.DataFrame, feature: str = "mj2", bins=50):
    """
    Plot overlapping histograms of a given feature for signal vs. background,
    normalized to unit area.
    """
    df_signal = df[df["label"] == 1]
    df_background = df[df["label"] == 0]

    plt.figure(figsize=(6, 4))
    plt.hist(df_signal[feature], bins=bins, density=True,
             histtype='step', label='signal')
    plt.hist(df_background[feature], bins=bins, density=True,
             histtype='step', label='background')

    plt.xlabel(f"{feature} [TeV]")
    plt.ylabel("Fraction of Events / bin")
    plt.legend()
    plt.tight_layout()
    plt.show()

In [3]:
###############################################################################
# 2. Load from HDF5: ensures m, mx, my, tau12, tau23, label exist
###############################################################################
def load_m_mx_my_tau12_tau23(
    file_path: str,
    key: str = "/df",
    scale_to_tev: bool = False,
    assign_label: int = None
) -> pd.DataFrame:
    """
    Reads HDF5 (file_path) under group 'key' into a DataFrame.
    - If 'assign_label' is an integer, sets 'label' to that.
    - If 'scale_to_tev' is True, multiply masses by 0.001.
    - Ensures columns: mj1, mj2, tau12j1, tau23j1, tau12j2, tau23j2, mx, my, label.
    """
    df = pd.read_hdf(file_path, key=key)

    # Assign label 
    if assign_label is not None:
        df["label"] = assign_label

    # scale from GeV to TeV
    scale = 0.001 if scale_to_tev else 1.0

    for j in ["j1", "j2"]:
        mcol = f"m{j}"
        if mcol not in df.columns:
            df[mcol] = 0.0
        df[mcol] *= scale

        for base in ["tau1", "tau2", "tau3"]:
            if f"{base}{j}" not in df.columns:
                df[f"{base}{j}"] = 0.0

        tau1 = df[f"tau1{j}"]
        tau2 = df[f"tau2{j}"]
        tau3 = df[f"tau3{j}"]
        mask_12 = (tau1 > 0) & (tau2 > 0)
        mask_23 = (tau2 > 0) & (tau3 > 0)
        df[f"tau12{j}"] = np.where(mask_12, tau2 / tau1, 0.0)
        df[f"tau23{j}"] = np.where(mask_23, tau3 / tau2, 0.0)

    if "mx" not in df.columns:
        df["mx"] = 0.0
    if "my" not in df.columns:
        df["my"] = 0.0

    needed_cols = [
        "mj1", "tau12j1", "tau23j1",
        "mj2", "tau12j2", "tau23j2",
        "mx", "my"
    ]
    if "label" in df.columns:
        needed_cols.append("label")

    
    for c in needed_cols:
        if c not in df.columns:
            df[c] = 0.0

    return df[needed_cols].copy()

In [4]:
from tensorflow.keras import Model, Input
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.optimizers import Adam
from aliad.interface.tensorflow.losses import ScaledBinaryCrossentropy  
import numpy as np

def base_highlevel_model(name="BaseModel", input_name="input_features", input_shape=(2, 3), use_scaled_bce=False):
    inputs = Input(shape=input_shape, name=input_name)
    x = Flatten()(inputs)

    x = Dense(256, activation="relu")(x)
    x = Dropout(0.3)(x)
    x = Dense(128, activation="relu")(x)
    x = Dense(64, activation="relu")(x)
    outputs = Dense(1, activation="sigmoid")(x)

    model = Model(inputs, outputs, name=name)

    if use_scaled_bce:
        loss_fn = ScaledBinaryCrossentropy(offset=-np.log(2), scale=1000)
    else:
        loss_fn = "binary_crossentropy"

    model.compile(optimizer=Adam(learning_rate=1e-3), loss=loss_fn, metrics=["accuracy"])
    return model

def simple_supervised_model():
    return base_highlevel_model(name="SimpleSupervisedModel", input_name="input_features", use_scaled_bce=False)

def ideal_weakly_supervised_model():
    return base_highlevel_model(name="IdealWeaklySupervisedModel", input_name="ideal_weakly_input", use_scaled_bce=True)



In [5]:
###############################################################################
# 4. Exclude Specific Mass Points
###############################################################################
def exclude_mass_points(df: pd.DataFrame, mass_pairs_to_exclude):
    """
    Splits 'df' into included vs. excluded data based on (mx, my).
    """
    mass_set = set(mass_pairs_to_exclude)
    df["excluded"] = df.apply(lambda r: (r["mx"], r["my"]) in mass_set, axis=1)
    df_excluded = df[df["excluded"]].copy()
    df_included = df[~df["excluded"]].copy()

    df.drop(columns=["excluded"], inplace=True)
    df_included.drop(columns=["excluded"], inplace=True)
    df_excluded.drop(columns=["excluded"], inplace=True)
    return df_included, df_excluded

In [6]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

def plot_mj2_histograms(df_signal, df_background):
    mj2_signal = df_signal["mj2"].values
    mj2_background = df_background["mj2"].values
    bins = np.linspace(0, 1, 1000)
    plt.hist(mj2_signal, bins=bins, density=True, histtype='step', color='black', label='signal', linewidth=1.5)
    plt.hist(mj2_background, bins=bins, density=True, histtype='stepfilled', color='crimson', alpha=0.5, label='background')
    plt.xlabel(r'$m_{j2}$ [TeV]', fontsize=14)
    plt.ylabel(f'Fraction of Events / {np.round(np.diff(bins)[0], 5)}', fontsize=14)
    plt.legend(frameon=False, fontsize=12)
    plt.tight_layout()
    plt.grid(False)
    plt.show()

def compute_sic(y_true, y_score):
    fpr, tpr, thresholds = roc_curve(y_true, y_score)
    with np.errstate(divide='ignore', invalid='ignore'):
        sic = tpr / np.sqrt(fpr)
        sic[np.isnan(sic)] = 0.0
        sic[np.isinf(sic)] = 0.0
    return thresholds, sic

def plot_sic_curves_grouped(models_and_preds, title_prefix):
    plt.figure(figsize=(8, 6))
    for label, y_true, y_score in models_and_preds:
        thresholds, sic = compute_sic(y_true, y_score)
        plt.plot(thresholds, sic, label=label)
        max_idx = np.argmax(sic)
        plt.scatter(thresholds[max_idx], sic[max_idx], marker="x", color="black")
        plt.text(thresholds[max_idx], sic[max_idx], f"{sic[max_idx]:.2f}", fontsize=8)
    plt.title(f"SIC Curve — {title_prefix}")
    plt.xlabel("NN Score Threshold")
    plt.ylabel("SIC (TPR / √FPR)")
    plt.legend()
    plt.tight_layout()
    plt.show()

def train_and_predict(model_fn, model_key, df_signal, df_background, df_test_true, epochs=10, is_weakly=False, signal_fraction=0.001):
    #model = model_fn()
    with strategy.scope():
        model = model_fn()
        model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
    outdir = "./my_trained_models"
    os.makedirs(outdir, exist_ok=True)
    best_path = os.path.join(outdir, f"best_{model_key}.keras")

    early_stopping = EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True, verbose=1)
    checkpoint = ModelCheckpoint(filepath=best_path, monitor="val_loss", save_best_only=True, verbose=1)

    df_true_test = df_test_true.sample(frac=1.0, random_state=42)
    X_test_true = df_true_test[["mj1", "tau12j1", "tau23j1", "mj2", "tau12j2", "tau23j2"]].values.reshape(-1, 2, 3)
    y_test_true = df_true_test["label"].values.astype(np.float32)

    if is_weakly:
        df_bg = df_background.sample(frac=1.0, random_state=42).reset_index(drop=True)
        half = len(df_bg) // 2
        R = df_bg.iloc[:half].copy()
        D_bg = df_bg.iloc[half:].copy()
        n_signal = int(len(D_bg) * signal_fraction)
        D_sig = df_signal.sample(n=n_signal, random_state=42)
        D = pd.concat([D_bg, D_sig], ignore_index=True)
        R["label"] = 0
        D["label"] = 1
        df_weak = pd.concat([R, D], ignore_index=True).sample(frac=1.0, random_state=42)
        X_weak = df_weak[["mj1", "tau12j1", "tau23j1", "mj2", "tau12j2", "tau23j2"]].values.reshape(-1, 2, 3)
        y_weak = df_weak["label"].values.astype(np.float32)

        model.fit(X_weak, y_weak,
                  validation_split=0.20,
                  epochs=epochs, batch_size=64,
                  callbacks=[early_stopping, checkpoint], verbose=1)

    else:
        df = pd.concat([df_signal, df_background], ignore_index=True)
        X = df[["mj1", "tau12j1", "tau23j1", "mj2", "tau12j2", "tau23j2"]].values.reshape(-1, 2, 3)
        y = df["label"].values.astype(np.float32)

        X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.20, stratify=y, random_state=42)

        model.fit(X_train, y_train,
                  validation_data=(X_val, y_val),
                  epochs=epochs, batch_size=64,
                  callbacks=[early_stopping, checkpoint], verbose=1)

    y_pred = model.predict(X_test_true).flatten()
    return y_test_true, y_pred

def main():
    qcd_file = "/global/cfs/projectdirs/m3246/alkaid/paws/datasets/original/events_anomalydetection_v2.features.h5"
    extra_qcd_file = "/global/cfs/projectdirs/m3246/alkaid/paws/datasets/original/events_anomalydetection_qcd_extra_inneronly_features.h5"
    w_qq_file = "/global/cfs/projectdirs/m3246/alkaid/paws/datasets/original/events_anomalydetection_Z_XY_qq_parametric.h5"

    df_qcd = load_m_mx_my_tau12_tau23(qcd_file, key="/df", scale_to_tev=True)
    df_extra = load_m_mx_my_tau12_tau23(extra_qcd_file, key="/df", scale_to_tev=True)
    df_signal = load_m_mx_my_tau12_tau23(w_qq_file, key="/output", scale_to_tev=True)

    df_qcd["label"] = 0
    df_extra["label"] = 0
    df_signal["label"] = 1

    df_background = pd.concat([df_qcd, df_extra], ignore_index=True)

    signal_300 = df_signal.query("mx == 300 and my == 300")
    signal_100_500 = df_signal.query("mx == 100 and my == 500")
    signal_500_100 = df_signal.query("mx == 500 and my == 100")

    df_signal_general = df_signal.query("not ((mx == 300 and my == 300) or \
                                              (mx == 100 and my == 500) or \
                                              (mx == 500 and my == 100))")

    background_test = pd.concat([
        df_qcd.sample(n=2500, random_state=1),
        df_extra.sample(n=2500, random_state=2)
    ])
    testset_300 = pd.concat([background_test, signal_300.sample(n=2500, random_state=3)])
    testset_100_500 = pd.concat([background_test, signal_100_500.sample(n=2500, random_state=4)])

    model_defs = [
        ("Simple Supervised (All)", simple_supervised_model, "Simple_All", df_signal_general, False),
        ("Simple Supervised (300_300)", simple_supervised_model, "Simple_300_300", signal_300, False),
        ("Simple Supervised (100_500)", simple_supervised_model, "Simple_100_500", signal_100_500, False),
    ]

    weakly_defs = [
        ("Ideal Weakly (300_300)", ideal_weakly_supervised_model, "Weakly_300_300", signal_300, True),
        ("Ideal Weakly (100_500)", ideal_weakly_supervised_model, "Weakly_100_500", signal_100_500, True),
    ]

    results = []

    signal_fracs = [0.001, 0.01, 0.005]

    for label, model_fn, key, signal_df, is_weak in model_defs:
        # Train once
        _ = train_and_predict(model_fn, key, signal_df, df_background, testset_300, is_weakly=is_weak)

        # Load trained weights
        model = model_fn()
        model.load_weights(f"./my_trained_models/best_{key}.keras")

        # Predict on testset_300
        X_300 = testset_300[["mj1", "tau12j1", "tau23j1", "mj2", "tau12j2", "tau23j2"]].values.reshape(-1, 2, 3)
        y_300 = testset_300["label"].values.astype(np.float32)
        y_pred_300 = model.predict(X_300).flatten()
        results.append((f"{label} on 300_300", y_300, y_pred_300))

        # Predict on testset_100_500
        X_100_500 = testset_100_500[["mj1", "tau12j1", "tau23j1", "mj2", "tau12j2", "tau23j2"]].values.reshape(-1, 2, 3)
        y_100_500 = testset_100_500["label"].values.astype(np.float32)
        y_pred_100_500 = model.predict(X_100_500).flatten()
        results.append((f"{label} on 100_500", y_100_500, y_pred_100_500))

    for signal_frac in signal_fracs:
        for label, model_fn, key, signal_df, is_weak in weakly_defs:
            key_frac = f"{key}_frac{signal_frac:.3f}".replace("0.", "")
            label_frac = f"{label} μ={signal_frac:.3f}"

            y_true_300, y_pred_300 = train_and_predict(model_fn, key_frac, signal_df, df_background, testset_300, is_weakly=is_weak, signal_fraction=signal_frac)
            results.append((f"{label_frac} on 300_300", y_true_300, y_pred_300))

            y_true_100_500, y_pred_100_500 = train_and_predict(model_fn, key_frac, signal_df, df_background, testset_100_500, is_weakly=is_weak, signal_fraction=signal_frac)
            results.append((f"{label_frac} on 100_500", y_true_100_500, y_pred_100_500))

    plot_sic_curves_grouped(
    [r for r in results if "on 300_300" in r[0] and not ("Ideal Weakly (100_500)" in r[0])],
    title_prefix="Mass Point (300, 300)")

    plot_sic_curves_grouped(
    [r for r in results if "on 100_500" in r[0] and not ("Ideal Weakly (300_300)" in r[0])],
    title_prefix="Mass Point (100, 500)")


    # === Start of Figure 2-style panel ===
    injection_levels = np.arange(0.001, 0.011, 0.001)
    mass_points = [(100, 500), (300, 300)]
    results_dict = {}

    for mx, my in mass_points:
        if (mx, my) == (100, 500):
            signal_df = signal_100_500
            test_df = testset_100_500
        elif (mx, my) == (300, 300):
            signal_df = signal_300
            test_df = testset_300

        results_dict[(mx, my)] = {"mu": [], "sic_median": [], "sic_std": [], "param_median": [], "param_std": []}

        for mu in injection_levels:
            sic_vals = []
            param_vals = []

            for seed in range(5):  # 5 retrainings
                np.random.seed(seed)
                tf.random.set_seed(seed)

                model = ideal_weakly_supervised_model()
                y_true, y_pred = train_and_predict(
                    model_fn=ideal_weakly_supervised_model,
                    model_key=f"mx{mx}_my{my}_mu{mu:.3f}_seed{seed}",
                    df_signal=signal_df,
                    df_background=df_background,
                    df_test_true=test_df,
                    is_weakly=True,
                    signal_fraction=mu
                )

                # Compute SIC at fixed εB = 0.001
                fpr, tpr, _ = roc_curve(y_true, y_pred)
                try:
                    idx = np.where(fpr >= 0.001)[0][0]
                    sic = tpr[idx] / np.sqrt(fpr[idx])
                except IndexError:
                    sic = 0
                sic_vals.append(sic)

                # Summary stat — here: mean of NN output
                param_vals.append(np.mean(y_pred))

            results_dict[(mx, my)]["mu"].append(mu)
            results_dict[(mx, my)]["sic_median"].append(np.median(sic_vals))
            results_dict[(mx, my)]["sic_std"].append(np.std(sic_vals))
            results_dict[(mx, my)]["param_median"].append(np.median(param_vals))
            results_dict[(mx, my)]["param_std"].append(np.std(param_vals))

    # === Plot Figure 2-style panel ===
    fig, axs = plt.subplots(2, 2, figsize=(12, 8), sharex=True)

    for i, (mx, my) in enumerate(mass_points):
        data = results_dict[(mx, my)]
        mu = np.array(data["mu"])

        # Top: SIC
        ax_top = axs[0, i]
        ax_top.plot(mu, data["sic_median"], label='Ideal Weak', color='blue')
        ax_top.fill_between(mu,
                            np.array(data["sic_median"]) - np.array(data["sic_std"]),
                            np.array(data["sic_median"]) + np.array(data["sic_std"]),
                            alpha=0.3, color='blue')
        ax_top.set_title(f"(mX, mY) = ({mx}, {my})")
        ax_top.set_ylabel("SIC at εB = 0.1%")
        ax_top.grid(True)

        # Bottom: Predicted param (NN output in this case)
        ax_bot = axs[1, i]
        ax_bot.plot(mu, data["param_median"], label='Predicted Value', color='green')
        ax_bot.fill_between(mu,
                            np.array(data["param_median"]) - np.array(data["param_std"]),
                            np.array(data["param_median"]) + np.array(data["param_std"]),
                            alpha=0.3, color='green')
        ax_bot.set_xlabel("Signal Injection Fraction μ")
        ax_bot.set_ylabel("NN Output (avg)")
        ax_bot.axhline(1.0, linestyle="--", color="gray", label="True Value")
        ax_bot.grid(True)

    axs[0, 0].legend()
    axs[1, 0].legend()
    plt.tight_layout()
    plt.show()
def main_figure3_style_side_by_side():
    df_qcd = load_m_mx_my_tau12_tau23(
        "/global/cfs/projectdirs/m3246/alkaid/paws/datasets/original/events_anomalydetection_v2.features.h5",
        key="/df", scale_to_tev=True
    )
    df_extra = load_m_mx_my_tau12_tau23(
        "/global/cfs/projectdirs/m3246/alkaid/paws/datasets/original/events_anomalydetection_qcd_extra_inneronly_features.h5",
        key="/df", scale_to_tev=True
    )
    df_signal = load_m_mx_my_tau12_tau23(
        "/global/cfs/projectdirs/m3246/alkaid/paws/datasets/original/events_anomalydetection_Z_XY_qq_parametric.h5",
        key="/output", scale_to_tev=True
    )

    df_qcd["label"] = 0
    df_extra["label"] = 0
    df_signal["label"] = 1
    df_background = pd.concat([df_qcd, df_extra], ignore_index=True)

    background_test = pd.concat([
        df_qcd.sample(n=2500, random_state=1),
        df_extra.sample(n=2500, random_state=2)
    ])
    signal_300 = df_signal.query("mx == 300 and my == 300")
    signal_100_500 = df_signal.query("mx == 100 and my == 500")

    testsets = {
        (300, 300): pd.concat([background_test, signal_300.sample(n=2500, random_state=3)]),
        (100, 500): pd.concat([background_test, signal_100_500.sample(n=2500, random_state=4)])
    }

    signal_sets = {
        (300, 300): signal_300,
        (100, 500): signal_100_500
    }

    fig, axs = plt.subplots(1, 2, figsize=(12, 5), sharey=True)

    for i, (mx, my) in enumerate([(300, 300), (100, 500)]):
        ax = axs[i]
        testset = testsets[(mx, my)]
        sig_df = signal_sets[(mx, my)]
        X_test = testset[["mj1", "tau12j1", "tau23j1", "mj2", "tau12j2", "tau23j2"]].values.reshape(-1, 2, 3)
        y_true = testset["label"].values.astype(np.float32)

        # === Train simple supervised model ===
        df = pd.concat([sig_df, df_background], ignore_index=True)
        X = df[["mj1", "tau12j1", "tau23j1", "mj2", "tau12j2", "tau23j2"]].values.reshape(-1, 2, 3)
        y = df["label"].values.astype(np.float32)
        X_train, X_val, y_train, y_val = train_test_split(X, y, stratify=y, test_size=0.2, random_state=42)

        model_sup = simple_supervised_model()
        model_sup.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=10, batch_size=512, verbose=0)
        y_pred_sup = model_sup.predict(X_test).flatten()

        fpr, tpr, _ = roc_curve(y_true, y_pred_sup)
        with np.errstate(divide='ignore', invalid='ignore'):
            sic = tpr / np.sqrt(fpr)
            sic[np.isnan(sic)] = 0.0
            sic[np.isinf(sic)] = 0.0
        ax.plot(tpr, sic, label="Simple Supervised")

        # === Train ideal weakly supervised model ===
        df_bg = df_background.sample(frac=1.0, random_state=42).reset_index(drop=True)
        half = len(df_bg) // 2
        R = df_bg.iloc[:half].copy()
        D_bg = df_bg.iloc[half:].copy()
        n_signal = int(len(D_bg) * 0.001)
        D_sig = sig_df.sample(n=n_signal, random_state=42)
        D = pd.concat([D_bg, D_sig], ignore_index=True)
        R["label"] = 0
        D["label"] = 1
        df_weak = pd.concat([R, D], ignore_index=True).sample(frac=1.0, random_state=42)
        X_weak = df_weak[["mj1", "tau12j1", "tau23j1", "mj2", "tau12j2", "tau23j2"]].values.reshape(-1, 2, 3)
        y_weak = df_weak["label"].values.astype(np.float32)

        model_weak = ideal_weakly_supervised_model()
        model_weak.fit(X_weak, y_weak, validation_split=0.2, epochs=10, batch_size=512, verbose=0)
        y_pred_weak = model_weak.predict(X_test).flatten()

        fpr, tpr, _ = roc_curve(y_true, y_pred_weak)
        with np.errstate(divide='ignore', invalid='ignore'):
            sic = tpr / np.sqrt(fpr)
            sic[np.isnan(sic)] = 0.0
            sic[np.isinf(sic)] = 0.0
        ax.plot(tpr, sic, label="Ideal Weakly Supervised", linestyle="--")

        ax.set_title(f"(mX, mY) = ({mx}, {my})", fontsize=13)
        ax.set_xlabel("Signal Efficiency", fontsize=12)
        if i == 0:
            ax.set_ylabel("Significance Improvement", fontsize=12)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.show()


In [None]:
if __name__ == "__main__":
    main()
    main_figure3_style_side_by_side()

Epoch 1/10
INFO:tensorflow:Collective all_reduce tensors: 8 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Collective all_reduce tensors: 8 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/rep

2025-04-25 08:51:35.019643: I external/local_xla/xla/service/service.cc:168] XLA service 0x7f43544365a0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-04-25 08:51:35.019664: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA A100-SXM4-40GB, Compute Capability 8.0
2025-04-25 08:51:35.019676: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (1): NVIDIA A100-SXM4-40GB, Compute Capability 8.0
2025-04-25 08:51:35.019681: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (2): NVIDIA A100-SXM4-40GB, Compute Capability 8.0
2025-04-25 08:51:35.019694: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (3): NVIDIA A100-SXM4-40GB, Compute Capability 8.0
2025-04-25 08:51:35.029368: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2025-04-25 08:5

  1864/170510 [..............................] - ETA: 11:52 - loss: 0.2398 - accuracy: 0.9104    