In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MBConv(nn.Module):
    """
    Inverted-residual (MobileNet-style) convolution block for 1D sequences.

    Stages, in order: an optional 1x1 expansion, a 3-tap depthwise
    convolution, and a 1x1 projection. Every convolution is followed by
    batch normalisation; GELU follows the first two stages only. The block
    adds its input back to the result only when it keeps both the sequence
    length (stride 1) and the channel count unchanged.

    Args:
        in_channels: Channels of the (batch, channels, length) input.
        out_channels: Channels produced by the projection stage.
        stride: Stride of the depthwise convolution.
        expand_ratio: Width multiplier for the hidden stage; a value of 1
            skips the expansion stage entirely.
    """
    def __init__(self, in_channels, out_channels, stride=1, expand_ratio=4):
        super().__init__()
        self.stride = stride
        self.use_res_connect = (stride == 1) and (in_channels == out_channels)

        width = in_channels * expand_ratio

        # Pointwise expansion -- left out when no widening is requested
        expand = [] if expand_ratio == 1 else [
            nn.Conv1d(in_channels, width, kernel_size=1, bias=False),
            nn.BatchNorm1d(width),
            nn.GELU(),
        ]

        # Per-channel (groups == channels) 3-tap convolution
        depthwise = [
            nn.Conv1d(width, width, kernel_size=3, stride=stride, padding=1, groups=width, bias=False),
            nn.BatchNorm1d(width),
            nn.GELU(),
        ]

        # Pointwise projection to the output width; no activation afterwards
        project = [
            nn.Conv1d(width, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm1d(out_channels),
        ]

        self.conv = nn.Sequential(*expand, *depthwise, *project)

    def forward(self, x):
        """Run the convolution stack, adding the skip connection when enabled."""
        out = self.conv(x)
        return x + out if self.use_res_connect else out

class MaxViTBlock(nn.Module):
    """
    A simplified MaxViT block for 1D time-series data.

    Applies, in order: an MBConv stage, local self-attention inside
    non-overlapping windows of ``block_size`` steps, and a coarse global
    self-attention over ``grid_size`` average-pooled tokens whose result is
    linearly interpolated back to the input length. Each attention stage is
    followed by a residual add and a LayerNorm.

    Args:
        dim: Feature width of the tokens.
        num_heads: Number of heads used by both attention stages.
        block_size: Window length for local attention.
        grid_size: Number of pooled tokens used for global attention.

    Raises:
        ValueError: If ``block_size`` or ``grid_size`` is not positive.
    """
    def __init__(self, dim, num_heads, block_size=16, grid_size=8):
        super().__init__()
        # Fail at construction time instead of with a ZeroDivisionError or
        # pooling error in the middle of forward().
        if block_size <= 0 or grid_size <= 0:
            raise ValueError(
                f"block_size and grid_size must be positive, got "
                f"block_size={block_size}, grid_size={grid_size}"
            )
        self.dim = dim
        self.num_heads = num_heads
        self.block_size = block_size
        self.grid_size = grid_size

        self.mb_conv = MBConv(dim, dim)

        # Local Attention
        self.local_attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)

        # Global Attention
        self.global_attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (batch, sequence_length, dim).

        Returns:
            Tensor of the same shape as ``x``.
        """
        B, L, C = x.shape

        # --- MBConv ---
        # Conv1d is channels-first, so go (B, L, C) -> (B, C, L) and back.
        x_conv = self.mb_conv(x.permute(0, 2, 1)).permute(0, 2, 1)
        # NOTE(review): MBConv(dim, dim) with stride 1 already adds its own
        # skip connection, so this second add doubles the identity path.
        # Left unchanged to keep existing checkpoints numerically compatible;
        # confirm whether that is intended.
        x = x + x_conv

        # --- Local Attention ---
        # Zero-pad the time axis up to the next multiple of block_size.
        pad_len = (-L) % self.block_size
        padded_len = L + pad_len
        num_blocks = padded_len // self.block_size
        x_blocks = F.pad(x, (0, 0, 0, pad_len)).reshape(B * num_blocks, self.block_size, C)

        # Padding is not data: mask it out as attention keys so the real
        # tokens of the last window are not diluted by zero vectors. The last
        # window always keeps at least one real token (pad_len < block_size),
        # so no attention row ends up fully masked.
        pad_mask = None
        if pad_len:
            pad_mask = (
                (torch.arange(padded_len, device=x.device) >= L)
                .view(num_blocks, self.block_size)
                .repeat(B, 1)  # row b * num_blocks + k matches x_blocks
            )

        # The attention maps are never used, so do not ask for them.
        local_attn_out, _ = self.local_attention(
            x_blocks, x_blocks, x_blocks,
            key_padding_mask=pad_mask,
            need_weights=False,
        )
        local_attn_out = local_attn_out.reshape(B, padded_len, C)

        # Drop the outputs computed for padded positions.
        x = self.norm1(x + local_attn_out[:, :L, :])

        # --- Global Attention ---
        # Summarise the sequence into grid_size tokens by average pooling.
        x_grid = F.adaptive_avg_pool1d(x.permute(0, 2, 1), self.grid_size).permute(0, 2, 1)

        global_attn_out, _ = self.global_attention(x_grid, x_grid, x_grid, need_weights=False)

        # Upsample back to the original length.
        global_attn_out = F.interpolate(
            global_attn_out.permute(0, 2, 1), size=L, mode='linear'
        ).permute(0, 2, 1)

        return self.norm2(x + global_attn_out)

class MaxViTFeatureExtractor(nn.Module):
    """
    The main MaxViT-inspired feature extractor for time-series.

    Projects every time step to ``embed_dim``, adds a learned positional
    embedding, runs ``num_blocks`` MaxViT blocks, averages over time and
    applies a final linear layer.

    Args:
        input_dim: Number of features per time step.
        embed_dim: Width of the token embeddings and of the output vector.
        num_heads: Attention heads per block.
        num_blocks: Number of stacked MaxViT blocks.
        sequence_length: Longest sequence the positional embedding covers.
        block_size: Local-attention window length passed to every block.
        grid_size: Global-attention token count passed to every block.
    """
    def __init__(self, input_dim, embed_dim, num_heads, num_blocks, sequence_length,
                 block_size=16, grid_size=8):
        super().__init__()
        self.embedding = nn.Linear(input_dim, embed_dim)
        self.pos_embedding = nn.Parameter(torch.randn(1, sequence_length, embed_dim))

        self.blocks = nn.ModuleList([
            MaxViTBlock(dim=embed_dim, num_heads=num_heads,
                        block_size=block_size, grid_size=grid_size)
            for _ in range(num_blocks)
        ])

        self.pool = nn.AdaptiveAvgPool1d(1)
        self.output_layer = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (batch_size, seq_len, input_dim) with
                ``seq_len`` no larger than the configured sequence length.

        Returns:
            Tensor of shape (batch_size, embed_dim).

        Raises:
            ValueError: If ``seq_len`` exceeds the positional-embedding length.
        """
        seq_len = x.shape[1]
        max_len = self.pos_embedding.shape[1]
        if seq_len > max_len:
            raise ValueError(
                f"Input sequence length {seq_len} exceeds the configured "
                f"sequence_length {max_len}"
            )

        x = self.embedding(x)
        # Slice the positional table to the actual length. Adding the full
        # table would silently broadcast a length-1 input up to max_len and
        # fail for any other shorter length.
        x = x + self.pos_embedding[:, :seq_len]

        for block in self.blocks:
            x = block(x)

        # Global average pooling over time
        x = x.permute(0, 2, 1) # (batch, embed_dim, seq_len)
        x = self.pool(x).squeeze(-1) # (batch, embed_dim)

        x = self.output_layer(x)
        return x

