In [None]:
# keras_conformer.py
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model

# ----------------- Transformer block (Keras) -----------------
def transformer_encoder_block(x, emb_size=40, num_heads=10, ff_mult=4, dropout=0.5, name=None):
    """Apply one pre-norm Transformer encoder block and return the result.

    Structure (each sub-layer is LayerNorm -> op -> Dropout -> residual add):
      1. multi-head self-attention, with queries, keys and values all taken
         from the normalized input;
      2. feed-forward network: Dense(ff_mult * emb_size, GELU) -> Dropout ->
         Dense(emb_size).

    Args:
        x: Keras tensor; presumably (batch, tokens, emb_size). Its last axis
            must equal ``emb_size`` for the feed-forward residual add to be valid.
        emb_size: output width of the feed-forward projection; also passed as
            ``key_dim`` (the per-head size) to MultiHeadAttention.
        num_heads: number of attention heads.
        ff_mult: hidden-width multiplier of the feed-forward network.
        dropout: rate shared by every Dropout layer in the block.
        name: name for the enclosing ``tf.name_scope``; "TransformerEncoderBlock"
            is used when it is falsy.

    Returns:
        Keras tensor with the same shape as ``x``.
    """
    # NOTE(review): key_dim=emb_size gives *each* head the full embedding width
    # (num_heads * emb_size projected units). If this is meant to mirror a
    # reference that splits emb_size across heads, key_dim would be
    # emb_size // num_heads. Confirm before changing; it alters the parameter count.
    scope = name or "TransformerEncoderBlock"
    with tf.name_scope(scope):
        # -- Sub-layer 1: self-attention with residual connection --
        attn_in = layers.LayerNormalization()(x)
        attn_out = layers.MultiHeadAttention(num_heads=num_heads, key_dim=emb_size)(attn_in, attn_in)
        attn_out = layers.Dropout(dropout)(attn_out)
        skip = layers.Add()([x, attn_out])

        # -- Sub-layer 2: position-wise feed-forward with residual connection --
        ffn_in = layers.LayerNormalization()(skip)
        hidden = layers.Dense(ff_mult * emb_size, activation="gelu")(ffn_in)
        hidden = layers.Dropout(dropout)(hidden)
        ffn_out = layers.Dense(emb_size)(hidden)
        ffn_out = layers.Dropout(dropout)(ffn_out)
        return layers.Add()([skip, ffn_out])


######################## SETUP #############################

# ----------------- Conformer (Keras) -----------------
def build_conformer_keras(
    input_shape=(61, 100, 1),
    num_classes=2,
    emb_size=40,
    depth=6,
    num_heads=10,
    dropout=0.5,
    ff_mult=4,
    flatten_head_for_1000=False,
):
    """
    Conformer in Keras for EEG.

    input_shape: (M, N, 1)   M=eeg channels, N=time samples (N >= 1)
    num_classes: width of the softmax output.
    emb_size: token embedding width fed to the Transformer stack.
    depth: number of Transformer encoder blocks.
    num_heads, ff_mult: forwarded to every encoder block.
    dropout: forwarded to every encoder block; also applied after the
        patch-embedding pooling.
    flatten_head_for_1000: if True, classify from the flattened token
        sequence (Dense input size depends on the token count); otherwise
        classify from the token mean (independent of the token count).

    The temporal conv kernel, pooling window and pooling stride default to
    25 / 75 / 15. They are capped by ``_safe_temporal_params`` so the token
    count, (N - kernel_t + 1 - pool_t) // stride_t + 1, is always >= 1.
    For the default N=100 this yields a single token.

    Returns an uncompiled ``Model`` mapping (B, M, N, 1) inputs to
    (B, num_classes) softmax probabilities.
    """
    M, N, C = input_shape
    assert C == 1, "Input must be (M, N, 1)."

    # Temporal hyper-parameters capped so neither the "valid" conv nor the
    # "valid" pooling collapses the time axis (the layer sequence is unchanged).
    kernel_t, pool_t, stride_t = _safe_temporal_params(N)

    inp = layers.Input(shape=input_shape)  # (B, M, N, 1)

    # --- PatchEmbedding ---
    # Temporal convolution: (B, M, N - kernel_t + 1, 40)
    x = layers.Conv2D(
        40, kernel_size=(1, kernel_t), strides=(1, 1),
        padding="valid", use_bias=True
    )(inp)
    # Spatial convolution over all M electrodes: (B, 1, N - kernel_t + 1, 40)
    x = layers.Conv2D(
        40, kernel_size=(M, 1), strides=(1, 1),
        padding="valid", use_bias=True
    )(x)
    x = layers.BatchNormalization()(x)
    x = layers.ELU()(x)
    x = layers.AveragePooling2D(
        pool_size=(1, pool_t),
        strides=(1, stride_t),
        padding="valid"
    )(x)
    x = layers.Dropout(dropout)(x)

    # Projection to emb_size and turn into token sequence
    x = layers.Conv2D(
        emb_size, kernel_size=(1, 1),
        strides=(1, 1), padding="valid", use_bias=True
    )(x)
    # x shape: (B, 1, W_tokens, emb_size)
    x = layers.Permute((2, 3, 1))(x)      # (B, W_tokens, emb_size, 1)
    x = layers.Reshape((-1, emb_size))(x) # (B, W_tokens, emb_size)

    # --- Transformer encoder stack ---
    for _ in range(depth):
        x = transformer_encoder_block(
            x, emb_size=emb_size, num_heads=num_heads, ff_mult=ff_mult, dropout=dropout
        )

    # --- Classification head ---
    if flatten_head_for_1000:
        # Flatten every token; works for any token count, but the next
        # Dense's input size then depends on W_tokens.
        x = layers.Flatten()(x)
    else:
        # Token mean pooling (shape-agnostic), then normalize.
        x = layers.GlobalAveragePooling1D()(x)  # (B, emb_size)
        x = layers.LayerNormalization()(x)
    # MLP tail shared by both branches (dropout rates here are fixed and do
    # not follow the `dropout` argument).
    x = layers.Dense(256, activation="elu")(x)
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(32, activation="elu")(x)
    x = layers.Dropout(0.3)(x)
    out = layers.Dense(num_classes, activation="softmax")(x)

    return Model(inp, out, name="EEG_Conformer_Keras")


