In [1]:
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import autocast

# from functools import partial
from typing import Optional, Tuple, Type, Dict, List

# from utils import window_partition, window_unpartition, patch_partition, patch_unpartition

In [2]:
class GlobalResponseNormalization(nn.Module):
    """Global Response Normalization (GRN) layer for channels-last tensors.

    From ConvNeXt V2, https://arxiv.org/pdf/2301.00808.

    The input is expected to end in ``(..., H, W, E)``: the two spatial axes
    followed by the feature axis. Both ``[B, H, W, E]`` and ``[B, C, H, W, E]``
    are therefore handled identically.

    ``gamma`` and ``beta`` are single learnable scalars initialised to zero,
    so a freshly constructed layer is the identity mapping.
    """

    def __init__(self, eps: float = 1e-6):
        """
        Args:
            eps (float): Added to the mean norm to avoid division by zero.
        """
        super().__init__()
        self.eps = eps

        self.gamma = nn.Parameter(torch.zeros(1))
        self.beta = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply GRN.

        Args:
            x (torch.Tensor): Channels-last tensor shaped ``(..., H, W, E)``.

        Returns:
            torch.Tensor: Tensor with the same shape as ``x``.
        """
        # L2 norm of every feature map over its spatial extent -> (..., 1, 1, E).
        # Negative axes are used on purpose: positive indices (2, 3) would
        # address (W, E) instead of (H, W) when the input is 4-D.
        gx = torch.norm(x, p=2, dim=(-3, -2), keepdim=True)
        # Divide each feature's norm by the mean norm across features.
        nx = gx / (gx.mean(dim=-1, keepdim=True) + self.eps)
        return self.gamma * (x * nx) + self.beta + x


class ConvNextBlock(nn.Module):
    """ConvNeXt V2-style residual block working on channels-last feature maps.

    Pipeline: depthwise 7x7 convolution -> LayerNorm -> pointwise expansion
    -> GELU -> GRN -> pointwise projection -> residual addition.
    """

    def __init__(self, dim: int, inter_dim: int):
        """
        Args:
            dim (int): Feature size of the input and of the output.
            inter_dim (int): Hidden feature size between the two pointwise layers.
        """
        super().__init__()
        # groups == dim makes this a depthwise convolution; padding=3 keeps H and W.
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.ln = nn.LayerNorm(dim)
        # Linear layers on a channels-last tensor act as 1x1 convolutions.
        self.conv2 = nn.Linear(dim, inter_dim)
        self.act = nn.GELU()
        self.grn = GlobalResponseNormalization()
        self.conv3 = nn.Linear(inter_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block.

        Args:
            x (torch.Tensor): Input of shape [B, H, W, E].

        Returns:
            torch.Tensor: Output of shape [B, H, W, E].
        """
        # nn.Conv2d wants channels-first, everything afterwards channels-last.
        spatial = self.conv1(x.permute(0, 3, 1, 2))  # [B, E, H, W]
        hidden = self.ln(spatial.permute(0, 2, 3, 1))  # [B, H, W, E]
        hidden = self.grn(self.act(self.conv2(hidden)))  # [B, H, W, inter_dim]
        return self.conv3(hidden) + x


class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> Linear."""

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        mlp_bias: bool = True,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        """
        Args:
            embedding_dim (int): Size of the input and output features.
            mlp_dim (int): Size of the hidden features.
            mlp_bias (bool): Whether both linear layers carry a bias term.
            act (Type[nn.Module]): Activation class, instantiated without arguments.
        """
        super().__init__()
        expand = nn.Linear(embedding_dim, mlp_dim, bias=mlp_bias)
        nonlinearity = act()
        project = nn.Linear(mlp_dim, embedding_dim, bias=mlp_bias)
        self.mlp = nn.Sequential(expand, nonlinearity, project)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the network along the last dimension of ``x``."""
        out = self.mlp(x)
        return out


