Training a neural network to predict $\delta C_9$ on an event-by-event basis (classification)

Import Libraries

In [2]:
from pathlib import Path
import pickle

import numpy as np
from scipy.special import logsumexp

import matplotlib as mpl
from matplotlib import pyplot as plt

import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn import CrossEntropyLoss
from torch.nn.functional import log_softmax

from library.plotting import setup_high_quality_mpl_params, plot_loss_curves, plot_prediction_linearity, make_plot_note
from library.datasets import Aggregated_Signal_Binned_Dataset
from library.util import bootstrap_over_bins
from library.nn_training import select_device, train_and_eval


# setup_high_quality_mpl_params()


Define Helper Functions

In [3]:

def predict_log_probabilities(x, model):
    """
    Compute the normalized log probability of each class for a whole set of events.

    The per-event log probabilities are summed over the events (i.e. the
    per-event probabilities are multiplied) and the result is normalized
    again with a second log-softmax over the classes.

    x : A torch tensor of event features (one row per event).
    model : A torch module that maps event features to per-event class logits.

    Returns a 1D torch tensor with one log probability per class.
    """
    model.eval()
    with torch.no_grad():
        per_event_log_probs = log_softmax(model(x), dim=1)
        summed_over_events = per_event_log_probs.sum(dim=0)
        return log_softmax(summed_over_events, dim=0)


def predict_log_probabilities_by_label(x, y, model):
    """
    Predict the set-level class log probabilities separately for every label.

    The events are grouped by their label, and each group is passed as one set
    to predict_log_probabilities.

    x : A torch tensor of features of events (from multiple labels).
    y : A torch tensor of event labels (same length as x).
    model : The torch module used to produce per-event logits.

    Returns a numpy array whose row i holds the class log probabilities
    predicted for the events carrying the i-th unique label (sorted order).
    The result is asserted to be square, i.e. the number of distinct labels
    present must equal the number of classes the model predicts.
    """
    unique_labels = np.unique(y.cpu())
    per_label_rows = [
        predict_log_probabilities(x[y == label], model).cpu().numpy()
        for label in unique_labels
    ]
    log_probabilities = np.stack(per_label_rows, axis=0)
    n_labels = len(unique_labels)
    assert log_probabilities.shape == (n_labels, n_labels)
    return log_probabilities


def calculate_predicted_expected_value_by_label(predictions, bin_values):
    """
    Calculate the expected bin value under each predicted distribution.

    The calculation is carried out in log space to stay numerically stable
    when the log probabilities are very negative.

    predictions : numpy array of predicted *log* probabilities with shape
        (number of distributions, number of bins). Each row is assumed to be
        normalized (its probabilities sum to one).
    bin_values : numpy array of the value corresponding to each bin.

    Returns a 1D numpy array with one expected value per row of predictions.
    """
    predictions = np.asarray(predictions)
    bin_values = np.asarray(bin_values)

    # log() needs strictly positive arguments, so move every bin value to >= 1
    # first; the shift is removed again after exponentiating (valid because
    # each row of probabilities sums to one).
    bin_value_shift = np.abs(np.min(bin_values)) + 1
    log_shifted_bin_values = np.log(bin_values + bin_value_shift)

    # Broadcasting adds the log bin values to every row, so the number of
    # distributions does not have to match the number of bins.
    log_shifted_expected_values = logsumexp(predictions + log_shifted_bin_values, axis=1)
    return np.exp(log_shifted_expected_values) - bin_value_shift


def plot_log_probabilities_over_labels(fig, ax, predictions, bin_values, cmap=plt.cm.viridis):
    """
    Draw one curve of predicted values against the bin values for every labeled subset.

    Each curve is colored by the bin value of the label it belongs to, and a
    matching discrete colorbar is attached to the axes.

    fig : The matplotlib figure that receives the colorbar.
    ax : The matplotlib axes to draw on.
    predictions : A numpy array of set-level predictions (rows correspond to labels, columns correspond to class predictions).
    bin_values : A numpy array of the value each bin represents (at least two entries).
    cmap : The colormap used to color the curves.
    """
    # The norm needs one more boundary than there are bins; extend past the
    # last bin by the spacing between the final two bins.
    last_spacing = bin_values[-1] - bin_values[-2]
    upper_boundary = bin_values[-1] + last_spacing
    norm = mpl.colors.BoundaryNorm(np.append(bin_values, upper_boundary), cmap.N)

    for actual_value, curve_values in zip(bin_values, predictions):
        curve_color = cmap(norm(actual_value))
        ax.plot(bin_values, curve_values, color=curve_color)

    color_key = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
    fig.colorbar(color_key, ax=ax, label=r"Actual $\delta C_9$")
    ax.set_xlabel(r"$\delta C_9$")
    ax.set_ylabel(r"$\log p(\delta C_9 | x_1, ..., x_N)$")


