## Metrics (Training Results)

- This notebook shows the training results of different models, including CNN and Particle Transformer.
- Most of the training results are repeated 10 times with different random seeds.
- The signal and background were set to be Higgs from VBF and GGF, respectively.

In [1]:
from itertools import product

import matplotlib.pyplot as plt
from pathlib import Path
import pandas as pd
import seaborn as sns

# Apply seaborn's default theme to every figure created below
sns.set_theme()
# Current seaborn colour palette; the plotting cells index it to colour the reference lines
default_colors = sns.color_palette()
# Passed as `frameon` to every `ax.legend(...)` call in this notebook
FRAMEON = False

# Project root, taken as the parent of the current working directory.
# NOTE(review): assumes the kernel is started from a directory directly below the project root — confirm.
ROOT = Path.cwd().parent

def get_metrics(channel: str, model: str, data_mode: str, date_time: str, data_suffix: str = '', num_rnd: int = None):
    """Collect the last-row test metrics of one model across its random-seed runs.

    Each run is read from
    ``ROOT/output/<channel>/<data_mode>/<model>-<date_time>-<data_suffix>/rnd_seed-<seed>/metrics.csv``
    for the seeds 123, 223, 323, ...  Only the last row of every ``metrics.csv``
    is kept. Scanning stops at the first seed whose file is missing.

    Args:
        channel: Channel directory name below ``ROOT/output``.
        model: Model name; first part of the run directory name.
        data_mode: Sub-directory of the channel (the notebook uses ``'jet_flavor'`` and ``'supervised'``).
        date_time: Training timestamp; second part of the run directory name.
        data_suffix: Third part of the run directory name. The separating ``-`` is
            always written, so the default ``''`` yields a directory name ending in ``-``.
        num_rnd: Maximum number of seeds to read. ``None`` (the default) reads
            consecutive seeds until the first missing file; this previously
            raised ``TypeError`` from ``range(None)``.

    Returns:
        A DataFrame with the columns ``model``, ``rnd_seed``, ``test_accuracy``,
        ``test_auc`` and ``epoch``, one row per run found. An empty DataFrame
        (no columns) when no run is found.
    """
    output_dir = ROOT / 'output'
    run_dir = Path(channel) / data_mode / f"{model}-{date_time}-{data_suffix}"

    # Seeds follow an arithmetic sequence: 123, 223, 323, ...
    first_seed, seed_step = 123, 100

    # Collect one small frame per seed and concatenate once at the end
    # instead of re-copying a growing DataFrame on every iteration.
    per_seed_frames = []
    run_idx = 0
    while num_rnd is None or run_idx < num_rnd:
        rnd_seed = first_seed + seed_step * run_idx
        metrics_file = run_dir / f'rnd_seed-{rnd_seed}' / 'metrics.csv'
        if not (output_dir / metrics_file).exists():
            # With an explicit count a missing file means fewer runs than
            # requested; in open-ended mode it is the normal end of the scan.
            if num_rnd is not None:
                print(f"Warning: Metrics file {metrics_file} does not exist. Stopping further checks for this model.")
            break

        # The last row of the CSV is taken as the test result of the run.
        test_metrics = pd.read_csv(output_dir / metrics_file).tail(1)
        per_seed_frames.append(pd.DataFrame({
            'model': model,
            'rnd_seed': rnd_seed,
            'test_accuracy': test_metrics['test_accuracy'].values,
            'test_auc': test_metrics['test_auc'].values,
            'epoch': test_metrics['epoch'].values
        }))
        run_idx += 1

    if not per_seed_frames:
        return pd.DataFrame()
    return pd.concat(per_seed_frames, ignore_index=True)

## $H \rightarrow \gamma \gamma$

#### Keras-like settings

- 20251001_052410
- Keras parameter initialization
- Keras batch normalization in CNN

#### Supervised
- 20251014_102416

#### PyTorch default settings (240K ParT + remove only decay product)

- 20250930_105915 : +0
- 20250923_232355 : +5
- 20250924_111848 : +10

#### PyTorch default settings (9.5K ParT + also remove neighbors near the decay products)

- 20251005_154731 : +0
- 20251006_114628 : +5
- 20251007_015709 : +10

#### Even lighter ParT (2.7K)
- 20251008_004504
- 20251008_150811
- 20251009_024328

## $H\to ZZ\to 4l$

