# In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model


# ----------------- TSception (Keras, channels_first) -----------------

def tsception_conv_block(x, out_chan, kernel, step, pool, name_prefix=None):
    """
    Apply Conv2D -> LeakyReLU -> AveragePooling2D to `x` and return the result.

    Every layer is created with data_format='channels_first', so `x` is
    expected to be laid out as (B, C_in, H, W).

    x           : input tensor, (B, C_in, H, W)
    out_chan    : number of convolution filters
    kernel      : (kH, kW) convolution kernel size
    step        : (sH, sW) convolution strides
    pool        : scalar; pooling window and stride are both (1, pool),
                  i.e. only the last (W) axis is down-sampled
    name_prefix : optional string; when given, the layers are named
                  '<prefix>_conv', '<prefix>_lrelu' and '<prefix>_avgpool',
                  otherwise Keras picks the names automatically

    The convolution uses 'valid' padding and a bias term. LeakyReLU is
    created without a slope argument, so the library default applies.
    """
    def _layer_name(suffix):
        # Keep name=None (auto-naming) when the caller gave no prefix.
        if name_prefix is None:
            return None
        return name_prefix + suffix

    pool_window = (1, pool)

    convolved = layers.Conv2D(
        filters=out_chan,
        kernel_size=kernel,
        strides=step,
        padding="valid",
        data_format="channels_first",
        use_bias=True,
        name=_layer_name("_conv"),
    )(x)

    activated = layers.LeakyReLU(name=_layer_name("_lrelu"))(convolved)

    pooled = layers.AveragePooling2D(
        pool_size=pool_window,
        strides=pool_window,
        data_format="channels_first",
        name=_layer_name("_avgpool"),
    )(activated)
    return pooled


