$B \rightarrow K^* \ell \ell$  machine learning experiment

Setup

In [1]:
from itertools import product

from torch.nn import MSELoss, CrossEntropyLoss
import matplotlib as mpl
import matplotlib.pyplot as plt

from helpers.datasets.make_and_save.aggregated_signal import Aggregated_Signal_Dataframe_Handler
from helpers.datasets.constants import Names_of_Levels, Names_of_q_Squared_Vetos, Raw_Signal_Trial_Ranges, Numbers_of_Events_per_Set, Names_of_Splits, Names_of_Labels
from helpers.experiment.experiment import CNN_Group, Deep_Sets_Group, Event_by_Event_Group
from helpers.experiment.results_table import Results_Table
from helpers.experiment.constants import Paths_to_Directories, delta_C9_value_new_physics, delta_C9_value_standard_model
from helpers.models.hardware_util import select_device
from helpers.experiment.experiment import evaluate_model
from helpers.datasets.settings.settings import Binned_Sets_Dataset_Settings
from helpers.datasets.datasets import Unbinned_Sets_Dataset, Binned_Sets_Dataset, Images_Dataset
from helpers.datasets.make_and_save.preprocessing import apply_q_squared_veto

from helpers.plot.linearity import plot_linearity
from helpers.plot.probabilities import plot_log_probability_distribution_examples

# Shared state handed to every model group constructed below.
results_table = Results_Table()
device = select_device()

# Global figure styling. Text is rendered through LaTeX (text.usetex) with the
# bm package loaded in the preamble.
mpl.rcParams.update(
    {
        "figure.figsize": (6, 4),
        "figure.dpi": 400,
        "axes.titlesize": 8,
        "figure.titlesize": 8,
        "figure.labelsize": 30,
        "text.usetex": True,
        "text.latex.preamble": r"\usepackage{bm}",
        "font.family": "serif",
        "font.serif": ["Computer Modern"],
        "font.size": 8,
        "axes.titley": None,
        "axes.titlepad": 2,
        "legend.fancybox": False,
        "legend.framealpha": 0,
        "legend.markerscale": 1,
        "legend.fontsize": 7.5,
    }
)

Device:  cuda


Remake aggregated signal dataframe files

In [None]:
# Rebuild one aggregated signal dataframe per (level, trial range) pair,
# for the generator and detector levels only.
signal_levels = (Names_of_Levels().generator, Names_of_Levels().detector)

for level, trial_range in product(signal_levels, Raw_Signal_Trial_Ranges().tuple_):
    handler = Aggregated_Signal_Dataframe_Handler(
        path_to_main_datasets_dir=Paths_to_Directories().path_to_main_datasets_dir,
        level=level,
        trial_range=trial_range,
    )
    handler.make_and_save(Paths_to_Directories().path_to_raw_signal_dir)

Deep Sets

In [None]:
# Per-set-size settings; the keys are presumably the number of events per set
# (6k / 24k / 70k) -- confirm against Deep_Sets_Group.
sets_per_label_by_set_size = {6_000: 583, 24_000: 145, 70_000: 50}
batch_size_by_set_size = {6_000: 373, 24_000: 93, 70_000: 32}

deep_sets_group = Deep_Sets_Group(
    # Dataset construction
    num_sets_per_label=sets_per_label_by_set_size,
    num_sets_per_label_sensitivity=2_000,
    q_squared_veto=Names_of_q_Squared_Vetos().resonances,
    std_scale=True,
    shuffle=True,
    uniform_label_counts=True,
    # Optimisation
    loss_fn=MSELoss(),
    learning_rate=3e-4,
    learning_rate_scheduler_reduction_factor=0.97,
    # Training and evaluation use the same batch sizes; each argument gets its
    # own dict object, as before.
    size_of_training_batch=dict(batch_size_by_set_size),
    size_of_evaluation_batch=dict(batch_size_by_set_size),
    number_of_epochs=100,
    number_of_epochs_between_checkpoints=1,
    # Bookkeeping and hardware
    results_table=results_table,
    device=device,
    # Background mixing
    bkg_fraction=0.44,
    bkg_charge_fraction=0.57,
)

# Training is disabled; uncomment one of the calls below to retrain.
# deep_sets_group.train_all(remake_datasets=True)
# deep_sets_group.train_subset([Names_of_Levels().detector_and_background], [6_000, 24_000, 70_000], remake_datasets=True)

deep_sets_group.evaluate_all(remake_datasets=False)

In [None]:
# 3x3 grid of linearity plots: the panels are filled in the order produced by
# product(levels, events-per-set values).
fig, axs = plt.subplots(nrows=3, ncols=3, sharex=True, sharey=True, layout="compressed")

