In [None]:
import math
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model

# ----------------- Graph Convolution -----------------

class GraphConvolution(layers.Layer):
    """
    Single graph-convolution layer.

    Called with a pair ``[x, adj]`` and computes
    ``relu(adj @ (x @ weight - bias))``.

    Shapes (as used by the caller in this file):
      x:   (B, N, in_features)
      adj: (B, N, N)
      out: (B, N, out_features)

    Args:
        in_features: size of the last axis of ``x``.
        out_features: size of the last axis of the output.
        use_bias: when False, no bias variable is created or applied.
    """

    def __init__(self, in_features, out_features, use_bias=True, **kwargs):
        super().__init__(**kwargs)
        self.use_bias = use_bias
        self.in_features = in_features
        self.out_features = out_features

    def build(self, input_shape):
        """Create the projection matrix and, optionally, the bias."""
        # input_shape = [x_shape, adj_shape]; the sizes come from the
        # constructor arguments rather than from input_shape.
        self.weight = self.add_weight(
            name="weight",
            shape=(self.in_features, self.out_features),
            initializer=tf.keras.initializers.GlorotUniform(),
            trainable=True,
        )
        self.bias = (
            self.add_weight(
                name="bias",
                shape=(1, 1, self.out_features),
                initializer="zeros",
                trainable=True,
            )
            if self.use_bias
            else None
        )
        super().build(input_shape)

    def call(self, inputs):
        """Project node features, subtract the bias, propagate, apply ReLU."""
        features, adjacency = inputs  # (B,N,in), (B,N,N)

        projected = tf.matmul(features, self.weight)  # (B,N,out)
        if self.bias is not None:
            # The bias is subtracted (not added) on purpose, before the
            # multiplication with the adjacency matrix.
            projected = projected - self.bias
        propagated = tf.matmul(adjacency, projected)  # (B,N,out)
        return tf.nn.relu(propagated)


# ----------------- Power Layer -----------------

class PowerLayer(layers.Layer):
    """
    Log-transformed average power.

    Steps applied to the input:
      1. element-wise square
      2. average pooling with window ``(1, length)`` and stride ``(1, step)``
      3. natural log of ``pooled + 1e-8`` (the constant keeps the log finite
         when the pooled power is zero)

    Input format: channels_first 4D: (B, C, H, W); pooling runs along W.

    Args:
        length: pooling window size along the last axis.
        step: pooling stride along the last axis.
    """

    def __init__(self, length, step, **kwargs):
        super().__init__(**kwargs)
        self.length = length
        self.step = step
        self.pool = layers.AveragePooling2D(
            pool_size=(1, length),
            strides=(1, step),
            data_format="channels_first",
        )

    def call(self, x):
        """Return log(mean(x**2) + 1e-8) over sliding windows."""
        mean_power = self.pool(tf.math.square(x))
        return tf.math.log(mean_power + 1e-8)


# ----------------- Aggregator -----------------

class Aggregator(layers.Layer):
    """
    Aggregates channels into brain areas by averaging.

    Args:
        idx_area: iterable with the number of channels in each brain area,
            e.g. [4, 6, 8, ...]. Any iterable of integers is accepted
            (list, tuple, NumPy array); it is copied into a plain list.

    Attributes:
        chan_in_area: channel counts per area, as a list of ints.
        idx: cumulative channel offsets, starting with 0
            (``len(idx) == area + 1``).
        area: number of brain areas.
    """

    def __init__(self, idx_area, **kwargs):
        super().__init__(**kwargs)
        # Copy into plain ints: ``[0] + idx_area`` would raise for tuples and
        # would silently turn into element-wise addition for NumPy arrays.
        self.chan_in_area = [int(n) for n in idx_area]
        self.idx = self._get_idx(self.chan_in_area)
        self.area = len(self.chan_in_area)

    def _get_idx(self, chan_in_area):
        """Return cumulative offsets ``[0, c0, c0 + c1, ...]``.

        Entry ``i`` is the first channel index of area ``i``; the final entry
        is the total number of channels listed in ``chan_in_area``.
        """
        offsets = [0]
        for count in chan_in_area:
            offsets.append(offsets[-1] + int(count))
        return offsets

    def call(self, x):
        """
        x: (B, channels, features)
        returns: (B, num_areas, features)

        Each area is the mean over its channel slice. The last area extends
        to the end of the channel axis, so it also absorbs any channels
        beyond the counts given in ``idx_area``.
        """
        area_means = []
        last = self.area - 1

        for i in range(self.area):
            start = self.idx[i]
            if i < last:
                members = x[:, start:self.idx[i + 1], :]  # (B, channels_in_area, F)
            else:
                # Open-ended slice: everything from `start` to the last channel.
                members = x[:, start:, :]
            area_means.append(tf.reduce_mean(members, axis=1))  # (B, F)

        # stack along new area dimension
        return tf.stack(area_means, axis=1)  # (B, num_areas, F)