class SpatioChannelAttention(nn.Module):
    """Spatial convolutional blocks (per channel/marker)
    and cross-channel attention on pixel level.

    Every marker image is first processed independently by a stack of
    ``ConvNextBlock`` modules. Afterwards, for each pixel location, the
    markers attend to each other with multi-head self-attention.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        num_conv_blocks: int = 1,
    ) -> None:
        """
        Args:
            embedding_dim (int): Feature size of every marker pixel.
            num_heads (int): Number of attention heads; must divide ``embedding_dim``.
            num_conv_blocks (int): Number of spatial convolutional blocks.
        """
        super().__init__()
        self.num_heads = num_heads
        assert embedding_dim % num_heads == 0, "embedding_dim must be divisible by num_heads"
        head_dim = embedding_dim // num_heads
        self.scale = head_dim**-0.5

        self.spatial_conv_blocks = nn.ModuleList()
        for _ in range(num_conv_blocks):
            self.spatial_conv_blocks.append(
                ConvNextBlock(
                    dim=embedding_dim,
                    inter_dim=4*embedding_dim,
                )
            )

        self.qkv = nn.Linear(embedding_dim, 3*embedding_dim)

        self.proj = MLP(embedding_dim=embedding_dim, mlp_dim=2*embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input of shape [B, C, H, W, E].

        Returns:
            torch.Tensor: Output of shape [B, C, H, W, E].
        """
        B, C, H, W, E = x.shape
        N_HEADS = self.num_heads
        shortcut = x

        # Spatial convolutional blocks: every marker image is an independent sample.
        x = x.reshape(B * C, H, W, E)  # [B * C, H, W, E]
        for block in self.spatial_conv_blocks:
            x = block(x)

        # Restore the 5-D layout before keeping the residual; a [B * C, H, W, E]
        # residual cannot be added to a [B, C, H, W, E] tensor when B > 1.
        x = x.reshape(B, C, H, W, E)
        residual = x

        # Cross-channel attention: pixels become the batch, markers the sequence.
        x = x.permute(0, 2, 3, 1, 4).reshape(B * H * W, C, E)  # [B*H*W, C, E]
        qkv = self.qkv(x).reshape(B * H * W, C, 3, N_HEADS, -1)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B*H*W, num_heads, C, E/num_heads]
        q, k, v = qkv.unbind(0)  # each [B*H*W, num_heads, C, E/num_heads]

        attn = (q * self.scale) @ k.transpose(-2, -1)  # [B*H*W, num_heads, C, C]
        attn = attn.softmax(dim=-1)
        x = attn @ v  # [B*H*W, num_heads, C, E/num_heads]

        # Move the head axis next to the per-head features BEFORE merging them;
        # reshaping directly would interleave heads and markers for num_heads > 1.
        x = x.transpose(1, 2).reshape(B, H, W, C, E).permute(0, 3, 1, 2, 4)  # [B, C, H, W, E]

        x = x + residual
        x = self.proj(x)
        x = x + shortcut

        return x