level_names = Names_of_Levels()
levels = level_names.tuple_
names_of_levels = {
    level_names.generator: "Generator",
    level_names.detector: "Detector",
    level_names.detector_and_background: "Detector and Bkg.",
}

panel_keys = product(levels, Numbers_of_Events_per_Set().tuple_)
for ax, (level, num_events_per_set) in zip(axs.flat, panel_keys):

    plot_linearity(
        linearity_test_results=deep_sets_group.results[level][num_events_per_set].linearity_results,
        ax=ax,
    )

    # The number of sets per label is read from the evaluation dataset
    # settings of the individual experiment.
    individual = deep_sets_group.get_individual(level=level, num_events_per_set=num_events_per_set)
    title_lines = (
        f"Level: {names_of_levels[level]}",
        f"Events/set: {num_events_per_set}",
        r"Sets/$\delta C_9$: " + f"{individual.evaluation_dataset_settings.set.num_sets_per_label}",
    )
    ax.set_title("\n".join(title_lines), loc="left")

# A single legend in the first panel serves the whole grid.
axs.flat[0].legend()
fig.suptitle("Deep sets\n", x=0.02, horizontalalignment="left")
fig.supxlabel(r"Actual $\delta C_9$", fontsize=11, x=0.56, y=-0.06)
fig.supylabel(r"Predicted $\delta C_9$", fontsize=11, y=0.45)

plot_path = Paths_to_Directories().path_to_plots_dir.joinpath("deep_sets_grid_lin.png")
plt.savefig(plot_path, bbox_inches="tight")
plt.close()

CNN

In [3]:
# Build the CNN experiment group and evaluate it with the existing datasets.
# Several values below deviate from the nominal configuration; each deviation
# is annotated on its line.
#
# NOTE(review): the recorded output of this cell shows cached evaluation
# tensors of shape (N, 1, 10, 10, 10) being loaded, followed by a RuntimeError
# from load_state_dict (checkpoint keys and shapes do not match the current
# model definition). With remake_datasets=False the cached images presumably
# do not reflect num_bins_per_dimension=50 -- confirm which binning and which
# checkpoints this evaluation is meant to use before trusting downstream plots.
cnn_group = CNN_Group(
    num_sets_per_label={6_000 : 583, 24_000 : 145, 70_000 : 50},
    num_sets_per_label_sensitivity=2_000,
    num_bins_per_dimension=50, # nominal is 10; retrain det_bkg models with 50
    q_squared_veto=Names_of_q_Squared_Vetos().resonances,
    std_scale=True,
    shuffle=True,
    uniform_label_counts=True,
    loss_fn=MSELoss(),
    learning_rate=1e-3, # nominal 3e-4
    learning_rate_scheduler_reduction_factor=0.97, # nominal 0.97
    size_of_training_batch={6_000 : 373, 24_000 : 93, 70_000 : 32},
    size_of_evaluation_batch={6_000 : 373, 24_000 : 93, 70_000 : 32},
    number_of_epochs=50, # nominal 100
    number_of_epochs_between_checkpoints=1,
    results_table=results_table,
    device=device,
    bkg_fraction=0.44,
    bkg_charge_fraction=0.57
)

# Training is disabled; uncomment to retrain the detector-and-background model at 24k events per set.
# cnn_group.train_subset(levels=[Names_of_Levels().detector_and_background,], nums_events_per_set=(24_000,), remake_datasets=False)
cnn_group.evaluate_all(remake_datasets=False)

Loaded tensor of shape: torch.Size([2200, 1, 10, 10, 10]) from ..\..\state\new_physics\data\processed\images_gen_q2v_resonances\70000_eval_features.pt
Loaded tensor of shape: torch.Size([2200]) from ..\..\state\new_physics\data\processed\images_gen_q2v_resonances\70000_eval_labels.pt
Loaded tensor of shape: torch.Size([2000, 1, 10, 10, 10]) from ..\..\state\new_physics\data\processed\images_gen_q2v_resonances\70000_eval_sens_features.pt
Loaded tensor of shape: torch.Size([2000]) from ..\..\state\new_physics\data\processed\images_gen_q2v_resonances\70000_eval_sens_labels.pt


