# In [None]:  (Jupyter notebook cell marker left over from export)
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model


# --------- Transformer utilities ---------

def transformer_block(x, dim, heads, dim_head, mlp_dim, dropout):
    """
    Apply a single pre-norm Transformer encoder block to a symbolic tensor.

    The block consists of two residual branches:
      1. LayerNorm -> multi-head self-attention -> Dense(dim) -> Dropout
      2. LayerNorm -> Dense(mlp_dim, GELU) -> Dropout -> Dense(dim) -> Dropout

    Args:
        x: Input tensor, documented by the caller as (B, tokens, dim).
        dim: Width of the residual stream (output size of both branches).
        heads: Number of attention heads.
        dim_head: Per-head key/query size (`key_dim` of MultiHeadAttention).
        mlp_dim: Hidden width of the feed-forward branch.
        dropout: Dropout rate used inside attention and after each Dense.

    Returns:
        Tensor produced by the second residual addition.
    """
    # --- Branch 1: self-attention on the normalized input ---
    attn_norm = layers.LayerNormalization()
    normed_in = attn_norm(x)

    self_attention = layers.MultiHeadAttention(
        num_heads=heads,
        key_dim=dim_head,
        dropout=dropout,
    )
    # Query and value are the same tensor (self-attention).
    # NOTE(review): per the Keras docs, MultiHeadAttention projects its output
    # back to the query's feature size by default, so the Dense below is an
    # additional linear map on top of that projection rather than the only one.
    attended = self_attention(normed_in, normed_in)

    attn_projection = layers.Dense(dim)
    attended = attn_projection(attended)
    attended = layers.Dropout(dropout)(attended)

    after_attn = layers.Add()([x, attended])

    # --- Branch 2: position-wise feed-forward on the normalized residual ---
    mlp_norm = layers.LayerNormalization()
    normed_mid = mlp_norm(after_attn)

    expand = layers.Dense(mlp_dim, activation="gelu")
    hidden = expand(normed_mid)
    hidden = layers.Dropout(dropout)(hidden)

    contract = layers.Dense(dim)
    hidden = contract(hidden)
    hidden = layers.Dropout(dropout)(hidden)

    return layers.Add()([after_attn, hidden])


