The full thing!

# Setup

## Imports

In [None]:
import numpy as np

import scipy.special

import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

import matplotlib as mpl
import matplotlib.pyplot as plt

from library.datasets import (
    Signal_Images_Dataset, 
    Bootstrapped_Unbinned_Signal_Dataset, 
    Binned_Signal_Dataset
)
from library.models import CNN_Res, Deep_Sets, Event_By_Event_NN
from library.nn_training import select_device, train_and_eval
from library.plotting import (
    plot_loss_curves, 
    setup_high_quality_mpl_params, 
    plot_prediction_linearity, 
    make_plot_note, 
    plot_volume_slices
)
from library.util import bootstrap_labeled_sets, get_num_per_unique_label


## Select device (cuda if available)

In [None]:
# Choose the torch device once; the `.to(device)` calls and `train_and_eval`
# invocations in the cells below all reuse this value.
device = select_device()

## Setup fancy plotting

In [None]:
# Apply the project's plotting setup before any figure is created
# (presumably adjusts matplotlib rcParams -- confirm in library.plotting).
setup_high_quality_mpl_params()

## Setup global parameters

In [None]:
# Filesystem locations, relative to this notebook's working directory.
dataset_save_dir = "../../state/new_physics/data/processed"
raw_signal_dir = "../../state/new_physics/data/raw/signal"
model_dir = "../../state/new_physics/models"

# Preprocessing switches forwarded unchanged to every `generate(...)` call.
std_scale = True
q_squared_veto = True
balanced_classes = True

# Number of events per bootstrapped set; one dataset and one model is built
# for each entry.
set_sizes = [70_000, 24_000, 6_000]

# Label sampled for the single-label (sensitivity) datasets and drawn as the
# target line in the sensitivity histograms.
new_physics_delta_c9_value = -0.82

# Generator Level

In [None]:
# Level tag passed to the dataset constructors and echoed in plot notes.
level = "gen"

## Shawn's Method

In [4]:
# Passed as `n_bins` when generating the image datasets and shown in plot notes.
num_image_bins = 10

### Setup datasets

In [None]:
def _image_datasets_by_set_size(split, describe):
    """Build one Signal_Images_Dataset per entry of `set_sizes`, keyed by set size.

    split : dataset split name handed to the constructor.
    describe : callable mapping a set size to the `extra_description` value.
    """
    return {
        set_size: Signal_Images_Dataset(
            level=level,
            split=split,
            save_dir=dataset_save_dir,
            extra_description=describe(set_size),
        )
        for set_size in set_sizes
    }


# Train and eval datasets are described by the bare set size; the single-label
# eval datasets get a "_single" suffix so they are distinguishable from the
# regular eval datasets.
train_datasets = _image_datasets_by_set_size("train", lambda set_size: set_size)
eval_datasets = _image_datasets_by_set_size("eval", lambda set_size: set_size)
single_label_eval_datasets = _image_datasets_by_set_size(
    "eval", lambda set_size: f"{set_size}_single"
)

### Generate datasets

In [None]:
# Generate the train, eval and single-label eval image datasets for every
# set size. Train uses raw trials 1-20; both eval variants use trials 21-40.
for num_events_per_set in set_sizes:

    # Settings identical across the three generate() calls for this set size.
    shared_settings = dict(
        raw_signal_dir=raw_signal_dir,
        num_events_per_set=num_events_per_set,
        n_bins=num_image_bins,
        q_squared_veto=q_squared_veto,
        std_scale=std_scale,
        balanced_classes=balanced_classes,
    )

    train_datasets[num_events_per_set].generate(
        raw_trials=range(1, 21),
        num_sets_per_label=50,
        **shared_settings,
    )

    eval_datasets[num_events_per_set].generate(
        raw_trials=range(21, 41),
        num_sets_per_label=50,
        **shared_settings,
    )

    # Many sets of a single label, used later for the sensitivity histogram.
    single_label_eval_datasets[num_events_per_set].generate(
        raw_trials=range(21, 41),
        num_sets_per_label=2000,
        labels_to_sample=[new_physics_delta_c9_value],
        **shared_settings,
    )

### Peek at features

In [None]:
# Visual sanity check: load the 24k-event training set and look at its first
# example.
num_events_per_set = 24_000