RuntimeError: Error(s) in loading state_dict for CNN_Model_Logic_Shawn:
	Missing key(s) in state_dict: "convolution_layers.1.weight", "convolution_layers.1.bias", "convolution_layers.1.running_mean", "convolution_layers.1.running_var", "convolution_layers.4.convolution_block.1.weight", "convolution_layers.4.convolution_block.1.bias", "convolution_layers.4.convolution_block.1.running_mean", "convolution_layers.4.convolution_block.1.running_var", "convolution_layers.4.convolution_block.3.weight", "convolution_layers.4.convolution_block.3.bias", "convolution_layers.4.convolution_block.4.weight", "convolution_layers.4.convolution_block.4.bias", "convolution_layers.4.convolution_block.4.running_mean", "convolution_layers.4.convolution_block.4.running_var", "convolution_layers.5.convolution_block.1.weight", "convolution_layers.5.convolution_block.1.bias", "convolution_layers.5.convolution_block.1.running_mean", "convolution_layers.5.convolution_block.1.running_var", "convolution_layers.5.convolution_block.3.weight", "convolution_layers.5.convolution_block.3.bias", "convolution_layers.5.convolution_block.4.weight", "convolution_layers.5.convolution_block.4.bias", "convolution_layers.5.convolution_block.4.running_mean", "convolution_layers.5.convolution_block.4.running_var", "convolution_layers.6.convolution_block.1.weight", "convolution_layers.6.convolution_block.1.bias", "convolution_layers.6.convolution_block.1.running_mean", "convolution_layers.6.convolution_block.1.running_var", "convolution_layers.6.convolution_block.3.weight", "convolution_layers.6.convolution_block.3.bias", "convolution_layers.6.convolution_block.4.weight", "convolution_layers.6.convolution_block.4.bias", "convolution_layers.6.convolution_block.4.running_mean", "convolution_layers.6.convolution_block.4.running_var", "convolution_layers.7.convolution_block_a.0.weight", "convolution_layers.7.convolution_block_a.0.bias", "convolution_layers.7.convolution_block_a.1.weight", "convolution_layers.7.convolution_block_a.1.bias", 
