Training a neural network to predict $\delta C_9$

Setup

In [1]:
from pathlib import Path

import matplotlib.pyplot as plt

import torch
from torch import nn

from library.nn_training import select_device, train_and_eval
from library.datasets import Signal_Unbinned_Dataset, Bootstrapped_Signal_Unbinned_Dataset
from library.plotting import plot_loss_curves

device = select_device()

Device:  cuda


Define Helper Functions

Define Model

In [5]:
class Set_Trained_Regressor_NN(nn.Module):
    """
    Neural network trained on sets (that works event-by-event?).
    """
    def __init__(self, save_dir, nickname):
        super().__init__()

        self.nickname = nickname
        self.save_path = Path(save_dir).joinpath(f"{nickname}.pt")

        self.f = nn.Sequential(
            nn.Linear(4, 32),
            nn.ReLU(),
            nn.Linear(32, 32)
        )
        self.g = nn.Sequential(
            nn.Linear(32, 1)
        )
        
    def forward(self, x):
        sum_f = torch.sum(self.f(x), dim=1)
        g_sum_f = self.g(sum_f)
        return g_sum_f

    

Load / Generate Set Datasets

In [None]:
level = "gen"
save_dir = "../../state/new_physics/data/processed"
raw_signal_dir = "../../state/new_physics/data/raw/signal"

num_events_per_set = 5_000
raw_trials = {"train": range(0,5), "eval": range(5,10)}
num_sets_per_label = {"train": 12, "eval": 10}

train_dataset = Bootstrapped_Signal_Unbinned_Dataset(level=level, split="train", save_dir=save_dir)
eval_dataset = Bootstrapped_Signal_Unbinned_Dataset(level=level, split="eval", save_dir=save_dir)

# train_dataset.generate(raw_trials["train"], raw_signal_dir, num_events_per_set, num_sets_per_label["train"])
# eval_dataset.generate(raw_trials["eval"], raw_signal_dir, num_events_per_set, num_sets_per_label["eval"])

train_dataset.load()
eval_dataset.load()

In [11]:
train_dataset.features.shape

torch.Size([528, 5000, 4])

Train / Load Model

In [7]:
retrain = True
save_dir = "../../state/new_physics/models"
nickname = "test1"
model = Set_Trained_Regressor_NN(save_dir, nickname)

if retrain:
    
    learning_rate = 4e-4
    epochs = 10
    train_batch_size = 16
    eval_batch_size = 16
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    loss_table = train_and_eval(
        model, 
        train_dataset, eval_dataset, 
        loss_fn, 
        optimizer, 
        epochs, 
        train_batch_size, eval_batch_size, 
        device, 
        move_data=True)
    torch.save(model.state_dict(), model.save_path)
    _, ax = plt.subplots()
    plot_loss_curves(loss_table["epoch"], loss_table["train_loss"], loss_table["eval_loss"], ax)
    plt.show()

else:
    model.load_state_dict(torch.load(model.save_path, weights_only=True))
    model.to(device)

  return F.mse_loss(input, target, reduction=self.reduction)


RuntimeError: Found dtype Double but expected Float

Evaluate Model

In [None]:
label = 32
log_p = softmax(model(eval_dataset.features[eval_dataset.labels==label].to(device)), dim=1)
for sset in log_p.detach().cpu():
    plt.plot(sset)
    n_bins = len(sset)
    bins = torch.arange(n_bins)
    weighted_avg = torch.sum(bins * sset) / torch.sum(sset)
    print("we", weighted_avg)
    print("argmax", torch.argmax(sset))
    print("sum p", torch.sum(sset))
    plt.scatter(weighted_avg.item(), 0)
    plt.scatter(torch.argmax(sset), torch.max(sset), zorder=10)

In [None]:
value_guesses = []
for label in range(0, 44):
    log_p = log_softmax(model(eval_dataset.features[eval_dataset.labels==label].to(device)), dim=1)
    bin_guess = torch.argmax(log_p, dim=1).to("cpu")
    value_guess = eval_dataset.bin_values[bin_guess]
    value_guesses.append(value_guess)

value_guesses = [np.expand_dims(guess, axis=0) for guess in value_guesses]
value_guesses = np.concat(value_guesses)
value_guess_stds = np.std(value_guesses, axis=1)
value_guess_means = np.mean(value_guesses, axis=1)
    

In [None]:
plot_prediction_linearity(
    eval_dataset.bin_values,
    value_guess_means,
    value_guess_stds,
    ref_line_buffer=0.05,
    xlim=(-2.25, 1.35),
    ylim=(-2.25, 1.35),
    xlabel=r"Actual $\delta C_9$", 
    ylabel=r"Predicted $\delta C_9$"
)