dset = train_datasets[num_events_per_set]
dset.load()

# Show 3 slices of the first feature volume, annotated with its label.
plot_volume_slices(
     dset.features[0], 
     n_slices=3, 
     note=r"$\delta C_9$ : "+f"{dset.labels[0]}"
)
plt.show()
plt.close()

### Setup Models

In [None]:
# One CNN_Res per set size; the description string embeds the set size so
# each model is distinguishable under `model_dir`.
models = {
    num_events_per_set : CNN_Res(
        model_dir, 
        extra_description=f"v2_{num_events_per_set}"
    )
    for num_events_per_set in set_sizes
}

### Model Training

In [None]:
# Hyperparameters shared by every image model.
learning_rate = 4e-4
epochs = 80
train_batch_size = 32
eval_batch_size = 32

# Train one model per set size on its matching train/eval datasets, then plot
# that model's loss curves.
for num_events_per_set in set_sizes:

    model = models[num_events_per_set]

    # A fresh loss and optimizer per model so no optimizer state is shared.
    loss_fn = nn.L1Loss()
    optimizer = Adam(model.parameters(), lr=learning_rate)

    train_dataset = train_datasets[num_events_per_set]
    eval_dataset = eval_datasets[num_events_per_set]
    train_dataset.load()
    eval_dataset.load()

    # The scheduler multiplies the learning rate by 0.9 when the monitored
    # metric stops improving for more than one epoch.
    train_and_eval(
        model, 
        train_dataset, 
        eval_dataset, 
        loss_fn, 
        optimizer, 
        epochs, 
        train_batch_size, 
        eval_batch_size, 
        device, 
        move_data=True,
        scheduler=ReduceLROnPlateau(
            optimizer, 
            factor=0.9, 
            patience=1
        ),
        checkpoint_epochs=5,
    )

    # Loss curves are read from the model's own loss table.
    _, ax = plt.subplots()
    plot_loss_curves(
        model.loss_table,
        ax,
        start_epoch=0,
        log_scale=True
    )
    plt.show()
    plt.close()

### Linearity

In [None]:
# Linearity check for the image models: predict every bootstrapped eval set,
# then plot the mean and standard deviation of the predictions per true label.
for num_events_per_set in set_sizes:

    model = models[num_events_per_set]
    model.load_final()
    model.to(device)
    model.eval()

    eval_dataset = eval_datasets[num_events_per_set]
    eval_dataset.load()

    with torch.no_grad():

        # One forward pass per set. The append must be inside the loop so that
        # every set contributes a prediction (otherwise only the last one is
        # kept and the per-label reshape below cannot work).
        predictions = []
        for set_features in eval_dataset.features:
            prediction = model(set_features.unsqueeze(0).to(device))
            predictions.append(prediction)
        predictions = torch.tensor(predictions)

        num_sets_per_label = get_num_per_unique_label(eval_dataset.labels)
        # DANGER: Below assumes data sorted by labels!
        predictions_by_label = predictions.reshape(-1, num_sets_per_label)
        avg_yhat_per_label = predictions_by_label.mean(dim=1).detach().cpu().numpy()
        std_yhat_per_label = predictions_by_label.std(dim=1).detach().cpu().numpy()
        unique_labels = torch.unique(eval_dataset.labels)

        mse = torch.nn.functional.mse_loss(predictions, eval_dataset.labels)
        mae = torch.nn.functional.l1_loss(predictions, eval_dataset.labels)

    print("mse:", mse)
    print("mae:", mae)

    fig, ax = plt.subplots()
    plot_prediction_linearity(
        ax,
        unique_labels,
        avg_yhat_per_label,
        std_yhat_per_label,
    )
    make_plot_note(
        ax, 
        (
            f"Images ({num_image_bins} bins), {level}., "
            + f"{num_sets_per_label} boots., "
            + f"{num_events_per_set} events/boots."
        ), 
    )
    plt.show()
    plt.close()

### Sensitivity

