# CNN for AdvSND target energy reconstruction

In [None]:
import numpy as np

In [None]:
import pandas as pd

In [None]:
import tensorflow as tf

In [None]:
import keras_tuner as kt

In [None]:
from preprocessing import reshape_data

In [None]:
import uproot

In [None]:
import scipy

In [None]:
import hist

In [None]:
import plotting

In [None]:
from iminuit import cost
from iminuit import Minuit

In [None]:
import matplotlib.pyplot as plt

# plt.style.use(["science", "notebook"])

In [None]:
# Global plot styling for every figure in this notebook.
plt.rcParams.update(
    {
        "font.size": 14,
        "axes.formatter.limits": (-5, 4),
        "figure.figsize": (6, 4),
    }
)
# Colours of the active property cycle, in cycle order.
colors = [entry["color"] for entry in plt.rcParams["axes.prop_cycle"]]

In [None]:
# Training and test inputs, given as uproot "file.root:object" paths; the
# object read from each file is "df".
filename_train = "dataframe_CC_saturation5_train.root:df"
filename_test = "dataframe_CC_saturation5_test.root:df"

In [None]:
# Handles to the two trees; events are read later via .iterate()/.array().
events_train = uproot.open(filename_train)
events_test = uproot.open(filename_test)

In [None]:
# Total number of events across both files; embedded in output file names.
n_events = events_train.num_entries + events_test.num_entries

In [None]:
# Regression target: lepton energy (learned as log(E) by the generator).
target = "lepton_energy"

target_pretty = "Lepton Energy"
# Raw string: "\e" is not a valid escape and emits a SyntaxWarning on 3.12+.
target_LaTeX = r"E_\ell"

In [None]:
# Regression target: hadron energy (learned as log(E) by the generator).
target, target_pretty, target_LaTeX = "hadron_energy", "Hadron Energy", "E_h"

In [None]:
# Regression target: neutrino energy (learned as log(E) by the generator).
target = "nu_energy"

target_pretty = "Neutrino energy"
# Raw string is required: in a plain literal "\n" is a newline, which turned
# the label into "E_" + newline + "u" instead of the LaTeX macro \nu.
target_LaTeX = r"E_\nu"

In [None]:
# Regression target: longitudinal start position (not log-transformed,
# since the name does not contain "energy").
target = "start_z"

target_pretty = "Start Z"
# Define the label here too, so it does not silently keep the value set by
# whichever energy-target cell ran before this one.
target_LaTeX = r"z_\mathrm{start}"

In [None]:
# target = "both"

In [None]:
# target = "deps"
# edep_correction = 1e-9

In [None]:
def event_generator(train=True, step_size=1):
    """Yield ``(X, X_mufilter, y)`` samples one event at a time.

    Reads ``events_train`` (or ``events_test`` when ``train`` is false) with
    uproot in chunks of ``step_size`` entries as NumPy arrays and yields each
    event individually. Both detector inputs are cast to ``float16``. The
    label is the module-level ``target`` branch, replaced by its natural
    logarithm when ``target`` contains "energy".

    Args:
        train: Truthy to read the training tree, falsy for the test tree.
        step_size: Entries read per uproot chunk. The default of 1 matches
            the original behaviour; larger values only reduce per-chunk
            overhead and yield the same samples in the same order.

    Yields:
        ``(X, X_mufilter, y)`` for a single event.
    """
    events = events_train if train else events_test
    take_log = "energy" in target
    for batch in events.iterate(step_size=step_size, library="np"):
        # Convert each chunk once, not once per event inside the inner loop.
        x_target = batch["X"].astype(np.float16)
        x_mufilter = batch["X_mufilter"].astype(np.float16)
        labels = np.log(batch[target]) if take_log else batch[target]
        for i in range(x_target.shape[0]):
            yield x_target[i], x_mufilter[i], labels[i]

In [None]:
# Training-set generator used to inspect a sample and to infer the tf.data
# output signature.
gen = event_generator(True)

In [None]:
# Shape of one target-detector input as given to the CNN's Input layers
# (rows, columns, channels) — channels last, see set_image_data_format.
input_shape = (100, 3072, 1)

