Setup

In [1]:
from pathlib import Path
from itertools import product

from torch.nn import MSELoss, L1Loss
import matplotlib as mpl
import matplotlib.pyplot as plt

from btokstmumu_ml_helpers.datasets.constants import Names_of_Levels, Names_of_q_Squared_Vetos, Numbers_of_Signal_Events_per_Set
from btokstmumu_ml_helpers.experiment.experiment import CNN_Group
from btokstmumu_ml_helpers.experiment.results_table import Results_Table
from btokstmumu_ml_helpers.experiment.constants import Paths_to_Directories
from btokstmumu_ml_helpers.models.hardware_util import select_device
from btokstmumu_ml_helpers.plot.linearity import plot_linearity

results_table = Results_Table()
device = select_device()

# Notebook-wide Matplotlib style: LaTeX text rendering (with the `bm` package
# for bold math), Computer Modern serif fonts, small titles/fonts, high-DPI
# output, and transparent, square-cornered legends.
# rcParams.update assigns each key in order through the validating setter,
# exactly like individual `mpl.rcParams[key] = value` statements.
mpl.rcParams.update({
    "figure.figsize": (6, 4),
    "figure.dpi": 400,
    "axes.titlesize": 8,
    "figure.titlesize": 8,
    "figure.labelsize": 30,
    "text.usetex": True,
    "text.latex.preamble": r"\usepackage{bm}",
    "font.family": "serif",
    "font.serif": ["Computer Modern"],
    "font.size": 8,
    "axes.titley": None,
    "axes.titlepad": 2,
    "legend.fancybox": False,
    "legend.framealpha": 0,
    "legend.markerscale": 1,
    "legend.fontsize": 7.5,
})

Device:  cuda


CNN

In [2]:
# Batch size keyed by number of signal events per set. Training and evaluation
# use the same values; each argument receives its own copy of the mapping.
batch_size_per_num_signal_events = {6_000 : 373, 24_000 : 93, 70_000 : 32}

# CNN experiment configuration. Trailing comments record the "nominal" value
# where this run deviates from it.
cnn_group = CNN_Group(
    model_variant="shawn",
    num_sets_per_label={6_000 : 583, 24_000 : 145, 70_000 : 50},
    num_sets_per_label_sensitivity=2_000,
    num_bins_per_dimension=50,  # nominal: 10; det_bkg models are retrained with 50
    q_squared_veto=Names_of_q_Squared_Vetos().resonances,
    std_scale=True,
    shuffle=True,
    uniform_label_counts=True,
    loss_fn=MSELoss(),  # nominal: MSE loss
    learning_rate=1e-3,  # nominal: 3e-4
    learning_rate_scheduler_reduction_factor=0.2,  # nominal: 0.97
    learning_rate_scheduler_patience=5,
    size_of_training_batch=dict(batch_size_per_num_signal_events),
    size_of_evaluation_batch=dict(batch_size_per_num_signal_events),
    number_of_epochs=50,  # nominal: 100
    number_of_epochs_between_checkpoints=1,
    results_table=results_table,
    device=device,
    bkg_fraction=0.44,
    bkg_charge_fraction=0.57,
    path_to_common_processed_datasets_dir="../../common/state/data/processed",
    path_to_local_processed_datasets_dir="../state/data",
    path_to_local_models_dir="../state/models",
)

In [3]:
# Train only the detector + background level, once per number of signal events
# per set; remake_datasets=True asks the group to rebuild the datasets first.
training_levels = [Names_of_Levels().detector_and_background]
training_nums_signal_events_per_set = (6_000, 24_000, 70_000)

cnn_group.train_subset(
    levels=training_levels,
    nums_signal_events_per_set=training_nums_signal_events_per_set,
    remake_datasets=True,
)

Making images dataset.
Applying signal preprocessing.
Removing rows that have a NaN.
Number of NA values: 
 q_squared          0