"convolution_layers.7.convolution_block_a.1.running_mean", "convolution_layers.7.convolution_block_a.1.running_var", "convolution_layers.7.convolution_block_a.3.weight", "convolution_layers.7.convolution_block_a.3.bias", "convolution_layers.7.convolution_block_a.4.weight", "convolution_layers.7.convolution_block_a.4.bias", "convolution_layers.7.convolution_block_a.4.running_mean", "convolution_layers.7.convolution_block_a.4.running_var", "convolution_layers.7.convolution_block_b.0.weight", "convolution_layers.7.convolution_block_b.0.bias", "convolution_layers.7.convolution_block_b.1.weight", "convolution_layers.7.convolution_block_b.1.bias", "convolution_layers.7.convolution_block_b.1.running_mean", "convolution_layers.7.convolution_block_b.1.running_var", "convolution_layers.8.convolution_block.1.weight", "convolution_layers.8.convolution_block.1.bias", "convolution_layers.8.convolution_block.1.running_mean", "convolution_layers.8.convolution_block.1.running_var", "convolution_layers.8.convolution_block.3.weight", "convolution_layers.8.convolution_block.3.bias", "convolution_layers.8.convolution_block.4.weight", "convolution_layers.8.convolution_block.4.bias", "convolution_layers.8.convolution_block.4.running_mean", "convolution_layers.8.convolution_block.4.running_var", "convolution_layers.9.convolution_block.1.weight", "convolution_layers.9.convolution_block.1.bias", "convolution_layers.9.convolution_block.1.running_mean", "convolution_layers.9.convolution_block.1.running_var", "convolution_layers.9.convolution_block.3.weight", "convolution_layers.9.convolution_block.3.bias", "convolution_layers.9.convolution_block.4.weight", "convolution_layers.9.convolution_block.4.bias", "convolution_layers.9.convolution_block.4.running_mean", "convolution_layers.9.convolution_block.4.running_var", "convolution_layers.10.convolution_block.1.weight", "convolution_layers.10.convolution_block.1.bias", "convolution_layers.10.convolution_block.1.running_mean", 
"convolution_layers.10.convolution_block.1.running_var", "convolution_layers.10.convolution_block.3.weight", "convolution_layers.10.convolution_block.3.bias", "convolution_layers.10.convolution_block.4.weight", "convolution_layers.10.convolution_block.4.bias", "convolution_layers.10.convolution_block.4.running_mean", "convolution_layers.10.convolution_block.4.running_var", "convolution_layers.11.convolution_block_a.0.weight", "convolution_layers.11.convolution_block_a.0.bias", "convolution_layers.11.convolution_block_a.1.weight", "convolution_layers.11.convolution_block_a.1.bias", "convolution_layers.11.convolution_block_a.1.running_mean", "convolution_layers.11.convolution_block_a.1.running_var", "convolution_layers.11.convolution_block_a.3.weight", "convolution_layers.11.convolution_block_a.3.bias", "convolution_layers.11.convolution_block_a.4.weight", "convolution_layers.11.convolution_block_a.4.bias", "convolution_layers.11.convolution_block_a.4.running_mean", "convolution_layers.11.convolution_block_a.4.running_var", "convolution_layers.11.convolution_block_b.0.weight", "convolution_layers.11.convolution_block_b.0.bias", "convolution_layers.11.convolution_block_b.1.weight", "convolution_layers.11.convolution_block_b.1.bias", "convolution_layers.11.convolution_block_b.1.running_mean", "convolution_layers.11.convolution_block_b.1.running_var", "convolution_layers.12.convolution_block.1.weight", "convolution_layers.12.convolution_block.1.bias", "convolution_layers.12.convolution_block.1.running_mean", "convolution_layers.12.convolution_block.1.running_var", "convolution_layers.12.convolution_block.3.weight", "convolution_layers.12.convolution_block.3.bias", "convolution_layers.12.convolution_block.4.weight", "convolution_layers.12.convolution_block.4.bias", "convolution_layers.12.convolution_block.4.running_mean", "convolution_layers.12.convolution_block.4.running_var", "convolution_layers.13.convolution_block.1.weight", 
"convolution_layers.13.convolution_block.1.bias", "convolution_layers.13.convolution_block.1.running_mean", "convolution_layers.13.convolution_block.1.running_var", "convolution_layers.13.convolution_block.3.weight", "convolution_layers.13.convolution_block.3.bias", "convolution_layers.13.convolution_block.4.weight", "convolution_layers.13.convolution_block.4.bias", "convolution_layers.13.convolution_block.4.running_mean", "convolution_layers.13.convolution_block.4.running_var", "convolution_layers.14.convolution_block.0.weight", "convolution_layers.14.convolution_block.0.bias", "convolution_layers.14.convolution_block.1.weight", "convolution_layers.14.convolution_block.1.bias", "convolution_layers.14.convolution_block.1.running_mean", "convolution_layers.14.convolution_block.1.running_var", "convolution_layers.14.convolution_block.3.weight", "convolution_layers.14.convolution_block.3.bias", "convolution_layers.14.convolution_block.4.weight", "convolution_layers.14.convolution_block.4.bias", "convolution_layers.14.convolution_block.4.running_mean", "convolution_layers.14.convolution_block.4.running_var", "convolution_layers.15.convolution_block.0.weight", "convolution_layers.15.convolution_block.0.bias", "convolution_layers.15.convolution_block.1.weight", "convolution_layers.15.convolution_block.1.bias", "convolution_layers.15.convolution_block.1.running_mean", "convolution_layers.15.convolution_block.1.running_var", "convolution_layers.15.convolution_block.3.weight", "convolution_layers.15.convolution_block.3.bias", "convolution_layers.15.convolution_block.4.weight", "convolution_layers.15.convolution_block.4.bias", "convolution_layers.15.convolution_block.4.running_mean", "convolution_layers.15.convolution_block.4.running_var", "convolution_layers.16.convolution_block.0.weight", "convolution_layers.16.convolution_block.0.bias", "convolution_layers.16.convolution_block.1.weight", "convolution_layers.16.convolution_block.1.bias", 
"convolution_layers.16.convolution_block.1.running_mean", "convolution_layers.16.convolution_block.1.running_var", "convolution_layers.16.convolution_block.3.weight", "convolution_layers.16.convolution_block.3.bias", "convolution_layers.16.convolution_block.4.weight", "convolution_layers.16.convolution_block.4.bias", "convolution_layers.16.convolution_block.4.running_mean", "convolution_layers.16.convolution_block.4.running_var", "convolution_layers.17.convolution_block_a.0.weight", "convolution_layers.17.convolution_block_a.0.bias", "convolution_layers.17.convolution_block_a.1.weight", "convolution_layers.17.convolution_block_a.1.bias", "convolution_layers.17.convolution_block_a.1.running_mean", "convolution_layers.17.convolution_block_a.1.running_var", "convolution_layers.17.convolution_block_a.3.weight", "convolution_layers.17.convolution_block_a.3.bias", "convolution_layers.17.convolution_block_a.4.weight", "convolution_layers.17.convolution_block_a.4.bias", "convolution_layers.17.convolution_block_a.4.running_mean", "convolution_layers.17.convolution_block_a.4.running_var", "convolution_layers.17.convolution_block_b.0.weight", "convolution_layers.17.convolution_block_b.0.bias", "convolution_layers.17.convolution_block_b.1.weight", "convolution_layers.17.convolution_block_b.1.bias", "convolution_layers.17.convolution_block_b.1.running_mean", "convolution_layers.17.convolution_block_b.1.running_var", "convolution_layers.18.convolution_block.0.weight", "convolution_layers.18.convolution_block.0.bias", "convolution_layers.18.convolution_block.1.weight", "convolution_layers.18.convolution_block.1.bias", "convolution_layers.18.convolution_block.1.running_mean", "convolution_layers.18.convolution_block.1.running_var", "convolution_layers.18.convolution_block.3.weight", "convolution_layers.18.convolution_block.3.bias", "convolution_layers.18.convolution_block.4.weight", "convolution_layers.18.convolution_block.4.bias", 
"convolution_layers.18.convolution_block.4.running_mean", "convolution_layers.18.convolution_block.4.running_var", "convolution_layers.19.convolution_block.0.weight", "convolution_layers.19.convolution_block.0.bias", "convolution_layers.19.convolution_block.1.weight", "convolution_layers.19.convolution_block.1.bias", "convolution_layers.19.convolution_block.1.running_mean", "convolution_layers.19.convolution_block.1.running_var", "convolution_layers.19.convolution_block.3.weight", "convolution_layers.19.convolution_block.3.bias", "convolution_layers.19.convolution_block.4.weight", "convolution_layers.19.convolution_block.4.bias", "convolution_layers.19.convolution_block.4.running_mean", "convolution_layers.19.convolution_block.4.running_var", "dense_layers.3.weight", "dense_layers.3.bias". 
	Unexpected key(s) in state_dict: "convolution_layers.3.convolution_block.0.weight", "convolution_layers.3.convolution_block.0.bias", "convolution_layers.3.convolution_block.2.weight", "convolution_layers.3.convolution_block.2.bias", "convolution_layers.4.convolution_block.2.weight", "convolution_layers.4.convolution_block.2.bias", "convolution_layers.5.convolution_block.2.weight", "convolution_layers.5.convolution_block.2.bias", "convolution_layers.6.convolution.weight", "convolution_layers.6.convolution.bias", "convolution_layers.6.convolution_block.2.weight", "convolution_layers.6.convolution_block.2.bias", "convolution_layers.7.convolution_block.0.weight", "convolution_layers.7.convolution_block.0.bias", "convolution_layers.7.convolution_block.2.weight", "convolution_layers.7.convolution_block.2.bias", "convolution_layers.8.convolution_block.2.weight", "convolution_layers.8.convolution_block.2.bias", "convolution_layers.9.convolution_block.2.weight", "convolution_layers.9.convolution_block.2.bias", "convolution_layers.10.convolution.weight", "convolution_layers.10.convolution.bias", "convolution_layers.10.convolution_block.2.weight", "convolution_layers.10.convolution_block.2.bias", "convolution_layers.11.convolution_block.0.weight", "convolution_layers.11.convolution_block.0.bias", "convolution_layers.11.convolution_block.2.weight", "convolution_layers.11.convolution_block.2.bias", "convolution_layers.12.convolution_block.2.weight", "convolution_layers.12.convolution_block.2.bias", "convolution_layers.13.convolution_block.2.weight", "convolution_layers.13.convolution_block.2.bias", "dense_layers.2.weight", "dense_layers.2.bias". 
	size mismatch for convolution_layers.0.weight: copying a param with shape torch.Size([16, 1, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 1, 7, 7, 7]).
	size mismatch for convolution_layers.4.convolution_block.0.weight: copying a param with shape torch.Size([16, 16, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3, 3]).
	size mismatch for convolution_layers.4.convolution_block.0.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for convolution_layers.5.convolution_block.0.weight: copying a param with shape torch.Size([16, 16, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3, 3]).
	size mismatch for convolution_layers.5.convolution_block.0.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for convolution_layers.6.convolution_block.0.weight: copying a param with shape torch.Size([16, 16, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3, 3]).
	size mismatch for convolution_layers.6.convolution_block.0.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for convolution_layers.8.convolution_block.0.weight: copying a param with shape torch.Size([16, 16, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3, 3]).
	size mismatch for convolution_layers.8.convolution_block.0.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for convolution_layers.9.convolution_block.0.weight: copying a param with shape torch.Size([16, 16, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3, 3]).
	size mismatch for convolution_layers.9.convolution_block.0.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for convolution_layers.10.convolution_block.0.weight: copying a param with shape torch.Size([16, 16, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3, 3]).
	size mismatch for convolution_layers.10.convolution_block.0.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for convolution_layers.12.convolution_block.0.weight: copying a param with shape torch.Size([16, 16, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3, 3]).
	size mismatch for convolution_layers.12.convolution_block.0.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for convolution_layers.13.convolution_block.0.weight: copying a param with shape torch.Size([16, 16, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3, 3]).
	size mismatch for convolution_layers.13.convolution_block.0.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for dense_layers.0.weight: copying a param with shape torch.Size([32, 16]) from checkpoint, the shape in current model is torch.Size([1000, 512]).
	size mismatch for dense_layers.0.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([1000]).