In [None]:
# First training event as (X, X_mufilter, label).
sample = gen.__next__()

In [None]:
# Visual check of the raw X and X_mufilter arrays; the aspect values only
# stretch the images for display.
plt.figure()
plt.imshow(sample[0], aspect=0.05)
plt.figure()
plt.imshow(sample[1], aspect=0.01)

In [None]:
# Label of that event (log-transformed when the target is an energy).
sample[2]

In [None]:
# tf.data output signature for (X, X_mufilter, label), inferred from the
# already-drawn `sample`. Calling gen.__next__() once per element would
# consume three more events and mix specs taken from different events.
generator_spec_0 = tf.type_spec_from_value(sample[0])
generator_spec_1 = tf.type_spec_from_value(sample[1])
generator_spec_2 = tf.type_spec_from_value(sample[2])

In [None]:
# TODO reshape data only once

In [None]:
# Streaming training dataset: event_generator() with its default train=True,
# followed by the project-local reshape_data transform.
# NOTE(review): presumably reshape_data splits X / X_mufilter into the four
# model inputs (target h/v, mufilter h/v) — confirm against preprocessing.py.
# assert_cardinality records the known length, which the batch-count split
# further down relies on.
ds_train = (
    tf.data.Dataset.from_generator(
        event_generator,
        output_signature=(
            generator_spec_0,
            generator_spec_1,
            generator_spec_2,
        ),
    )
    .map(reshape_data)
    .apply(tf.data.experimental.assert_cardinality(events_train.num_entries))
)

In [None]:
# Same pipeline over the test tree (args=[False] -> train=False).
ds_test = (
    tf.data.Dataset.from_generator(
        event_generator,
        args=[False],
        output_signature=(
            generator_spec_0,
            generator_spec_1,
            generator_spec_2,
        ),
    )
    .map(reshape_data)
    .apply(tf.data.experimental.assert_cardinality(events_test.num_entries))
)

In [None]:
# y_test = events_test["energy_dep_target"].array() + edep_correction, events_test["energy_dep_mufilter"].array()+edep_correction
# Test-set ground truth as a plain NumPy array, using the same library="np"
# as event_generator. NumPy/TensorFlow code later receives an ndarray rather
# than an awkward Array. Energies are mapped to log(E), the space the model
# is trained in.
y_test = events_test[target].array(library="np")
if "energy" in target:
    y_test = np.log(y_test)

In [None]:
# Events per batch, shared by the training and test pipelines.
batch_size = 30

In [None]:
batched_ds_train = ds_train.batch(batch_size)

In [None]:
batched_ds_test = ds_test.batch(batch_size)

In [None]:
# Split the batched training data into whole-batch tuner subsets: the first
# `train_fraction` of the batches for training, the remainder for validation.
# NOTE(review): `val_fraction` is not used; the validation share is whatever
# remains after `train_fraction`. `val_size` is likewise not used in the
# cells shown here.
train_fraction = 0.8
val_fraction = 0.2

total_batches = tf.data.experimental.cardinality(batched_ds_train).numpy()
train_size = int(total_batches * train_fraction)
val_size = total_batches - train_size

train_subset = batched_ds_train.take(train_size)
val_subset = batched_ds_train.skip(train_size)

# NOTE(review): shuffling is disabled, so the split follows file order.
# train_subset = train_subset.shuffle(buffer_size=1000)

In [None]:
import tensorflow.keras
from tensorflow.keras.layers import (
    BatchNormalization,
    Concatenate,
    Conv2D,
    Dense,
    Dropout,
    Flatten,
    GlobalAveragePooling2D,
    Input,
    MaxPooling2D,
    Multiply,
    ReLU,
    Reshape,
)
from tensorflow.keras import Input, layers
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import plot_model
import tensorflow.keras.optimizers
import tensorflow.keras.metrics
import tensorflow.keras.losses

In [None]:
from CBAM3D import CBAM

In [None]:
import tensorflow.keras.backend as K