- 20251011_225958 (+0)
- 20251013_003505 (+5)
- 20251011_230612 (+10)

#### Test more data with $\mathcal{L} \in \{9000, 18000, 30000\}$
- 20251015_131405

In [2]:
# Training campaigns, one tuple per (model, date_time, num_rot_aug).
# `date_time` identifies the training run directory and `num_rot_aug` is the
# augmentation label ('+0', '+5', '+10') later used for the plot legends.
diphoton_info_list = [
    ('CNN_EventCNN', '20251005_154731', '+0'),
    ('CNN_EventCNN', '20251006_114628', '+5'),
    ('CNN_EventCNN', '20251007_015709', '+10'),

    ('ParT_Light', '20251005_154731', '+0'),
    ('ParT_Light', '20251006_114628', '+5'),
    ('ParT_Light', '20251007_015709', '+10'),
]

zz4l_info_list = [
    ('CNN_EventCNN', '20251011_225958', '+0'),
    ('CNN_EventCNN', '20251013_003505', '+5'),
    ('CNN_EventCNN', '20251011_230612', '+10'),

    ('ParT_Light', '20251011_225958', '+0'),
    ('ParT_Light', '20251013_003505', '+5'),
    ('ParT_Light', '20251011_230612', '+10'),
]

# Luminosities in fb^-1; each one maps to the run-directory suffix 'L<luminosity>'.
luminosity_list = [100, 300, 900, 1800, 3000]
# Additional training for zz4l at L = [9000, 18000, 30000]
additional_luminosity_list = [9000, 18000, 30000]

# Supervised runs, one tuple per (channel, model, date_time).
supervised_list = [
    ('diphoton', 'CNN_EventCNN', '20251014_102416'),
    ('diphoton', 'ParT_Light', '20251014_102416'),
    ('ex-diphoton', 'CNN_EventCNN', '20251014_102416'),
    ('ex-diphoton', 'ParT_Light', '20251014_102416'),

    ('zz4l', 'CNN_EventCNN', '20251015_002601'),
    ('zz4l', 'ParT_Light', '20251015_002601'),
    ('ex-zz4l', 'CNN_EventCNN', '20251015_002601'),
    ('ex-zz4l', 'ParT_Light', '20251015_002601'),
]


def _load_jet_flavor_runs(channel_infos, luminosities, num_rnd=10):
    """Return one annotated metrics frame per (channel, campaign, luminosity).

    `channel_infos` yields ``(channel, (model, date_time, num_rot_aug))`` pairs.
    Every frame returned by `get_metrics` is tagged with the channel, timestamp,
    luminosity, augmentation label and ``data_mode='jet_flavor'``.
    """
    frames = []
    for channel, (model, date_time, num_rot_aug) in channel_infos:
        for luminosity in luminosities:
            run_df = get_metrics(channel=channel, model=model, data_mode='jet_flavor', date_time=date_time, data_suffix=f'L{luminosity}', num_rnd=num_rnd)
            frames.append(run_df.assign(
                channel=channel,
                date_time=date_time,
                luminosity=luminosity,
                num_rot_aug=num_rot_aug,
                data_mode='jet_flavor',
            ))
    return frames


# Gather every frame first and concatenate once, rather than re-copying a
# growing DataFrame inside the loops.
metric_frames = []

# Jet-Flavor
info_list = list(product(['diphoton', 'ex-diphoton'], diphoton_info_list)) + list(product(['zz4l', 'ex-zz4l'], zz4l_info_list))
metric_frames += _load_jet_flavor_runs(info_list, luminosity_list)

# Additional training for zz4l at the higher luminosities (no augmentation)
info_list = list(product(['zz4l', 'ex-zz4l'], [('CNN_EventCNN', '20251015_131405', '+0'), ('ParT_Light', '20251015_131405', '+0')]))
metric_frames += _load_jet_flavor_runs(info_list, additional_luminosity_list)

# Supervised: these frames carry no luminosity / num_rot_aug, so those columns
# are NaN for the supervised rows after concatenation.
for channel, model, date_time in supervised_list:
    sv_df = get_metrics(channel=channel, model=model, data_mode='supervised', date_time=date_time, data_suffix='SV', num_rnd=10)
    metric_frames.append(sv_df.assign(channel=channel, date_time=date_time, data_mode='supervised'))

