Reproducing Shawn's approach: residual 3D CNN regression of δC9 from binned signal images

Setup

In [None]:
import numpy as np

import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable


from library.datasets import Signal_Images_Dataset
from library.nn_training import select_device, train_and_eval, Custom_Model
from library.plotting import plot_loss_curves, setup_high_quality_mpl_params, plot_prediction_linearity, make_plot_note

# Device object handed to the dataset `load` calls and to `model.to(...)` in
# the cells below. NOTE(review): presumably this picks GPU when available and
# falls back to CPU — confirm in library.nn_training.select_device.
device = select_device()


Model definition

In [2]:
class Res_Block(nn.Module):
    """Residual unit: conv -> ReLU -> conv, added to the input, then ReLU.

    Both convolutions use a 3x3x3 kernel, stride 1 and "same" padding with an
    identical number of input and output channels, so the branch output has
    the same shape as the input and the two can be summed element-wise.
    No normalisation layers are applied.
    """

    def __init__(self, in_out_channels):
        super().__init__()
        # Shared settings for the two convolutions of the residual branch.
        conv_kwargs = dict(
            in_channels=in_out_channels,
            out_channels=in_out_channels,
            kernel_size=3,
            stride=1,
            padding="same",
        )
        # Attribute names and submodule order define the state_dict keys.
        self.block = nn.Sequential(
            nn.Conv3d(**conv_kwargs),
            nn.ReLU(),
            nn.Conv3d(**conv_kwargs),
        )
        self.last_activation = nn.ReLU()

    def forward(self, x):
        """Return ``ReLU(block(x) + x)``."""
        branch = self.block(x)
        return self.last_activation(branch + x)


