Setup

In [None]:
from itertools import product 
from pathlib import Path

import matplotlib.pyplot as plt
import torch

import helpers

In [None]:
splits = ["train", "val"]
levels = ["gen", "det", "det_bkg"]

dset_name = "sets_unbinned"


def make_model_name(level, num_signal):
    """Return the name under which a Deep Sets model is saved and loaded.

    Parameters
    ----------
    level : str
        Dataset level, one of ``levels``.
    num_signal : int
        Number of signal events per set, one of ``num_signal_per_set``.

    Returns
    -------
    str
        ``"deep_sets_{level}_{num_signal}"``.
    """
    return f"deep_sets_{level}_{num_signal}"


# delta C9 value used as the single label of the sensitivity datasets.
dc9_new_phys = -0.82

# Events per set; every data/training/evaluation loop iterates over these.
num_signal_per_set = [8_000, 16_000, 32_000]
# Number of sets generated per label value, keyed by events per set.
num_sets_per_label = {8_000: 400, 16_000: 200, 32_000: 100}
# Number of sets in each sensitivity dataset.
num_sets_sensitivity = 2_000
# Forwarded to helpers.data.make_sets (exact semantics defined there).
bkg_signal_ratio = 0.79
charge_bkg_fraction = 0.57

device = helpers.models.select_device()
loss_fn = torch.nn.MSELoss()
lr = 3e-4
# Forwarded to helpers.models.train; presumably control an LR-on-plateau schedule — confirm there.
lr_reduce_factor = 0.8
lr_reduce_patience = 3
# Keyed by events per set; used for both training and validation batches.
batch_sizes = {8_000: 32, 16_000: 64, 32_000: 128}
epochs = 100
epochs_checkpoint = 1

# Fail fast if a configured set size is missing from a per-size lookup table,
# rather than hitting a KeyError midway through a long generation/training loop.
for lookup_name, lookup in {"num_sets_per_label": num_sets_per_label, "batch_sizes": batch_sizes}.items():
    missing_sizes = set(num_signal_per_set) - lookup.keys()
    assert not missing_sizes, f"{lookup_name} has no entry for {sorted(missing_sizes)}"
epochs_checkpoint = 1

Compute and save standard-scaling constants (per-feature mean and std of the training sets, for each level and set size)

In [None]:
# Scaling constants are computed from the training split only.
split = "train"

for level, num_signal in product(levels, num_signal_per_set):

    # Labels are not needed to compute feature statistics.
    # NOTE(review): the dataset-creation cell calls make_sets again for the train
    # split; if make_sets samples randomly, these constants come from a different
    # draw than the saved training features — confirm this is intended.
    sets_features, _ = helpers.data.make_sets(
        level,
        split,
        num_signal,
        num_sets_per_label[num_signal],
        bkg_signal_ratio=bkg_signal_ratio,
        charge_bkg_fraction=charge_bkg_fraction,
    )

    # Reduce over dims 0 and 1, leaving one mean/std per entry of the trailing dim.
    # NOTE(review): assumes features are (num_sets, events_per_set, num_features) — confirm.
    std_scale_mean = torch.mean(sets_features, dim=(0, 1))
    std_scale_std = torch.std(sets_features, dim=(0, 1))

    helpers.data.save_dset_file(std_scale_mean, dset_name, level, split, "mean", num_signal_per_set=num_signal)
    helpers.data.save_dset_file(std_scale_std, dset_name, level, split, "std", num_signal_per_set=num_signal)

Dataset creation

In [None]:
# Build, standard-scale and save every (level, events-per-set, split) dataset.
# Iteration order is level -> set size -> split.
for level, num_signal in product(levels, num_signal_per_set):
    for split in splits:
        per_size = {"num_signal_per_set": num_signal}

        raw_features, sets_labels = helpers.data.make_sets(
            level,
            split,
            num_signal,
            num_sets_per_label[num_signal],
            bkg_signal_ratio=bkg_signal_ratio,
            charge_bkg_fraction=charge_bkg_fraction,
        )
        sets_features = helpers.data.apply_std_scale(raw_features, dset_name, level, **per_size)

        # Persist features first, then labels.
        for file_kind, tensor in (("features", sets_features), ("labels", sets_labels)):
            helpers.data.save_dset_file(tensor, dset_name, level, split, file_kind, **per_size)

Sensitivity datasets

In [None]:
# Sensitivity sets: drawn from the validation split using only the dc9_new_phys label.
split = "val"