costheta_mu      170
costheta_K       760
chi              760
dc9                0
dc9_bin_index      0
dtype: int64
Removed rows that have a NaN.
Applying q^2 veto.
Applied q^2 veto.
Applying standand scale.
Applying q^2 veto.
Applied q^2 veto.
Loaded background file: ..\..\common\state\data\processed\aggregated_generic\charge_sb_bkg_train.parquet  dtype: float32
Loaded background file: ..\..\common\state\data\processed\aggregated_generic\mix_sb_bkg_train.parquet  dtype: float32
Applying q^2 veto.
Applied q^2 veto.
Applying q^2 veto.
Applied q^2 veto.
Applied standard scale.
Shuffled dataframe.
Reducing events per label to lowest per label.
Shuffled dataframe.
Reduced events per label to lowest per label.
Applied signal preprocessing.
Loaded background file: ..\..\common\state\data\processed\aggregated_generic\charge_sb_bkg_train.parquet  dtype: float32
Loa

In [4]:
# Evaluate the detector + background level for each number of signal events
# per set; remake_datasets=True asks the group to rebuild the datasets first.
evaluation_levels = [Names_of_Levels().detector_and_background]
evaluation_nums_signal_events_per_set = [6_000, 24_000, 70_000]

cnn_group.evaluate_subset(
    levels=evaluation_levels,
    nums_signal_events_per_set=evaluation_nums_signal_events_per_set,
    remake_datasets=True,
)

Making images dataset.
Applying signal preprocessing.
Removing rows that have a NaN.
Number of NA values: 
 q_squared          0