df = pd.concat(metric_frames, ignore_index=True)

# Short model names used in the figures
for model_name, paper_name in (('CNN_EventCNN', 'CNN'), ('ParT_Light', 'ParT')):
    df.loc[df['model'] == model_name, 'model_paper'] = paper_name

## CWoLa vs. Supervised (SV)

In [3]:
def plot(channel: str, legend_loc: str, show_fig: bool = False):
    """Plot test AUC vs. luminosity for jet-flavor training with the supervised reference.

    Draws one line per model (CNN, ParT) from the un-augmented (``'+0'``)
    jet-flavor runs in the global `df`, and one dashed horizontal line per model
    at the mean supervised test AUC. The figure is saved to
    ``ROOT/figures/AUC_CWoLa-vs-SV_<channel>.pdf``.

    Args:
        channel: Value of the ``channel`` column to plot. For ``'zz4l'`` and
            ``'ex-zz4l'`` the x axis is shown in ab^-1 and includes the
            additional luminosities; otherwise it is shown in fb^-1.
        legend_loc: Matplotlib legend location string.
        show_fig: Show the figure when True; otherwise close it after saving.
    """
    fig, ax = plt.subplots(1, 1, figsize=(5, 4), dpi=300)

    is_zz4l = channel in ['zz4l', 'ex-zz4l']
    # Fixed model order so the line colours and the reference-line colours
    # (default_colors[0], default_colors[1]) cannot drift apart.
    model_order = ['CNN', 'ParT']

    tmp_df = df[df['channel'] == channel].copy()
    tmp_df['hue'] = tmp_df['model_paper']
    if is_zz4l:
        tmp_df['luminosity'] = tmp_df['luminosity'] / 1000  # convert to ab^-1

    # Jet-Flavor
    jet_flavor_df = tmp_df[(tmp_df['data_mode'] == 'jet_flavor') & (tmp_df['num_rot_aug'] == '+0')]
    sns.lineplot(data=jet_flavor_df, x='luminosity', y='test_auc', hue='hue', style='hue', hue_order=model_order, style_order=model_order, markers=True, dashes=False, ax=ax)

    # Supervised: mean test AUC over all seeds, drawn as a horizontal reference
    sv_df = tmp_df[tmp_df['data_mode'] == 'supervised']
    for color, model_paper in zip(default_colors, model_order):
        sv_auc = sv_df.loc[sv_df['model_paper'] == model_paper, 'test_auc'].mean()
        ax.axhline(y=sv_auc, color=color, linestyle='--', linewidth=1.5, label=f'{model_paper}-SV')

    # x-ticks and labels
    ax.set_xscale("log")
    if is_zz4l:
        xticks = [xtick / 1000 for xtick in luminosity_list + additional_luminosity_list]
        # Show whole numbers without a trailing '.0'
        xticks = [int(xtick) if xtick.is_integer() else xtick for xtick in xticks]
        ax.set_xticks(xticks)
        ax.set_xticklabels(xticks)
        ax.set_ylim(0.45, 1.0)
        ax.set(xlabel=r"Luminosity [ab$^{-1}$]", ylabel="AUC")
    else:
        ax.set_xticks(luminosity_list)
        ax.set_xticklabels(luminosity_list)
        ax.set_ylim(0.6, 0.9)
        ax.set(xlabel=r"Luminosity [fb$^{-1}$]", ylabel="AUC")

    ax.legend(loc=legend_loc, frameon=FRAMEON)

    file_name = ROOT / 'figures' / f'AUC_CWoLa-vs-SV_{channel}.pdf'
    # savefig does not create missing directories
    file_name.parent.mkdir(parents=True, exist_ok=True)
    print(f"Saving figure to {file_name}")
    fig.tight_layout()
    fig.savefig(file_name, bbox_inches='tight')

    if show_fig:
        plt.show()
    else:
        plt.close(fig)

In [4]:
# CWoLa vs. supervised figures for both diphoton channels
for diphoton_channel in ('diphoton', 'ex-diphoton'):
    plot(diphoton_channel, legend_loc='lower right', show_fig=False)

Saving figure to /home/yianchen/NTUHEPML-CWoLa/figures/AUC_CWoLa-vs-SV_diphoton.pdf
Saving figure to /home/yianchen/NTUHEPML-CWoLa/figures/AUC_CWoLa-vs-SV_ex-diphoton.pdf