# Channels are the last axis of every image tensor, matching input_shape.
K.set_image_data_format("channels_last")

In [None]:
# Per-epoch metrics accumulated over successive fit() calls (None until the
# first run is concatenated in).
history_df = None

In [None]:
# Base name for the tuner project and every saved artifact (model, CSVs, plots).
model_name = f"CNN_3dSat5_grandjorasses_{target}"

In [None]:
def _conv_branch(
    branch_name,
    branch_input_shape,
    input_name,
    tag,
    suffix,
    conv_specs,
    pool_specs,
    drop_rate,
    final_dropout,
):
    """Build one convolutional branch as a ``Sequential`` model.

    Every stage is Conv2D("same") -> BatchNormalization -> ReLU -> CBAM ->
    MaxPooling2D, followed by Dropout. The Dropout after the last stage is
    added only when ``final_dropout`` is true. The branch ends in Flatten.

    Args:
        branch_name: Name of the returned ``Sequential`` model.
        branch_input_shape: Per-sample input shape (no batch axis).
        input_name: Name of the branch's ``Input`` layer.
        tag: "h" or "v"; embedded in the layer names.
        suffix: Appended to every named layer ("" or "_1") so names stay
            unique across branches.
        conv_specs: One ``(filters, kernel_size)`` pair per stage.
        pool_specs: One ``(pool_size, padding)`` pair per stage.
        drop_rate: Rate of every Dropout layer in the branch.
        final_dropout: Whether to add Dropout after the last stage.

    Returns:
        The assembled ``Sequential`` branch.
    """
    branch = Sequential(name=branch_name)
    branch.add(Input(shape=branch_input_shape, name=input_name))

    n_stages = len(conv_specs)
    for stage, ((filters, kernel_size), (pool_size, pool_padding)) in enumerate(
        zip(conv_specs, pool_specs), start=1
    ):
        branch.add(
            Conv2D(
                filters=filters,
                kernel_size=kernel_size,
                padding="same",
                name=f"conv_{tag}{stage}{suffix}",
            )
        )
        branch.add(BatchNormalization(name=f"batch_norm_{tag}{stage}{suffix}"))
        branch.add(ReLU())
        branch.add(CBAM(name=f"CBAM_{tag}{stage}{suffix}"))
        branch.add(
            MaxPooling2D(
                pool_size=pool_size,
                padding=pool_padding,
                name=f"pool_{tag}{stage}{suffix}",
            )
        )
        if stage < n_stages or final_dropout:
            branch.add(Dropout(rate=drop_rate))

    branch.add(Flatten(name=f"flatten_{tag}{suffix}"))
    return branch


