# In [None]:
############### AFNet (accepts (K, M, S, 1); architecture unchanged) ###############
import tensorflow as tf
from tensorflow.keras.layers import (
    Input, Conv2D, DepthwiseConv2D, BatchNormalization, Activation,
    AveragePooling2D, Dropout, Dense, SeparableConv2D, Add,
    GlobalAveragePooling2D, GlobalAveragePooling1D,
    MultiHeadAttention, Reshape, LayerNormalization, Multiply
)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import backend as K


# -------- Utilities: safe temporal parameters (no architecture change) --------
def _cap_kernel(k: int, S: int) -> int:
    """Ensure temporal kernel length does not exceed segment length S."""
    return max(1, min(int(k), int(S)))

def _pick_safe_temporal_pools(S: int,
                              preferred=(4, 4)):
    """
    Keep the two AveragePooling2D layers but choose pool sizes so that after
    two temporal pools the time axis remains >= 1.
      L1 = floor(S / p1), L2 = floor(L1 / p2) >= 1  â‡’  S >= p1 * p2
    Returns ((1, p1), (1, p2)).
    """
    candidates = [
        preferred, (4, 4), (4, 3), (3, 3), (3, 2), (2, 2), (2, 1), (1, 1)
    ]
    seen, uniq = set(), []
    for c in candidates:
        if c not in seen:
            uniq.append(c); seen.add(c)
    for p1, p2 in uniq:
        if (S // p1) // p2 >= 1:
            return (1, p1), (1, p2)
    return (1, 1), (1, 1)


# -------- Spatial Attention (unchanged) --------
def SpatialAttention(x, name="spatial_attn"):
    """
    Re-weight the electrode axis of ``x`` with a learned sigmoid gate.

    The input is globally average-pooled, passed through a 64-unit ReLU
    Dense layer and then a sigmoid Dense layer with one unit per electrode.
    The resulting weights are reshaped to (B, M, 1, 1) and multiplied into
    ``x``. Every layer created here is named with the ``name`` prefix.

    Raises ValueError when axis 1 of ``x`` (the electrode axis) has no
    static size.
    """
    pooled = GlobalAveragePooling2D(name=f"{name}_gap")(x)              # (B, C)
    hidden = Dense(64, activation='relu', name=f"{name}_fc1")(pooled)

    # The gate needs one output unit per electrode, so M must be static.
    n_electrodes = K.int_shape(x)[1]
    if n_electrodes is None:
        raise ValueError("Electrode dimension must be known. Pass input_shape=(M, S, 1).")
    n_electrodes = int(n_electrodes)

    gate_layer = Dense(n_electrodes, activation='sigmoid', name=f"{name}_weights")
    gate = gate_layer(hidden)                                           # (B, M)
    reshape_layer = Reshape((n_electrodes, 1, 1), name=f"{name}_map")
    gate_map = reshape_layer(gate)                                      # (B, M, 1, 1)

    # Multiply the (B, M, 1, 1) gate into x.
    return Multiply(name=f"{name}_apply")([x, gate_map])


# -------- Transformer Block (unchanged) --------
def TransformerBlock(x, num_heads=4, key_dim=64, ff_dim=128, dropout_rate=0.1):
    """
    Transformer encoder block for sequences shaped (B, L, D).

    Applies multi-head self-attention and then a two-layer feed-forward
    network. Each sub-layer is followed by dropout, a residual addition and
    layer normalization (normalization comes after the addition). The output
    keeps the last-axis size of ``x``.
    """
    # Feature width of the input; the feed-forward output projects back to it.
    model_dim = K.int_shape(x)[-1]

    # --- Self-attention sub-layer (query and value are both x) ---
    attended = MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)(x, x)
    attended = Dropout(dropout_rate)(attended)
    attn_residual = Add()([x, attended])
    attn_normed = LayerNormalization()(attn_residual)

    # --- Position-wise feed-forward sub-layer ---
    hidden = Dense(ff_dim, activation='relu')(attn_normed)
    hidden = Dropout(dropout_rate)(hidden)
    projected = Dense(model_dim)(hidden)
    ff_residual = Add()([attn_normed, projected])
    return LayerNormalization()(ff_residual)


