# In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# -------------------------------
# Positional Encoding
# -------------------------------
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch-first sequence.

    Even feature columns hold ``sin(pos * w_i)`` and odd columns hold
    ``cos(pos * w_i)``, with ``w_i = 10000^(-2i / d_model)``.

    Args:
        d_model: feature size of the incoming sequence (odd values are supported).
        max_len: longest sequence length the precomputed table covers.
    """

    def __init__(self, d_model, max_len=500):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float()
            * (-torch.log(torch.tensor(10000.0)) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column,
        # so trim the frequency vector to the number of cosine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # Shape: (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: tensor of shape (batch, seq_len, d_model).

        Raises:
            ValueError: if ``seq_len`` exceeds the table's ``max_len``.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds positional-encoding max_len {max_len}"
            )
        return x + self.pe[:, :seq_len, :]

# -------------------------------
# Region Embedding (CNN + Class Token)
# -------------------------------
class RegionEmbedding(nn.Module):
    """Turn one region's recording into a token sequence led by a class token.

    A Conv2d whose kernel and stride are both ``(1, patch_size)`` cuts every
    electrode row into non-overlapping time patches and projects each patch to
    ``embed_dim`` features. The patches are flattened into one sequence, a
    learnable class token is prepended, and sinusoidal positions are added.
    """

    def __init__(self, in_channels, embed_dim, patch_size):
        super().__init__()
        patch_shape = (1, patch_size)
        self.conv = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_shape, stride=patch_shape)
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_encoder = PositionalEncoding(embed_dim)

    def forward(self, x):
        # x: (batch, channels, electrodes, time_bins)
        patches = self.conv(x)  # (batch, embed_dim, electrodes, time_bins // patch_size)
        tokens = patches.flatten(start_dim=2).transpose(1, 2)  # (batch, num_patches, embed_dim)
        cls = self.cls_token.expand(tokens.size(0), -1, -1)  # (batch, 1, embed_dim)
        sequence = torch.cat([cls, tokens], dim=1)  # class token sits at position 0
        return self.pos_encoder(sequence)

# -------------------------------
# Multi-Head Self-Attention Block
# -------------------------------
class MultiHeadAttention(nn.Module):
    """Post-norm Transformer block: multi-head attention then a feed-forward layer.

    Despite its name this is a full encoder block. Each of the two sublayers
    (attention, FFN) has its own residual connection and its own LayerNorm.
    Passing ``kv`` turns the attention into cross-attention; otherwise it is
    self-attention.

    Args:
        embed_dim: feature size; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        # Normalizes the attention residual.
        self.norm = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, 2 * embed_dim),
            nn.GELU(),
            nn.Linear(2 * embed_dim, embed_dim),
        )
        # Separate norm for the feed-forward residual so the two sublayers do
        # not share affine parameters (standard encoder-layer layout).
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x, kv=None):
        """Apply attention + FFN.

        Args:
            x: query tensor, (batch, seq_len, embed_dim).
            kv: optional key/value tensor, (batch, seq_len_kv, embed_dim), for
                cross-attention; defaults to ``x`` (self-attention).

        Returns:
            Tensor with the same shape as ``x``.
        """
        source = x if kv is None else kv
        # Attention weights are discarded, so don't ask PyTorch to compute them.
        attn_out, _ = self.attn(x, source, source, need_weights=False)
        x = self.norm(x + attn_out)
        return self.norm2(x + self.ff(x))