In [None]:
# 3x3 grid of CNN linearity plots: the panels are filled in the order produced
# by product(levels, events-per-set values).
fig, axs = plt.subplots(nrows=3, ncols=3, sharex=True, sharey=True, layout="compressed")

names_of_levels = {
    Names_of_Levels().generator: "Generator",
    Names_of_Levels().detector: "Detector",
    Names_of_Levels().detector_and_background: "Detector and Bkg.",
}

panel_keys = product(Names_of_Levels().tuple_, Numbers_of_Events_per_Set().tuple_)
for ax, (level, num_events_per_set) in zip(axs.flat, panel_keys):

    plot_linearity(
        linearity_test_results=cnn_group.results[level][num_events_per_set].linearity_results,
        ax=ax,
    )

    title_lines = (
        f"Level: {names_of_levels[level]}",
        f"Events/set: {num_events_per_set}",
        r"Sets/$\delta C_9$: " + f"{cnn_group.num_sets_per_label[num_events_per_set]}",
    )
    ax.set_title("\n".join(title_lines), loc="left")

# A single legend in the first panel serves the whole grid.
axs.flat[0].legend()
# Alternative title that also reports the binning:
# fig.suptitle(f"CNN, bins/dim.: {cnn_group.num_bins_per_dimension}\n", x=0.02, horizontalalignment="left")
fig.suptitle("CNN\n", x=0.02, horizontalalignment="left")
fig.supxlabel(r"Actual $\delta C_9$", fontsize=11, x=0.56, y=-0.06)
fig.supylabel(r"Predicted $\delta C_9$", fontsize=11, y=0.45)