class Conv_Block(nn.Module):
    """Two-branch convolutional unit whose output channel count may differ from its input.

    ``block_a`` is conv -> ReLU -> conv and ``block_b`` is a single
    convolution. Both map ``in_channels`` to ``out_channels`` with 3x3x3
    kernels, stride 1 and "same" padding, so their outputs have equal shape.
    The branches are summed and passed through a final ReLU.
    No normalisation layers are applied.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def conv3(n_in):
            # Every convolution in this block produces `out_channels` channels.
            return nn.Conv3d(
                in_channels=n_in,
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding="same",
            )

        # Attribute names and submodule order define the state_dict keys.
        self.block_a = nn.Sequential(
            conv3(in_channels),
            nn.ReLU(),
            conv3(out_channels),
        )
        self.block_b = nn.Sequential(
            conv3(in_channels),
        )
        self.last_activation = nn.ReLU()

    def forward(self, x):
        """Return ``ReLU(block_a(x) + block_b(x))``."""
        summed = self.block_a(x) + self.block_b(x)
        return self.last_activation(summed)


class CNN_Res(Custom_Model):
    """Residual 3D CNN producing one scalar per input volume.

    Architecture:
      * ``conv``: a bias-free 3x3x3 convolution (1 -> 16 channels), ReLU and a
        max-pool with kernel 2, stride 1, padding 1 (which grows every spatial
        axis by one), followed by three groups of three ``Res_Block`` units
        separated by two ``Conv_Block`` units, all with 16 channels.
      * global average over the three spatial axes.
      * ``dense``: Linear(16 -> 32), ReLU, Linear(32 -> 1).

    All parameters are cast to float64 via ``self.double()``, so inputs must
    be float64 as well.

    Parameters
    ----------
    nickname, model_dir
        Forwarded unchanged to ``Custom_Model.__init__``.
    """

    def __init__(self, nickname, model_dir,):
        super().__init__(nickname, model_dir)

        # NOTE: the layer order below determines the state_dict keys used by
        # saved checkpoints; do not reorder without retraining.
        self.conv = nn.Sequential(
            nn.Conv3d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding="same", bias=False),
            # nn.BatchNorm3d(num_features=16),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2, stride=1, padding=1),
            *[Res_Block(in_out_channels=16) for _ in range(3)],
            Conv_Block(in_channels=16, out_channels=16),
            *[Res_Block(in_out_channels=16) for _ in range(3)],
            Conv_Block(in_channels=16, out_channels=16),
            *[Res_Block(in_out_channels=16) for _ in range(3)],
            # Conv_Block(in_channels=128, out_channels=128),
            # *[Res_Block(in_out_channels=128) for _ in range(1)],
        )

        self.dense = nn.Sequential(
            nn.Linear(in_features=16, out_features=32),
            nn.ReLU(),
            # nn.Dropout(0.5),
            nn.Linear(in_features=32, out_features=1),
        )

        self.double()

    def forward(self, x):
        """Map a batch of volumes to a 1-D tensor of predictions.

        Parameters
        ----------
        x : torch.Tensor
            Five-dimensional input (batch, 1, d0, d1, d2); the mean below is
            taken over axes 2, 3 and 4.

        Returns
        -------
        torch.Tensor
            Shape ``(batch,)``.
        """
        x = self.conv(x)
        # Global average pooling over the three spatial axes -> (batch, 16).
        x = torch.mean(x, dim=(2,3,4))
        x = self.dense(x)
        # Drop only the trailing size-1 feature axis. A bare torch.squeeze(x)
        # would also drop the batch axis when the batch holds a single sample
        # and return a 0-d tensor instead of shape (1,).
        x = torch.squeeze(x, dim=-1)
        return x

Data loading

In [None]:
# Set to False to reuse previously generated data under `save_dir`.
regenerate = True

level = "gen"
save_dir = "../../state/new_physics/data/processed"

# Generation settings shared by the train and eval splits.
common_generate_kwargs = dict(
    raw_signal_dir="../../state/new_physics/data/raw/signal",
    std_scale=True,
    q_squared_veto=True,
    balanced_classes=True,
    num_events_per_set=70_000,
    num_sets_per_label=50,
    n_bins=10,
)

# Raw trials feeding each split: 1-20 for training, 21-40 for evaluation.
raw_trials_by_split = {
    "train": range(1, 21),
    "eval": range(21, 41),
}

datasets = {
    split: Signal_Images_Dataset(level=level, split=split, save_dir=save_dir)
    for split in raw_trials_by_split
}

if regenerate:
    for split, raw_trials in raw_trials_by_split.items():
        datasets[split].generate(raw_trials=raw_trials, **common_generate_kwargs)

# Load both splits (train first, then eval) onto the selected device.
for dataset in datasets.values():
    dataset.load(device)

Image visualization

In [None]:
def plot_volume_slices(arr, n_slices=3, cmap=plt.cm.magma, note=""):
    """
    Plot colored slices of volumetric data on a new 3D axis.

    The three axes of the input are taken to be (chi, costheta_mu,
    costheta_K). The volume is transposed so that costheta_mu runs along the
    plot's x-axis, costheta_K along y and chi along z. Slices are drawn at
    fixed z, i.e. they are cut along axis 0 of the input.

    Parameters:
        arr: torch.Tensor or array-like. Size-1 axes are squeezed away and
            exactly three axes must remain.
        n_slices: number of slices to draw. Slice indices come from
            np.linspace truncated to integers, so they might not be evenly
            spaced.
        cmap: matplotlib colormap applied after normalizing values with
            Normalize(vmin=-1.1, vmax=1.1).
        note: text used as the axes title.

    Returns None. A new figure is created and left for the caller to show,
    save or close through pyplot.

    Raises:
        ValueError: if the squeezed array is not three-dimensional.
    """

    fig = plt.figure()
    ax_3d = fig.add_subplot(projection="3d")

    # Meaning of each axis of the *input* array.
    var_dim = {
        0: "chi",
        1: "costheta_mu",
        2: "costheta_K",
    }

    # Which *input* axis is drawn along each cartesian plot axis.
    dim_ind_cart = { # dont change for now
        "x": 1,
        "y": 2,
        "z": 0,
    }

    norm = Normalize(vmin=-1.1, vmax=1.1)

    # Convert explicitly to a NumPy array instead of relying on NumPy's
    # implicit handling of torch tensors.
    if isinstance(arr, torch.Tensor):
        arr = arr.detach().cpu().numpy()
    arr = np.squeeze(np.asarray(arr))
    if arr.ndim != 3:
        raise ValueError(
            f"expected a three-dimensional array after squeezing, got shape {arr.shape}"
        )

    arr = np.transpose(
        arr,
        (dim_ind_cart["x"], dim_ind_cart["y"], dim_ind_cart["z"])
    )
    colors = cmap(norm(arr))

    # After the transpose the array axes are ordered (x, y, z), so sizes must
    # be read positionally. Indexing the transposed shape with the
    # pre-transpose axis numbers is only correct for cubic volumes.
    cart_dim_shape = dict(zip(("x", "y", "z"), arr.shape))

    def xy_plane(z_pos):
        # Grid of cell corners (one more point than cells along x and y) at a
        # constant height z_pos.
        grid_shape = (cart_dim_shape["x"] + 1, cart_dim_shape["y"] + 1)
        x, y = np.indices(grid_shape)
        z = np.full(grid_shape, z_pos)
        return x, y, z

    def plot_slice(z_index):
        x, y, z = xy_plane(z_index)
        ax_3d.plot_surface(
            x, y, z,
            rstride=1, cstride=1,
            facecolors=colors[:,:,z_index],
            shade=False,
        )

    def plot_outline(z_index, offset=0.3):
        # Light-grey sheet drawn slightly below the slice it accompanies.
        x, y, z = xy_plane(z_index - offset)

        ax_3d.plot_surface(
            x, y, z,
            rstride=1, cstride=1,
            shade=False,
            color="#f2f2f2",
            edgecolor="#f2f2f2",
        )

    z_indices = np.linspace(0, cart_dim_shape["z"]-1, n_slices, dtype=int) # forces integer indices
    for i in z_indices:
        plot_outline(i)
        plot_slice(i)

    cbar = fig.colorbar(ScalarMappable(norm=norm, cmap=cmap), ax=ax_3d, location="left", shrink=0.5, pad=-0.05)
    cbar.set_label(r"${q^2}$ (Avg.)", size=11)

    ax_labels = {
        "chi": r"$\chi$",
        "costheta_mu": r"$\cos\theta_\mu$",
        "costheta_K": r"$\cos\theta_K$",
    }

    ax_3d.set_xlabel(ax_labels[var_dim[dim_ind_cart["x"]]], labelpad=0)
    ax_3d.set_ylabel(ax_labels[var_dim[dim_ind_cart["y"]]], labelpad=0)
    # ax_3d.zaxis.set_rotate_label(False)
    ax_3d.set_zlabel(ax_labels[var_dim[dim_ind_cart["z"]]], labelpad=-3,)#rotation="horizontal")

    ticks = {
        "costheta_mu": ["-1", "1"],
        "costheta_K": ["-1", "1"],
        "chi": ['0', r"$2\pi$"],
    }

    # NOTE(review): the slice surfaces span 0..n along x and y, so the upper
    # tick at n-1 sits one cell inside the edge — confirm this is intended.
    ax_3d.set_xticks([0, cart_dim_shape["x"]-1], ticks[var_dim[dim_ind_cart["x"]]])
    ax_3d.set_yticks([0, cart_dim_shape["y"]-1], ticks[var_dim[dim_ind_cart["y"]]])
    ax_3d.set_zticks([0, cart_dim_shape["z"]-1], ticks[var_dim[dim_ind_cart["z"]]])

    ax_3d.tick_params(pad=0.3)

    ax_3d.set_box_aspect(None, zoom=0.85)

    ax_3d.set_title(f"{note}", loc="center", y=1)


# Show the first training example, titled with its label.
# The inner subscript uses single quotes: reusing the enclosing double quotes
# inside an f-string replacement field is a SyntaxError before Python 3.12.
plot_volume_slices(
    datasets["train"].features[0],
    n_slices=3,
    note=r"$\delta C_9$ : " + f"{datasets['train'].labels[0]}"
)
# plt.savefig(f"{i}", bbox_inches="tight")
plt.show()
plt.close()

In [None]:
# Set to False to load the stored final weights instead of training.
retrain = True

nickname = "cnn_res_with_checkpoints_6k"

model = CNN_Res(nickname, "../../state/new_physics/models")

if not retrain:
    model.load_final()
    # model.load_checkpoint(epoch_number=10)
    model.to(device)
else:
    # Training hyperparameters.
    learning_rate = 4e-4
    epochs = 80
    train_batch_size = 32
    eval_batch_size = 32
    loss_fn = nn.L1Loss()
    optimizer = Adam(model.parameters(), lr=learning_rate)
    lr_scheduler = ReduceLROnPlateau(optimizer, factor=0.9, patience=1)

    train_and_eval(
        model,
        datasets["train"],
        datasets["eval"],
        loss_fn,
        optimizer,
        epochs,
        train_batch_size,
        eval_batch_size,
        device,
        move_data=False,
        scheduler=lr_scheduler,
        checkpoint_epochs=5,
    )

    # Loss curves on a log scale, starting from `plot_epoch_start`.
    plot_epoch_start = 0
    shown_epochs = slice(plot_epoch_start, None)
    loss_table = model.loss_table
    _, ax = plt.subplots()
    plot_loss_curves(
        loss_table["epoch"][shown_epochs],
        loss_table["train_loss"][shown_epochs],
        loss_table["eval_loss"][shown_epochs],
        ax,
    )
    ax.set_yscale("log")
    plt.show()

Evaluate model (eval dataset)

In [None]:
def _grouped_prediction_stats(model, dataset, set_size):
    """Run `model` on all of `dataset.features` and summarise per group.

    Predictions and labels are reshaped to rows of `set_size` consecutive
    samples. Returns NumPy arrays `(unique_y, avgs, stds)`: the mean label,
    mean prediction and prediction standard deviation of each row.

    NOTE(review): assumes samples are stored in consecutive runs of
    `set_size` that share one label — confirm against Signal_Images_Dataset.
    """
    with torch.no_grad():
        yhat = model(dataset.features).reshape(-1, set_size)
        avgs = yhat.mean(1).detach().cpu().numpy()
        stds = yhat.std(1).detach().cpu().numpy()
        unique_y = dataset.labels.reshape(-1, set_size).mean(1).detach().cpu().numpy()
    return unique_y, avgs, stds


model = CNN_Res(nickname, "../../state/new_physics/models")

num_sets_per_label = common_generate_kwargs["num_sets_per_label"]

# One linearity plot per saved checkpoint. The range mirrors the training
# settings (80 epochs, checkpoint every 5).
for ep in range(0, 80, 5):
    model.load_checkpoint(epoch_number=ep)
    model.to(device)
    model.eval()

    unique_y, avgs, stds = _grouped_prediction_stats(
        model, datasets["eval"], num_sets_per_label
    )

    fig, ax = plt.subplots()
    plot_prediction_linearity(ax, unique_y, avgs, stds)
    plt.show()
    plt.close()

# Plot styling only needs to be applied once; do it before the final figure
# rather than on every loop iteration.
setup_high_quality_mpl_params()

print("final:")

model.load_final()
model.to(device)
model.eval()

unique_y, avgs, stds = _grouped_prediction_stats(
    model, datasets["eval"], num_sets_per_label
)

fig, ax = plt.subplots()

plot_prediction_linearity(
    ax,
    unique_y,
    avgs,
    stds,
    ref_line_buffer=0.05,
    xlim=(-2.25, 1.35),
    ylim=(-2.25, 1.35),
    xlabel=r"Actual $\delta C_9$",
    ylabel=r"Predicted $\delta C_9$"
)
# Single quotes inside the replacement fields keep this f-string valid on
# Python versions before 3.12.
make_plot_note(ax, f"Image (10 bin), Gen., {common_generate_kwargs['num_sets_per_label']} boots., {common_generate_kwargs['num_events_per_set']} events/boots.", fontsize="medium")

plt.show()
plt.close()

Evaluate model (train dataset)

In [None]:
def _grouped_prediction_stats(model, dataset, set_size):
    """Run `model` on all of `dataset.features` and summarise per group.

    Predictions and labels are reshaped to rows of `set_size` consecutive
    samples. Returns NumPy arrays `(unique_y, avgs, stds)`: the mean label,
    mean prediction and prediction standard deviation of each row.

    NOTE(review): assumes samples are stored in consecutive runs of
    `set_size` that share one label — confirm against Signal_Images_Dataset.
    """
    with torch.no_grad():
        yhat = model(dataset.features).reshape(-1, set_size)
        avgs = yhat.mean(1).detach().cpu().numpy()
        stds = yhat.std(1).detach().cpu().numpy()
        unique_y = dataset.labels.reshape(-1, set_size).mean(1).detach().cpu().numpy()
    return unique_y, avgs, stds


model = CNN_Res(nickname, "../../state/new_physics/models")

num_sets_per_label = common_generate_kwargs["num_sets_per_label"]

# One linearity plot per inspected checkpoint.
# NOTE(review): only checkpoints up to epoch 45 are visited here although
# training is configured for 80 epochs — confirm this is intended.
for ep in range(0, 50, 5):
    model.load_checkpoint(epoch_number=ep)
    model.to(device)
    model.eval()

    unique_y, avgs, stds = _grouped_prediction_stats(
        model, datasets["train"], num_sets_per_label
    )

    fig, ax = plt.subplots()
    plot_prediction_linearity(ax, unique_y, avgs, stds)
    plt.show()
    plt.close()


print("final:")

setup_high_quality_mpl_params()

model = CNN_Res(nickname, "../../state/new_physics/models")

model.load_final()
model.to(device)
model.eval()

unique_y, avgs, stds = _grouped_prediction_stats(
    model, datasets["train"], num_sets_per_label
)

fig, ax = plt.subplots()

plot_prediction_linearity(
    ax,
    unique_y,
    avgs,
    stds,
    ref_line_buffer=0.05,
    xlim=(-2.25, 1.35),
    ylim=(-2.25, 1.35),
    xlabel=r"Actual $\delta C_9$",
    ylabel=r"Predicted $\delta C_9$"
)
# Single quotes inside the replacement fields keep this f-string valid on
# Python versions before 3.12.
make_plot_note(ax, f"Image (10 bin), Gen., {common_generate_kwargs['num_sets_per_label']} boots., {common_generate_kwargs['num_events_per_set']} events/boots.", fontsize="medium")

plt.show()
plt.close()