# -------- Model (architecture unchanged; parameters adapted to M,S) --------
def EEGNet_SpatialTransformer(input_shape=(22, 100, 1),
                              dropout_rate=0.5,
                              num_heads=4, ff_dim=128,
                              num_classes=4):
    """
    AFNet-style EEG model with Spatial Attention + Transformer encoder.

    Pipeline: temporal separable conv -> spatial attention -> depthwise
    spatial conv (electrode axis kept via 'same' padding) -> residual add
    with a 1x1 projection of the input -> two temporal average pools with a
    separable conv in between -> reshape to an electrode sequence
    (B, M, S'*C') -> Transformer encoder -> global average pool -> MLP head
    -> softmax.

    Temporal kernel lengths and pool sizes are adapted to S so that short
    segments do not collapse the time axis; the layer stack itself is fixed.

    Args:
        input_shape: (M, S, 1) with M electrodes and S time samples. Both M
            and S must be concrete positive integers.
        dropout_rate: Dropout rate after each pooling stage and in the MLP
            head. The Transformer block uses a fixed rate of 0.1.
        num_heads: Number of attention heads in the Transformer block.
        ff_dim: Hidden width of the Transformer feed-forward network.
        num_classes: Number of softmax output units.

    Returns:
        An uncompiled ``Model`` named "EEGNet_SpatialTransformer_refit".

    Raises:
        ValueError: If the channel dimension is not 1, if M or S is None or
            smaller than 1, or if static shapes are unknown before the
            Transformer reshape.
    """
    M, S, C = input_shape
    if C != 1:
        raise ValueError("This model expects a single input channel. Use input_shape=(M, S, 1).")
    # Fail early with a clear message: the helpers below call int(S) and the
    # depthwise kernel uses M, so None would otherwise surface as a TypeError.
    if M is None or S is None:
        raise ValueError("Static shapes must be known. Pass concrete input_shape=(M, S, 1).")
    if M < 1 or S < 1:
        raise ValueError(f"M and S must be positive, got input_shape={tuple(input_shape)!r}.")

    # Safe temporal hyperparameters (no architecture change)
    k1 = _cap_kernel(5, S)     # for SeparableConv2D (1, 5)
    k2 = _cap_kernel(3, S)     # for SeparableConv2D (1, 3)
    pool1, pool2 = _pick_safe_temporal_pools(S, preferred=(8, 4))

    inputs = Input(shape=input_shape)  # (B, M, S, 1)

    # 1) Temporal-focused SeparableConv2D
    x = SeparableConv2D(32, (1, k1), padding='same', use_bias=False)(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # 2) Spatial Attention
    x = SpatialAttention(x)

    # 3) Spatial filtering across electrodes; keep M dimension
    x = DepthwiseConv2D((M, 1), use_bias=False, depth_multiplier=2, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # 4) Residual projection to match channels for Add()
    ch_out = K.int_shape(x)[-1] or 64  # should be 32*depth_multiplier = 64
    proj = Conv2D(filters=ch_out, kernel_size=(1, 1), padding='same', use_bias=False)(inputs)
    proj = BatchNormalization()(proj)
    x = Add()([x, proj])
    x = BatchNormalization()(x)

    # 5) Downsample (temporal)
    x = AveragePooling2D(pool_size=pool1)(x)
    x = Dropout(dropout_rate)(x)

    # 6) Separable conv
    x = SeparableConv2D(64, (1, k2), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # 7) Downsample again (temporal)
    x = AveragePooling2D(pool_size=pool2)(x)
    x = Dropout(dropout_rate)(x)

    # 8) Prepare for Transformer: reshape to (B, M, S'*C')
    _, M_eff, S_eff, C_eff = K.int_shape(x)
    if None in (M_eff, S_eff, C_eff):
        raise ValueError("Static shapes must be known. Pass concrete input_shape=(M, S, 1).")
    x = Reshape((M_eff, S_eff * C_eff))(x)

    # 9) Transformer encoder (each electrode is one sequence element)
    x = TransformerBlock(x, num_heads=num_heads, key_dim=32, ff_dim=ff_dim, dropout_rate=0.1)

    # 10) Pool over sequence (electrodes)
    x = GlobalAveragePooling1D()(x)
    x = BatchNormalization()(x)

    # 11) MLP head
    x = Dense(128, activation='relu')(x)
    x = BatchNormalization()(x)
    x = Dropout(dropout_rate)(x)

    # 12) Output
    outputs = Dense(num_classes, activation='softmax')(x)

    return Model(inputs, outputs, name="EEGNet_SpatialTransformer_refit")


# -------- Instantiate & Compile (no normalisation of inputs here) --------
# Example:
# M = X_train.shape[1]; S = X_train.shape[2]; C = X_train.shape[3]  # must be 1
# n_classes = int(max(y_train.max(), y_val.max())) + 1
# model = EEGNet_SpatialTransformer(input_shape=(M, S, 1),
#                                   dropout_rate=0.5,
#                                   num_heads=4, ff_dim=128,
#                                   num_classes=n_classes)
# model.compile(optimizer=Adam(learning_rate=1e-3),
#               loss=tf.keras.losses.SparseCategoricalCrossentropy(),
#               metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')])

# from tensorflow.keras.callbacks import EarlyStopping
# earlystop = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=15,
#                                              restore_best_weights=True, verbose=1)
# history = model.fit(X_train, y_train,
#                     validation_data=(X_val, y_val),
#                     batch_size=50, epochs=150,
#                     callbacks=[earlystop], verbose=1)
# model.summary()