def build_model(hp):
    """Build and compile the four-branch CNN regressor sampled from ``hp``.

    Two branches read the target-detector input (module-level
    ``input_shape``). Two read mufilter inputs of shape (21, 4608, 1) and
    (5, 4608, 1). Each branch has four Conv2D/BatchNorm/ReLU/CBAM/MaxPool
    stages whose filters and kernel sizes are tuned; the stages share those
    values across branches. The flattened outputs of all four branches are
    concatenated and fed through a small dense head ending in one linear
    unit. The model is compiled with Adam, MSE loss and an MAE metric.

    Model inputs are ordered [target_h, target_v, mufilter_h, mufilter_v].
    NOTE(review): this order must match what ``reshape_data`` yields — confirm.

    Args:
        hp: keras_tuner hyperparameter container used to sample filters,
            kernel sizes, dropout rate and learning rate.

    Returns:
        The compiled ``Model`` named ``model_name``.
    """
    # Registration order is unchanged, so the tuner search space is unchanged.
    filters1 = hp.Int("filters1", min_value=16, max_value=64, step=16)
    kernel_size1 = hp.Int("kernel_size1", min_value=1, max_value=9, step=1)

    filters2 = hp.Int("filters2", min_value=16, max_value=64, step=16)
    kernel_size2 = hp.Int("kernel_size2", min_value=1, max_value=9, step=1)

    filters3 = hp.Int("filters3", min_value=16, max_value=64, step=16)
    kernel_size3 = hp.Int("kernel_size3", min_value=1, max_value=9, step=1)

    filters4 = hp.Int("filters4", min_value=16, max_value=64, step=16)
    kernel_size4 = hp.Int("kernel_size4", min_value=1, max_value=9, step=1)

    drop_rate = hp.Float("drop_rate", min_value=0.1, max_value=0.6, step=0.1)

    conv_specs = [
        (filters1, kernel_size1),
        (filters2, kernel_size2),
        (filters3, kernel_size3),
        (filters4, kernel_size4),
    ]
    # Target and mufilter-horizontal branches: three (2, 4) "valid" pools,
    # then a (2, 2) "same" pool.
    hv_pools = [((2, 4), "valid")] * 3 + [((2, 2), "same")]
    # The mufilter-vertical input has only 5 rows, so its later pools no
    # longer halve the row axis.
    mufilter_v_pools = [
        ((2, 4), "valid"),
        ((2, 4), "valid"),
        ((1, 4), "valid"),
        ((1, 2), "same"),
    ]

    # Branches are built in the same order as before, so auto-named layers
    # (ReLU, Dropout) keep their names.
    target_h_branch = _conv_branch(
        "target_h_branch", input_shape, "target_h_in", "h", "",
        conv_specs, hv_pools, drop_rate, final_dropout=False,
    )
    target_v_branch = _conv_branch(
        "target_v_branch", input_shape, "target_v_in", "v", "",
        conv_specs, hv_pools, drop_rate, final_dropout=False,
    )
    mufilter_h_branch = _conv_branch(
        "mufilter_h_branch", (21, 4608, 1), "mufilter_h_in", "h", "_1",
        conv_specs, hv_pools, drop_rate, final_dropout=True,
    )
    mufilter_v_branch = _conv_branch(
        "mufilter_v_branch", (5, 4608, 1), "mufilter_v_in", "v", "_1",
        conv_specs, mufilter_v_pools, drop_rate, final_dropout=True,
    )

    # All four branch outputs are used. The previous version listed
    # mufilter_v_branch.output twice and never connected mufilter_h_branch,
    # so the horizontal mufilter input was ignored.
    X = Concatenate(name="concat_branches")(
        [
            target_h_branch.output,
            target_v_branch.output,
            mufilter_h_branch.output,
            mufilter_v_branch.output,
        ]
    )
    X = Dense(4)(X)
    X = BatchNormalization()(X)
    X = ReLU()(X)
    X = Dense(20)(X)
    X = BatchNormalization()(X)
    X = ReLU()(X)
    X = Dropout(rate=0.2)(X)
    X = Dense(1)(X)

    model = Model(
        inputs=[
            target_h_branch.input,
            target_v_branch.input,
            mufilter_h_branch.input,
            mufilter_v_branch.input,
        ],
        outputs=X,
        name=model_name,
    )

    adam = Adam(
        learning_rate=hp.Float(
            "learning_rate", min_value=1e-4, max_value=1e-2, sampling="log"
        )
    )
    model.compile(optimizer=adam, loss="mse", metrics=["mae"])

    return model

In [None]:
# Hyperband search over build_model's hyperparameters, optimising val_mae.
# Trial state is kept under directory/project_name.
tuner = kt.Hyperband(
    build_model,
    objective="val_mae",
    max_epochs=5,
    factor=3,
    directory="3D_hyperparam_opt",
    project_name=model_name,
)

In [None]:
tuner.search_space_summary()

In [None]:
# NOTE(review): EarlyStopping(patience=5) cannot trigger when a trial runs at
# most max_epochs=5 epochs. It also watches val_loss (MSE) while the tuner
# objective is val_mae.
tuner.search(
    train_subset,
    validation_data=val_subset,
    callbacks=[tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=5)],
)

In [None]:
tuner.results_summary()

In [None]:
# Fresh, untrained model built from the best hyperparameters found.
best_hps = tuner.get_best_hyperparameters(num_trials=1)[0]
toy_model = tuner.hypermodel.build(best_hps)

In [None]:
toy_model.summary()

In [None]:
plot_model(toy_model)

