Setup

In [None]:
from itertools import product
from pathlib import Path

import numpy
import matplotlib.pyplot as plt
import torch

import helpers

In [None]:
# Dataset splits and simulation levels that the later cells iterate over.
splits = ["train", "val"]
levels = ["gen", "det"]

# Loaded once through the project helper. Later cells compare it against
# `dc9_new_phys` to find a label index and pass it to
# `helpers.models.predict_values_event_model` together with the log-probs.
bin_map = helpers.data.open_bin_map_file()

# Dataset-name arguments handed to the save/load helpers: one for per-event
# data, one for per-set data.
dset_events_name = "events_binned"
dset_sets_name = "sets_binned"

def make_model_name(level):
    """Return the identifier under which the model for `level` is saved and loaded."""
    return "ebe_{}".format(level)

# Benchmark value looked up in `bin_map` when building the sensitivity sets.
dc9_new_phys = -0.82

# Signal events per set, and how many sets to generate per label for each
# of those set sizes (keys must match `num_signal_per_set`).
num_signal_per_set = [8_000, 16_000, 32_000]
num_sets_per_label = {
    8_000: 400,
    16_000: 200,
    32_000: 100,
}
# Number of sets generated for the single benchmark label in the sensitivity study.
num_sets_sensitivity = 2_000

# Training configuration; every value below is forwarded to helpers.models.train.
# NOTE(review): no random seed is set in the visible cells - confirm whether
# the helpers seed their own generators if reproducible runs are required.
device = helpers.models.select_device()
# NOTE(review): the evaluation cells treat the model output as log-probabilities
# over bins; confirm that a mean-squared-error loss is what training intends.
loss_fn = torch.nn.MSELoss()
lr = 4e-3
lr_reduce_factor = 0.95
lr_reduce_patience = 0
batch_size = 10_000
epochs = 300
epochs_checkpoint = 5

Save standard scaling constants

In [None]:
split = "train"

# For each level, reduce the training-split features over dim 0 and persist
# the resulting mean and standard deviation for later standard scaling.
for level in levels:
    features, _ = helpers.data.make_events(level, split)

    std_scale_mean = features.mean(dim=0)
    std_scale_std = features.std(dim=0)

    for stat_kind, stat_value in (("mean", std_scale_mean), ("std", std_scale_std)):
        helpers.data.save_dset_file(stat_value, dset_events_name, level, split, stat_kind)

Dataset creation

Event

In [None]:
# Build, scale and save the per-event features/labels for every level and split.
for level in levels:
    for split in splits:
        features, labels = helpers.data.make_events(level, split)

        # No split is passed here; presumably the helper loads the constants
        # saved from the training split by the previous cell - confirm.
        features = helpers.data.apply_std_scale(features, dset_events_name, level)

        for kind, payload in (("features", features), ("labels", labels)):
            helpers.data.save_dset_file(payload, dset_events_name, level, split, kind)

Set

In [None]:
split = "val"
binned_labels = True

# One validation dataset of sets per (level, set size); the number of sets
# generated per label is looked up from the configuration for that set size.
for level in levels:
    for num_signal in num_signal_per_set:
        num_sets = num_sets_per_label[num_signal]

        sets_features, sets_labels = helpers.data.make_sets(
            level, split, num_signal, num_sets, binned_labels=binned_labels
        )

        # Scaled with the constants stored under the per-event dataset name.
        sets_features = helpers.data.apply_std_scale(
            sets_features, dset_events_name, level
        )

        helpers.data.save_dset_file(
            sets_features, dset_sets_name, level, split, "features",
            num_signal_per_set=num_signal,
        )
        helpers.data.save_dset_file(
            sets_labels, dset_sets_name, level, split, "labels",
            num_signal_per_set=num_signal,
        )

Sensitivity

In [None]:
split = "val"
binned_labels = True

# Label index of the bin holding the benchmark value. It does not depend on
# the loop variables, so it is resolved once up front. The match is tolerant
# (numpy.isclose) rather than exact, so it still succeeds when the bin map is
# stored at a different floating-point precision than the Python literal.
# `.item()` raises if zero or several bins match, which is the desired failure.
new_phys_label = numpy.argwhere(numpy.isclose(bin_map, dc9_new_phys)).item()

# One validation dataset of sets per (level, set size), restricted to the
# single benchmark label.
for level, num_signal in product(levels, num_signal_per_set):

    sets_features, sets_labels = helpers.data.make_sets(
        level,
        split,
        num_signal,
        num_sets_sensitivity,
        binned_labels=binned_labels,
        label_subset=[new_phys_label],
    )

    # Scaled with the constants stored under the per-event dataset name.
    sets_features = helpers.data.apply_std_scale(sets_features, dset_events_name, level)

    helpers.data.save_dset_file(sets_features, dset_sets_name, level, split, "sens_features", num_signal_per_set=num_signal)
    helpers.data.save_dset_file(sets_labels, dset_sets_name, level, split, "sens_labels", num_signal_per_set=num_signal)

