Full comparison of all approaches (class-based version)

Setup

In [None]:
from math import sqrt
from itertools import product
from pathlib import Path

import matplotlib.pyplot as plt
import pandas

from library.plotting import make_plot_note, setup_high_quality_mpl_params

from library.approaches import (
    Shawns_Approach,
    Deep_Sets_Approach,
    Event_By_Event_Approach
)
from library.nn_training import select_device
from library.predict import Summary_Table

setup_high_quality_mpl_params()

device = select_device()

# Locations of the processed datasets, trained models, and output plots.
datasets_dir, models_dir, plots_dir = (
    "../../state/new_physics/data/processed",
    "../../state/new_physics/models",
    "../../state/new_physics/plots",
)

# Single results table; every approach object below is handed this instance.
summary_table = Summary_Table()


In [None]:
# Images approach, level "gen", constructor-default q^2 veto setting;
# dataset regeneration and model retraining are both switched off.
shawns_approach_gen = Shawns_Approach(
    device, "gen", datasets_dir, models_dir, plots_dir, summary_table,
    regenerate_datasets=False, retrain_models=False,
)

In [None]:
# Images approach, level "det", constructor-default q^2 veto setting;
# dataset regeneration and model retraining are both switched off.
shawns_approach_det = Shawns_Approach(
    device, "det", datasets_dir, models_dir, plots_dir, summary_table,
    regenerate_datasets=False, retrain_models=False,
)

In [None]:
# Images approach, level "gen", with the q^2 veto explicitly disabled;
# dataset regeneration and model retraining are both switched off.
shawns_approach_gen_no_q2_veto = Shawns_Approach(
    device, "gen", datasets_dir, models_dir, plots_dir, summary_table,
    q_squared_veto=False,
    regenerate_datasets=False, retrain_models=False,
)

In [None]:
# Images approach, level "det", with the q^2 veto explicitly disabled;
# dataset regeneration and model retraining are both switched off.
shawns_approach_det_no_q2_veto = Shawns_Approach(
    device, "det", datasets_dir, models_dir, plots_dir, summary_table,
    q_squared_veto=False,
    regenerate_datasets=False, retrain_models=False,
)

In [None]:
# Deep Sets approach, level "gen", constructor-default q^2 veto setting;
# dataset regeneration and model retraining are both switched off.
deep_sets_approach_gen = Deep_Sets_Approach(
    device, "gen", datasets_dir, models_dir, plots_dir, summary_table,
    regenerate_datasets=False, retrain_models=False,
)

In [None]:
# Deep Sets approach, level "det", constructor-default q^2 veto setting;
# dataset regeneration and model retraining are both switched off.
deep_sets_approach_det = Deep_Sets_Approach(
    device, "det", datasets_dir, models_dir, plots_dir, summary_table,
    regenerate_datasets=False, retrain_models=False,
)

In [None]:
# Deep Sets approach, level "gen", with the q^2 veto explicitly disabled;
# dataset regeneration and model retraining are both switched off.
deep_sets_approach_gen_no_q2_veto = Deep_Sets_Approach(
    device, "gen", datasets_dir, models_dir, plots_dir, summary_table,
    q_squared_veto=False,
    regenerate_datasets=False, retrain_models=False,
)

In [None]:
# Deep Sets approach, level "det", with the q^2 veto explicitly disabled;
# dataset regeneration and model retraining are both switched off.
deep_sets_approach_det_no_q2_veto = Deep_Sets_Approach(
    device, "det", datasets_dir, models_dir, plots_dir, summary_table,
    q_squared_veto=False,
    regenerate_datasets=False, retrain_models=False,
)

In [None]:
# Event-by-event approach, level "gen", constructor-default q^2 veto setting.
# Note this class takes `retrain_model` (singular), unlike the other approaches.
event_by_event_approach_gen = Event_By_Event_Approach(
    device, "gen", datasets_dir, models_dir, plots_dir, summary_table,
    regenerate_datasets=False, retrain_model=False,
)

In [None]:
# Event-by-event approach, level "det", constructor-default q^2 veto setting.
# Note this class takes `retrain_model` (singular), unlike the other approaches.
event_by_event_approach_det = Event_By_Event_Approach(
    device, "det", datasets_dir, models_dir, plots_dir, summary_table,
    regenerate_datasets=False, retrain_model=False,
)

In [None]:
# Event-by-event approach, level "det", with the q^2 veto explicitly disabled.
# NOTE(review): this det cell runs before the gen/no-veto cell, the reverse of
# the other approaches; if Summary_Table keeps insertion order, its rows will
# be ordered differently for this method — confirm whether that matters.
event_by_event_approach_det_no_q2_veto = Event_By_Event_Approach(
    device, "det", datasets_dir, models_dir, plots_dir, summary_table,
    q_squared_veto=False,
    regenerate_datasets=False, retrain_model=False,
)