class AddClassTokenAndPosEmbedding(layers.Layer):
    """
    Prepends a learned [CLS] token and adds learned positional embeddings,
    followed by dropout.

    Input:  (B, num_patches, dim)
    Output: (B, num_patches+1, dim)

    Args:
        num_patches: Number of patch tokens expected on axis 1 of the input.
        dim: Feature size of each token.
        emb_dropout: Dropout rate applied after the positional embedding.
        **kwargs: Forwarded to `layers.Layer` (e.g. `name`).
    """
    def __init__(self, num_patches, dim, emb_dropout=0.1, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.dim = dim
        # Kept as a plain attribute so get_config() can serialize it.
        self.emb_dropout = emb_dropout
        self.dropout = layers.Dropout(emb_dropout)

    def build(self, input_shape):
        # pos_embedding: (1, num_patches+1, dim) -- one slot per patch + CLS
        self.pos_embedding = self.add_weight(
            name="pos_embedding",
            shape=(1, self.num_patches + 1, self.dim),
            initializer="random_normal",
            trainable=True,
        )
        # cls_token: (1, 1, dim)
        self.cls_token = self.add_weight(
            name="cls_token",
            shape=(1, 1, self.dim),
            initializer="random_normal",
            trainable=True,
        )
        super().build(input_shape)

    def call(self, x, training=None):
        # x: (B, num_patches, dim)
        B = tf.shape(x)[0]
        # Replicate the single learned token once per batch element.
        cls_tokens = tf.repeat(self.cls_token, repeats=B, axis=0)   # (B, 1, dim)
        x = tf.concat([cls_tokens, x], axis=1)                      # (B, num_patches+1, dim)

        # Add positional embedding (broadcast over batch); the slice keeps
        # only as many positions as there are tokens in x.
        x = x + self.pos_embedding[:, :tf.shape(x)[1], :]

        x = self.dropout(x, training=training)
        return x

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created.

        Without this, models containing the layer cannot be reliably
        serialized, cloned, or reloaded, because the base config does not
        include `num_patches`, `dim`, or `emb_dropout`.
        """
        config = super().get_config()
        config.update(
            {
                "num_patches": self.num_patches,
                "dim": self.dim,
                "emb_dropout": self.emb_dropout,
            }
        )
        return config


def build_eeg_vit_keras(
    num_chan,
    num_time,
    num_patches,
    num_classes,
    dim=32,
    depth=4,
    heads=16,
    mlp_dim=64,
    pool="cls",            # 'cls' or 'mean'
    dim_head=64,
    dropout=0.1,
    emb_dropout=0.1,
):
    """
    Keras version of EEGViT.

    PyTorch original:
      - Input: (B, num_chan, num_time)
      - num_patches * patch_len = num_time

    Here we use the same input convention: (B, num_chan, num_time).

    Args:
        num_chan: Number of EEG channels (axis 1 of the input).
        num_time: Number of time samples (axis 2 of the input).
        num_patches: Number of equal-length patches the time axis is cut into.
        num_classes: Size of the final logits layer.
        dim: Token embedding width.
        depth: Number of stacked transformer blocks.
        heads: Attention heads per block.
        mlp_dim: Hidden width of each block's feed-forward branch.
        pool: 'cls' to read out the class token, 'mean' to average all tokens.
        dim_head: Per-head key size passed to each block.
        dropout: Dropout rate inside the transformer blocks.
        emb_dropout: Dropout rate applied after the positional embedding.

    Returns:
        An uncompiled `Model` named "EEGViT" mapping
        (B, num_chan, num_time) -> (B, num_classes) logits.

    Raises:
        AssertionError: If `pool` is not 'cls'/'mean' or `num_time` is not
            divisible by `num_patches`.
    """

    assert pool in {"cls", "mean"}, "pool must be either 'cls' or 'mean'"
    assert num_time % num_patches == 0, "num_time must be divisible by num_patches"

    patch_len = num_time // num_patches  # l in original code

    # Input: (B, num_chan, num_time)
    inp = layers.Input(shape=(num_chan, num_time), name="eeg_input")

    # ---- Patch embedding: Rearrange('b c (n l) -> b n (c l)'); then Linear(c*l -> dim) ----
    # Built-in reshaping layers are used instead of a closure-based Lambda:
    # every dimension here is statically known, so the Dense below always sees
    # a defined last axis, and the model stays serializable without pickled
    # Python code.
    x = layers.Reshape(
        (num_chan, num_patches, patch_len), name="to_patches_split"
    )(inp)                                                      # (B, C, n, l)
    x = layers.Permute((2, 1, 3), name="to_patches_swap")(x)    # (B, n, C, l)
    x = layers.Reshape(
        (num_patches, num_chan * patch_len), name="to_patches"
    )(x)                                                        # (B, n, C*l)

    x = layers.Dense(dim, name="patch_linear")(x)               # (B, num_patches, dim)

    # ---- Add [CLS] token + positional embedding + dropout ----
    x = AddClassTokenAndPosEmbedding(
        num_patches=num_patches,
        dim=dim,
        emb_dropout=emb_dropout,
        name="cls_pos_embedding"
    )(x)                                                        # (B, num_patches+1, dim)

    # ---- Transformer encoder stack ----
    for _ in range(depth):
        x = transformer_block(
            x,
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            mlp_dim=mlp_dim,
            dropout=dropout,
        )

    # ---- Pooling: CLS token or mean ----
    if pool == "mean":
        # A pooling layer (rather than a raw tf.reduce_mean on the symbolic
        # tensor) keeps the functional graph valid on Keras versions that
        # reject TensorFlow ops applied directly to Keras tensors.
        # channels_last is pinned so the token axis (axis 1) is the one averaged.
        x = layers.GlobalAveragePooling1D(
            data_format="channels_last", name="mean_pool"
        )(x)                                  # (B, dim)
    else:
        x = x[:, 0, :]                        # (B, dim)  CLS token

    # ---- Final MLP head ----
    x = layers.LayerNormalization(name="final_norm")(x)
    logits = layers.Dense(num_classes, name="logits")(x)

    model = Model(inputs=inp, outputs=logits, name="EEGViT")
    return model


def compile_eeg_vit(model, num_classes=2, lr=1e-3):
    """
    Compile `model` with Adam and a logits-based loss, then return it.

    With more than one class the model is trained with sparse categorical
    cross-entropy and sparse categorical accuracy; otherwise binary
    cross-entropy and binary accuracy are used. Both losses are configured
    with `from_logits=True`.

    Args:
        model: Keras model to compile (compiled in place).
        num_classes: Number of output classes; selects the loss/metric pair.
        lr: Learning rate for the Adam optimizer.

    Returns:
        The same `model` instance, now compiled.
    """
    multiclass = num_classes > 1

    if multiclass:
        criterion = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")
    else:
        criterion = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        accuracy = tf.keras.metrics.BinaryAccuracy(name="accuracy")

    adam = tf.keras.optimizers.Adam(learning_rate=lr)
    model.compile(optimizer=adam, loss=criterion, metrics=[accuracy])
    return model