Train models

In [None]:
# Train one fresh event-by-event model per level on the per-event datasets.
for level in levels:

    model = helpers.models.Event_by_Event_Model()

    # Use the shared naming helper so the name written here is guaranteed to
    # match the one the evaluation and plotting cells load.
    model_name = make_model_name(level)

    dataset_train = helpers.data.Dataset(dset_events_name, level, "train")
    dataset_val = helpers.data.Dataset(dset_events_name, level, "val")

    helpers.models.train(
        model,
        model_name,
        loss_fn,
        dataset_train,
        dataset_val,
        device,
        lr,
        lr_reduce_factor,
        lr_reduce_patience,
        # batch_size is passed twice on purpose in the original call;
        # presumably the training and evaluation batch sizes - confirm
        # against the signature of helpers.models.train.
        batch_size,
        batch_size,
        epochs,
        epochs_checkpoint,
    )

Evaluate models

Linearity and error

In [None]:
# Linearity and error tests on the validation sets for every level and set size.
for level in levels:

    # The trained model depends only on the level, so it is built and loaded
    # once per level instead of once per (level, set size) pair.
    model_name = make_model_name(level)
    model = helpers.models.Event_by_Event_Model()
    model.load_state_dict(helpers.models.open_model_state_dict(model_name))

    for num_signal in num_signal_per_set:

        dataset_val = helpers.data.Dataset(dset_sets_name, level, "val", num_signal_per_set=num_signal)

        # NOTE(review): assumes the prediction helper handles eval mode and
        # gradient disabling itself - confirm in helpers.models.
        log_probs = helpers.models.predict_log_probs_event_model(model, dataset_val.features, device)
        preds = helpers.models.predict_values_event_model(log_probs, bin_map, device)

        results_lin = helpers.models.run_linearity_test(preds, dataset_val.labels)
        results_err = helpers.models.run_error_test(preds, dataset_val.labels)

        helpers.models.save_test_result(results_lin, "lin", num_signal, model_name)
        helpers.models.save_test_result(results_err, "err", num_signal, model_name)

Sensitivity

In [None]:
# Sensitivity test on the benchmark-only validation sets for every level and set size.
for level in levels:

    # The trained model depends only on the level, so it is built and loaded
    # once per level instead of once per (level, set size) pair.
    model_name = make_model_name(level)
    model = helpers.models.Event_by_Event_Model()
    model.load_state_dict(helpers.models.open_model_state_dict(model_name))

    for num_signal in num_signal_per_set:

        dataset_val_sens = helpers.data.Dataset(dset_sets_name, level, "val", num_signal_per_set=num_signal, sensitivity=True)

        # NOTE(review): assumes the prediction helper handles eval mode and
        # gradient disabling itself - confirm in helpers.models.
        log_probs = helpers.models.predict_log_probs_event_model(model, dataset_val_sens.features, device)
        preds = helpers.models.predict_values_event_model(log_probs, bin_map, device)

        results_sens = helpers.models.run_sensitivity_test(preds, dataset_val_sens.labels)

        # `model_name` is already set for this level; no need to recompute it.
        helpers.models.save_test_result(results_sens, "sens", num_signal, model_name)

Plot results

Linearity

In [None]:
# Grid of linearity plots: one row per level, one column per set size.
# The grid shape follows the configuration lists; squeeze=False keeps `axs`
# two-dimensional even when one of the lists has a single entry.
fig, axs = plt.subplots(
    len(levels),
    len(num_signal_per_set),
    sharex=True,
    sharey=True,
    layout="compressed",
    squeeze=False,
)

fancy_level_names = {
    "gen": "Generator",
    "det": "Detector",
}

for (level, num_signal), ax in zip(product(levels, num_signal_per_set), axs.flat):

    model_name = make_model_name(level)

    result = helpers.models.open_test_result("lin", num_signal, model_name)

    helpers.plot.plot_linearity(result, ax=ax)

    ax.set_title(
        f"Level: {fancy_level_names[level]}"
        f"\nEvents/set: {num_signal}"
        "\n" + r"Sets/$\delta C_9$: " + f"{num_sets_per_label[num_signal]}",
        loc="left",
    )

axs.flat[0].legend()
# This notebook evaluates the event-by-event ("ebe") models, so the figure is
# labelled and saved accordingly. The previous "Deep Sets" title and
# deep_sets_grid_lin.png filename mislabelled the results and would overwrite
# a Deep Sets figure saved under the same name.
fig.suptitle("Event-by-event\n", x=0.02, horizontalalignment="left")
fig.supxlabel(r"Actual $\delta C_9$", fontsize=11, x=0.56, y=-0.06)
fig.supylabel(r"Predicted $\delta C_9$", fontsize=11, y=0.45)

# Create the output directory if needed so saving works on a fresh checkout.
plots_dir = Path("plots")
plots_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(plots_dir.joinpath("ebe_grid_lin.png"), bbox_inches="tight")
plt.close(fig)