In [None]:
# Sensitivity check for the image models: predict many bootstrapped sets that
# all share the single new-physics label and histogram the predictions.
for num_events_per_set in set_sizes:
    
    eval_dataset = single_label_eval_datasets[num_events_per_set]
    eval_dataset.load()

    model = models[num_events_per_set]
    model.load_final()
    model.to(device)
    model.eval()

    with torch.no_grad():
        predictions = []
        for set_features in eval_dataset.features:
            prediction = model(set_features.unsqueeze(0).to(device))
            predictions.append(prediction.detach().cpu())
        # Work in numpy from here on: `.round(3)` with a positional argument
        # and `ax.hist` both expect numpy values, not torch tensors.
        predictions = torch.tensor(predictions).numpy()
        mean = predictions.mean()
        # ddof=1 keeps the sample standard deviation that torch's `.std()`
        # computed before the switch to numpy.
        std = predictions.std(ddof=1)
        bias = mean - new_physics_delta_c9_value

    # One prediction per bootstrapped set, so this is the number of bootstraps
    # actually shown (not the value left over from the linearity cell).
    num_sets = len(predictions)

    print("bias: ", bias.round(3))
    print("std: ", std.round(3))

    fig, ax = plt.subplots()

    bins = 50
    xbounds = (-1.5, 0)
    ybounds = (0, 200)
    std_marker_height = 20

    ax.hist(predictions, bins=bins, range=xbounds)
    ax.vlines(
        new_physics_delta_c9_value, 
        0, 
        ybounds[1], 
        color="red", 
        label=f"Target ({new_physics_delta_c9_value})"
    )
    ax.vlines(
        mean, 
        0, 
        ybounds[1], 
        color="red", 
        linestyles="--", 
        label=r"$\mu = $ "+f"{mean.round(3)}"
    )
    ax.hlines(
        std_marker_height, 
        mean, 
        mean+std, 
        color="orange", 
        linestyles="dashdot", 
        label=r"$\sigma = $ "+f"{std.round(3)}"
    )
    
    ax.set_xlabel(r"Predicted $\delta C_9$")
    
    ax.set_xbound(*xbounds)
    ax.set_ybound(*ybounds)
    
    ax.legend()
    make_plot_note(
        ax, 
        (
            f"Images ({num_image_bins} bins), {level}., "
            + f"{num_sets} boots., "
            + f"{num_events_per_set} events/boots."
        ), 
        fontsize="medium"
    )
    plt.show()
    plt.close()

NameError: name 'single_label_eval_datasets' is not defined

## Deep Sets

### Model Training

#### Generate Datasets

In [None]:
# NOTE(review): these settings repeat the notebook-wide parameters defined in
# the setup section (`save_dir` here holds the same path as `dataset_save_dir`)
# -- consider reusing those globals.
level = "gen"
save_dir = "../../state/new_physics/data/processed"
raw_signal_dir = "../../state/new_physics/data/raw/signal"
std_scale = True
q_squared_veto = True
balanced_classes = True
num_sets_per_label = 50

# Generate a train/eval pair of unbinned bootstrapped datasets per set size.
# Train uses raw trials 1-20, eval uses raw trials 21-40.
for num_events_per_set in [70_000, 24_000, 6_000]:

    # The dataset name encodes the set size; later cells rebuild the same name
    # to load these datasets.
    name = f"unbinned_sets_{num_events_per_set}"

    datasets = {
        "train": Bootstrapped_Unbinned_Signal_Dataset(name=name, level=level, split="train", save_dir=save_dir),
        "eval": Bootstrapped_Unbinned_Signal_Dataset(name=name, level=level, split="eval", save_dir=save_dir),
    }

    datasets["train"].generate(
        raw_trials=range(1,21), 
        raw_signal_dir=raw_signal_dir,
        num_events_per_set=num_events_per_set,
        num_sets_per_label=num_sets_per_label,
        q_squared_veto=q_squared_veto,
        std_scale=std_scale,
        balanced_classes=balanced_classes,
    )
    datasets["eval"].generate(
        raw_trials=range(21,41), 
        raw_signal_dir=raw_signal_dir,
        num_events_per_set=num_events_per_set,
        num_sets_per_label=num_sets_per_label,
        q_squared_veto=q_squared_veto,
        std_scale=std_scale,
        balanced_classes=balanced_classes,
    )


#### Training

In [None]:
# Hyperparameters shared by every Deep Sets model.
learning_rate = 4e-4
epochs = 80
train_batch_size = 32
eval_batch_size = 32