In [5]:
# CWoLa vs. supervised figures for both ZZ -> 4l channels
for zz4l_channel in ('zz4l', 'ex-zz4l'):
    plot(zz4l_channel, legend_loc='center left', show_fig=False)

Saving figure to /home/yianchen/NTUHEPML-CWoLa/figures/AUC_CWoLa-vs-SV_zz4l.pdf
Saving figure to /home/yianchen/NTUHEPML-CWoLa/figures/AUC_CWoLa-vs-SV_ex-zz4l.pdf


## CWoLa with Augmentations

In [6]:
def plot(channel: str, model_paper: str, legend_loc: str, ylim: tuple, show_fig: bool = False):
    """Plot test AUC vs. luminosity for one model, one line per augmentation setting.

    Uses the rows of the global `df` for the given channel and model, restricted
    to the base luminosities in `luminosity_list`. Rows without a luminosity
    (the supervised runs) are dropped by that restriction. The figure is saved
    to ``ROOT/figures/AUC_CWoLa-Aug_<channel>-<model_paper>.pdf``.

    Args:
        channel: Value of the ``channel`` column to plot.
        model_paper: Value of the ``model_paper`` column to plot (``'CNN'`` or ``'ParT'``).
        legend_loc: Matplotlib legend location string.
        ylim: ``(ymin, ymax)`` passed to ``ax.set_ylim``.
        show_fig: Show the figure when True; otherwise close it after saving.
    """
    fig, ax = plt.subplots(1, 1, figsize=(5, 4), dpi=300)

    # The x ticks are `luminosity_list`, so cut the data at its largest entry
    # (3000 fb^-1); this excludes the additional high-luminosity runs.
    max_luminosity = max(luminosity_list)
    selection = (
        (df['channel'] == channel)
        & (df['model_paper'] == model_paper)
        & (df['luminosity'] <= max_luminosity)
    )
    tmp_df = df[selection].copy()

    # Legend labels: 'Original' for no augmentation, 'Augment +N' otherwise
    tmp_df['hue'] = "Augment " + tmp_df["num_rot_aug"]
    tmp_df.loc[tmp_df['hue'] == 'Augment +0', 'hue'] = 'Original'

    sns.lineplot(data=tmp_df, x='luminosity', y='test_auc', hue='hue', style='hue', markers=True, dashes=False, ax=ax)
    ax.set_xscale("log")
    ax.set_xticks(luminosity_list)
    ax.set_xticklabels(luminosity_list)
    ax.set_ylim(*ylim)
    ax.set(xlabel=r"Luminosity [fb$^{-1}$]", ylabel="AUC")
    ax.legend(loc=legend_loc, frameon=FRAMEON)

    file_name = ROOT / 'figures' / f'AUC_CWoLa-Aug_{channel}-{model_paper}.pdf'
    # savefig does not create missing directories
    file_name.parent.mkdir(parents=True, exist_ok=True)
    print(f"Saving figure to {file_name}")
    fig.tight_layout()
    fig.savefig(file_name, bbox_inches='tight')

    if show_fig:
        plt.show()
    else:
        plt.close(fig)

In [7]:
# Shared axis range and legend position for the four diphoton augmentation figures
YLIM = (0.58, 0.77)
LEGEND_LOC = 'lower right'
for diphoton_channel, paper_model in product(('diphoton', 'ex-diphoton'), ('CNN', 'ParT')):
    plot(diphoton_channel, paper_model, legend_loc=LEGEND_LOC, ylim=YLIM, show_fig=False)

Saving figure to /home/yianchen/NTUHEPML-CWoLa/figures/AUC_CWoLa-Aug_diphoton-CNN.pdf
Saving figure to /home/yianchen/NTUHEPML-CWoLa/figures/AUC_CWoLa-Aug_diphoton-ParT.pdf
Saving figure to /home/yianchen/NTUHEPML-CWoLa/figures/AUC_CWoLa-Aug_ex-diphoton-CNN.pdf
Saving figure to /home/yianchen/NTUHEPML-CWoLa/figures/AUC_CWoLa-Aug_ex-diphoton-ParT.pdf