plot_path = Paths_to_Directories().path_to_plots_dir.joinpath("cnn_grid_lin.png")
plt.savefig(plot_path, bbox_inches="tight")
plt.close()

In [None]:
import numpy

def plot_image_slices(
    image,
    norm, 
    cmap,
    ax_3d,
    num_slices=3, 
):  
    """
    Draw evenly spaced xy-slices of a 3D image as colored planes on a 3D axes.

    Each selected z index is drawn as a flat surface whose faces are colored
    by ``cmap(norm(image))``, with a light gray plane slightly below it that
    acts as an outline.

    Parameters
    ----------
    image : torch.Tensor or array-like
        Image data. Singleton axes are squeezed away and exactly three axes
        (x, y, z) must remain. Tensors are detached and moved to the CPU.
    norm : callable
        Maps the image values to the input range of ``cmap``.
    cmap : callable
        Maps normalized values to colors, appending a trailing color axis.
    ax_3d : matplotlib 3D axes
        Axes to draw on; its tick labels are hidden and its axis labels set.
    num_slices : int, optional
        Number of z positions sampled between the first and the last z index.
        Positions that truncate to the same index are drawn only once.

    Raises
    ------
    ValueError
        If the squeezed image does not have exactly three axes.
    """
    # Torch tensors must leave the autograd graph and the GPU before numpy
    # can view them; plain arrays skip this step.
    if hasattr(image, "detach"):
        image = image.detach().cpu()
    voxels = numpy.squeeze(numpy.asarray(image))
    if voxels.ndim != 3:
        raise ValueError(
            f"Expected an image with three non-singleton axes, got shape {voxels.shape}."
        )

    num_x, num_y, num_z = voxels.shape
    colors = cmap(norm(voxels))

    def xy_plane_at(z_position):
        # plot_surface colors faces, so the vertex grid needs one more point
        # than there are bins along each axis.
        grid_shape = (num_x + 1, num_y + 1)
        x, y = numpy.indices(grid_shape)
        z = numpy.full(grid_shape, z_position)
        return x, y, z

    def plot_slice(z_index):
        x, y, z = xy_plane_at(z_index)
        ax_3d.plot_surface(
            x, y, z,
            rstride=1, cstride=1,
            facecolors=colors[:, :, z_index],
            shade=False
        )

    def plot_outline(z_index, offset=0.3):
        # Gray plane just below the slice.
        x, y, z = xy_plane_at(z_index - offset)
        ax_3d.plot_surface(
            x, y, z,
            rstride=1,
            cstride=1,
            shade=False,
            color="#f2f2f2",
            edgecolor="#f2f2f2"
        )

    # Integer truncation can map several sample positions to the same index
    # when num_slices exceeds the z depth; unique() keeps each plane once
    # (and keeps them in ascending order).
    z_indices = numpy.unique(
        numpy.linspace(
            start=0,
            stop=num_z - 1,
            num=num_slices,
            dtype=int  # forces integer indices
        )
    )
    for z_index in z_indices:
        plot_outline(z_index)
        plot_slice(z_index)

    ax_labels = {
        "x": r"$\cos\theta_\mu$",
        "y": r"$\cos\theta_K$",
        "z": r"$\chi$"
    }
    ax_3d.tick_params(
        axis="both",
        which="both",
        bottom=False,
        top=True,
        labelbottom=False,
        labeltop=False,
        labelleft=False,
        labelright=False
    )
    # Negative padding pulls the labels toward the axes now that the tick
    # labels are hidden.
    ax_3d.set_xlabel(ax_labels["x"], labelpad=-16)
    ax_3d.set_ylabel(ax_labels["y"], labelpad=-16)
    ax_3d.set_zlabel(ax_labels["z"], labelpad=-16)
    

