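The full thing: training and evaluating the image-based (Shawn's), Deep Sets, and event-by-event methods at each simulation level.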

# Setup

## Imports

In [None]:
import numpy as np
import scipy.special
import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import matplotlib as mpl
import matplotlib.pyplot as plt

from library.datasets import (
    Signal_Images_Dataset, 
    Signal_Sets_Dataset, 
    Binned_Signal_Dataset
)
from library.models import (
    CNN_Res, 
    Deep_Sets, 
    Event_By_Event_NN
)
from library.nn_training import (
    select_device, 
    train_and_eval
)
from library.predict import (
    Summary_Table,
    make_predictions,
    run_linearity_test,
    run_sensitivity_test,
    calculate_mse_mae
)
from library.plotting import (
    plot_loss_curves, 
    setup_high_quality_mpl_params, 
    plot_prediction_linearity, 
    plot_sensitivity,
    plot_volume_slices
)

## Select device (cuda if available)

In [None]:
# Device handed to training and prediction calls below
# (per the section heading: CUDA when available).
device = select_device()

## Setup fancy plotting

In [None]:
# Apply the project's plotting configuration before any figure is created.
# NOTE(review): presumably updates matplotlib rcParams - see library.plotting.
setup_high_quality_mpl_params()

## Setup global parameters

In [None]:
# Directories passed as `save_dir` to the dataset constructors and as the
# first argument to the model constructors, respectively.
dataset_save_dir = "../../state/new_physics/data/processed"
model_dir = "../../state/new_physics/models"

# Preprocessing switches forwarded unchanged to every dataset constructor.
std_scale = True
q_squared_veto = True
balanced_classes = True

# Numbers of events per set; one dataset (and, for the set-based methods,
# one model) is built per entry.
set_sizes = [70_000, 24_000, 6_000]

# Sets per label for the train/eval datasets, and for the single-label
# datasets used in the sensitivity tests.
num_sets_per_label = 50
num_sets_per_label_single = 2000

# Passed as `n_bins` to Signal_Images_Dataset.
num_image_bins = 10

# The delta C9 label sampled for the single-label datasets and used as the
# reference value in the sensitivity tests.
new_physics_delta_c9_value = -0.82

# Collects the metrics reported by every method via `add_item`.
summary_table = Summary_Table()

# Generator Level

## Setup

In [None]:
# Simulation level passed to every dataset in this section
# ("gen": generator level, per the section heading).
level = "gen"

## Shawn's Method

#### Datasets

In [None]:
# Forwarded to every dataset constructor below.
# NOTE(review): presumably forces the cached files to be rebuilt - confirm.
regenerate = False

# Keyword arguments common to all image datasets of this section.
_image_dataset_kwargs = dict(
    level=level,
    save_dir=dataset_save_dir,
    n_bins=num_image_bins,
    q_squared_veto=q_squared_veto,
    std_scale=std_scale,
    balanced_classes=balanced_classes,
    regenerate=regenerate,
)

# Training images, keyed by number of events per set.
train_datasets = {
    size: Signal_Images_Dataset(
        split="train",
        num_events_per_set=size,
        num_sets_per_label=num_sets_per_label,
        extra_description=size,
        **_image_dataset_kwargs,
    )
    for size in set_sizes
}

# Evaluation images covering all labels, keyed by number of events per set.
eval_datasets = {
    size: Signal_Images_Dataset(
        split="eval",
        num_events_per_set=size,
        num_sets_per_label=num_sets_per_label,
        extra_description=size,
        **_image_dataset_kwargs,
    )
    for size in set_sizes
}

# Evaluation images sampled only at the new-physics label, used by the
# sensitivity test.
single_label_eval_datasets = {
    size: Signal_Images_Dataset(
        split="eval",
        num_events_per_set=size,
        num_sets_per_label=num_sets_per_label_single,
        labels_to_sample=[new_physics_delta_c9_value],
        extra_description=f"{size}_single",
        **_image_dataset_kwargs,
    )
    for size in set_sizes
}

#### Models

In [None]:
# One CNN per set size; the description string distinguishes the saved models.
models = {
    size: CNN_Res(model_dir, extra_description=f"v2_{size}")
    for size in set_sizes
}

### Peek at features

In [None]:
# Inspect the first training example of the 24k-events-per-set image dataset.
num_events_per_set = 24_000
dset = train_datasets[num_events_per_set]
dset.load()

plot_volume_slices(
    dset.features[0],
    note=r"$\delta C_9$ : " + f"{dset.labels[0]}",
    n_slices=3,
)
plt.show()
plt.close()

### Model Training

In [None]:
# Hyperparameters shared by every set size.
learning_rate = 4e-4
epochs = 80
train_batch_size = 32
eval_batch_size = 32

# Train one CNN per set size and show its loss curves.
for num_events_per_set in set_sizes:

    model = models[num_events_per_set]

    loss_fn = nn.L1Loss()
    optimizer = Adam(model.parameters(), lr=learning_rate)

    train_dataset = train_datasets[num_events_per_set]
    eval_dataset = eval_datasets[num_events_per_set]
    train_dataset.load()
    eval_dataset.load()

    train_and_eval(
        model,
        train_dataset,
        eval_dataset,
        loss_fn,
        optimizer,
        epochs,
        train_batch_size,
        eval_batch_size,
        device,
        move_data=True,
        # Shrink the learning rate by 10% after more than one epoch without
        # improvement of the monitored quantity.
        scheduler=ReduceLROnPlateau(
            optimizer,
            factor=0.9,
            patience=1,
        ),
        checkpoint_epochs=5,
    )

    # Release this set size's data before loading the next one, mirroring the
    # load/unload pairing used by the evaluation cells.
    train_dataset.unload()
    eval_dataset.unload()

    _, ax = plt.subplots()
    plot_loss_curves(
        model.loss_table,
        ax,
        start_epoch=0,
        log_scale=True,
    )
    plt.show()
    plt.close()

### Evaluation

#### Linearity

In [None]:
# Linearity test for the image method: predict on the all-label eval set,
# record MSE/MAE, and plot mean +/- std of the predictions per label.
for num_events_per_set in set_sizes:

    # Load the final saved state of this set size's model and put it in
    # evaluation mode on the selected device.
    model = models[num_events_per_set]
    model.load_final()
    model.to(device)
    model.eval()

    eval_dataset = eval_datasets[num_events_per_set]
    eval_dataset.load()

    predictions = make_predictions(model, eval_dataset.features, device)

    mse, mae = calculate_mse_mae(predictions, eval_dataset.labels)
    summary_table.add_item("Images", "MSE", num_events_per_set, mse)
    summary_table.add_item("Images", "MAE", num_events_per_set, mae)

    unique_labels, avgs, stds = run_linearity_test(
        predictions, eval_dataset.labels
    )

    # Everything needed from the dataset has been extracted; free it.
    eval_dataset.unload()

    fig, ax = plt.subplots()
    plot_prediction_linearity(
        ax,
        unique_labels.detach().cpu().numpy(),
        avgs.detach().cpu().numpy(),
        stds.detach().cpu().numpy(),
        note=(
            f"Images ({num_image_bins} bins), {level}., "
            f"{num_sets_per_label} boots., "
            f"{num_events_per_set} events/boots."
        ),
    )
    plt.show()
    plt.close()

#### Sensitivity

In [None]:
# Sensitivity test for the image method: predict on sets sampled only at the
# new-physics label and record mean, std and bias of the predictions.
for num_events_per_set in set_sizes:

    eval_dataset = single_label_eval_datasets[num_events_per_set]
    eval_dataset.load()

    model = models[num_events_per_set]
    model.load_final()
    model.to(device)
    model.eval()

    predictions = make_predictions(model, eval_dataset.features, device)

    mean, std, bias = run_sensitivity_test(
        predictions, new_physics_delta_c9_value
    )
    summary_table.add_item("Images", "Mean at NP", num_events_per_set, mean)
    summary_table.add_item("Images", "Std. at NP", num_events_per_set, std)
    summary_table.add_item("Images", "Bias at NP", num_events_per_set, bias)

    eval_dataset.unload()

    fig, ax = plt.subplots()
    plot_sensitivity(
        ax,
        predictions,
        new_physics_delta_c9_value,
        note=(
            f"Images ({num_image_bins} bins), {level}., "
            f"{num_sets_per_label_single} boots., "
            f"{num_events_per_set} events/boots."
        ),
    )
    plt.show()
    plt.close()

#### Summary

In [None]:
# Show the metrics collected so far as the cell's output.
summary_table

In [None]:
# Print the MSE and MAE columns as a LaTeX table, three decimals per value.
print(
    summary_table[["MSE", "MAE"]]
    .to_latex(float_format="%.3f")
)

## Deep Sets

### Setup

#### Datasets

In [None]:
# Forwarded to every dataset constructor below.
# NOTE(review): presumably forces the cached files to be rebuilt - confirm.
regenerate = False

# Keyword arguments common to all unbinned set datasets of this section.
_sets_dataset_kwargs = dict(
    level=level,
    save_dir=dataset_save_dir,
    binned=False,
    q_squared_veto=q_squared_veto,
    std_scale=std_scale,
    balanced_classes=balanced_classes,
    regenerate=regenerate,
)

# Training sets, keyed by number of events per set.
train_datasets = {
    size: Signal_Sets_Dataset(
        split="train",
        num_events_per_set=size,
        num_sets_per_label=num_sets_per_label,
        extra_description=f"unbinned_{size}",
        **_sets_dataset_kwargs,
    )
    for size in set_sizes
}

# Evaluation sets covering all labels, keyed by number of events per set.
eval_datasets = {
    size: Signal_Sets_Dataset(
        split="eval",
        num_events_per_set=size,
        num_sets_per_label=num_sets_per_label,
        extra_description=f"unbinned_{size}",
        **_sets_dataset_kwargs,
    )
    for size in set_sizes
}

# Evaluation sets sampled only at the new-physics label, used by the
# sensitivity test.
single_label_eval_datasets = {
    size: Signal_Sets_Dataset(
        split="eval",
        num_events_per_set=size,
        num_sets_per_label=num_sets_per_label_single,
        labels_to_sample=[new_physics_delta_c9_value],
        extra_description=f"unbinned_{size}_single",
        **_sets_dataset_kwargs,
    )
    for size in set_sizes
}


#### Models

In [None]:
# One Deep Sets model per set size; the set size doubles as the description.
models = {
    size: Deep_Sets(model_dir, extra_description=size)
    for size in set_sizes
}

### Model Training

In [None]:
# Hyperparameters shared by every set size.
learning_rate = 4e-4
epochs = 80
train_batch_size = 32
eval_batch_size = 32

# Train one Deep Sets model per set size and show its loss curves.
for num_events_per_set in set_sizes:

    model = models[num_events_per_set]

    # NOTE(review): this line was annotated "trained with L1 loss" although
    # it constructs MSELoss. Confirm which loss the saved models were trained
    # with before comparing results; the code itself is left unchanged.
    loss_fn = nn.MSELoss()
    optimizer = Adam(model.parameters(), lr=learning_rate)

    train_dataset = train_datasets[num_events_per_set]
    eval_dataset = eval_datasets[num_events_per_set]
    train_dataset.load()
    eval_dataset.load()

    train_and_eval(
        model,
        train_dataset,
        eval_dataset,
        loss_fn,
        optimizer,
        epochs,
        train_batch_size,
        eval_batch_size,
        device,
        move_data=True,
        # Shrink the learning rate by 10% after more than one epoch without
        # improvement of the monitored quantity.
        scheduler=ReduceLROnPlateau(
            optimizer,
            factor=0.9,
            patience=1,
        ),
        checkpoint_epochs=5,
    )

    # Release this set size's data before loading the next one, mirroring the
    # load/unload pairing used by the evaluation cells.
    train_dataset.unload()
    eval_dataset.unload()

    _, ax = plt.subplots()
    plot_loss_curves(
        model.loss_table,
        ax,
        start_epoch=0,
        log_scale=True,
    )
    plt.show()
    plt.close()

### Evaluation

#### Linearity

In [None]:
# Linearity test for Deep Sets: predict on the all-label eval set, record
# MSE/MAE, and plot mean +/- std of the predictions per label.
for num_events_per_set in set_sizes:

    # Load the final saved state of this set size's model and put it in
    # evaluation mode on the selected device.
    model = models[num_events_per_set]
    model.load_final()
    model.to(device)
    model.eval()

    eval_dataset = eval_datasets[num_events_per_set]
    eval_dataset.load()

    predictions = make_predictions(model, eval_dataset.features, device)

    unique_labels, avgs, stds = run_linearity_test(
        predictions, eval_dataset.labels
    )

    mse, mae = calculate_mse_mae(predictions, eval_dataset.labels)

    summary_table.add_item("Deep Sets", "MSE", num_events_per_set, mse)
    summary_table.add_item("Deep Sets", "MAE", num_events_per_set, mae)

    # Everything needed from the dataset has been extracted; free it.
    eval_dataset.unload()

    fig, ax = plt.subplots()
    plot_prediction_linearity(
        ax,
        unique_labels.detach().cpu().numpy(),
        avgs.detach().cpu().numpy(),
        stds.detach().cpu().numpy(),
        note=(
            f"Deep Sets, {level}., "
            f"{num_sets_per_label} boots., "
            f"{num_events_per_set} events/boots."
        ),
    )
    plt.show()
    plt.close()

#### Sensitivity

In [None]:
# Sensitivity test for Deep Sets: predict on sets sampled only at the
# new-physics label and record mean, std and bias of the predictions.
for num_events_per_set in set_sizes:

    eval_dataset = single_label_eval_datasets[num_events_per_set]
    eval_dataset.load()

    model = models[num_events_per_set]
    model.load_final()
    model.to(device)
    model.eval()

    predictions = make_predictions(
        model,
        eval_dataset.features,
        device,
    )

    mean, std, bias = run_sensitivity_test(
        predictions,
        new_physics_delta_c9_value,
    )

    summary_table.add_item("Deep Sets", "Mean at NP", num_events_per_set, mean)
    summary_table.add_item("Deep Sets", "Std. at NP", num_events_per_set, std)
    summary_table.add_item("Deep Sets", "Bias at NP", num_events_per_set, bias)

    eval_dataset.unload()

    fig, ax = plt.subplots()

    plot_sensitivity(
        ax,
        predictions,
        new_physics_delta_c9_value,
        note=(
            f"Deep Sets, {level}., "
            # The single-label datasets are built with
            # num_sets_per_label_single, so that is the count to report.
            + f"{num_sets_per_label_single} boots., "
            + f"{num_events_per_set} events/boots."
        ),
    )

    plt.show()
    plt.close()

#### Summary

In [None]:
# Show the metrics collected so far as the cell's output.
summary_table

In [None]:
# Print the MSE and MAE columns as a LaTeX table, three decimals per value.
print(
    summary_table[["MSE", "MAE"]]
    .to_latex(float_format="%.3f")
)

## Event-by-event Method

### Setup

#### Datasets

In [None]:
# Forwarded to every dataset constructor below.
# NOTE(review): presumably forces the cached files to be rebuilt - confirm.
regenerate = False

# Keyword arguments common to the two per-event datasets used for training.
_events_dataset_kwargs = dict(
    level=level,
    save_dir=dataset_save_dir,
    q_squared_veto=q_squared_veto,
    std_scale=std_scale,
    balanced_classes=balanced_classes,
    shuffle=True,
    extra_description=None,
    regenerate=regenerate,
)

train_events_dataset = Binned_Signal_Dataset(split="train", **_events_dataset_kwargs)

eval_events_dataset = Binned_Signal_Dataset(split="eval", **_events_dataset_kwargs)

# Keyword arguments common to the binned set datasets used for evaluation.
_binned_sets_kwargs = dict(
    level=level,
    split="eval",
    save_dir=dataset_save_dir,
    binned=True,
    q_squared_veto=q_squared_veto,
    std_scale=std_scale,
    balanced_classes=balanced_classes,
    regenerate=regenerate,
)

# Evaluation sets covering all labels, keyed by number of events per set.
eval_sets_datasets = {
    size: Signal_Sets_Dataset(
        num_events_per_set=size,
        num_sets_per_label=num_sets_per_label,
        extra_description=f"binned_{size}",
        **_binned_sets_kwargs,
    )
    for size in set_sizes
}

# Evaluation sets sampled only at the new-physics label, used by the
# sensitivity test.
single_label_eval_sets_datasets = {
    size: Signal_Sets_Dataset(
        num_events_per_set=size,
        num_sets_per_label=num_sets_per_label_single,
        labels_to_sample=[new_physics_delta_c9_value],
        extra_description=f"binned_{size}_single",
        **_binned_sets_kwargs,
    )
    for size in set_sizes
}


#### Model

In [None]:
# A single event-level network is shared by all set sizes in this section.
model = Event_By_Event_NN(model_dir)

### Model Training

In [None]:
# Hyperparameters for the event-by-event classifier.
learning_rate = 3e-3
epochs = 200

train_batch_size = 10_000
eval_batch_size = 10_000

loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(
    model.parameters(),
    lr=learning_rate,
)

# Load the datasets that are actually passed to train_and_eval below.
# (This cell previously called .load() on `train_dataset` / `eval_dataset`,
# which are leftovers from the Deep Sets training loop and are not used here.)
train_events_dataset.load()
eval_events_dataset.load()

loss_table = train_and_eval(
    model,
    train_events_dataset,
    eval_events_dataset,
    loss_fn,
    optimizer,
    epochs,
    train_batch_size,
    eval_batch_size,
    device,
    move_data=True,
    # Multiply the learning rate by 0.95 after any epoch without improvement
    # (patience=0, threshold=0).
    scheduler=ReduceLROnPlateau(
        optimizer,
        factor=0.95,
        threshold=0,
        patience=0,
        eps=1e-9,
    ),
    checkpoint_epochs=5,
)

_, ax = plt.subplots()
plot_loss_curves(
    model.loss_table,
    ax,
    start_epoch=0,
    log_scale=True,
)
plt.show()
plt.close()

### Evaluation

#### Linearity

In [None]:
# Linearity test for the event-by-event method: the single trained model is
# evaluated on binned sets of every size.
model.load_final()
model.to(device)
model.eval()

for num_events_per_set in set_sizes:

    eval_sets_dataset = eval_sets_datasets[num_events_per_set]
    eval_sets_dataset.load()

    predictions = make_predictions(
        model,
        eval_sets_dataset.features,
        device,
        event_by_event=True,
        bin_values=eval_sets_dataset.bin_values,
    )

    assert predictions.shape == eval_sets_dataset.labels.shape

    # The dataset labels are bin indices; map them to the value each bin
    # represents so they can be compared with the predictions.
    unbinned_labels = eval_sets_dataset.bin_values[
        eval_sets_dataset.labels.int()
    ]

    mse, mae = calculate_mse_mae(predictions, unbinned_labels)

    summary_table.add_item("Event by event", "MSE", num_events_per_set, mse)
    summary_table.add_item("Event by event", "MAE", num_events_per_set, mae)

    # Group by the unbinned label values (not the bin indices) so the
    # x-axis of the linearity plot is in the same units as the predictions,
    # consistent with the MSE/MAE computation above.
    unique_labels, avgs, stds = run_linearity_test(
        predictions,
        unbinned_labels,
    )

    eval_sets_dataset.unload()

    fig, ax = plt.subplots()

    plot_prediction_linearity(
        ax,
        unique_labels.detach().cpu().numpy(),
        avgs.detach().cpu().numpy(),
        stds.detach().cpu().numpy(),
        note=(
            # Label the plot with this section's method (was "Deep Sets").
            f"Event by event, {level}., "
            + f"{num_sets_per_label} boots., "
            + f"{num_events_per_set} events/boots."
        ),
    )

    plt.show()
    plt.close()

#### Sensitivity

In [None]:
# Sensitivity test for the event-by-event method: predict on sets sampled
# only at the new-physics label and record mean, std and bias.
model.load_final()
model.to(device)
model.eval()

for num_events_per_set in set_sizes:

    eval_sets_dataset = single_label_eval_sets_datasets[num_events_per_set]
    eval_sets_dataset.load()

    predictions = make_predictions(
        model,
        eval_sets_dataset.features,
        device,
        event_by_event=True,
        bin_values=eval_sets_dataset.bin_values,
    )

    mean, std, bias = run_sensitivity_test(
        predictions,
        new_physics_delta_c9_value,
    )

    summary_table.add_item("Event by event", "Mean at NP", num_events_per_set, mean)
    summary_table.add_item("Event by event", "Std. at NP", num_events_per_set, std)
    summary_table.add_item("Event by event", "Bias at NP", num_events_per_set, bias)

    eval_sets_dataset.unload()

    fig, ax = plt.subplots()

    plot_sensitivity(
        ax,
        predictions,
        new_physics_delta_c9_value,
        note=(
            # Label the plot with this section's method (was "Deep Sets") and
            # with the set count the single-label datasets were built with.
            f"Event by event, {level}., "
            + f"{num_sets_per_label_single} boots., "
            + f"{num_events_per_set} events/boots."
        ),
    )

    plt.show()
    plt.close()

#### Summary

In [None]:
# Show the metrics collected so far as the cell's output.
summary_table

In [None]:
# Print the MSE and MAE columns as a LaTeX table, three decimals per value.
print(
    summary_table[["MSE", "MAE"]]
    .to_latex(float_format="%.3f")
)

## Summary

### Table

In [None]:
# Show the complete table of metrics for all methods as the cell's output.
summary_table

In [None]:
# Print the MSE and MAE columns as a LaTeX table, three decimals per value.
print(
    summary_table[["MSE", "MAE",]]
    .to_latex(float_format="%.3f")
)

In [None]:
# Print the sensitivity-test columns (std. and bias at the new-physics value)
# as a LaTeX table, three decimals per value.
print(
    summary_table[["Std. at NP", "Bias at NP"]]
    .to_latex(float_format="%.3f")
)

### Plots

In [None]:
# List the distinct values of the "Method" index level; the plotting cell
# that follows iterates over them in this order.
summary_table.index.unique("Method")

In [None]:
# Format strings for ax.plot, paired positionally with the methods.
linestyles = ["-D", "-o", "-s"]

# y-axis limits, paired positionally with the summary table's columns.
y_lims = [(0, None), (0, None), (0, None), (None, None), (-0.15, 0.15)]

# One figure per metric, one curve per method, plotted against set size.
for col, y_lim in zip(summary_table.columns, y_lims):

    fig, ax = plt.subplots()

    for method, style in zip(summary_table.index.unique("Method"), linestyles):
        # Select this method's values for the metric; the remaining index
        # supplies the x coordinates.
        y = summary_table.loc[method, col]
        x = y.index
        ax.plot(x, y, style, label=f"{method}")

    ax.set_ylim(y_lim)
    ax.set_title(f"{col}")
    ax.set_xlabel("Number of events / set")
    ax.legend()

    plt.show()
    plt.close()

# Detector Level

In [None]:
# Switch every dataset built from here on to the detector level.
level = "det"

# Forwarded to every dataset constructor below.
# NOTE(review): presumably forces the cached files to be rebuilt - confirm.
regenerate = False

# Keyword arguments common to all detector-level image datasets.
_image_dataset_kwargs = dict(
    level=level,
    save_dir=dataset_save_dir,
    n_bins=num_image_bins,
    q_squared_veto=q_squared_veto,
    std_scale=std_scale,
    balanced_classes=balanced_classes,
    regenerate=regenerate,
)

# Training images, keyed by number of events per set.
train_datasets = {
    size: Signal_Images_Dataset(
        split="train",
        num_events_per_set=size,
        num_sets_per_label=num_sets_per_label,
        extra_description=size,
        **_image_dataset_kwargs,
    )
    for size in set_sizes
}

# Evaluation images covering all labels, keyed by number of events per set.
eval_datasets = {
    size: Signal_Images_Dataset(
        split="eval",
        num_events_per_set=size,
        num_sets_per_label=num_sets_per_label,
        extra_description=size,
        **_image_dataset_kwargs,
    )
    for size in set_sizes
}

# Evaluation images sampled only at the new-physics label, used by the
# sensitivity test.
single_label_eval_datasets = {
    size: Signal_Images_Dataset(
        split="eval",
        num_events_per_set=size,
        num_sets_per_label=num_sets_per_label_single,
        labels_to_sample=[new_physics_delta_c9_value],
        extra_description=f"{size}_single",
        **_image_dataset_kwargs,
    )
    for size in set_sizes
}







## Shawn's Method

### Datasets

### Model Training

### Evaluation

## Deep Sets

### Datasets

### Model Training

### Evaluation

## Event-by-event Method

### Datasets

### Model Training

### Evaluation

# Detector Level with Backgrounds

## Shawn's Method

## Deep Sets

## Event-by-event Method

In [None]:




def predict_log_probabilities_by_label(x, y, model):
    """
    Predict class log probabilities separately for each group of same-labeled events.

    The events are split by unique label, each group is passed to
    `predict_log_probabilities` (defined outside this block), and the results
    are stacked along a new leading axis with one entry per unique label, in
    the sorted order returned by `np.unique`.

    x : A torch tensor of features of events (from multiple labels).
    y : A torch tensor of event labels.
    model : Passed through to `predict_log_probabilities`.
    """
    unique_labels = np.unique(y.cpu())

    # One array per label, each given a leading axis of length 1 so the
    # results can be joined along that axis.
    per_label = [
        predict_log_probabilities(x[y == label], model).cpu().numpy()[np.newaxis]
        for label in unique_labels
    ]

    stacked = np.concatenate(per_label, axis=0)
    assert stacked.shape[0] == len(unique_labels)
    return stacked


def calculate_predicted_expected_value_by_label(predictions, bin_values):
    """
    Calculate the expected bin value for each row of predicted log probabilities.

    The expectation is evaluated in log space with `logsumexp`. Because the
    logarithm needs positive arguments, the bin values are first shifted to be
    >= 1 and the shift is subtracted from the result afterwards; removing the
    shift this way is exact only when each row's probabilities sum to one.

    predictions : numpy array of shape (num_rows, num_bins) holding the
        predicted log probability of each bin, one row per label. The number
        of rows does not have to equal the number of bins.
    bin_values : array-like of length num_bins with the value each bin
        represents.

    Returns a numpy array of shape (num_rows,) with the expected values.
    """
    bin_values = np.asarray(bin_values)

    # Shift so the smallest bin value maps to at least 1 and the log is defined.
    bin_value_shift = np.abs(np.min(bin_values)) + 1
    log_shifted_bin_values = np.log(bin_values + bin_value_shift)

    # Broadcasting adds the per-bin log values to every row of predictions.
    log_shifted_expected_values = scipy.special.logsumexp(
        predictions + log_shifted_bin_values, axis=1
    )

    return np.exp(log_shifted_expected_values) - bin_value_shift


def plot_log_probabilities_over_labels(fig, ax, predictions, bin_values, cmap=plt.cm.viridis):
    """
    Draw one log-probability curve per label, colored by the label's bin value.

    fig : Figure that receives the colorbar.
    ax : Axes the curves are drawn on.
    predictions : A numpy array of set probabilities (rows correspond to labels, columns correspond to class predictions).
    bin_values : A numpy array of the value each bin represents.
    cmap : Colormap used for the curves and the colorbar.
    """
    # Close the last color interval by extending the bin values by one more
    # step of the same width as the final spacing.
    last_step = bin_values[-1] - bin_values[-2]
    color_bounds = np.append(bin_values, bin_values[-1] + last_step)
    color_norm = mpl.colors.BoundaryNorm(color_bounds, cmap.N)

    # Rows of predictions are paired positionally with the bin values.
    for value, curve in zip(bin_values, predictions):
        curve_color = cmap(color_norm(value))
        ax.plot(bin_values, curve, color=curve_color)

    mappable = mpl.cm.ScalarMappable(norm=color_norm, cmap=cmap)
    fig.colorbar(mappable, ax=ax, label=r"Actual $\delta C_9$")

    ax.set_xlabel(r"$\delta C_9$")
    ax.set_ylabel(r"$\log p(\delta C_9 | x_1, ..., x_N)$")


def plot_expected_value_over_labels(ax, expected_values, bin_values):
    """
    Scatter the predicted expected values against the bin values.

    A slope-one reference line (bin values against themselves) is drawn
    underneath the points, and axis labels and a legend are added.

    ax : Axes to draw on.
    expected_values : Predicted expected value for each label.
    bin_values : The value each bin represents (used as the x coordinates).
    """
    ax.scatter(
        bin_values,
        expected_values,
        s=16,
        color="firebrick",
        zorder=5,
        label="Prediction",
    )

    # zorder=0 keeps the reference line behind the scatter points (zorder=5).
    reference_style = dict(
        label="Ref. Line (Slope = 1)",
        color="grey",
        linewidth=0.5,
        zorder=0,
    )
    ax.plot(bin_values, bin_values, **reference_style)

    ax.set_xlabel(r"Actual $\delta C_9$")
    ax.set_ylabel(r"Predicted $\delta C_9$")
    ax.legend()
