# CNN for AdvSND energy reconstruction

In [None]:
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.layers import (
    Add,
    BatchNormalization,
    Concatenate,
    Conv2D,
    Dense,
    Dropout,
    Flatten,
    Input,
    Lambda,
    MaxPooling2D,
    RandomFlip,
    ReLU,
)
from tensorflow.keras.models import Model

from CBAM3D import CBAM
from config import input_shape, input_shape_mf
from losses import normalised_mse

In [None]:
# All layers in this notebook treat tensors as (batch, height, width, channels).
K.set_image_data_format(data_format="channels_last")

In [None]:
# Used below both as the Keras model's name and as the prefix of the saved file.
model_name: str = "CNN_nadelhorn-nmse-5-flip_energy_combined"

In [None]:
def res_block(inputs):
    """Residual block: two 3x3, 16-filter convolutions added back onto ``inputs``.

    Branch layout: Conv2D -> BatchNorm -> ReLU -> Conv2D -> BatchNorm, then the
    branch output is summed with ``inputs`` via ``Add``. No activation is
    applied after the addition.

    NOTE(review): the ``Add`` requires the shape of ``inputs`` to be compatible
    with the 16-channel branch output — confirm for the configured input shapes.
    """
    branch = inputs
    for conv_index in range(2):
        branch = Conv2D(16, 3, padding="same")(branch)
        branch = BatchNormalization()(branch)
        # Only the first convolution is followed by a ReLU.
        if conv_index == 0:
            branch = ReLU()(branch)
    return Add()([inputs, branch])

In [None]:
def _conv_stage(x, kernel_size, pool_size, add_CBAM):
    """Apply Conv2D(16) -> BatchNorm -> ReLU -> [CBAM] -> MaxPooling2D to ``x``."""
    x = Conv2D(16, kernel_size=kernel_size, padding="same")(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    if add_CBAM:
        # Attend over the current feature map (lower-case ``x``); an upper-case
        # ``X`` here would be undefined or pick up an unrelated global tensor.
        x = CBAM()(x)
    return MaxPooling2D(pool_size=pool_size, padding="valid")(x)


# (kernel_size, pool_size) of each convolutional stage, in order.
_CONV_STAGES = (
    ((1, 9), (2, 4)),
    ((3, 3), (2, 4)),
    ((3, 3), (2, 4)),
    ((3, 3), (2, 2)),
)


def conv_model(inputs, drop_middle=0.25, add_CBAM=False):
    """Build a four-stage convolutional feature extractor on ``inputs``.

    Each stage is Conv2D (16 filters, "same" padding) -> BatchNormalization ->
    ReLU -> optional CBAM -> MaxPooling2D ("valid" padding). Kernel and pool
    sizes per stage are listed in ``_CONV_STAGES``. A Dropout layer follows
    every stage except the last.

    Args:
        inputs: Keras tensor to process.
        drop_middle: Dropout rate applied after each of the first three stages.
        add_CBAM: If True, insert a ``CBAM`` block after each stage's ReLU.

    Returns:
        The Keras tensor produced by the final max-pooling layer.
    """
    x = inputs
    last_stage = len(_CONV_STAGES) - 1
    for stage_index, (kernel_size, pool_size) in enumerate(_CONV_STAGES):
        x = _conv_stage(x, kernel_size, pool_size, add_CBAM)
        if stage_index < last_stage:
            x = Dropout(rate=drop_middle)(x)
    return x

In [None]:
def res_net(inputs, depth=5):
    """Chain ``res_block`` calls on ``inputs`` and return the last output.

    ``depth`` blocks are stacked; a ``depth`` below 1 still yields exactly
    one block.
    """
    x = inputs
    for _ in range(max(depth, 1)):
        x = res_block(x)
    return x

In [None]:
def sum_input(input):
    """Sum all elements of each sample in ``input`` into a single feature.

    Args:
        input: Keras tensor with a leading batch dimension. (The name shadows
            the builtin but is kept so existing keyword callers still work.)

    Returns:
        Keras tensor of shape (batch, 1) holding each sample's total.
    """
    flat = Flatten()(input)
    # TF2's reduce_sum takes `keepdims` (the old `keep_dims` spelling raises a
    # TypeError). After Flatten, each sample reduces to a single value, so the
    # batch-excluded output shape is (1,).
    return Lambda(
        lambda t: tf.reduce_sum(t, axis=1, keepdims=True), output_shape=(1,)
    )(flat)

In [None]:
# def output_block(inputs):
#    return Dense(1)(inputs)

In [None]:
def output_block(inputs):
    """Dense regression head producing a single output unit.

    Two hidden stages (3 units, then 20 units), each Dense -> BatchNorm -> ReLU.
    These are followed by Dropout(0.2) and a final linear Dense(1).
    """
    hidden = inputs
    for units in (3, 20):
        hidden = Dense(units)(hidden)
        hidden = BatchNormalization()(hidden)
        hidden = ReLU()(hidden)
    hidden = Dropout(rate=0.2)(hidden)
    return Dense(1)(hidden)

In [None]:
lr = 2e-4  # learning rate passed to the Adam optimizer below

# Reset Keras' global state (e.g. auto-generated layer-name counters) so that
# re-running this cell builds the model from a clean slate. This has to run
# before any layer is created, not after the model has been built.
K.clear_session()


def _view_branch(view_input):
    """Flip-augment one detector view, run the residual stack, then flatten.

    Without augmentation this would be ``Flatten()(res_net(view_input))``.
    """
    # TODO check whether we can relax seed for one target view
    x = RandomFlip(mode="vertical", seed=42)(view_input)
    x = res_net(x)
    return Flatten()(x)


target_h_input = Input(input_shape, name="target_h_in")
x_h = _view_branch(target_h_input)

target_v_input = Input(input_shape, name="target_v_in")
x_v = _view_branch(target_v_input)

mufilter_input = Input(input_shape_mf, name="mufilter_in")
x_mf = _view_branch(mufilter_input)

X = Concatenate()([x_h, x_v, x_mf])
# X = x_mf  # alternative: use only the mufilter branch

# Optional per-view total-sum features (via sum_input), currently disabled:
# sum_h = sum_input(target_h_input)
# sum_v = sum_input(target_v_input)
# sum_mf = sum_input(mufilter_input)
# feats = Concatenate()([sum_h, sum_v, sum_mf])
# X = Dense(3)(X)
# X = Concatenate()([X, feats])
X = output_block(X)

model = Model(
    inputs=[target_h_input, target_v_input, mufilter_input],
    outputs=X,
    name=model_name,
)

model.compile(
    # Pass `lr` explicitly. The bare string "Adam" builds the optimizer with
    # Keras' default learning rate and silently ignores the value above.
    optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
    loss=normalised_mse,
    metrics=[
        "mae",
    ],
)

In [None]:
# Write the compiled model in the native `.keras` format; the `_e0` suffix
# presumably marks epoch 0 (before any training) — confirm against training cells.
save_path = f"{model_name}_e0.keras"
model.save(save_path)