# Model Definition

In [1]:
# keras >= 2.9
import tensorflow as tf
from tensorflow.keras import layers, models

# ----- helpers -----
def crelu(x):
    """Concatenated ReLU: join ReLU(x) and ReLU(-x) along the last axis.

    Both halves have the shape of ``x``, so the result has twice as many
    channels as the input.
    """
    positive_part = layers.ReLU()(x)
    negative_part = layers.ReLU()(-x)
    return layers.Concatenate(axis=-1)([positive_part, negative_part])

def sep_res_block(x, filters, ksize=9, stride=1, name="b"):
    """Residual block: SeparableConv1D -> BatchNorm -> CReLU, plus a skip path.

    CReLU concatenates ReLU(y) and ReLU(-y), which doubles the channel count.
    The separable convolution therefore produces ``filters // 2`` channels, so
    the main path ends with exactly ``filters`` channels. That is the same
    width as the skip path, which ``layers.Add`` requires.

    Args:
        x: input tensor. The code treats the last axis as channels
            (channels-last); presumably (batch, length, channels).
        filters: number of output channels. Must be a positive even integer.
        ksize: kernel size of the separable convolution.
        stride: temporal stride of the block. When it is not 1, or the input
            channel count differs from ``filters``, the skip path becomes a
            strided 1x1 Conv1D followed by BatchNormalization.
        name: prefix for the names of every layer created here.

    Returns:
        The element-wise sum of the main and skip paths, with ``filters``
        channels.

    Raises:
        ValueError: if ``filters`` is not a positive even integer.
    """
    if filters <= 0 or filters % 2:
        raise ValueError(
            f"{name}: filters must be a positive even integer because CReLU "
            f"doubles the channel count, got {filters}"
        )

    # Main path: depthwise-separable conv at half width; CReLU restores `filters` channels.
    y = layers.SeparableConv1D(filters // 2, ksize, strides=stride, padding="same",
                               use_bias=False, name=f"{name}_sep")(x)
    y = layers.BatchNormalization(name=f"{name}_bn")(y)
    y = crelu(y)

    # Skip path: project when the temporal resolution or channel count changes.
    if stride != 1 or x.shape[-1] != filters:
        skip = layers.Conv1D(filters, 1, strides=stride, padding="same",
                             use_bias=False, name=f"{name}_proj")(x)
        skip = layers.BatchNormalization(name=f"{name}_proj_bn")(skip)
    else:
        skip = x

    return layers.Add(name=f"{name}_add")([y, skip])

def build_model(input_length, n_channels, n_classes_main, use_aux_or_location=True, n_or_locations=3):
    """Assemble the ``LDR_1D_CNN`` functional model.

    Args:
        input_length: number of time steps per example (first non-batch dim).
        n_channels: number of input channels (last dim).
        n_classes_main: size of the softmax output ``cls_main``.
        use_aux_or_location: when true, add a second softmax head ``cls_or_loc``.
        n_or_locations: size of the ``cls_or_loc`` head.

    Returns:
        A Keras ``Model`` with input ``signal`` and outputs ``[cls_main]``, or
        ``[cls_main, cls_or_loc]`` when the auxiliary head is enabled.
    """
    signal = layers.Input(shape=(input_length, n_channels), name="signal")

    # Stem: wide strided conv + BN, then CReLU doubles the 32 filters to 64 channels.
    feats = layers.Conv1D(32, 64, strides=4, padding="same", use_bias=False, name="stem_conv")(signal)
    feats = crelu(layers.BatchNormalization(name="stem_bn")(feats))

    # (filters, stride, block name) for each residual block, applied in this order.
    block_specs = (
        (64, 2, "s1_b1"),
        (64, 1, "s1_b2"),
        (128, 2, "s2_b1"),
        (128, 1, "s2_b2"),
        (128, 1, "s3_b1"),
        (128, 1, "s3_b2"),
    )
    for n_filters, step, block_name in block_specs:
        feats = sep_res_block(feats, n_filters, ksize=9, stride=step, name=block_name)

    pooled = layers.GlobalAveragePooling1D(name="gap")(feats)
    neck = layers.Dense(100, activation="relu", name="neck")(pooled)

    heads = [layers.Dense(n_classes_main, activation="softmax", name="cls_main")(neck)]
    if use_aux_or_location:
        heads.append(layers.Dense(n_or_locations, activation="softmax", name="cls_or_loc")(neck))

    return models.Model(signal, heads, name="LDR_1D_CNN")

# ----- compile with masked auxiliary loss -----
def compile_with_aux_mask(model, aux_weight=0.2):
    """Compile ``model`` with a main CE loss and, if present, a masked auxiliary CE loss.

    Expected targets:
      * ``cls_main``: one-hot labels.
      * ``cls_or_loc`` (only when the model has that output): one-hot labels
        with one extra mask channel appended, i.e. shape
        ``[..., n_or_locations + 1]``. The last position is 1.0 if the
        auxiliary label is present and 0.0 otherwise.

    Args:
        model: Keras functional model whose output layers are named
            ``cls_main`` and optionally ``cls_or_loc``.
        aux_weight: loss weight of the auxiliary head relative to the main head.

    Returns:
        The same ``model``, compiled in place with Adam(1e-3).
    """
    def masked_ce(y_true, y_pred):
        """Categorical cross-entropy averaged over labelled samples only."""
        y, m = y_true[..., :-1], y_true[..., -1]
        ce = tf.keras.losses.categorical_crossentropy(y, y_pred)
        m = tf.cast(m, ce.dtype)
        # Return per-sample losses. Keras' default "sum_over_batch_size"
        # reduction then averages over the batch. Dividing by the labelled
        # *fraction* makes that average equal sum(ce * m) / sum(m), the mean
        # CE over labelled samples. A batch with no labels contributes 0.
        labelled_frac = tf.maximum(tf.reduce_mean(m), 1e-6)
        return ce * m / labelled_frac

    losses = {"cls_main": "categorical_crossentropy"}
    loss_weights = {"cls_main": 1.0}

    # Output *tensor* names generally differ from layer names, so match on the
    # model's output layer names. Fall back to layer names if the attribute is
    # unavailable.
    output_names = getattr(model, "output_names", None) or [layer.name for layer in model.layers]
    if "cls_or_loc" in output_names:
        losses["cls_or_loc"] = masked_ce
        loss_weights["cls_or_loc"] = aux_weight

    model.compile(optimizer=tf.keras.optimizers.Adam(1e-3),
                  loss=losses, loss_weights=loss_weights,
                  metrics={"cls_main": ["accuracy"]})
    return model