def plot_expected_value_over_labels(ax, expected_values, bin_values):
    """
    Scatter the predicted expected value of every label against its actual value.

    A thin grey identity line (slope 1) is drawn underneath the points as a
    reference for perfect predictions.

    ax : The matplotlib axes to draw on.
    expected_values : The predicted expected value for each label.
    bin_values : The actual value of each label (x positions of the points).
    """
    prediction_style = {"label": "Prediction", "color": "firebrick", "s": 16, "zorder": 5}
    reference_style = {"label": "Ref. Line (Slope = 1)", "color": "grey", "linewidth": 0.5, "zorder": 0}

    ax.scatter(bin_values, expected_values, **prediction_style)
    ax.plot(bin_values, bin_values, **reference_style)

    ax.set_xlabel(r"Actual $\delta C_9$")
    ax.set_ylabel(r"Predicted $\delta C_9$")
    ax.legend()


Select Device

In [None]:
# Device returned by library.nn_training.select_device; later cells load the
# datasets onto it and move the model to it.
device = select_device()

Define Model

In [5]:
class Event_By_Event_NN(nn.Module):
    """
    Fully connected classifier mapping the 4 features of one event to 44 class logits.

    The network has two hidden layers of width 32 with ReLU activations and
    uses float64 parameters. Besides the network, the class keeps a loss table
    and knows how to save / load its weights (final and per-epoch checkpoints)
    and the loss table under a common directory.

    nickname : Stem of every file written for this model.
    model_dir : Directory holding the weight files and the loss table.
        Defaults to the project's model directory. It is created on the
        first save if it does not exist yet.
    """

    def __init__(self, nickname, model_dir="../../state/new_physics/models"):
        super().__init__()

        self.nickname = nickname
        self.model_dir = Path(model_dir)
        self.loss_table = self.make_empty_loss_table()

        self.base = nn.Sequential(
            nn.Linear(4, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, 44),
        )

        # Convert all parameters to float64.
        self.double()

    def forward(self, x):
        """Return the per-event class logits for a batch of event features."""
        event_logits = self.base(x)
        return event_logits

    def _save_state(self, file_path):
        """Write the current weights to file_path, creating its directory if needed."""
        file_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(self.state_dict(), file_path)

    def _load_state(self, file_path):
        """Load weights from file_path onto the device the model currently lives on."""
        # Without map_location, tensors are restored to the device they were
        # saved from, which fails when that device is unavailable here.
        current_device = next(self.parameters()).device
        state = torch.load(file_path, map_location=current_device, weights_only=True)
        self.load_state_dict(state)

    def make_final_save_path(self):
        """Return the path of the file holding the final weights."""
        return self.model_dir.joinpath(f"{self.nickname}.pt")

    def save_final(self):
        """Save the current weights as the final weights."""
        self._save_state(self.make_final_save_path())

    def load_final(self):
        """Replace the current weights with the saved final weights."""
        self._load_state(self.make_final_save_path())

    def make_checkpoint_save_path(self, epoch_number):
        """Return the path of the checkpoint file for the given epoch."""
        checkpoint_save_name = self.nickname + f"_epoch_{epoch_number}"
        return self.model_dir.joinpath(f"{checkpoint_save_name}.pt")

    def save_checkpoint(self, epoch_number):
        """Save the current weights as the checkpoint of the given epoch."""
        self._save_state(self.make_checkpoint_save_path(epoch_number))

    def load_checkpoint(self, epoch_number):
        """Replace the current weights with the checkpoint of the given epoch."""
        self._load_state(self.make_checkpoint_save_path(epoch_number))

    def make_loss_table_file_path(self):
        """Return the path of the pickled loss table."""
        file_name = f"{self.nickname}_loss.pkl"
        return self.model_dir.joinpath(file_name)

    def save_loss_table(self):
        """Pickle the in-memory loss table to disk."""
        file_path = self.make_loss_table_file_path()
        file_path.parent.mkdir(parents=True, exist_ok=True)
        with open(file_path, "wb") as handle:
            pickle.dump(self.loss_table, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def load_loss_table(self):
        """
        Read the pickled loss table from disk and return it.

        The table is only returned; self.loss_table is left untouched.
        """
        file_path = self.make_loss_table_file_path()
        # NOTE: unpickling can execute arbitrary code. Only load loss tables
        # that were written by save_loss_table from a trusted run.
        with open(file_path, "rb") as handle:
            loss_table = pickle.load(handle)
        return loss_table

    def append_to_loss_table(self, epoch, train_loss, eval_loss):
        """Append one epoch's training and evaluation loss to the loss table."""
        self.loss_table["epoch"].append(epoch)
        self.loss_table["train_loss"].append(train_loss)
        self.loss_table["eval_loss"].append(eval_loss)
        # All three columns must stay the same length.
        assert len(self.loss_table["epoch"]) == len(self.loss_table["train_loss"]) == len(self.loss_table["eval_loss"])

    def make_empty_loss_table(self):
        """Create an empty loss table."""
        empty_loss_table = {"epoch":[], "train_loss":[], "eval_loss":[]}
        return empty_loss_table

    def clear_loss_table(self):
        """Discard every entry of the in-memory loss table."""
        self.loss_table = self.make_empty_loss_table()



Load / Generate Datasets

In [None]:
regenerate = False

level = "gen"
save_dir = "../../state/new_physics/data/processed"
raw_signal_dir = "../../state/new_physics/data/raw/signal"

std_scale = True
q_squared_veto = True

# Raw trial numbers that feed each split when the datasets are regenerated.
raw_trials_by_split = {"train": range(1, 20), "eval": range(20, 30)}

datasets = {
    split: Aggregated_Signal_Binned_Dataset(level=level, split=split, save_dir=save_dir)
    for split in raw_trials_by_split
}

if regenerate:
    for split, raw_trials in raw_trials_by_split.items():
        datasets[split].generate(
            raw_trials=raw_trials,
            raw_signal_dir=raw_signal_dir,
            std_scale=std_scale,
            q_squared_veto=q_squared_veto,
        )

for dataset in datasets.values():
    dataset.load(device)

# Both splits must use the same bin values.
np.testing.assert_equal(datasets["train"].bin_values, datasets["eval"].bin_values)

Train / Load Model

In [7]:
retrain = False

model = Event_By_Event_NN("ebe_with_checkpoints")

if not retrain:
    # Reuse saved weights: the epoch-10 checkpoint is loaded here
    # (model.load_final() would load the final weights instead).
    model.load_checkpoint(epoch_number=10)
    model.to(device)
else:
    learning_rate = 3e-3
    epochs = 100
    train_batch_size = 10_000
    eval_batch_size = 10_000
    loss_fn = CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=learning_rate)
    lr_scheduler = ReduceLROnPlateau(optimizer, factor=0.95, patience=0)

    loss_table = train_and_eval(
        model,
        datasets["train"], datasets["eval"],
        loss_fn, optimizer,
        epochs,
        train_batch_size, eval_batch_size,
        device,
        move_data=True,
        scheduler=lr_scheduler,
        # Presumably the checkpoint interval in epochs (the evaluation cell
        # loads checkpoints in steps of 5) - confirm in train_and_eval.
        checkpoint_epochs=5,
    )

    # Show the loss curves of the run that just finished, on a log scale.
    plot_epoch_start = 0
    shown = slice(plot_epoch_start, None)
    _, ax = plt.subplots()
    plot_loss_curves(loss_table["epoch"][shown], loss_table["train_loss"][shown], loss_table["eval_loss"][shown], ax)
    ax.set_yscale("log")
    plt.show()