class MultiplexBlock(nn.Module):
    """Sequential stack of ``SpatioChannelAttention`` blocks."""

    def __init__(
        self,
        num_blocks: int,
        embedding_dim: int,
        num_heads: int,
    ) -> None:
        """
        Args:
            num_blocks (int): How many attention blocks to stack.
            embedding_dim (int): Feature size passed to every attention block.
            num_heads (int): Number of attention heads in each attention block.
        """
        super().__init__()
        stacked = [
            SpatioChannelAttention(embedding_dim=embedding_dim, num_heads=num_heads)
            for _ in range(num_blocks)
        ]
        self.blocks = nn.ModuleList(stacked)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Feed ``x`` through every block in order and return the final result."""
        out = x
        for attention_block in self.blocks:
            out = attention_block(out)
        return out

        
class MultiplexImageTransformer(nn.Module):
    """Multiplex Image Transformer integrating spatial and cross-channel blocks.

    Every stage consists of a strided convolution that downsamples each marker
    image independently, followed by a ``MultiplexBlock``.
    """

    def __init__(
        self,
        layers_blocks: List[int],
        embedding_dims: List[int],
        num_heads: List[int],
        channel_embedding_dim: int,
        num_channels: int,
    ) -> None:
        """
        Args:
            layers_blocks (List[int]): Number of attention blocks in each layer.
            embedding_dims (List[int]): Feature size of each layer.
            num_heads (List[int]): Number of attention heads in each layer.
            channel_embedding_dim (int): Embedding dimension per channel pixel.
            num_channels (int): Maximal number of channels/markers (vocab dim).

        Raises:
            ValueError: If the three per-layer lists differ in length.
        """
        super().__init__()
        # zip() would silently drop the surplus entries of the longer lists and
        # build a shallower network than requested, so reject mismatches early.
        if not (len(layers_blocks) == len(embedding_dims) == len(num_heads)):
            raise ValueError(
                "layers_blocks, embedding_dims and num_heads must have the same length, "
                f"got {len(layers_blocks)}, {len(embedding_dims)} and {len(num_heads)}"
            )

        self.channel_embedder = nn.Embedding(num_channels, channel_embedding_dim)

        # First pooling: non-overlapping 4x4 convolution (stride 4).
        self.poolings = nn.ModuleList()
        self.poolings.append(
            nn.Conv2d(channel_embedding_dim, embedding_dims[0], kernel_size=4, padding=0, stride=4)
        )
        # Remaining poolings: 2x2 convolution (stride 2) from one layer width to the next.
        for input_dim, out_dim in zip(embedding_dims[:-1], embedding_dims[1:]):
            self.poolings.append(
                nn.Conv2d(input_dim, out_dim, kernel_size=2, padding=0, stride=2)
            )

        self.layers = nn.ModuleList()
        for blocks, dim, heads in zip(layers_blocks, embedding_dims, num_heads):
            self.layers.append(
                MultiplexBlock(
                    num_blocks=blocks,
                    embedding_dim=dim,
                    num_heads=heads,
                )
            )

    def forward(self, x: torch.Tensor, channel_ids: torch.Tensor) -> torch.Tensor:
        """Forward pass of the Multiplex Image Transformer.

        Args:
            x (torch.Tensor): Multiplex images batch tensor with shape [B, C, H, W]
            channel_ids (torch.Tensor): Channel ids tensor with shape [B, C]

        Returns:
            torch.Tensor: Embedding tensor with shape [B, C, H', W', E'], where
                H' and W' are the spatial sizes left after all poolings and E'
                is the output width of the last pooling.
        """
        B, C, H, W = x.shape
        E = self.channel_embedder.embedding_dim

        channel_embeds = self.channel_embedder(channel_ids)  # [B, C, E]
        x = x.unsqueeze(-1)  # [B, C, H, W, 1]
        channel_embeds = channel_embeds.unsqueeze(-2).unsqueeze(-2)  # [B, C, 1, 1, E]

        # channel embedding: scale every marker's embedding by the pixel intensity
        x = x * channel_embeds  # [B, C, H, W, E]

        # poolings and blocks
        for pooling, layer in zip(self.poolings, self.layers):
            # Conv2d is channels-first and 4-D: fold the markers into the batch.
            x = x.reshape(B * C, H, W, E).permute(0, 3, 1, 2)
            x = pooling(x)
            # Pick up the new feature width and the reduced spatial size.
            _, E, H, W = x.shape
            x = x.permute(0, 2, 3, 1).reshape(B, C, H, W, E)
            x = layer(x)

        return x


In [3]:
# Pick the best available accelerator and fall back to the CPU, so that the
# notebook also runs on machines with neither CUDA nor Apple MPS.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

In [5]:
# Smoke test: one single-head SpatioChannelAttention block with three
# spatial conv blocks, run on a random [B, C, H, W, E] tensor.
blck = SpatioChannelAttention(
    embedding_dim=96,
    num_heads=1,
    num_conv_blocks=3,
)
blck = blck.to(device)
input = torch.rand((1, 40, 128, 128, 96)).to(device)
output = blck(input)

In [4]:
# Smoke test: a MultiplexBlock holding one single-head attention block,
# run on a random [B, C, H, W, E] tensor.
blck = MultiplexBlock(
    num_blocks=1,
    embedding_dim=96,
    num_heads=1,
)
blck = blck.to(device)
input = torch.rand((1, 40, 128, 128, 96)).to(device)
output = blck(input)

In [5]:
# Smoke test: multi-head MultiplexBlock.
# MultiplexBlock only accepts num_blocks, embedding_dim and num_heads, and the
# last input dimension has to equal embedding_dim. A small 16x16 spatial size
# keeps the memory footprint modest at this feature width.
blck = MultiplexBlock(num_blocks=1, embedding_dim=512, num_heads=8).to(device)
input = torch.rand(1, 40, 16, 16, 512).to(device)
output = blck(input)

In [5]:
# End-to-end smoke test of the full model: four stages, three blocks each.
layers_blocks = [3] * 4
embedding_dims = [96, 192, 384, 768]
num_heads = [1, 2, 4, 8]

model = MultiplexImageTransformer(
    layers_blocks=layers_blocks,
    embedding_dims=embedding_dims,
    num_heads=num_heads,
    channel_embedding_dim=48,
    num_channels=40,
)
model = model.to(device)

# Random image batch [B, C, H, W] with one marker id per channel [B, C].
input = torch.rand((1, 40, 224, 224)).to(device)
channel_ids = torch.randint(0, 40, (1, 40)).to(device)
# Optional mixed precision:
# with autocast(device_type=device, dtype=torch.float16):
output = model(input, channel_ids)