In [None]:
# TODO activation for max pooling?
# TODO Reduce number of convolutional layers?
# TODO Add a hidden layer (or two?) before outputs?
# TODO predict independently?

In [None]:
# Train the best model for 5 epochs on the full batched training set.
# NOTE(review): this includes the batches held out as val_subset during tuning.
fit_result = toy_model.fit(
    batched_ds_train.prefetch(tf.data.AUTOTUNE),
    epochs=5,
)

In [None]:
# Append this run's per-epoch metrics; pd.concat drops the initial None.
history_df = pd.concat([history_df, pd.DataFrame(fit_result.history)])

In [None]:
# File names encode the total event count and the cumulative number of epochs.
history_df.to_csv(f"history_{model_name}_n{n_events}_e{len(history_df)}.csv")

In [None]:
toy_model.save(f"{model_name}_n{n_events}_e{len(history_df)}.keras")

In [None]:
# Training convergence: loss on the left axis, MAE on a twin right axis,
# annotated with dataset sizes and epoch count, saved as PDF and PNG.
from pathlib import Path

fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
# Title follows the selected target; the old hard-coded
# "lepton + hadron energy" was wrong for every other target.
plt.title(f"CNN {target_pretty}")
ax1.plot(history_df["loss"].values, color=colors[0])
ax1.set_xlabel("Epochs")
ax1.set_ylabel("Loss Function", color=colors[0])
try:
    ax2.plot(history_df["mae"].values, color=colors[1])
except KeyError:
    # Presumably the two-output variant, which logs one MAE column per output.
    ax2.plot(history_df["dense_2_mae"].values, color=colors[1])
    ax2.plot(history_df["dense_3_mae"].values, color=colors[1])
ax2.set_ylabel("Error", color=colors[1])
plt.text(
    0.3,
    0.7,
    f"Training dataset: {events_train.num_entries} events\n"
    f"Test dataset: {events_test.num_entries} events\n"
    f"Training duration: {len(history_df)} epochs\n{model_name}",
    transform=ax1.transAxes,
)
# savefig does not create missing directories.
Path("plots").mkdir(exist_ok=True)
plt.savefig(f"plots/convergence_{model_name}_n{n_events}_e{len(history_df)}.pdf")
plt.savefig(f"plots/convergence_{model_name}_n{n_events}_e{len(history_df)}.png")

In [None]:
# test=retoy_model.predict(x=[x_test['scifi_h'], x_test['scifi_v'], x_test['us'], x_test['ds']])
# Predictions for the whole test set.
# NOTE(review): assumed to be in the same event order as y_test, since both
# read the test tree sequentially — confirm.
y_pred = toy_model.predict(batched_ds_test)

In [None]:
# RMSE in training space (log(E) for energy targets).
rms = tensorflow.keras.metrics.RootMeanSquaredError()
rms.update_state(y_test, y_pred)
rmse_value = rms.result().numpy()
print("Root Mean Squared Error:", rmse_value)

In [None]:
# Clear the metric's accumulated state. reset_state() is the current Keras
# API; reset_states() is deprecated and no longer exists in Keras 3.
rms.reset_state()

In [None]:
# df = pd.DataFrame({"lepton_energy_pred" : np.squeeze(np.exp(y_pred)[0]), "lepton_energy_test" : np.squeeze(np.exp(y_test)[0]),
#                  "hadron_energy_pred" : np.squeeze(np.exp(y_pred)[1]), "hadron_energy_test" : np.squeeze(np.exp(y_test)[1])})

# Per-event prediction/truth table. Energy targets were learned as log(E), so
# they are mapped back with exp. Other targets (e.g. start_z) are used as-is.
# exp() is no longer applied to them, where it could overflow and emit warnings.
if "energy" in target:
    pred_values, test_values = np.exp(y_pred), np.exp(y_test)
else:
    pred_values, test_values = y_pred, y_test

df = pd.DataFrame(
    {
        f"{target}_pred": np.squeeze(pred_values),
        f"{target}_test": np.squeeze(test_values),
    }
)