def build_tsception_keras(
    num_classes,
    input_size,        # (1, M, T)  -> (freq, channels, time)
    sampling_rate,
    num_T,             # temporal filters
    num_S,             # spatial filters
    hidden,            # hidden units in FC
    dropout_rate,
):
    """
    Build the Keras (functional API, channels_first) TSception model.

    Pipeline: three parallel temporal conv blocks (T1..T3) -> concat over
    time -> BatchNorm -> two spatial conv blocks (S1, S2) -> concat over the
    electrode axis -> BatchNorm -> fusion conv block -> BatchNorm -> global
    average pooling -> Dense(relu) -> Dropout -> Dense(softmax).

    Parameters
    ----------
    num_classes : int
        Units of the final softmax layer.
    input_size : (1, M, T)
        1 frequency band, M channels (electrodes), T time points. T may be
        None; the time-axis checks below are then skipped.
    sampling_rate : Hz
        Temporal kernel lengths are int(w * sampling_rate) for
        w in (0.5, 0.25, 0.125).
    num_T : int
        Filters in each temporal block.
    num_S : int
        Filters in each spatial block and in the fusion block.
    hidden : int
        Units of the hidden fully-connected layer.
    dropout_rate : float
        Dropout rate applied after the hidden layer.

    Returns
    -------
    Model named "TSceptionKeras" mapping (B, 1, M, T) -> (B, num_classes)
    class probabilities.

    Raises
    ------
    ValueError
        If the first entry of input_size is not 1, or if the requested
        geometry cannot produce a non-empty feature map (sampling rate too low
        for the shortest kernel, fewer than 2 channels, or too few time
        points).

    NOTE(review): the output layer applies softmax, and the conv blocks use
    the Keras default LeakyReLU slope; if exact parity with a PyTorch
    reference (raw logits, torch's LeakyReLU default) matters, confirm both.
    """
    F, M, T = input_size
    # Explicit raise instead of assert: asserts disappear under `python -O`.
    if F != 1:
        raise ValueError(
            "TSceptionKeras assumes a single 'freq band' dimension (F=1)."
        )

    inception_window = (0.5, 0.25, 0.125)
    base_pool = 8
    # pool for spatial modules: int(base_pool * 0.25)
    s_pool = int(base_pool * 0.25)
    fusion_pool = 4

    # ---------- Geometry checks (fail early with a readable message) ----------
    temporal_kernels = [int(w * sampling_rate) for w in inception_window]
    if min(temporal_kernels) < 1:
        raise ValueError(
            f"sampling_rate={sampling_rate} is too low: temporal kernel "
            f"lengths would be {temporal_kernels}, each must be >= 1."
        )

    half_M = int(M * 0.5)
    if half_M < 1:
        # half_M is used both as kernel height and as stride of S2.
        raise ValueError(f"Need at least 2 channels (M >= 2), got M={M}.")

    if T is not None:
        if T < max(temporal_kernels):
            raise ValueError(
                f"T={T} time points is shorter than the longest temporal "
                f"kernel ({max(temporal_kernels)})."
            )
        # 'valid' conv with stride 1 leaves T - k + 1 steps; each pooling
        # stage floors the length by its window. Only the time axis shrinks
        # in the pooling layers.
        concat_steps = sum((T - k + 1) // base_pool for k in temporal_kernels)
        final_steps = (concat_steps // s_pool) // fusion_pool
        if final_steps < 1:
            raise ValueError(
                f"T={T} time points leaves an empty time axis after pooling; "
                "use a longer window."
            )

    # We follow PyTorch: input to network is (B, 1, chan, time) = (B, 1, M, T)
    inp = layers.Input(shape=(1, M, T))  # channels_first

    # ---------- Temporal branches ----------
    # Tception1,2,3: kernel=(1, k), stride=1, pool=8; layer prefixes T1, T2, T3
    temporal_branches = [
        tsception_conv_block(
            inp, out_chan=num_T, kernel=(1, k_len),
            step=(1, 1), pool=base_pool, name_prefix=f"T{idx}",
        )
        for idx, k_len in enumerate(temporal_kernels, start=1)
    ]

    # Concatenate along temporal dimension (last axis for channels_first: B,C,H,W -> axis=3)
    t_out = layers.Concatenate(axis=3, name="T_concat")(temporal_branches)

    # BatchNorm over channel axis (axis=1 in channels_first)
    t_out = layers.BatchNormalization(axis=1, name="BN_t")(t_out)

    # ---------- Spatial branches ----------
    # Sception1: kernel=(M,1), stride=(1,1) -> spans all electrodes at once
    s1 = tsception_conv_block(
        t_out, out_chan=num_S,
        kernel=(int(M), 1),
        step=(1, 1),
        pool=s_pool,
        name_prefix="S1",
    )

    # Sception2: kernel=(M*0.5,1), stride=(M*0.5,1) -> non-overlapping halves
    s2 = tsception_conv_block(
        t_out, out_chan=num_S,
        kernel=(half_M, 1),
        step=(half_M, 1),
        pool=s_pool,
        name_prefix="S2",
    )

    # Concatenate along the electrode axis (H dimension here): (B, num_S, H, W) -> axis=2
    s_out = layers.Concatenate(axis=2, name="S_concat")([s1, s2])
    s_out = layers.BatchNormalization(axis=1, name="BN_s")(s_out)

    # ---------- Fusion layer ----------
    # fusion_layer: kernel=(3,1), stride=1, pool=4
    fusion = tsception_conv_block(
        s_out, out_chan=num_S,
        kernel=(3, 1),
        step=(1, 1),
        pool=fusion_pool,
        name_prefix="fusion",
    )
    fusion = layers.BatchNormalization(axis=1, name="BN_fusion")(fusion)

    # ---------- Global pooling + FC ----------
    # Mean over spatial dims -> (B, num_S)
    feat = layers.GlobalAveragePooling2D(
        data_format="channels_first", name="global_avg"
    )(fusion)

    x = layers.Dense(hidden, activation="relu", name="fc1")(feat)
    x = layers.Dropout(dropout_rate, name="fc1_drop")(x)
    out = layers.Dense(num_classes, activation="softmax", name="fc_out")(x)

    return Model(inp, out, name="TSceptionKeras")


# ----------------- Helper: reshape (B,M,T,1) -> (B,1,M,T) -----------------

def to_tsception_input(X):
    """
    Convert (B, M, T, 1) -> (B, 1, M, T) for TSceptionKeras.

    X : array-like of rank 4 whose last axis has length 1. Anything
        np.asarray accepts (ndarray, nested lists, ...) is allowed.

    Returns a new float32 ndarray of shape (B, 1, M, T); the input is not
    modified.

    Raises ValueError when X is not rank 4 or its last axis is not 1.
    """
    X = np.asarray(X)
    # Explicit raise instead of assert: with `python -O` an assert would be
    # stripped and a wrong-rank input would be reshaped silently.
    if X.ndim != 4 or X.shape[-1] != 1:
        raise ValueError(f"Expected X (B,M,T,1), got {X.shape}")
    # Drop the trailing singleton axis and insert a new one after the batch axis.
    return X[:, np.newaxis, :, :, 0].astype("float32")