if __name__ == '__main__':
    # Demo: push one random batch through the extractor and report the shapes.

    # Data layout: 4 samples, 60 time steps each (e.g. 60 days of data),
    # 10 features per step (Open, High, Low, Close, Volume, etc.)
    BATCH_SIZE, SEQ_LEN, INPUT_DIM = 4, 60, 10
    # Model size
    EMBED_DIM, NUM_HEADS, NUM_BLOCKS = 64, 4, 2

    model = MaxViTFeatureExtractor(
        sequence_length=SEQ_LEN,
        num_blocks=NUM_BLOCKS,
        num_heads=NUM_HEADS,
        embed_dim=EMBED_DIM,
        input_dim=INPUT_DIM,
    )

    # Random stand-in for a real (batch, time, feature) batch
    dummy_input = torch.randn(BATCH_SIZE, SEQ_LEN, INPUT_DIM)

    # One feature vector per sample
    output_features = model(dummy_input)

    print("--- MaxViT Feature Extractor ---")
    print(f"Input shape: {dummy_input.shape}")
    # Expected: (BATCH_SIZE, EMBED_DIM)
    print(f"Output features shape: {output_features.shape}")

--- MaxViT Feature Extractor ---
Input shape: torch.Size([4, 60, 10])
Output features shape: torch.Size([4, 64])