# ----------------- Local Filter Layer -----------------

class LocalFilterLayer(layers.Layer):
    """
    Implements: x = ReLU( x * W - bias )
    with W: (channels, features), bias: (1, channels, 1)

    The product ``x * W`` is element-wise; W is shared across the batch and
    the bias is shared across the batch and the feature axis.

    Args:
        num_channels: size of the channel axis of the input.
        feature_dim: size of the feature axis of the input.
    """

    def __init__(self, num_channels, feature_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_channels = num_channels
        self.feature_dim = feature_dim

    def build(self, input_shape):
        """Create the per-element weight and the per-channel bias."""
        self.local_filter_weight = self.add_weight(
            shape=(self.num_channels, self.feature_dim),
            initializer=tf.keras.initializers.GlorotUniform(),
            trainable=True,
            name="local_filter_weight"
        )
        self.local_filter_bias = self.add_weight(
            shape=(1, self.num_channels, 1),
            initializer="zeros",
            trainable=True,
            name="local_filter_bias"
        )
        super().build(input_shape)

    def call(self, x):
        """
        x: (B, channels, features)
        returns: tensor of the same shape
        """
        # (C, F) broadcasts against (B, C, F), so the weights do not need to
        # be tiled along the batch axis; tiling would only allocate a
        # B x C x F copy of the same values.
        scaled = x * self.local_filter_weight
        # bias (1, C, 1) broadcasts over the batch and feature axes
        return tf.nn.relu(scaled - self.local_filter_bias)


# ----------------- Helper: temporal output size -----------------

def compute_temporal_feature_dim(
    input_size,      # (freq, channels, time)
    sampling_rate,
    num_T,
    pool,
    pool_step_rate,
    window=(0.5, 0.25, 0.125),
):
    """
    Analytic reimplementation of get_size_temporal() from PyTorch code.
    Returns the feature dimension per channel after temporal blocks.

    For every entry of ``window`` the time axis goes through an unpadded
    stride-1 convolution with kernel ``int(w * sampling_rate)`` followed by
    an unpadded average pooling (kernel ``pool``, stride
    ``int(pool_step_rate * pool)``). The resulting widths are summed
    (concatenation along time), pooled once more with kernel 2 / stride 2,
    and multiplied by ``num_T``.

    Args:
        input_size: 3-tuple (freq, channels, time); only ``time`` is used.
        sampling_rate: samples per second, used to size the kernels.
        num_T: number of temporal filters per branch.
        pool: pooling window length along time.
        pool_step_rate: pooling stride as a fraction of ``pool``.
        window: kernel lengths in seconds, one per temporal branch.

    Returns:
        ``num_T * width`` where ``width`` is the pooled, concatenated width.

    Raises:
        ValueError: if the pooling stride or a kernel length truncates to
            zero, or if the time axis is too short to produce at least one
            output sample.
    """
    _, _, T = input_size

    # AvgPool2d kernel=pool, stride=pool_step_rate*pool (no padding)
    stride = int(pool_step_rate * pool)
    if stride < 1:
        raise ValueError(
            f"pooling stride int(pool_step_rate * pool) must be >= 1, "
            f"got {stride} (pool={pool}, pool_step_rate={pool_step_rate})"
        )

    total_w = 0
    for w in window:
        kernel_len = int(w * sampling_rate)               # conv kernel along time
        if kernel_len < 1:
            raise ValueError(
                f"kernel length int({w} * {sampling_rate}) must be >= 1, "
                f"got {kernel_len}"
            )

        # Conv2d (no padding, stride=1)
        T_conv = T - kernel_len + 1
        T_pool = math.floor((T_conv - pool) / stride + 1)
        if T_pool < 1:
            raise ValueError(
                f"time axis of length {T} is too short for kernel length "
                f"{kernel_len} followed by pooling of size {pool}"
            )
        total_w += T_pool

    # OneXOneConv: Conv2d with kernel=(1,1) (no change), then AvgPool2D((1,2), stride=(1,2))
    k2 = 2
    s2 = 2
    width_after = math.floor((total_w - k2) / s2 + 1)
    if width_after < 1:
        raise ValueError(
            f"concatenated temporal width {total_w} is too short for the "
            f"final (1, {k2}) average pooling"
        )

    return num_T * width_after


# ----------------- LGGNet Keras Model -----------------

class LGGNetKeras(Model):
    """
    Keras implementation of LGGNet.

    Expected input shape: (batch, F, C, T)
      - F: number of "frequency" bands (input_size[0] in original code)
      - C: number of EEG channels       (input_size[1])
      - T: time points                  (input_size[2])

    Pipeline (see ``call``): three temporal branches -> concat along time ->
    BN -> 1x1 conv + pooling -> BN -> per-channel flatten -> local filter ->
    aggregation into brain areas -> graph convolution over areas -> BN ->
    dropout + dense classifier producing logits.

    Args:
        num_classes: number of output logits.
        input_size: (F, C, T) as described above.
        sampling_rate: used to turn the window lengths (seconds) into
            convolution kernel lengths.
        num_T: number of temporal filters per branch.
        out_graph: output feature size of the graph convolution.
        dropout_rate: dropout rate applied before the final dense layer.
        pool: pooling window of the temporal branches.
        pool_step_rate: pooling stride as a fraction of ``pool``.
        idx_graph: number of channels in each brain area.
    """

    def __init__(
        self,
        num_classes,
        input_size,        # (F, C, T)
        sampling_rate,
        num_T,
        out_graph,
        dropout_rate,
        pool,
        pool_step_rate,
        idx_graph,         # list of channels per brain area
        **kwargs
    ):
        super().__init__(**kwargs)

        self.window = [0.5, 0.25, 0.125]
        self.pool = pool
        self.channel = input_size[1]
        self.brain_area = len(idx_graph)
        self.idx_graph = idx_graph
        # Kept so that call() can reshape to fully static feature sizes.
        self.out_graph = out_graph

        # Temporal learners (three Tception branches)
        self.Tception1 = self._temporal_learner(
            in_chan=input_size[0],
            out_chan=num_T,
            kernel_len=int(self.window[0] * sampling_rate),
            pool=pool,
            pool_step_rate=pool_step_rate,
            name="Tception1"
        )
        self.Tception2 = self._temporal_learner(
            in_chan=input_size[0],
            out_chan=num_T,
            kernel_len=int(self.window[1] * sampling_rate),
            pool=pool,
            pool_step_rate=pool_step_rate,
            name="Tception2"
        )
        self.Tception3 = self._temporal_learner(
            in_chan=input_size[0],
            out_chan=num_T,
            kernel_len=int(self.window[2] * sampling_rate),
            pool=pool,
            pool_step_rate=pool_step_rate,
            name="Tception3"
        )

        self.BN_t  = layers.BatchNormalization(axis=1, name="BN_t")   # channels_first
        self.BN_t_ = layers.BatchNormalization(axis=1, name="BN_t_")

        self.OneXOneConv = tf.keras.Sequential(
            [
                layers.Conv2D(
                    filters=num_T,
                    kernel_size=(1, 1),
                    strides=(1, 1),
                    padding="valid",
                    data_format="channels_first",
                    name="conv_1x1"
                ),
                # NOTE(review): Keras' LeakyReLU default slope is 0.3 while
                # PyTorch's nn.LeakyReLU default is 0.01 — confirm which
                # slope this port is meant to use.
                layers.LeakyReLU(),
                layers.AveragePooling2D(
                    pool_size=(1, 2),
                    strides=(1, 2),
                    data_format="channels_first",
                    name="avgpool_1x2"
                )
            ],
            name="OneXOneConv"
        )

        # Compute feature dimension after temporal blocks, per channel
        feature_dim = compute_temporal_feature_dim(
            input_size=input_size,
            sampling_rate=sampling_rate,
            num_T=num_T,
            pool=pool,
            pool_step_rate=pool_step_rate,
            window=self.window
        )
        self.feature_dim = feature_dim

        # Local filter: W and bias
        self.local_filter = LocalFilterLayer(
            num_channels=self.channel,
            feature_dim=feature_dim,
            name="LocalFilter"
        )

        # Aggregator over brain areas
        self.aggregate = Aggregator(idx_area=self.idx_graph, name="Aggregator")

        # Global adjacency (trainable)
        self.global_adj = self.add_weight(
            shape=(self.brain_area, self.brain_area),
            initializer=tf.keras.initializers.GlorotUniform(),
            trainable=True,
            name="global_adj"
        )

        # BN over brain areas
        self.bn  = layers.BatchNormalization(axis=1, name="bn_global1")
        self.bn_ = layers.BatchNormalization(axis=1, name="bn_global2")

        # GCN layer
        self.gcn = GraphConvolution(
            in_features=feature_dim,
            out_features=out_graph,
            name="GCN"
        )

        # Final classifier
        self.fc = tf.keras.Sequential(
            [
                layers.Dropout(rate=dropout_rate),
                layers.Dense(num_classes)
            ],
            name="Classifier"
        )

    def _temporal_learner(self, in_chan, out_chan, kernel_len, pool, pool_step_rate, name=None):
        """Build one temporal branch: (1, kernel_len) conv followed by PowerLayer.

        ``in_chan`` is accepted for parity with the call sites but is not
        used: the Keras Conv2D layer is constructed without it.
        """
        step = int(pool_step_rate * pool)
        return tf.keras.Sequential(
            [
                layers.Conv2D(
                    filters=out_chan,
                    kernel_size=(1, kernel_len),
                    strides=(1, 1),
                    padding="valid",
                    data_format="channels_first",
                    use_bias=True,
                    name=f"{name}_conv" if name else None
                ),
                PowerLayer(
                    length=pool,
                    step=step,
                    name=f"{name}_power" if name else None
                )
            ],
            name=name
        )

    def _self_similarity(self, x):
        """Return the batched Gram matrix x @ x^T."""
        # x: (B, node, feature)
        return tf.matmul(x, x, transpose_b=True)  # (B, node, node)

    def _get_adj(self, x, self_loop=True):
        """Build a symmetrically normalised adjacency D^-1/2 A D^-1/2.

        A is ``relu(x @ x^T * (G + G^T))`` with G the trainable
        ``global_adj``; the identity is added when ``self_loop`` is True.
        """
        # x: (B, node, feature)
        adj = self._self_similarity(x)  # (B, N, N)

        # symmetric learned adjacency
        sym = self.global_adj + tf.transpose(self.global_adj)  # (N, N)
        adj = tf.nn.relu(adj * sym)                            # broadcast over batch

        if self_loop:
            num_nodes = tf.shape(adj)[-1]
            # Match adj's dtype so the addition also works when the compute
            # dtype is not float32.
            adj = adj + tf.eye(num_nodes, dtype=adj.dtype)[tf.newaxis, :, :]

        rowsum = tf.reduce_sum(adj, axis=-1)          # (B, N)
        # Replace zero row sums by 1 so the inverse square root stays finite.
        mask = tf.cast(tf.equal(rowsum, 0.0), rowsum.dtype)
        rowsum = rowsum + mask
        d_inv_sqrt = tf.pow(rowsum, -0.5)
        d_mat_inv_sqrt = tf.linalg.diag(d_inv_sqrt)   # (B, N, N)
        adj_norm = tf.matmul(tf.matmul(d_mat_inv_sqrt, adj), d_mat_inv_sqrt)
        return adj_norm  # (B, N, N)

    def call(self, x, training=False):
        """
        x: (B, F, C, T)  — frequency x channel x time
        returns: logits of shape (B, num_classes)
        """
        # --- Temporal blocks ---
        y1 = self.Tception1(x, training=training)
        y2 = self.Tception2(x, training=training)
        y3 = self.Tception3(x, training=training)
        out = tf.concat([y1, y2, y3], axis=-1)   # concat on time axis

        out = self.BN_t(out, training=training)
        out = self.OneXOneConv(out, training=training)
        out = self.BN_t_(out, training=training)

        # Permute (B, Tchan, C, W) -> (B, C, Tchan, W)
        out = tf.transpose(out, perm=[0, 2, 1, 3])

        # Flatten last two dims: (B, C, features).
        # Only the batch size is taken dynamically; the other sizes are the
        # known Python ints so that the static shape survives graph tracing.
        B = tf.shape(out)[0]
        out = tf.reshape(out, (B, self.channel, self.feature_dim))

        # Local filter
        out = self.local_filter(out)

        # Aggregate into brain areas: (B, brain_area, features)
        out = self.aggregate(out)

        # Build adjacency
        adj = self._get_adj(out)

        # Global graph conv
        out = self.bn(out, training=training)
        out = self.gcn([out, adj])
        out = self.bn_(out, training=training)

        # Flatten and classify. A `(B, -1)` target would leave the last
        # static dimension undefined when traced with an unknown batch size,
        # and Dense requires it to be defined when it builds.
        out = tf.reshape(out, (B, self.brain_area * self.out_graph))
        logits = self.fc(out, training=training)
        return logits


# ----------------- Helper: build + compile -----------------

def build_lggnet_keras(
    num_classes,
    input_size,        # (F, C, T)
    sampling_rate,
    num_T,
    out_graph,
    dropout_rate,
    pool,
    pool_step_rate,
    idx_graph,
    lr=1e-3
):
    """
    Construct an LGGNetKeras model and compile it for integer-label
    classification.

    All arguments except ``lr`` are forwarded unchanged to ``LGGNetKeras``.
    The model is compiled with Adam (learning rate ``lr``), sparse
    categorical cross-entropy on logits, and a sparse categorical accuracy
    metric named "accuracy".

    Returns:
        The compiled model.
    """
    net_kwargs = dict(
        num_classes=num_classes,
        input_size=input_size,
        sampling_rate=sampling_rate,
        num_T=num_T,
        out_graph=out_graph,
        dropout_rate=dropout_rate,
        pool=pool,
        pool_step_rate=pool_step_rate,
        idx_graph=idx_graph,
    )
    net = LGGNetKeras(**net_kwargs)

    # The network outputs raw logits, hence from_logits=True.
    criterion = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
    accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")

    net.compile(optimizer=optimizer, loss=criterion, metrics=[accuracy])
    return net


# ----------------- Helper: reshape to LGGNet input -----------------

def to_lggnet_input(X):
    """
    Convert your data to LGGNet format (B, F, C, T),
    here F=1 (single 'frequency' band).

    Args:
        X: array-like of shape
          - (B, C, T, 1)  as in your previous pipelines
          - or (B, C, T)
          Anything ``np.asarray`` accepts (ndarray, nested lists, ...).

    Returns:
        NumPy array of shape (B, 1, C, T).

    Raises:
        ValueError: if X is neither (B, C, T) nor (B, C, T, 1); returning a
            tensor of some other rank would only fail later inside the model.
    """
    X = np.asarray(X)

    # If there's a trailing singleton channel dimension, drop it
    if X.ndim == 4 and X.shape[-1] == 1:
        X = np.squeeze(X, axis=-1)   # (B, C, T)

    if X.ndim != 3:
        raise ValueError(
            f"expected input of shape (B, C, T) or (B, C, T, 1), "
            f"got shape {X.shape}"
        )

    # Add fake frequency dimension F=1 at axis=1
    return np.expand_dims(X, axis=1)    # (B, 1, C, T)