# The image grids cover the generator and detector levels only.
levels = (Names_of_Levels().generator, Names_of_Levels().detector)
names_of_levels = {
    Names_of_Levels().generator: "Generator",
    Names_of_Levels().detector: "Detector",
}

# One figure per scenario: (delta C9 label, text for the title, output file name).
scenarios = (
    (
        delta_C9_value_standard_model,
        r"SM ($\delta C_9 = " + f"{delta_C9_value_standard_model}" + r"$)",
        "image_grid_SM.png",
    ),
    (
        delta_C9_value_new_physics,
        r"NP ($\delta C_9 = " + f"{delta_C9_value_new_physics}" + r"$)",
        "image_grid_NP.png",
    ),
)

for delta_c9, delta_c9_description, save_name in scenarios:

    fig, axs = plt.subplots(nrows=2, ncols=3, subplot_kw={"projection": "3d"}, layout="compressed")

    # Shared color scale for every panel, drawn as one colorbar without ticks.
    norm = mpl.colors.Normalize(vmin=-1, vmax=1)
    cmap = plt.cm.magma
    cbar = fig.colorbar(
        mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
        ax=axs,
        location="right",
        shrink=0.6,
    )
    cbar.set_label(r"Normalized ${q^2}$ (Avg.)", size=11)
    cbar.set_ticks([])

    panel_keys = product(levels, Numbers_of_Events_per_Set().tuple_)
    for ax, (level, num_events_per_set) in zip(axs.flat, panel_keys):

        evaluation_settings = cnn_group.get_individual(
            level=level, num_events_per_set=num_events_per_set
        ).evaluation_dataset_settings
        dataset = Images_Dataset(settings=evaluation_settings)
        dataset.load()

        # Show the first evaluation image whose label equals this delta C9.
        first_matching_image = dataset.features[dataset.labels==delta_c9][0]
        plot_image_slices(
            image=first_matching_image,
            norm=norm,
            cmap=cmap,
            ax_3d=ax
        )
        ax.set_title(
            f"Level: {names_of_levels[level]}\nEvents: {num_events_per_set}",
            loc="left",
            y=0.97
        )

    fig.suptitle(
        f"Images, bins/dim.: {cnn_group.num_bins_per_dimension}, {delta_c9_description}\n",
        x=0.02,
        horizontalalignment="left"
    )

    plt.savefig(Paths_to_Directories().path_to_plots_dir.joinpath(save_name), bbox_inches="tight")
    plt.close()


Event-by-event

In [None]:
# Evaluation sets per label; the keys are presumably the number of events per
# set -- confirm against Event_by_Event_Group.
evaluation_sets_per_label_by_set_size = {6_000: 583, 24_000: 145, 70_000: 50}

event_by_event_group = Event_by_Event_Group(
    # Dataset construction
    num_evaluation_sets_per_label=evaluation_sets_per_label_by_set_size,
    num_evaluation_sets_per_label_sensitivity=2_000,
    q_squared_veto=Names_of_q_Squared_Vetos().resonances,
    std_scale=True,
    shuffle=True,
    uniform_label_counts=True,
    # Optimisation: this group uses a cross-entropy loss and one batch size
    # for all set sizes.
    loss_fn=CrossEntropyLoss(),
    learning_rate=3e-3,
    learning_rate_scheduler_reduction_factor=0.95,
    size_of_training_batch=10_000,
    size_of_evaluation_batch=10_000,
    number_of_epochs=300,
    number_of_epochs_between_checkpoints=2,
    # Bookkeeping and hardware
    results_table=results_table,
    device=device,
)

# Training is disabled; uncomment to retrain the detector-level model.
# event_by_event_group.train_subset(levels=[Names_of_Levels().detector], remake_datasets=True)
event_by_event_group.evaluate_all(remake_datasets=False)