# Train one Deep Sets model per set size and plot its loss curves.
for num_events_per_set in [70_000, 24_000, 6_000]:

    nickname = f"deep_sets_{num_events_per_set}"
    print("Training: ", nickname)
    model = Deep_Sets(nickname, "../../state/new_physics/models")

    loss_fn = nn.L1Loss()
    optimizer = Adam(model.parameters(), lr=learning_rate)

    # Rebuild the dataset handles by name and load what the generation cell
    # produced.
    train_dataset = Bootstrapped_Unbinned_Signal_Dataset(
        name=f"unbinned_sets_{num_events_per_set}", 
        level="gen", split="train", 
        save_dir="../../state/new_physics/data/processed"
    )
    eval_dataset = Bootstrapped_Unbinned_Signal_Dataset(
        name=f"unbinned_sets_{num_events_per_set}", 
        level="gen", split="eval", 
        save_dir="../../state/new_physics/data/processed"
    )
    train_dataset.load()
    eval_dataset.load()
    # The datasets are moved to the device up front, hence move_data=False below.
    train_dataset.to(device)
    eval_dataset.to(device)

    train_and_eval(
        model, 
        train_dataset, eval_dataset, 
        loss_fn, optimizer, 
        epochs, 
        train_batch_size, eval_batch_size, 
        device, 
        move_data=False,
        scheduler=ReduceLROnPlateau(optimizer, factor=0.9, patience=1),
        checkpoint_epochs=5,
    )

    # NOTE(review): the image-model training cell calls
    # plot_loss_curves(model.loss_table, ax, start_epoch=0, log_scale=True);
    # this call passes three column slices followed by ax instead. Confirm
    # which signature library.plotting currently exposes.
    _, ax = plt.subplots()
    plot_epoch_start = 0
    plot_loss_curves(
        model.loss_table["epoch"][plot_epoch_start:], 
        model.loss_table["train_loss"][plot_epoch_start:], 
        model.loss_table["eval_loss"][plot_epoch_start:], 
        ax
    )
    ax.set_yscale("log")
    plt.show()
    plt.close()

### Evaluation

#### Linearity

In [None]:
# Linearity check for the Deep Sets models: predict every bootstrapped eval
# set, then plot the per-label mean and standard deviation of the predictions.
num_sets_per_label = 50

for num_events_per_set in [70_000, 24_000, 6_000]:

    nickname = f"deep_sets_{num_events_per_set}"
    model = Deep_Sets(nickname, "../../state/new_physics/models")
    model.load_final()
    model.to(device)
    model.eval()

    eval_dataset = Bootstrapped_Unbinned_Signal_Dataset(
        name=f"unbinned_sets_{num_events_per_set}",
        level="gen", split="eval",
        save_dir="../../state/new_physics/data/processed"
    )
    eval_dataset.load()
    eval_dataset.to(device)

    with torch.no_grad():

        # One forward pass per set, each with a leading batch axis of size 1.
        predictions = torch.tensor(
            [model(set_features.unsqueeze(0)) for set_features in eval_dataset.features]
        )

        # Rows group the sets that share a label (assumes label-sorted data).
        predictions_by_label = predictions.reshape(-1, num_sets_per_label)
        avgs = predictions_by_label.mean(1).detach().cpu().numpy()
        stds = predictions_by_label.std(1).detach().cpu().numpy()

        ys = eval_dataset.labels
        unique_y = ys.reshape(-1, num_sets_per_label).mean(1).detach().cpu().numpy()

        # Overall error metrics, computed in numpy on the CPU copies.
        residuals = predictions.detach().cpu().numpy() - ys.detach().cpu().numpy()
        mse = (residuals**2).mean()
        mae = np.abs(residuals).mean()

    print("mse:", mse)
    print("mae:", mae)

    fig, ax = plt.subplots()

    plot_prediction_linearity(
        ax,
        unique_y,
        avgs,
        stds,
        ref_line_buffer=0.05,
        xlim=(-2.25, 1.35),
        ylim=(-2.25, 1.35),
        xlabel=r"Actual $\delta C_9$",
        ylabel=r"Predicted $\delta C_9$"
    )
    note_text = (
        "Deep Sets, Gen., "
        + f"{num_sets_per_label} boots., "
        + f"{num_events_per_set} events/boots."
    )
    make_plot_note(ax, note_text, fontsize="medium")

    plt.show()
    plt.close()