In [None]:
# Save per-event predictions and truth.
df.to_csv(f"{model_name}_n{n_events}_e{len(history_df)}.csv")

In [None]:
# File name of the model saved after training.
f"{model_name}_n{n_events}_e{len(history_df)}.keras"

In [None]:
# Residual histogram: 20 bins over [-30, 30], labelled in cm.
# NOTE(review): the label/range assume target == "start_z"; for energy
# targets y_pred/y_test are log-energies and this axis would be wrong.
h = hist.Hist.new.Regular(20, -30, +30, name=r"𝛥z [cm]").Double()

In [None]:
# Fill with prediction minus truth per event.
h.fill(np.squeeze(y_pred) - np.squeeze(y_test))

In [None]:
def model(x, mu, sigma):
    """Evaluate the normal CDF with mean ``mu`` and width ``sigma`` at ``x``.

    Used as the model for ``cost.BinnedNLL``, which takes a cumulative
    distribution evaluated at the bin edges. Minuit reads the parameter
    names ``mu`` and ``sigma`` from this signature.
    """
    gaussian = scipy.stats.norm(loc=mu, scale=sigma)
    return gaussian.cdf(x)

In [None]:
# Bin contents and bin edges of the residual histogram as NumPy arrays.
entries, edges = h.to_numpy()

In [None]:
# Binned maximum-likelihood fit of a Gaussian shape to the residual
# histogram, starting from mu=0, sigma=25.
m = Minuit(cost.BinnedNLL(entries, edges, model), 0, 25)
# Keep the width physical: norm.cdf returns nan for a negative sigma, which
# would derail the minimiser.
m.limits["sigma"] = (0, None)

In [None]:
# Run MIGRAD; the returned object's .params holds fitted values and errors.
res = m.migrad()

In [None]:
# Residual histogram with the fitted Gaussian overlaid, annotated with the
# fit parameters and dataset/training metadata, saved as PDF and PNG.
from pathlib import Path

h.plot()
plt.xlabel(r"$\Delta z\;[\mathrm{cm}]$")
plt.title("3D CNN")
plot_range = edges[0], edges[-1]
x = np.linspace(*plot_range, 100)
best_fit = scipy.stats.norm(res.params[0].value, res.params[1].value)
n_bins = len(entries)
binsize = (plot_range[1] - plot_range[0]) / n_bins
# BinnedNLL fits only the shape. Normalise the pdf to the histogram count
# inside the plotted range: expected counts per bin = N * pdf * bin width.
scale = h.sum() / (best_fit.cdf(plot_range[1]) - best_fit.cdf(plot_range[0])) * binsize
plt.plot(x, scale * best_fit.pdf(x))

ax = plt.gca()
# With usetex=True the strings go through LaTeX. An unescaped "_" outside
# math mode (model_name contains several) makes the LaTeX run fail, so
# escape it. Units are kept inside the math expression.
model_name_tex = model_name.replace("_", r"\_")
plt.text(
    0.6,
    0.9,
    rf"$\mu = {res.params[0].value:.2f} \pm {res.params[0].error:.3f}\;\mathrm{{cm}}$",
    transform=ax.transAxes,
    usetex=True,
    fontsize=11,
)
plt.text(
    0.6,
    0.81,
    rf"$\sigma = {res.params[1].value:.2f} \pm {res.params[1].error:.3f}\;\mathrm{{cm}}$",
    transform=ax.transAxes,
    usetex=True,
    fontsize=11,
)
plt.text(
    0.02,
    0.78,
    f"Training dataset: {events_train.num_entries} events\n"
    f"Test dataset: {events_test.num_entries} events\n"
    f"Training duration: {len(history_df)} epochs\n{model_name_tex}",
    transform=ax.transAxes,
    usetex=True,
    fontsize=10,
)

plotting.watermark()
# savefig does not create missing directories.
Path("plots").mkdir(exist_ok=True)
plt.savefig(f"plots/h_dz_{model_name}_n{n_events}_e{len(history_df)}.pdf")
plt.savefig(f"plots/h_dz_{model_name}_n{n_events}_e{len(history_df)}.png")