In [8]:
# Shared axis range and legend position for the four ZZ -> 4l augmentation figures
YLIM = (0.47, 0.74)
LEGEND_LOC = 'upper left'
for zz4l_channel, paper_model in product(('zz4l', 'ex-zz4l'), ('CNN', 'ParT')):
    plot(zz4l_channel, paper_model, legend_loc=LEGEND_LOC, ylim=YLIM, show_fig=False)

Saving figure to /home/yianchen/NTUHEPML-CWoLa/figures/AUC_CWoLa-Aug_zz4l-CNN.pdf
Saving figure to /home/yianchen/NTUHEPML-CWoLa/figures/AUC_CWoLa-Aug_zz4l-ParT.pdf
Saving figure to /home/yianchen/NTUHEPML-CWoLa/figures/AUC_CWoLa-Aug_ex-zz4l-CNN.pdf
Saving figure to /home/yianchen/NTUHEPML-CWoLa/figures/AUC_CWoLa-Aug_ex-zz4l-ParT.pdf


## $H \rightarrow \gamma \gamma$ applied to $H \rightarrow ZZ \rightarrow 4l$

- The inference uses data and trained models **WITHOUT** decay products.

In [9]:
# One entry per trained model: (data format, model, training timestamp, legend label)
inference_info_list = [
    # Runs that also remove the neighbors near the decay products
    ('image', 'CNN_EventCNN', '20251005_154731', 'Original'),
    ('image', 'CNN_EventCNN', '20251006_114628', 'Augment +5'),
    ('image', 'CNN_EventCNN', '20251007_015709', 'Augment +10'),
    ('sequence', 'ParT_Light', '20251005_154731', 'Original'),
    ('sequence', 'ParT_Light', '20251006_114628', 'Augment +5'),
    ('sequence', 'ParT_Light', '20251007_015709', 'Augment +10'),
]

# Read one inference CSV per model/timestamp and tag it with its legend label;
# the data format entry is not needed to locate the file.
inference_frames = []
for _data_format, model_name, stamp, legend_label in inference_info_list:
    run_df = pd.read_csv(ROOT / 'output' / 'inference' / f"{model_name}_{stamp}.csv")
    run_df['hue'] = legend_label
    inference_frames.append(run_df)
inf_df = pd.concat(inference_frames, ignore_index=True)

# Short model names used in the figures
for raw_name, paper_name in (('CNN_EventCNN', 'CNN'), ('ParT_Light', 'ParT')):
    inf_df.loc[inf_df['model'] == raw_name, 'model_paper'] = paper_name

def plot(model_paper: str, show_fig: bool = False):
    """Plot inference test AUC vs. luminosity for one model, one line per augmentation label.

    Uses the rows of the global `inf_df` whose ``model_paper`` matches; the
    ``hue`` column provides the legend labels. The figure is saved to
    ``ROOT/figures/AUC_transfer_<model_paper>.pdf``.

    Args:
        model_paper: Value of the ``model_paper`` column to plot (``'CNN'`` or ``'ParT'``).
        show_fig: Show the figure when True; otherwise close it after saving.
    """
    fig, ax = plt.subplots(1, 1, figsize=(5, 4), dpi=300)

    tmp_df = inf_df[inf_df['model_paper'] == model_paper].copy()
    sns.lineplot(data=tmp_df, x='luminosity', y='test_auc', hue='hue', style='hue', markers=True, dashes=False, ax=ax)
    ax.set_xscale("log")
    ax.set_xticks(luminosity_list)
    ax.set_xticklabels(luminosity_list)
    ax.set_ylim(0.58, 0.77)
    ax.set(xlabel=r"Luminosity [fb$^{-1}$]", ylabel="AUC")
    ax.legend(loc='lower right', frameon=FRAMEON)

    file_name = ROOT / 'figures' / f'AUC_transfer_{model_paper}.pdf'
    # savefig does not create missing directories
    file_name.parent.mkdir(parents=True, exist_ok=True)
    print(f"Saving figure to {file_name}")
    fig.tight_layout()
    fig.savefig(file_name, bbox_inches='tight')

    if show_fig:
        plt.show()
    else:
        plt.close(fig)

# One transfer figure per model
for paper_model in ('CNN', 'ParT'):
    plot(paper_model, show_fig=False)

Saving figure to /home/yianchen/NTUHEPML-CWoLa/figures/AUC_transfer_CNN.pdf
Saving figure to /home/yianchen/NTUHEPML-CWoLa/figures/AUC_transfer_ParT.pdf