In [None]:
# Event-by-event approach, level "gen", with the q^2 veto explicitly disabled.
# Note this class takes `retrain_model` (singular), unlike the other approaches.
event_by_event_approach_gen_no_q2_veto = Event_By_Event_Approach(
    device, "gen", datasets_dir, models_dir, plots_dir, summary_table,
    q_squared_veto=False,
    regenerate_datasets=False, retrain_model=False,
)

In [None]:
# Show the summary table shared by all approach objects above; as the cell's
# final bare expression, the notebook renders it as the cell output.
summary_table.table

In [None]:
# Print two LaTeX tables from the summary table (3 decimal places):
# first MSE/MAE, then the "Std. at NP"/"Bias at NP" columns, with a gap between.
latex_column_groups = (["MSE", "MAE",], ["Std. at NP", "Bias at NP"])
for group_number, columns in enumerate(latex_column_groups):
    if group_number > 0:
        print('\n')
    print(summary_table.table[columns].to_latex(float_format="%.3f"))

In [None]:
# Compare every approach on every summary-table metric: one figure per column,
# one line per (method, level, q^2 veto) combination, each saved to plots_dir.

# Number of bootstrapped sets behind the metrics (reported in a plot note).
num_sets_for_sensitivity = 2000
num_sets_for_non_sensitivity = 44*50

levels = ["det", "gen"]
q_2_vetos = [True, False]
methods = ["Images", "Deep Sets", "Event by event"]
colors = ['#984ea3', '#999999', '#4daf4a',]
markers = ["o", "^", "s"]

# Line style encodes (level, q^2 veto); color and marker encode the method.
linestyles = {
    ("det", True): "--",
    ("gen", True): "-",
    ("det", False): ":",
    ("gen", False): "-.",
}

# Columns computed from the bootstrapped sensitivity sets.
sensitivity_cols = ("Bias at NP", "Mean at NP", "Std. at NP")
# Columns drawn with standard-error bars: Std. at NP / sqrt(num sets).
errorbar_cols = ("Bias at NP", "Mean at NP")

# y-limits are matched to summary-table columns by position; zip stops at the
# shorter sequence, so columns beyond len(y_lims) are not plotted.
y_lims = [(0, 0.18), (0, None), (0, 0.45), (-1, -0.7), (-0.15, 0.15)]

for col, y_lim in zip(summary_table.table.columns, y_lims):

    fig, ax = plt.subplots()

    for (method, color, marker), level, q_2_veto in product(zip(methods, colors, markers), levels, q_2_vetos):
        key = pandas.IndexSlice[level, q_2_veto, method]
        y = summary_table.table.loc[key, col]
        x = y.index
        fmt = linestyles[(level, q_2_veto)] + marker
        ax.plot(x, y, fmt, label=f"{method}, {level}, veto: {q_2_veto}", c=color, markersize=5, alpha=.8)
        if col in errorbar_cols:
            errors = (
                summary_table.table.loc[key, "Std. at NP"]
                / sqrt(num_sets_for_sensitivity)
            )
            ax.errorbar(x=x, y=y, yerr=errors, fmt='none', elinewidth=0.5, capsize=0.5, color="black",)

    # Decorate the axes once, after all lines are drawn. Inside the loop,
    # set_ylim disabled autoscaling after the first line (freezing a `None`
    # limit to that line's range) and make_plot_note stacked duplicate notes.
    ax.set_ylim(y_lim)
    ax.set_ylabel(f"{col}")
    ax.set_xlabel("Number of events / set")
    ax.legend(ncols=2, markerscale=0.5, numpoints=1)
    num_boots = num_sets_for_sensitivity if col in sensitivity_cols else num_sets_for_non_sensitivity
    make_plot_note(ax, f"Num boots.: {num_boots}", fontsize="large")

    file_path = Path(plots_dir).joinpath(f"comp_{col}.png")
    fig.savefig(file_path, bbox_inches="tight")

    plt.show()
    plt.close(fig)


In [None]:
# counting parameters:

def count_model_trainable_params(model):
    """Return the number of trainable scalar parameters in *model*.

    Sums ``numel()`` over every tensor from ``model.parameters()`` whose
    ``requires_grad`` flag is set; frozen parameters are ignored.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Key used to pick one model from each per-size `.models` lookup; presumably
# the number of events per set — TODO confirm. The event-by-event approach
# exposes a single `.model`.
reference_num_events = 70_000

for label, model in (
    ("Images approach: ", shawns_approach_gen.models[reference_num_events]),
    ("Deep Sets approach: ", deep_sets_approach_gen.models[reference_num_events]),
    ("Event-by-event approach: ", event_by_event_approach_gen.model),
):
    print(label, count_model_trainable_params(model))