In [None]:
# Plot the loss curves stored on disk for this model.
loss_table = model.load_loss_table()

plot_epoch_start = 0
epochs_shown, train_loss_shown, eval_loss_shown = (
    loss_table[column][plot_epoch_start:]
    for column in ("epoch", "train_loss", "eval_loss")
)

_, ax = plt.subplots()
plot_loss_curves(epochs_shown, train_loss_shown, eval_loss_shown, ax)
# Linear y axis here; use ax.set_yscale("log") for a log axis.
plt.show()

In [None]:
# Zoom in on the evaluation loss from table entry 50 onwards.
first_entry_shown = 50

fig, ax = plt.subplots()

# Plot against the recorded epoch numbers so the x axis shows the real epoch
# instead of restarting at 0 after slicing.
ax.plot(
    loss_table["epoch"][first_entry_shown:],
    loss_table["eval_loss"][first_entry_shown:],
)
ax.set_xlabel("Epoch")
ax.set_ylabel("Evaluation loss")

plt.show()

Evaluate Model

In [None]:
def save_expected_value_plot(model, epoch, output_dir=Path("plots_tmp")):
    """
    Plot predicted vs. actual expected values on the evaluation data and save the figure.

    model : The model whose currently loaded weights are evaluated.
    epoch : Epoch number shown in the plot note and used in the file name.
    output_dir : Directory receiving the image; created if it does not exist.

    Returns the matplotlib figure so the caller can decide to show or close it.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    eval_data = datasets["eval"]

    fig, ax = plt.subplots(layout="constrained")

    log_probs = predict_log_probabilities_by_label(eval_data.features, eval_data.labels, model)
    expected_values = calculate_predicted_expected_value_by_label(log_probs, eval_data.bin_values)
    plot_expected_value_over_labels(ax, expected_values, eval_data.bin_values)

    make_plot_note(ax, f"About 77k events/prediction - Epoch: {epoch}", fontsize="large")

    fig.savefig(output_dir / f"expected_ep_{epoch}.png", bbox_inches="tight")
    return fig


# One image per saved checkpoint. These figures are closed right after saving
# so that 20 figures do not stay open at once; the images remain on disk.
for ep in range(0, 100, 5):
    model.load_checkpoint(epoch_number=ep)
    model.to(device)
    plt.close(save_expected_value_plot(model, ep))

# The final weights are labelled as epoch 99. The model keeps these weights
# for the cells that follow, and this last figure stays open for display.
model.load_final()
model.to(device)
fig = save_expected_value_plot(model, 99)






In [None]:
# On evaluation data

eval_data = datasets["eval"]
log_probs = predict_log_probabilities_by_label(eval_data.features, eval_data.labels, model)

fig, ax = plt.subplots(layout="constrained")
plot_log_probabilities_over_labels(fig, ax, log_probs, eval_data.bin_values)
make_plot_note(ax, "About 77k events/curve", fontsize="large")

In [None]:
# plot_expected_value_over_labels takes the target axes as its first argument,
# so create a figure for it (the call raised a TypeError without one).
fig, ax = plt.subplots(layout="constrained")

expected_values = calculate_predicted_expected_value_by_label(log_probs, datasets["eval"].bin_values)
plot_expected_value_over_labels(ax, expected_values, datasets["eval"].bin_values)

In [None]:
# On training data

train_data = datasets["train"]
log_probs = predict_log_probabilities_by_label(train_data.features, train_data.labels, model)

fig, ax = plt.subplots(layout="constrained")
plot_log_probabilities_over_labels(fig, ax, log_probs, train_data.bin_values)

In [None]:
# plot_expected_value_over_labels takes the target axes as its first argument,
# so create a figure for it (the call raised a TypeError without one).
fig, ax = plt.subplots(layout="constrained")

expected_values = calculate_predicted_expected_value_by_label(log_probs, datasets["train"].bin_values)
plot_expected_value_over_labels(ax, expected_values, datasets["train"].bin_values)

Additional Analyses

Linearity

In [30]:
n_trials = 10
n_events_per_trial = 70_000

# The evaluation data does not change between trials, so copy it to host
# memory once instead of once per trial.
eval_features = datasets["eval"].features.cpu().numpy()
eval_labels = datasets["eval"].labels.cpu().numpy()

# Repeat the prediction on bootstrapped resamples of the evaluation data.
expected_values_per_trial = []
for _ in range(n_trials):

    boot_x, boot_y = bootstrap_over_bins(
        eval_features,
        eval_labels,
        n_events_per_trial,
    )
    boot_x = torch.from_numpy(boot_x).to(device)
    boot_y = torch.from_numpy(boot_y).to(device)

    log_probs = predict_log_probabilities_by_label(boot_x, boot_y, model)
    expected_values = calculate_predicted_expected_value_by_label(log_probs, datasets["eval"].bin_values)
    expected_values_per_trial.append(expected_values)

# np.stack is available in every NumPy release, whereas np.concat only exists
# in NumPy >= 2.0. Result: one row per trial, one column per label.
expected_values_all_trials = np.stack(expected_values_per_trial, axis=0)

# Mean and spread over the trials for each label.
expected_values_all_trials_means = np.mean(expected_values_all_trials, axis=0)
expected_values_all_trials_stds = np.std(expected_values_all_trials, axis=0)

In [None]:
# Bootstrapped mean prediction (with its spread) against the actual value.
fig, ax = plt.subplots()

axis_limits = (-2.25, 1.35)
plot_prediction_linearity(
    ax,
    datasets["eval"].bin_values,
    expected_values_all_trials_means,
    expected_values_all_trials_stds,
    ref_line_buffer=0.05,
    xlim=axis_limits,
    ylim=axis_limits,
    xlabel=r"Actual $\delta C_9$",
    ylabel=r"Predicted $\delta C_9$",
)

note_text = f"{n_trials} bootstrapped trials, {n_events_per_trial} events/trial"
make_plot_note(ax, note_text, fontsize="large")

plt.show()