for level, num_signal in product(levels, num_signal_per_set):
    per_size = {"num_signal_per_set": num_signal}

    raw_features, sets_labels = helpers.data.make_sets(
        level,
        split,
        num_signal,
        num_sets_sensitivity,
        label_subset=[dc9_new_phys],
        bkg_signal_ratio=bkg_signal_ratio,
        charge_bkg_fraction=charge_bkg_fraction,
    )
    sets_features = helpers.data.apply_std_scale(raw_features, dset_name, level, **per_size)

    for file_kind, tensor in (("sens_features", sets_features), ("sens_labels", sets_labels)):
        helpers.data.save_dset_file(tensor, dset_name, level, split, file_kind, **per_size)

Train models

In [None]:
# Train a fresh Deep Sets model for every (level, events-per-set) combination.
for level, num_signal in product(levels, num_signal_per_set):

    model = helpers.models.Deep_Sets_Model()
    model_name = make_model_name(level, num_signal)

    # Loaded in order: train, then val.
    dataset_train, dataset_val = (
        helpers.data.Dataset(dset_name, level, split_name, num_signal_per_set=num_signal)
        for split_name in ("train", "val")
    )

    # The same batch size is passed for both the training and validation batches.
    batch_size = batch_sizes[num_signal]

    helpers.models.train(
        model, model_name, loss_fn,
        dataset_train, dataset_val,
        device,
        lr, lr_reduce_factor, lr_reduce_patience,
        batch_size, batch_size,
        epochs, epochs_checkpoint,
    )

Evaluate models

Linearity and error

In [None]:
# Load each trained model, predict on its validation set, and save the
# linearity ("lin") and error ("err") test results.
for level, num_signal in product(levels, num_signal_per_set):

    model_name = make_model_name(level, num_signal)
    model = helpers.models.Deep_Sets_Model()
    state_dict = helpers.models.open_model_state_dict(model_name)
    model.load_state_dict(state_dict)

    dataset_val = helpers.data.Dataset(dset_name, level, "val", num_signal_per_set=num_signal)
    labels_val = dataset_val.labels

    preds = helpers.models.predict_values_set_model(model, dataset_val.features, device)

    # Run both tests before saving anything.
    results_lin = helpers.models.run_linearity_test(preds, labels_val)
    results_err = helpers.models.run_error_test(preds, labels_val)

    for test_name, test_result in (("lin", results_lin), ("err", results_err)):
        helpers.models.save_test_result(test_result, test_name, num_signal, model_name)

Sensitivity

In [None]:
# Load each trained model, predict on its sensitivity dataset, and save the
# sensitivity ("sens") test result.
for level, num_signal in product(levels, num_signal_per_set):

    model_name = make_model_name(level, num_signal)
    model = helpers.models.Deep_Sets_Model()
    state_dict = helpers.models.open_model_state_dict(model_name)
    model.load_state_dict(state_dict)

    dataset_val_sens = helpers.data.Dataset(
        dset_name,
        level,
        "val",
        num_signal_per_set=num_signal,
        sensitivity=True,
    )
    sens_features, sens_labels = dataset_val_sens.features, dataset_val_sens.labels

    preds = helpers.models.predict_values_set_model(model, sens_features, device)
    results_sens = helpers.models.run_sensitivity_test(preds, sens_labels)

    helpers.models.save_test_result(results_sens, "sens", num_signal, model_name)

Plot results

Linearity

In [None]:
# Grid of linearity plots: one row per level, one column per events-per-set value
# (panels are filled row-major from product(levels, num_signal_per_set)).
# The grid shape is derived from the config so zip() never silently drops panels.
fig, axs = plt.subplots(
    len(levels),
    len(num_signal_per_set),
    sharex=True,
    sharey=True,
    layout="compressed",
    squeeze=False,  # always a 2-D array, so axs.flat works for any grid size
)

fancy_level_names = {
    "gen": "Generator",
    "det": "Detector",
    "det_bkg": "Detector and Bkg.",
}

for (level, num_signal), ax in zip(product(levels, num_signal_per_set), axs.flat):

    model_name = make_model_name(level, num_signal)

    result = helpers.models.open_test_result("lin", num_signal, model_name)

    helpers.plot.plot_linearity(result, ax=ax)

    ax.set_title(
        f"Level: {fancy_level_names[level]}\n"
        f"Events/set: {num_signal}\n"
        rf"Sets/$\delta C_9$: {num_sets_per_label[num_signal]}",
        loc="left",
    )

axs.flat[0].legend()
fig.suptitle("Deep Sets\n", x=0.02, horizontalalignment="left")
fig.supxlabel(r"Actual $\delta C_9$", fontsize=11, x=0.56, y=-0.06)
fig.supylabel(r"Predicted $\delta C_9$", fontsize=11, y=0.45)

# Create the output directory on a fresh checkout instead of failing at save time.
plots_dir = Path("plots")
plots_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(plots_dir.joinpath("deep_sets_grid_lin.png"), bbox_inches="tight")
plt.close(fig)