#### Sensitivity

##### Generate single-label dataset

In [None]:
# Settings for the single-label evaluation datasets: many (2000) bootstrapped
# sets, all drawn from the one label listed in `labels_to_sample`.
raw_signal_dir = "../../state/new_physics/data/raw/signal"
std_scale = True
q_squared_veto = True
balanced_classes = True
num_sets_per_label = 2_000

for num_events_per_set in [70_000, 24_000, 6_000]:
    # The "_at_-0.82" suffix keeps these apart from the regular eval datasets
    # of the same set size.
    single_label_eval_dataset = Bootstrapped_Unbinned_Signal_Dataset(
            name=f"unbinned_sets_{num_events_per_set}_at_-0.82", 
            level="gen", split="eval", 
            save_dir="../../state/new_physics/data/processed"
    )
    single_label_eval_dataset.generate(
        raw_trials=range(21,41), 
        raw_signal_dir=raw_signal_dir,
        num_events_per_set=num_events_per_set,
        num_sets_per_label=num_sets_per_label,
        q_squared_veto=q_squared_veto,
        std_scale=std_scale,
        balanced_classes=balanced_classes,
        labels_to_sample=[-0.82],
    )


##### Run sensitivity test

In [None]:
# Sensitivity check for the Deep Sets models: predict every single-label set
# and histogram the predictions around the target value -0.82.
num_sets_per_label = 2000

for num_events_per_set in [70_000, 24_000, 6_000]:
    single_label_eval_dataset = Bootstrapped_Unbinned_Signal_Dataset(
            name=f"unbinned_sets_{num_events_per_set}_at_-0.82", 
            level="gen", split="eval", 
            save_dir="../../state/new_physics/data/processed"
    )
    single_label_eval_dataset.load()
    single_label_eval_dataset.to(device)

    nickname = f"deep_sets_{num_events_per_set}"
    model = Deep_Sets(nickname, "../../state/new_physics/models")
    model.load_final()
    model.to(device)
    model.eval()

    with torch.no_grad():
        predictions = []
        for set_features in single_label_eval_dataset.features:
            prediction = model(set_features.unsqueeze(0))
            # Move each prediction to the CPU so the list converts to numpy.
            predictions.append(prediction.detach().cpu())
        predictions = np.array(predictions)
        mean = predictions.mean()
        # numpy's default here is the population standard deviation (ddof=0).
        std = predictions.std()
        # Bias is the mean prediction minus the target label (-0.82).
        bias = mean - -0.82

    print("bias: ", bias.round(3))
    print("std: ", std.round(3))

    fig, ax = plt.subplots()

    xbounds = (-1.5, 0)
    ybounds = (0, 200)

    # Solid red line: target. Dashed red line: mean prediction. Orange
    # segment: one standard deviation, drawn to the right of the mean.
    ax.hist(predictions, bins=50, range=xbounds)
    ax.vlines(-0.82, 0, ybounds[1], color="red", label="Target (-0.82)")
    ax.vlines(mean, 0, ybounds[1], color="red", linestyles="--", label=r"$\mu = $ "+f"{mean.round(3)}")
    ax.hlines(20, mean, mean+std, color="orange", linestyles="dashdot", label=r"$\sigma = $ "+f"{std.round(3)}")
    ax.set_xlabel(r"Predicted $\delta C_9$")
    make_plot_note(ax, f"Deep Sets, Gen., {num_sets_per_label} boots., {num_events_per_set} events/boots.", fontsize="medium")
    ax.set_xbound(*xbounds)
    ax.set_ybound(*ybounds)
    ax.legend()
    plt.show()
    plt.close()

## Event-by-event Method

### Model Training

#### Generate Datasets

In [None]:
# Settings for the binned (event-by-event) datasets. `save_dir` and `datasets`
# defined here are read again by later event-by-event cells.
level = "gen"
save_dir = "../../state/new_physics/data/processed"
raw_signal_dir = "../../state/new_physics/data/raw/signal"