def _safe_temporal_params(n_times, kernel=25, pool=75, stride=15):
    """Cap the temporal conv/pool hyper-parameters so no time axis collapses to 0.

    Args:
        n_times: number of time samples N; must be >= 1.
        kernel, pool, stride: preferred temporal conv kernel width, pooling
            window and pooling stride.

    Returns:
        ``(kernel_t, pool_t, stride_t)`` with ``1 <= kernel_t <= n_times``,
        ``1 <= pool_t <= n_times - kernel_t + 1`` and ``1 <= stride_t <= pool_t``.
        A "valid" conv followed by "valid" pooling therefore yields >= 1 step.

    Raises:
        ValueError: if ``n_times < 1``.
    """
    if n_times < 1:
        raise ValueError(f"Need at least one time sample, got N={n_times}.")
    kernel_t = max(1, min(kernel, n_times))
    # The "valid" temporal conv leaves this many steps. The pooling window must
    # fit inside *this* width, not the raw input length N.
    conv_width = n_times - kernel_t + 1
    pool_t = max(1, min(pool, conv_width))
    stride_t = max(1, min(stride, pool_t))
    return kernel_t, pool_t, stride_t


# ----------------- Minimal runner -----------------
def train_conformer_keras(
    X, y,
    X1, y1,
    input_shape=None,
    num_classes=2,
    epochs=200,
    batch_size=50,
    lr=1e-3,
    use_flatten_head=False,
):
    """
    Build, compile and fit an EEG Conformer; return ``(model, history)``.

    X  : (K, M, N, 1) float32  (train)
    y  : (K,) int labels
    X1 : (K_val, M, N, 1)      (val/test)
    y1 : (K_val,) int labels

    If the smallest label across y and y1 is 1 (e.g. labels {1,2}), both are
    shifted down by one (to {0,1}), even if one split lacks class 1.
    input_shape is inferred from X if not provided; a list or tuple is accepted.

    Training uses Adam and sparse categorical cross-entropy. EarlyStopping
    monitors ``val_loss`` (patience 20, best weights restored).
    ReduceLROnPlateau monitors the *training* ``loss`` (factor 0.5, patience 5,
    min_lr 1e-6).

    Raises:
        AssertionError: if X / X1 are not 4-D with a trailing channel of 1,
            or their (M, N, 1) shapes do not match ``input_shape``.
        ValueError: if a label array is empty or any label is outside
            [0, num_classes) after the shift.
    """
    assert X.ndim == 4 and X.shape[3] == 1, f"Expected X (K,M,N,1), got {X.shape}"
    assert X1.ndim == 4 and X1.shape[3] == 1, f"Expected X1 (K,M,N,1), got {X1.shape}"

    # Normalize to a tuple so e.g. [61, 100, 1] compares equal to X.shape[1:].
    input_shape = tuple(X.shape[1:]) if input_shape is None else tuple(input_shape)
    assert X.shape[1:] == input_shape, f"X shape {X.shape[1:]} != input_shape {input_shape}"
    assert X1.shape[1:] == input_shape, f"X1 shape {X1.shape[1:]} != input_shape {input_shape}"

    y, y1 = _prepare_labels(y, y1, num_classes)

    # Build and compile
    model = build_conformer_keras(
        input_shape=input_shape,
        num_classes=num_classes,
        flatten_head_for_1000=use_flatten_head,
    )
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="acc")]
    )

    earlystop = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss', patience=20,
        restore_best_weights=True, verbose=1
    )
    # NOTE(review): this monitors the training loss while EarlyStopping uses
    # val_loss. Confirm the asymmetry is intended.
    lr_plateau = tf.keras.callbacks.ReduceLROnPlateau(
        monitor='loss', factor=0.5, patience=5, min_lr=1e-6, verbose=1
    )

    history = model.fit(
        X, y,
        validation_data=(X1, y1),
        epochs=epochs,
        batch_size=batch_size,
        callbacks=[earlystop, lr_plateau],
        verbose=1
    )
    return model, history


def _prepare_labels(y, y1, num_classes):
    """Return (y, y1) as flat int64 arrays of 0-based class indices.

    Both splits are shifted down by one when the smallest label across *both*
    is 1. Using the union matters: a validation split that happens to lack
    class 1 must still be shifted together with the training split.

    Raises:
        ValueError: if either split is empty, or any label falls outside
            [0, num_classes) after the shift (sparse cross-entropy needs
            in-range class indices).
    """
    y = np.asarray(y).astype("int64").ravel()
    y1 = np.asarray(y1).astype("int64").ravel()
    if y.size == 0 or y1.size == 0:
        raise ValueError("Label arrays y and y1 must be non-empty.")
    if min(y.min(), y1.min()) == 1:
        y = y - 1
        y1 = y1 - 1
    lo = min(y.min(), y1.min())
    hi = max(y.max(), y1.max())
    if lo < 0 or hi >= num_classes:
        raise ValueError(
            f"Labels must lie in [0, {num_classes}) after shifting; got min={lo}, max={hi}."
        )
    return y, y1