costheta_mu      168
costheta_K       706
chi              706
dc9                0
dc9_bin_index      0
dtype: int64
Removed rows that have a NaN.
Applying q^2 veto.
Applied q^2 veto.
Applying standand scale.
Applying q^2 veto.
Applied q^2 veto.
Loaded background file: ..\..\common\state\data\processed\aggregated_generic\charge_sb_bkg_train.parquet  dtype: float32
Loaded background file: ..\..\common\state\data\processed\aggregated_generic\mix_sb_bkg_train.parquet  dtype: float32
Applying q^2 veto.
Applied q^2 veto.
Applying q^2 veto.
Applied q^2 veto.
Applied standard scale.
Shuffled dataframe.
Reducing events per label to lowest per label.
Shuffled dataframe.
Reduced events per label to lowest per label.
Applied signal preprocessing.
Loaded background file: ..\..\common\state\data\processed\aggregated_generic\charge_sb_bkg_val.parquet  dtype: float32
Loade

  self.table.loc[


Unloaded datasets.
Unloaded datasets.
Making images dataset.
Applying signal preprocessing.
Removing rows that have a NaN.
Number of NA values: 
 q_squared          0
costheta_mu      168
costheta_K       706
chi              706
dc9                0
dc9_bin_index      0
dtype: int64
Removed rows that have a NaN.
Applying q^2 veto.
Applied q^2 veto.
Applying standand scale.
Applying q^2 veto.
Applied q^2 veto.
Loaded background file: ..\..\common\state\data\processed\aggregated_generic\charge_sb_bkg_train.parquet  dtype: float32
Loaded background file: ..\..\common\state\data\processed\aggregated_generic\mix_sb_bkg_train.parquet  dtype: float32
Applying q^2 veto.
Applied q^2 veto.
Applying q^2 veto.
Applied q^2 veto.
Applied standard scale.
Shuffled dataframe.
Reducing events per label to lowest per label.
Shuffled dataframe.
Reduced events per label to lowest per label.
Applied signal preprocessing.
Loaded background file: ..\..\common\state\data\processed\aggregated_generic\charge_sb

  self.table.loc[


Unloaded datasets.
Unloaded datasets.
Making images dataset.
Applying signal preprocessing.
Removing rows that have a NaN.
Number of NA values: 
 q_squared          0
costheta_mu      168
costheta_K       706
chi              706
dc9                0
dc9_bin_index      0
dtype: int64
Removed rows that have a NaN.
Applying q^2 veto.
Applied q^2 veto.
Applying standand scale.
Applying q^2 veto.
Applied q^2 veto.
Loaded background file: ..\..\common\state\data\processed\aggregated_generic\charge_sb_bkg_train.parquet  dtype: float32
Loaded background file: ..\..\common\state\data\processed\aggregated_generic\mix_sb_bkg_train.parquet  dtype: float32
Applying q^2 veto.
Applied q^2 veto.
Applying q^2 veto.
Applied q^2 veto.
Applied standard scale.
Shuffled dataframe.
Reducing events per label to lowest per label.
Shuffled dataframe.
Reduced events per label to lowest per label.
Applied signal preprocessing.
Loaded background file: ..\..\common\state\data\processed\aggregated_generic\charge_sb

  self.table.loc[


Unloaded datasets.
Unloaded datasets.


In [None]:
# 3x3 grid of linearity plots. Panels are filled row-major from
# product(levels, signal-event counts), so each row is one level and each
# column one number of signal events per set (assuming three of each).
fig, axs = plt.subplots(3, 3, sharex=True, sharey=True, layout="compressed")

names_of_levels = {Names_of_Levels().generator : "Generator", Names_of_Levels().detector : "Detector", Names_of_Levels().detector_and_background : "Detector and Bkg."}

# NOTE(review): this notebook only trains/evaluates the detector_and_background
# level; the generator and detector panels need results already available in
# `cnn_group.results` — confirm, otherwise the lookup below will fail.
# NOTE(review): assumes Numbers_of_Signal_Events_per_Set exposes `tuple_` like
# Names_of_Levels does — confirm.
for (level, num_events_per_set), ax in zip(product(Names_of_Levels().tuple_, Numbers_of_Signal_Events_per_Set().tuple_), axs.flat):

    plot_linearity(
        linearity_test_results=cnn_group.results[level][num_events_per_set].linearity_results,
        ax=ax,
    )

    # The per-set count is the number of *signal* events (background is added
    # on top at the detector+bkg level), matching the single-plot cell's label.
    ax.set_title(
        f"Level: {names_of_levels[level]}"
        f"\nSignal events/set: {num_events_per_set}"
        "\n" + r"Sets/$\delta C_9$: " + f"{cnn_group.num_sets_per_label[num_events_per_set]}",
        loc="left"
    )

axs.flat[0].legend()
# Alternative title that also reports the binning:
# fig.suptitle(f"CNN, bins/dim.: {cnn_group.num_bins_per_dimension}\n", x=0.02, horizontalalignment="left")
fig.suptitle("CNN\n", x=0.02, horizontalalignment="left")
fig.supxlabel(r"Actual $\delta C_9$", fontsize=11, x=0.56, y=-0.06)
fig.supylabel(r"Predicted $\delta C_9$", fontsize=11, y=0.45)

# Save and close this specific figure rather than whatever pyplot considers current.
fig.savefig(Paths_to_Directories().path_to_plots_dir.joinpath("cnn_grid_lin.png"), bbox_inches="tight")
plt.close(fig)

In [7]:
# Single linearity plot for one (level, signal events per set) combination,
# saved as ../state/plots/cnn_lin_<level>_<num_events_per_set>.png.
plots_dir = Path("../state/plots")
# savefig raises FileNotFoundError if the target directory does not exist.
plots_dir.mkdir(parents=True, exist_ok=True)

names_of_levels = {Names_of_Levels().generator : "Generator", Names_of_Levels().detector : "Detector", Names_of_Levels().detector_and_background : "Detector and Bkg."}

fig, ax = plt.subplots()

level = Names_of_Levels().detector_and_background
num_events_per_set = 70_000

plot_linearity(
    linearity_test_results=cnn_group.results[level][num_events_per_set].linearity_results,
    ax=ax,
)

ax.set_title(
    f"Level: {names_of_levels[level]}"
    f"\nSignal events/set: {num_events_per_set}"
    "\n" + r"Sets/$\delta C_9$: " + f"{cnn_group.num_sets_per_label[num_events_per_set]}",
    loc="left"
)

ax.legend()

fig.suptitle("CNN\n", x=0.02, horizontalalignment="left")
ax.set_xlabel(r"Actual $\delta C_9$", fontsize=11)
ax.set_ylabel(r"Predicted $\delta C_9$", fontsize=11)

# Save and close this specific figure rather than whatever pyplot considers current.
fig.savefig(plots_dir.joinpath(f"cnn_lin_{level}_{num_events_per_set}.png"), bbox_inches="tight")
plt.close(fig)