# -------------------------------
# Region-wise Self-Attention (RSA)
# -------------------------------
class RSA(nn.Module):
    """Region-wise self-attention: a stack of attention blocks over one region.

    The whole token sequence is refined by ``depth`` blocks, but only the
    token at position 0 (the class token) is returned.
    """

    def __init__(self, embed_dim, num_heads, depth):
        super().__init__()
        blocks = [MultiHeadAttention(embed_dim, num_heads) for _ in range(depth)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        # x: (batch, num_patches + 1, embed_dim), class token first
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden[:, 0]  # (batch, embed_dim)

# -------------------------------
# Cross-Region Attention (CRA)
# -------------------------------
class CRA(nn.Module):
    """Cross-region attention: self-attention across the per-region class tokens."""

    def __init__(self, embed_dim, num_heads, depth):
        super().__init__()
        self.layers = nn.ModuleList(
            MultiHeadAttention(embed_dim, num_heads) for _ in range(depth)
        )

    def forward(self, class_tokens):
        """Let region tokens attend to each other.

        Args:
            class_tokens: list of per-region tokens, each (batch, embed_dim).

        Returns:
            List of updated tokens, one per region, in the input order.
        """
        tokens = torch.stack(class_tokens, dim=1)  # (batch, num_regions, embed_dim)
        for block in self.layers:
            tokens = block(tokens)
        return list(tokens.unbind(dim=1))

# -------------------------------
# Combiner
# -------------------------------
class Combiner(nn.Module):
    """Fuse per-region tokens by taking their element-wise mean."""

    def forward(self, class_tokens):
        """Average a list of (batch, embed_dim) tokens into one (batch, embed_dim) tensor."""
        stacked = torch.stack(class_tokens)  # (num_regions, batch, embed_dim)
        return stacked.mean(0)

# -------------------------------
# Full Model
# -------------------------------
class IEegTransformer(nn.Module):
    """Multi-region transformer that maps per-region recordings to label probabilities.

    Forward pipeline:
      1. every region is embedded (RegionEmbedding) and refined by its own RSA
         stack, yielding one class token per region;
      2. CRA lets the region tokens attend to one another;
      3. the tokens are averaged (Combiner) and a linear head produces
         ``num_classes`` scores, squashed with a sigmoid.
    """

    def __init__(self, regions_config, embed_dim=396, num_heads=6, depth_rsa=6, depth_cra=6, patch_size=5, num_classes=8):
        """
        Args:
            regions_config: list of tuples (in_channels, num_electrodes), one per
                region. Only ``in_channels`` is used to build layers here;
                ``num_electrodes`` is not read by this class.
            embed_dim: token size; must be divisible by ``num_heads``.
            num_heads: attention heads in every RSA/CRA block.
            depth_rsa: number of blocks in each region's RSA stack.
            depth_cra: number of blocks in the CRA stack.
            patch_size: width, in time bins, of each non-overlapping patch.
            num_classes: number of output labels.
        """
        super().__init__()
        self.num_regions = len(regions_config)
        self.region_embeddings = nn.ModuleList([RegionEmbedding(in_ch, embed_dim, patch_size)
                                                for in_ch, _ in regions_config])
        self.rsa = nn.ModuleList([RSA(embed_dim, num_heads, depth_rsa) for _ in regions_config])
        self.cra = CRA(embed_dim, num_heads, depth_cra)
        self.combiner = Combiner()
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x_list):
        """
        Args:
            x_list: one tensor per region, in ``regions_config`` order, each shaped
                (batch, channels, electrodes, time_bins).

        Returns:
            (batch, num_classes) tensor of sigmoid probabilities, suitable for
            ``nn.BCELoss`` (``nn.BCEWithLogitsLoss`` expects pre-sigmoid logits,
            which this method does not return).

        Raises:
            ValueError: if ``len(x_list)`` differs from the number of regions.
        """
        if len(x_list) != self.num_regions:
            raise ValueError(
                f"expected {self.num_regions} region tensors, got {len(x_list)}"
            )
        # Per-region embedding + RSA -> one class token per region.
        region_tokens = [
            rsa(embed(x))
            for embed, rsa, x in zip(self.region_embeddings, self.rsa, x_list)
        ]
        region_tokens = self.cra(region_tokens)  # CRA updates each region token
        unified_repr = self.combiner(region_tokens)  # average over regions
        logits = self.classifier(unified_repr)  # (batch, num_classes)
        return torch.sigmoid(logits)



# Cell output: ModuleNotFoundError: No module named 'torch'

# In [None]:
# -------------------------------
# Example Usage
# -------------------------------
if __name__ == "__main__":
    # Three toy regions, each with 2 input channels and 8 electrodes.
    regions_config = [(2, 8), (2, 8), (2, 8)]
    batch_size, time_bins = 4, 50

    # One random (batch, channels, electrodes, time_bins) tensor per region.
    x_list = []
    for channels, electrodes in regions_config:
        x_list.append(torch.randn(batch_size, channels, electrodes, time_bins))

    model = IEegTransformer(regions_config)
    output = model(x_list)
    print(output.shape)  # expected: torch.Size([4, 8])
