# In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.constraints import max_norm
import numpy as np


# ----------------- FeedForward (MLP) -----------------
def feed_forward_block(x, hidden_dim, dropout=0.0, name=None):
    """Build the feed-forward (MLP) sub-block of a transformer stage.

    Described by the original author as the Keras counterpart of the PyTorch
    ``FeedForward`` module. The pipeline is
    LayerNormalization -> Dense(hidden_dim, GELU) -> Dropout ->
    Dense(input feature size) -> Dropout. No residual connection is added
    here; the caller combines the result with whatever it needs.

    Args:
        x: Input tensor; the static size of its last axis is reused as the
            width of the final Dense projection.
        hidden_dim: Number of units in the hidden Dense layer.
        dropout: Dropout rate applied after each Dense layer.
        name: Name passed to ``tf.name_scope``; falls back to
            ``"FeedForward"`` when empty or ``None``.

    Returns:
        The transformed tensor, whose last axis has the same size as ``x``'s.
    """
    scope = name if name else "FeedForward"
    with tf.name_scope(scope):
        out_width = tf.keras.backend.int_shape(x)[-1]
        # Layers are instantiated in application order so that Keras'
        # auto-generated layer names keep their usual sequence.
        pipeline = [
            layers.LayerNormalization(),
            layers.Dense(hidden_dim, activation="gelu"),
            layers.Dropout(dropout),
            layers.Dense(out_width),
            layers.Dropout(dropout),
        ]
        y = x
        for stage in pipeline:
            y = stage(y)
    return y


# ----------------- Attention block -----------------
def attention_block(x, dim_head=64, heads=8, dropout=0.0, name=None):
    """Apply multi-head self-attention followed by dropout.

    Described by the original author as the Keras counterpart of the PyTorch
    ``Attention(dim, heads, dim_head)`` module. The residual connection is
    deliberately left to the caller (``attn(x) + x``).

    Args:
        x: Token tensor laid out as (B, N, D) = (batch, tokens, embedding).
        dim_head: Per-head key dimension (``key_dim`` of the Keras layer).
        heads: Number of attention heads.
        dropout: Rate used both inside the attention layer and for the
            Dropout layer applied to its output.
        name: Name passed to ``tf.name_scope``; falls back to
            ``"Attention"`` when empty or ``None``.

    Returns:
        Tensor projected back to the embedding size D of ``x``.
    """
    scope = name if name else "Attention"
    with tf.name_scope(scope):
        embed_dim = tf.keras.backend.int_shape(x)[-1]
        self_attention = layers.MultiHeadAttention(
            num_heads=heads,
            key_dim=dim_head,
            dropout=dropout,
            output_shape=embed_dim,  # project the heads back to D
        )
        # Self-attention: the same tensor is passed as query and as value.
        attended = self_attention(x, x)
        return layers.Dropout(dropout)(attended)


# ----------------- 1D CNN block (fine-grained branch) -----------------
def cnn_1d_block(x, in_chan, kernel_size, dropout=0.0, name=None):
    """Build the fine-grained 1D convolutional branch.

    Keras counterpart of the PyTorch sequence noted by the original author:
    Dropout -> Conv1d(in_chan, in_chan, k, padding) -> BatchNorm -> ELU ->
    MaxPool1d(2). Every layer works in ``channels_first`` layout.

    Args:
        x: Input tensor laid out as (B, in_chan, L).
        in_chan: Number of convolution filters (output channels).
        kernel_size: Temporal kernel size of the convolution.
        dropout: Dropout rate applied to the input before the convolution.
        name: Name passed to ``tf.name_scope``; falls back to
            ``"CNN1D_Block"`` when empty or ``None``.

    Returns:
        Tensor laid out as (B, in_chan, L // 2): "same" padding keeps the
        length through the convolution and the pooling halves it.
    """
    scope = name if name else "CNN1D_Block"
    with tf.name_scope(scope):
        # Layers are instantiated in application order so that Keras'
        # auto-generated layer names keep their usual sequence.
        pipeline = [
            layers.Dropout(dropout),
            layers.Conv1D(
                filters=in_chan,
                kernel_size=kernel_size,
                padding="same",
                data_format="channels_first",  # input is (B, C, L)
                use_bias=True,
            ),
            # axis=1 is the channel axis in channels_first layout.
            layers.BatchNormalization(axis=1),
            layers.ELU(),
            layers.MaxPooling1D(
                pool_size=2,
                strides=2,
                data_format="channels_first",
            ),
        ]
        y = x
        for stage in pipeline:
            y = stage(y)
    return y