In [None]:
# 2x3 grid of event-by-event linearity plots: the panels are filled in the
# order produced by product(possible levels, events-per-set values).
fig, axs = plt.subplots(nrows=2, ncols=3, sharex=True, sharey=True, layout="compressed")

# Also read by the log-probability grid later in this cell.
names_of_levels = {
    Names_of_Levels().generator: "Generator",
    Names_of_Levels().detector: "Detector",
}

panel_keys = product(event_by_event_group.possible_levels, Numbers_of_Events_per_Set().tuple_)
for ax, (level, num_events_per_set) in zip(axs.flat, panel_keys):

    plot_linearity(
        linearity_test_results=event_by_event_group.results[level][num_events_per_set].linearity_results,
        ax=ax,
    )

    title_lines = (
        f"Level: {names_of_levels[level]}",
        f"Events/set: {num_events_per_set}",
        r"Sets/$\delta C_9$: " + f"{event_by_event_group.num_evaluation_sets_per_label[num_events_per_set]}",
    )
    ax.set_title("\n".join(title_lines), loc="left")

# A single legend in the first panel serves the whole grid.
axs.flat[0].legend()
fig.suptitle("Event-by-event\n", x=0.02, horizontalalignment="left")
fig.supxlabel(r"Actual $\delta C_9$", fontsize=11, x=0.56, y=-0.06)
fig.supylabel(r"Predicted $\delta C_9$", fontsize=11, y=0.45)

plot_path = Paths_to_Directories().path_to_plots_dir.joinpath("ebe_grid_lin.png")
plt.savefig(plot_path, bbox_inches="tight")
plt.close()



# 2x3 grid of example log-probability distributions from the event-by-event
# model, one panel per (level, events-per-set) pair.
fig, axs = plt.subplots(nrows=2, ncols=3, sharex=True, sharey=True, layout="compressed")

panel_keys = product(event_by_event_group.possible_levels, Numbers_of_Events_per_Set().tuple_)
for ax, (level, num_events_per_set) in zip(axs.flat, panel_keys):

    # The binned evaluation sets supply the labels and the bin map used for plotting.
    evaluation_settings = event_by_event_group.get_individual(level)._get_evaluation_set_dataset_settings(num_events_per_set)
    dataset = Binned_Sets_Dataset(settings=evaluation_settings)
    dataset.load()

    plot_log_probability_distribution_examples(
        log_probabilities=event_by_event_group.results[level][num_events_per_set].log_probabilities,
        binned_labels=dataset.labels,
        bin_map=dataset.bin_map,
        ax=ax
    )

    ax.set_title(
        "\n".join((f"Level: {names_of_levels[level]}", f"Events: {num_events_per_set}")),
        loc="left"
    )

axs.flat[0].legend(loc="lower right", markerscale=2)
fig.suptitle("Event-by-event\n", x=0.02, horizontalalignment="left")
fig.supxlabel(r"$\delta C_9$", fontsize=11, x=0.56, y=-0.06)
fig.supylabel(r"$\log\;p(\delta C_9 | x_1, ..., x_N)$", fontsize=11, y=0.45)

plot_path = Paths_to_Directories().path_to_plots_dir.joinpath("ebe_grid_proba.png")
plt.savefig(plot_path, bbox_inches="tight")
plt.close()

In [None]:
import pandas
import matplotlib.pyplot as plt

In [None]:
# Load the background sideband samples (charged and mixed) for the training
# and evaluation splits. NOTE: read_pickle must only be used on trusted files.
charge_train, mix_train = (
    pandas.read_pickle(path)
    for path in (
        "../../state/new_physics/data/raw/bkg/mu_sideb_generic_charge_train.pkl",
        "../../state/new_physics/data/raw/bkg/mu_sideb_generic_mix_train.pkl",
    )
)
# Combined frames are built from the samples as loaded, i.e. before any veto.
all_train = pandas.concat([charge_train, mix_train])

charge_eval, mix_eval = (
    pandas.read_pickle(path)
    for path in (
        "../../state/new_physics/data/raw/bkg/mu_sideb_generic_charge_eval.pkl",
        "../../state/new_physics/data/raw/bkg/mu_sideb_generic_mix_eval.pkl",
    )
)
all_eval = pandas.concat([charge_eval, mix_eval])

# Only the charged samples are passed through the resonances q^2 veto here;
# the mixed samples and the combined frames above are left as loaded.
resonances_veto = Names_of_q_Squared_Vetos().resonances
charge_eval = apply_q_squared_veto(charge_eval, resonances_veto)
charge_train = apply_q_squared_veto(charge_train, resonances_veto)