q_squared_veto = True
std_scale = True
balanced_classes = True

name = "binned_signal"

datasets = {
    "train": Binned_Signal_Dataset(name, level=level, split="train", save_dir=save_dir),
    "eval": Binned_Signal_Dataset(name, level=level, split="eval", save_dir=save_dir),
}

# Train uses raw trials 1-20, eval uses raw trials 21-40.
datasets["train"].generate(
    raw_trials=range(1,21), 
    raw_signal_dir=raw_signal_dir, 
    q_squared_veto=q_squared_veto,
    std_scale=std_scale, 
    balanced_classes=balanced_classes
)
datasets["eval"].generate(
    raw_trials=range(21,41), 
    raw_signal_dir=raw_signal_dir, 
    q_squared_veto=q_squared_veto,
    std_scale=std_scale, 
    balanced_classes=balanced_classes
)

#### Training

In [None]:
# Train the event-by-event classifier on the binned datasets (cross-entropy
# over bins) and plot its loss curves.
model = Event_By_Event_NN("event_by_event_nn", "../../state/new_physics/models")

learning_rate = 3e-3
epochs = 200
train_batch_size = 10_000
eval_batch_size = 10_000
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=learning_rate)

# `save_dir` comes from the dataset-generation cell above.
dataset_name = "binned_signal"
training_dataset = Binned_Signal_Dataset(dataset_name, level="gen", split="train", save_dir=save_dir)
eval_dataset = Binned_Signal_Dataset(dataset_name, level="gen", split="eval", save_dir=save_dir)
training_dataset.load()
eval_dataset.load()

# NOTE(review): this cell uses the return value of train_and_eval as the loss
# table, whereas the other training cells ignore the return value and read
# `model.loss_table`. Confirm that train_and_eval actually returns the table.
# With patience=0 and threshold=0 the scheduler lowers the learning rate by
# 5% after any epoch that does not improve the monitored metric.
loss_table = train_and_eval(
    model, 
    training_dataset, eval_dataset, 
    loss_fn, optimizer, 
    epochs, 
    train_batch_size, eval_batch_size, 
    device, 
    move_data=True,
    scheduler= ReduceLROnPlateau(optimizer, factor=0.95, threshold=0, patience=0, eps=1e-9),
    checkpoint_epochs=5,
)
_, ax = plt.subplots()
plot_epoch_start = 0
plot_loss_curves(loss_table["epoch"][plot_epoch_start:], loss_table["train_loss"][plot_epoch_start:], loss_table["eval_loss"][plot_epoch_start:], ax)
ax.set_yscale("log")
plt.show()

### Evaluation

#### Linearity

In [None]:
# Linearity check for the event-by-event model: bootstrap sets of events per
# label from the binned eval data, predict the expected delta C9 of each set,
# and plot the per-label mean and standard deviation of those predictions.
dataset_name = "binned_signal"
eval_dataset = Binned_Signal_Dataset(dataset_name, level="gen", split="eval", save_dir="../../state/new_physics/data/processed")
eval_dataset.load()

model = Event_By_Event_NN("event_by_event_nn", "../../state/new_physics/models")
model.load_final()
model.to(device)
model.eval()

num_sets_per_label = 50

# The bin values do not change between sets, so move them to the device once.
bin_values = eval_dataset.bin_values.to(device)