# ----------------- Helper: get_info (log(mean(x^2))) -----------------
def get_info_tensor(x):
    """Return the log of the mean squared value along the last axis.

    Counterpart of the PyTorch ``get_info`` helper noted by the original
    author (``log(mean(x^2, dim=-1))``), with a small constant added before
    the logarithm so an all-zero row does not produce ``-inf``.

    Args:
        x: Tensor laid out as (b, k, l).

    Returns:
        Tensor of shape (b, k).
    """
    mean_power = tf.reduce_mean(tf.square(x), axis=-1)
    stabilised = mean_power + 1e-12
    return tf.math.log(stabilised)


# ----------------- Main Deformer builder -----------------
def build_eeg_deformer_keras(
    num_chan,
    num_time,
    temporal_kernel,
    num_kernel=64,
    num_classes=2,
    depth=4,
    heads=16,
    mlp_dim=16,
    dim_head=16,
    dropout=0.5,
):
    """
    Build the Keras (functional API) implementation of the EEG Deformer.

    The network is a 2D CNN encoder followed by ``depth`` stages. Each stage
    runs a coarse branch (max-pool -> self-attention + residual ->
    feed-forward) and a fine branch (1D CNN), sums them, and records a
    log-power summary of the fine branch. The flattened output of the last
    stage is concatenated with all recorded summaries and fed to a softmax
    classifier.

    Args
    ----
    num_chan        : number of EEG channels (C)
    num_time        : number of time samples (T); must be at least
                      ``2 ** (depth + 1)`` because the encoder and every
                      stage each halve the temporal length
    temporal_kernel : temporal kernel (same as PyTorch 'temporal_kernel')
    num_kernel      : number of CNN kernels in first 2D conv block (64 in original)
    num_classes     : output classes
    depth           : number of Transformer layers (at least 1)
    heads           : attention heads
    mlp_dim         : hidden dim for FFN inside transformer
    dim_head        : per-head key dimension
    dropout         : dropout rate

    Input shape
    -----------
    EEG input is expected as (batch, num_chan, num_time)

    Returns
    -------
    A ``tf.keras.Model`` named ``"EEG_Deformer_Keras"`` mapping the EEG input
    to a ``(batch, num_classes)`` tensor. NOTE: the output layer is named
    ``"logits"`` for backward compatibility, but it applies a softmax, so the
    model emits class probabilities (use losses with ``from_logits=False``).

    Raises
    ------
    ValueError
        If any size argument is smaller than 1, or if ``num_time`` is too
        short to survive the ``depth + 1`` temporal halvings.
    """

    # ----------------- Argument validation -----------------
    # Done before any layer is created so that a bad configuration fails
    # with a clear message instead of an obscure error (or zero-length
    # feature maps) deep inside graph construction.
    size_args = {
        "num_chan": num_chan,
        "num_time": num_time,
        "temporal_kernel": temporal_kernel,
        "num_kernel": num_kernel,
        "num_classes": num_classes,
        "depth": depth,
        "heads": heads,
        "mlp_dim": mlp_dim,
        "dim_head": dim_head,
    }
    for arg_name, arg_value in size_args.items():
        if arg_value < 1:
            raise ValueError(
                f"{arg_name} must be a positive integer, got {arg_value!r}"
            )

    # One halving in the CNN encoder plus one per stage.
    min_time = 2 ** (depth + 1)
    if num_time < min_time:
        raise ValueError(
            f"num_time={num_time!r} is too short for depth={depth!r}: "
            f"at least {min_time} time samples are required"
        )

    # ----------------- Input -----------------
    inp = layers.Input(shape=(num_chan, num_time), name="eeg_input")
    # PyTorch unsqueezes along dim=1 => (B, 1, chan, time).
    # A Reshape is used instead of a Lambda wrapping an anonymous function so
    # the model stays serializable without custom Python code.
    x = layers.Reshape((1, num_chan, num_time))(inp)  # (B, 1, C, T)

    # ----------------- CNN Encoder (2D) -----------------
    # PyTorch:
    #   Conv2dWithConstraint(1, num_kernel, (1, temporal_kernel), padding=(0, pad))
    #   Conv2dWithConstraint(num_kernel, num_kernel, (num_chan, 1), padding=0)
    #   BN, ELU, MaxPool2d((1,2))

    # First conv: kernel (1, temporal_kernel).
    # "same" padding matches the explicit temporal padding of
    # temporal_kernel // 2 used in PyTorch when the kernel size is odd.
    x = layers.Conv2D(
        filters=num_kernel,
        kernel_size=(1, temporal_kernel),
        padding="same",
        data_format="channels_first",  # (B, 1, C, T)
        use_bias=True,
        kernel_constraint=max_norm(2.0)
    )(x)

    # Second conv: kernel (num_chan, 1) collapses spatial/channel dim
    x = layers.Conv2D(
        filters=num_kernel,
        kernel_size=(num_chan, 1),
        padding="valid",
        data_format="channels_first",   # (B, num_kernel, 1, T)
        use_bias=True,
        kernel_constraint=max_norm(2.0)
    )(x)

    x = layers.BatchNormalization(axis=1)(x)
    x = layers.ELU()(x)
    x = layers.MaxPool2D(
        pool_size=(1, 2),
        strides=(1, 2),
        data_format="channels_first"
    )(x)  # (B, num_kernel, 1, 0.5*num_time)

    # After this, in PyTorch: dim = int(0.5*num_time)
    dim = int(0.5 * num_time)

    # ----------------- To patch embedding: Rearrange 'b k c f -> b k (c f)' -----------------
    # Current shape: (B, num_kernel, 1, dim). The collapsed axis has size 1,
    # so a plain reshape merges it into the feature axis; no permutation is
    # needed.
    x = layers.Reshape((num_kernel, dim))(x)  # (B, num_kernel, dim) = (B, k, d)

    # x will now be passed to "Transformer" equivalent
    # We treat shape as (B, in_chan=num_kernel, L=dim) for CNN, and (B, tokens=num_kernel, D=dim) for attention.

    dense_features = []
    # We'll emulate the iterative Transformer.forward:
    #   for attn, ff, cnn in layers:
    #       x_cg = pool(x)
    #       x_cg = attn(x_cg) + x_cg
    #       x_fg = cnn(x)
    #       x_info = get_info(x_fg)
    #       dense_feature.append(x_info)
    #       x = ff(x_cg) + x_fg

    # Shared coarse MaxPool1D (same as self.pool in PyTorch). The layer has
    # no weights, so reusing one instance across stages is safe.
    pool_coarse = layers.MaxPooling1D(
        pool_size=2, strides=2,
        data_format="channels_first"  # treats shape as (B, C, L)
    )

    # We'll keep x as (B, num_kernel, current_dim) throughout.
    for i in range(depth):
        # Coarse-grained branch (pool then attention)
        x_cg = pool_coarse(x)          # (B, num_kernel, dim/2)
        # Attention expects (B, tokens, features) = same shape.
        attn_out = attention_block(
            x_cg,
            dim_head=dim_head,
            heads=heads,
            dropout=dropout,
            name=f"attn_block_{i}"
        )
        x_cg = layers.Add()([attn_out, x_cg])  # attn(x_cg) + x_cg

        # Fine-grained CNN branch
        x_fg = cnn_1d_block(
            x,
            in_chan=num_kernel,
            kernel_size=temporal_kernel,
            dropout=dropout,
            name=f"cnn1d_block_{i}"
        )  # (B, num_kernel, dim/2) same length as x_cg

        # Info feature from x_fg: (B, num_kernel)
        x_info = layers.Lambda(get_info_tensor, name=f"info_{i}")(x_fg)
        dense_features.append(x_info)

        # FeedForward on coarse branch, then combine with fine-grained
        ff_out = feed_forward_block(
            x_cg,
            hidden_dim=mlp_dim,
            dropout=dropout,
            name=f"ff_block_{i}"
        )
        x = layers.Add()([ff_out, x_fg])   # ff(x_cg) + x_fg

    # Concatenate dense features: list of (B, num_kernel) -> (B, num_kernel * depth).
    # depth >= 1 is guaranteed by the validation above; Concatenate itself
    # needs at least two inputs, hence the single-stage special case.
    if len(dense_features) > 1:
        x_dense = layers.Concatenate(axis=-1)(dense_features)
    else:
        x_dense = dense_features[0]

    # Flatten x from last layer: (B, num_kernel, d_last) -> (B, num_kernel * d_last)
    x_flat = layers.Flatten()(x)

    # Final embedding: concat [x_flat, x_dense]
    emd = layers.Concatenate(axis=-1, name="embedding_concat")([x_flat, x_dense])

    # ----------------- MLP Head -----------------
    # The layer name "logits" is kept so existing get_layer("logits") lookups
    # keep working, but the softmax means the output is class probabilities.
    out = layers.Dense(num_classes, activation="softmax", name="logits")(emd)

    model = Model(inputs=inp, outputs=out, name="EEG_Deformer_Keras")
    return model


# # ----------------- Example usage -----------------
# if __name__ == "__main__":
#     # Match the PyTorch example:
#     # data = torch.ones((16, 32, 1000))
#     batch_size = 16
#     num_chan = 32
#     num_time = 1000

#     model = build_eeg_deformer_keras(
#         num_chan=num_chan,
#         num_time=num_time,
#         temporal_kernel=11,
#         num_kernel=64,
#         num_classes=2,
#         depth=4,
#         heads=16,
#         mlp_dim=16,
#         dim_head=16,
#         dropout=0.5,
#     )

#     model.summary()

#     # Dummy forward
#     x = np.ones((batch_size, num_chan, num_time), dtype=np.float32)
#     y = model(x)
#     print("Output shape:", y.shape)