for num_events_per_set in [70_000, 24_000, 6_000]:

    boot_x, boot_y_bin_indices = bootstrap_labeled_sets(
        eval_dataset.features, 
        eval_dataset.labels, 
        num_events_per_set, num_sets_per_label,
        reduce_labels=True,
    )
    # Translate each set's bin index into the value that bin represents.
    boot_y = eval_dataset.bin_values[boot_y_bin_indices]

    # Inference only: keep the forward passes and the metrics under no_grad
    # so no autograd graph is built for any of the sets.
    with torch.no_grad():
        predictions = []
        for set_features in boot_x:
            set_features = set_features.to(device)
            expected_value = model.calculate_expected_value(set_features, bin_values)
            predictions.append(expected_value)
        predictions = torch.tensor(predictions)
        assert predictions.shape == boot_y.shape

        mse = torch.mean((predictions - boot_y)**2)
        mae = torch.mean(torch.abs(predictions - boot_y))
        # Rows group the sets that share a label (assumes label-sorted sets).
        predictions_by_label = predictions.reshape(-1, num_sets_per_label)
        yhat_avgs = predictions_by_label.mean(1)
        yhat_stds = predictions_by_label.std(1)
    print("mse:", mse)
    print("mae:", mae)

    fig, ax = plt.subplots()
    plot_prediction_linearity(
        ax,
        eval_dataset.bin_values.numpy(),
        yhat_avgs,
        yhat_stds,
        ref_line_buffer=0.05,
        xlim=(-2.25, 1.35),
        ylim=(-2.25, 1.35),
        xlabel=r"Actual $\delta C_9$", 
        ylabel=r"Predicted $\delta C_9$"
    )
    make_plot_note(ax, f"Event-by-event, Gen., {num_sets_per_label} boots., {num_events_per_set} events/boots.", fontsize="medium")
    plt.show()
    plt.close()

In [None]:
# Sensitivity check for the event-by-event model: repeatedly bootstrap sets of
# events from the single bin whose value is -0.82, predict the expected
# delta C9 of each set, and histogram the predictions.
# NOTE(review): `predict_log_probabilities` is neither defined nor imported in
# the visible part of this notebook -- confirm where it comes from.
n_events_per_set = 6_000
n_sets = 2000

# Added to the bin values before taking logs and subtracted again after
# exponentiating (presumably chosen so every shifted bin value is positive).
bin_value_shift = 5

# NOTE(review): exact float comparison; this raises if no bin equals -0.82.
target_bin_index = np.argwhere(datasets["eval"].bin_values==-0.82).item()
rng = np.random.default_rng()

# Pool of events to resample from: everything labeled with the target bin.
pool_x = datasets["eval"].features[datasets["eval"].labels==target_bin_index]

# Loop-invariant: the log of the shifted bin values is the same for every set.
bin_values = datasets["eval"].bin_values + bin_value_shift
log_bin_values = np.log(bin_values)

predictions = []
for _ in range(n_sets):
    # Sample with replacement from the pool.
    selection_indices = rng.choice(len(pool_x), n_events_per_set)
    boot_x = pool_x[selection_indices]

    log_probs = predict_log_probabilities(boot_x, model)
    # Expected (shifted) bin value, evaluated in log space.
    lse = scipy.special.logsumexp(log_bin_values + log_probs.cpu().numpy())
    prediction = np.exp(lse) - bin_value_shift
    predictions.append(prediction)

predictions = np.array(predictions)

mean= predictions.mean()
std = predictions.std()

setup_high_quality_mpl_params()
fig, ax = plt.subplots()

xbounds = (-1.5, 0)
ybounds = (0, 200)

ax.hist(predictions, bins=50, range=xbounds)
ax.vlines(-0.82, 0, ybounds[1], color="red", label="Target (-0.82)")
ax.vlines(mean, 0, ybounds[1], color="red", linestyles="--", label=r"$\mu = $ "+f"{mean.round(3)}")
ax.hlines(20, mean, mean+std, color="orange", linestyles="dashdot", label=r"$\sigma = $ "+f"{std.round(3)}")
ax.set_xlabel(r"Predicted $\delta C_9$")
make_plot_note(ax, f"Event-by-event, Gen., {n_sets} boots., {n_events_per_set} events/boots.", fontsize="medium")
ax.set_xbound(*xbounds)
ax.set_ybound(*ybounds)
ax.legend()


plt.show()
plt.close()

# Detector Level

## Shawn's Method

### Datasets

### Model Training

### Evaluation

## Deep Sets

### Datasets

### Model Training

### Evaluation

## Event-by-event Method

### Datasets

### Model Training

### Evaluation

# Detector Level with Backgrounds

## Shawn's Method

## Deep Sets

## Event-by-event Method

In [None]:




def predict_log_probabilities_by_label(x, y, model):
    """
    Compute one row of predicted class log probabilities per distinct label.

    For each unique value in y (in the sorted order returned by np.unique),
    the events of x carrying that label are passed to
    predict_log_probabilities together with the model, and the result is
    converted to a numpy array. The rows are stacked along a new first axis.

    x : A torch tensor of features of events (from multiple labels).
    y : A torch tensor of event labels.
    model : The model forwarded to predict_log_probabilities.

    Returns a numpy array with one entry along axis 0 per unique label.
    """
    unique_labels = np.unique(y.cpu())
    rows = [
        predict_log_probabilities(x[y==label], model).cpu().numpy()
        for label in unique_labels
    ]
    stacked = np.stack(rows, axis=0)
    assert stacked.shape[0] == len(unique_labels)
    return stacked


def calculate_predicted_expected_value_by_label(predictions, bin_values):
    """
    Calculate the predicted expected binned value for each subset of same labeled events, given 
    the predicted log-probability distribution for each subset.

    The expectation is evaluated in log space with logsumexp. Because the log of a
    non-positive bin value is undefined, the bin values are first shifted to be >= 1
    and the shift is removed again from the result.

    predictions : numpy array of predicted log probabilities, one row per label and one
        column per bin. The number of rows does not need to equal the number of bins.
        A 1-D array is treated as a single row.
    bin_values : numpy array of the value corresponding to each bin.

    Returns a 1-D numpy array with one expected value per row of predictions.
    """
    predictions = np.atleast_2d(np.asarray(predictions))
    bin_values = np.asarray(bin_values)

    # min(bin_values) + |min(bin_values)| + 1 >= 1, so every shifted value has a finite log.
    bin_value_shift = np.abs(np.min(bin_values)) + 1
    log_shifted_bin_values = np.log(bin_values + bin_value_shift)

    # Broadcasting adds the per-bin log values to every row, whatever the row count.
    log_shifted_expected_values = scipy.special.logsumexp(
        predictions + log_shifted_bin_values, axis=1
    )
    shifted_expected_values = np.exp(log_shifted_expected_values)
    expected_values = shifted_expected_values - bin_value_shift
    return expected_values


def plot_log_probabilities_over_labels(fig, ax, predictions, bin_values, cmap=plt.cm.viridis):
    """
    Draw one log-probability curve per label, colored by the label's bin value.

    fig : Figure that receives the colorbar.
    ax : Axes that receives the curves and axis labels.
    predictions : A numpy array of set probabilities (rows correspond to labels, columns correspond to class predictions).
    bin_values : A numpy array of the value each bin represents. 
    cmap : Colormap used for both the curves and the colorbar.
    """
    # Color boundaries are the bin values plus one extra upper edge, placed one
    # (last) bin spacing beyond the final bin value.
    last_spacing = bin_values[-1] - bin_values[-2]
    upper_edge = bin_values[-1] + last_spacing
    color_bounds = np.append(bin_values, upper_edge)
    color_norm = mpl.colors.BoundaryNorm(color_bounds, cmap.N)

    # Rows of predictions are paired with bin values in order.
    for actual_value, log_prob_curve in zip(bin_values, predictions):
        curve_color = cmap(color_norm(actual_value))
        ax.plot(bin_values, log_prob_curve, color=curve_color)

    mappable = mpl.cm.ScalarMappable(norm=color_norm, cmap=cmap)
    fig.colorbar(mappable, ax=ax, label=r"Actual $\delta C_9$")
    ax.set_xlabel(r"$\delta C_9$")
    ax.set_ylabel(r"$\log p(\delta C_9 | x_1, ..., x_N)$")


def plot_expected_value_over_labels(ax, expected_values, bin_values):
    """
    Scatter the predicted expected values against the bin values and add a slope-1 reference line.

    ax : Axes to draw on.
    expected_values : Predicted expected value for each label (y coordinates of the scatter).
    bin_values : Value of each bin (x coordinates, and both coordinates of the reference line).
    """
    scatter_style = dict(label="Prediction", color="firebrick", s=16, zorder=5)
    reference_style = dict(
        label="Ref. Line (Slope = 1)",
        color="grey",
        linewidth=0.5,
        zorder=0,
    )

    ax.scatter(bin_values, expected_values, **scatter_style)
    # The reference line y = x sits underneath the points (lower zorder).
    ax.plot(bin_values, bin_values, **reference_style)

    ax.set_xlabel(r"Actual $\delta C_9$")
    ax.set_ylabel(r"Predicted $\delta C_9$")
    ax.legend()
