In [1]:
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import autocast

from functools import partial
from typing import Optional, Tuple, Type, Dict, List

from utils import window_partition, window_unpartition, patch_partition, patch_unpartition

In [2]:
# code from https://github.com/naver-ai/rope-vit/blob/main/self-attn/rope_self_attn.py

def init_2d_freqs(dim: int, num_heads: int, theta: float = 10.0, rotate: bool = True):
    freqs_x = []
    freqs_y = []
    mag = 1 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
    for i in range(num_heads):
        angles = torch.rand(1) * 2 * torch.pi if rotate else torch.zeros(1)        
        fx = torch.cat([mag * torch.cos(angles), mag * torch.cos(torch.pi/2 + angles)], dim=-1)
        fy = torch.cat([mag * torch.sin(angles), mag * torch.sin(torch.pi/2 + angles)], dim=-1)
        freqs_x.append(fx)
        freqs_y.append(fy)
    freqs_x = torch.stack(freqs_x, dim=0)
    freqs_y = torch.stack(freqs_y, dim=0)
    freqs = torch.stack([freqs_x, freqs_y], dim=0)
    return freqs

def init_t_xy(end_x: int, end_y: int):
    t = torch.arange(end_x * end_y, dtype=torch.float32)
    t_x = (t % end_x).float()
    t_y = torch.div(t, end_x, rounding_mode='floor').float()
    return t_x, t_y

def compute_mixed_cis(freqs: torch.Tensor, t_x: torch.Tensor, t_y: torch.Tensor, num_heads: int):
    N = t_x.shape[0]
    # No float 16 for this range
    # with torch.amp.autocast(freqs.device, enabled=False):
    freqs_x = (t_x.unsqueeze(-1) @ freqs[0].unsqueeze(-2)).view(N, num_heads, -1).permute(1, 0, 2)
    freqs_y = (t_y.unsqueeze(-1) @ freqs[1].unsqueeze(-2)).view(N, num_heads, -1).permute(1, 0, 2)
    freqs_cis = torch.polar(torch.ones_like(freqs_x), freqs_x + freqs_y)
    return freqs_cis

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 1 < ndim
    if freqs_cis.shape == (x.shape[-2], x.shape[-1]):
        shape = [d if i >= ndim-2 else 1 for i, d in enumerate(x.shape)]

    elif freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1]):
        shape = [d if i >= ndim-3 else 1 for i, d in enumerate(x.shape)]

    return freqs_cis.view(*shape)

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)

def load_tokenizer(vocab_path: str = './tokenizer_vocab.json') -> Dict[str, int]:
    """Load tokenizer dictionary from its vocab json file.  

    Args:
        vocab_path (Optional[str]): Path to the vocab json file. Defaults to './tokenizer_vocab.json'.

    Returns:
        Dict[str, int]: Dictionary mapping marker names to their respective indices.
    """
    with open(vocab_path, 'r') as f:
        tokenizer = json.load(f)
    return tokenizer


In [3]:
class MLP(nn.Module):
    """Standard MLP module"""
    def __init__(
            self, 
            embedding_dim: int,
            mlp_dim: int,
            mlp_bias: bool = True,
            act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, mlp_dim, bias=mlp_bias),
            act(),
            nn.Linear(mlp_dim, embedding_dim, bias=mlp_bias),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)


# code adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py
# and https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py
class SpatialAttention(nn.Module):
    """Multi-head Attention block with ROPE relative position embeddings (per channel/marker)."""

    def __init__(
        self,
        embedding_dim: int = 768,
        num_heads: int = 8,
        qkv_bias: bool = True,
        rope_theta: float = 10.,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool):  If True, add a learnable bias to query, key, value.
            rope_theta (float): Theta value for relative positional embeddings.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size. If None, use the default value of 16x16 and will be adjusted
                to the input shape.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = embedding_dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(embedding_dim, embedding_dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(embedding_dim, embedding_dim)

        self.compute_cis = partial(compute_mixed_cis, num_heads=self.num_heads)
            
        freqs = init_2d_freqs(
            dim=head_dim, num_heads=self.num_heads, theta=rope_theta, 
            rotate=True
        ).view(2, -1)
        self.freqs = nn.Parameter(freqs, requires_grad=True)
        
        if input_size is not None:
            end_x, end_y = input_size
        else:
            end_x, end_y = 16, 16

        t_x, t_y = init_t_xy(end_x=end_x, end_y=end_y)
        self.register_buffer('freqs_t_x', t_x)
        self.register_buffer('freqs_t_y', t_y)


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, _ = x.shape
        # qkv with shape (3, B, nHead, H * W, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)

        # q, k, v with shape (B, nHead, H * W, C)
        q, k, v = qkv.unbind(0)

        # Apply rotary embeddings
        t_x, t_y = self.freqs_t_x, self.freqs_t_y
        
        if self.freqs_t_x.shape[0] != H:
            t_x, t_y = init_t_xy(end_x=W, end_y=H)
            t_x, t_y = t_x.to(x.device), t_y.to(x.device)
        freqs_cis = self.compute_cis(self.freqs, t_x, t_y)

        q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)        

        # Attention
        attn = (q * self.scale) @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
        x = self.proj(x)

        return x


class SpatialAttentionBlock(nn.Module):
    """Attention block with spatial (per-marker) attention on pixel level."""

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int = 8,
        mlp_dim: int = 1024,
        mlp_bias: bool = True,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        window_size: int = 0,
        window_shift: Tuple[int, int] = (0, 0),
        mini_batch: int = 0,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            window_size (int): Window size for window attention blocks. If it equals 0, then
                use global attention.
            window_shift (Tuple[int, int]): Shift for window attention blocks.
            mini_batch (int): Mini-batch size for window attention blocks. If it equals 0, then
                all windows are processed at once.
            
        """
        super().__init__()
        self.norm1 = norm_layer(embedding_dim)
        self.attn = SpatialAttention(
            embedding_dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            input_size=(window_size, window_size) if window_size > 0 else None,
        )

        self.norm2 = norm_layer(embedding_dim)
        self.mlp = MLP(embedding_dim=embedding_dim, mlp_dim=mlp_dim, mlp_bias=mlp_bias, act=act_layer)

        self.window_size = window_size
        self.window_shift = window_shift
        self.mini_batch = mini_batch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W, E = x.shape
        x = x.reshape(B * C, H, W, E)
        
        # Window partition
        if self.window_size > 0:
            x, pad_hw = window_partition(x, self.window_size, self.window_shift) # [B * C * num_windows, window_size, window_size, E]
        
        shortcut = x
        x = self.norm1(x)

        if self.mini_batch > 0:
            # Iterate over slices of size of mini_batch
            i, j = 0, self.mini_batch
            while i < x.shape[0]:
                x[i:j] = shortcut[i:j] + self.attn(x[i:j])
                x[i:j] = x[i:j] + self.mlp(self.norm2(x[i:j]))
                i, j = j, j + self.mini_batch
        else:
            x = self.attn(x)
            x = shortcut + x
            x = x + self.mlp(self.norm2(x))

        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, self.window_shift, pad_hw, (B, C, H, W, E))

        x = x.reshape(B, C, H, W, E)

        return x


class ChannelAttention(nn.Module):
    """Multi-head Attention block (across channels/markers)."""

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool):  If True, add a learnable bias to query, key, value.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = embedding_dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(embedding_dim, embedding_dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(embedding_dim, embedding_dim)


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, _ = x.shape
        
        qkv = self.qkv(x).reshape(B, C, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # qkv with shape (3, B, nHead, C, E)
        q, k, v = qkv.reshape(3, B * self.num_heads, C, -1).unbind(0) # q, k, v with shape (B * nHead, C, E)

        attn = (q * self.scale) @ k.transpose(-2, -1)

        attn = attn.softmax(dim=-1)
        x = (attn @ v).view(B, self.num_heads, C, -1).permute(0, 2, 1, 3).reshape(B, C, -1)
        x = self.proj(x)

        return x


class CrossChannelAttentionBlock(nn.Module):
    """Attention block with cross-channel (per-position) attention on patch level."""

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int = 8,
        mlp_dim: int = 1024,
        mlp_bias: bool = True,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        patch_size: int = 16,
        patch_shift: Tuple[int, int] = (0, 0),
        mini_batch: int = 0,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            patch_size (int): Patch size for attention blocks.
            patch_shift (Tuple[int, int]): Shift for attention blocks.
            mini_batch (int): Mini-batch size for per-position patch blocks. If it equals 0, then
                all patches are processed at once.
        """
        super().__init__()
        self.norm1 = norm_layer(embedding_dim)
        self.attn = ChannelAttention(
            embedding_dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
        )

        self.norm2 = norm_layer(embedding_dim)
        self.mlp = MLP(embedding_dim=embedding_dim, mlp_dim=mlp_dim, mlp_bias=mlp_bias, act=act_layer)

        self.patch_size = patch_size
        # self.patch_proj = nn.Parameter(torch.randn(patch_size * patch_size, embedding_dim))
        self.patch_proj = nn.Linear(embedding_dim, embedding_dim)
        self.patch_shift = patch_shift

        self.mini_batch = mini_batch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W, E = x.shape

        x, pad_hw = patch_partition(x, self.patch_size, self.patch_shift) # [B * num_patches, C, patch_size * patch_size, E]

        # x_patched = torch.einsum("bcpe,pe->bce", x, self.patch_proj) # [B * num_patches, C, E]
        x_patched = self.patch_proj(x).mean(dim=2) # [B * num_patches, C, E]
        x_patched = self.norm1(x_patched)

        if self.mini_batch > 0:
            # Iterate over slices of size of mini_batch
            i, j = 0, self.mini_batch
            while i < x_patched.shape[0]:
                atn = self.attn(x_patched[i:j]).unsqueeze(2) # [mini_batch, C, 1, E]
                atn = atn.expand_as(x[i:j]) # [mini_batch, C, patch_size * patch_size, E]

                x[i:j] = x[i:j] + atn
                x[i:j] = x[i:j] + self.norm2(self.mlp(x[i:j]))
                i, j = j, j + self.mini_batch
        else:
            x = x + self.attn(x_patched).unsqueeze(2).expand_as(x)
            x = x + self.mlp(self.norm2(x))

        x = patch_unpartition(x, self.patch_size, self.patch_shift, pad_hw, (B, C, H, W, E))
        return x



class MultiplexBlock(nn.Module):
    """Multiplex Block integrating spatial and cross-channel attention blocks."""

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        spatial_window_size: int,
        channel_patch_size: int,
        spatial_window_shift: Tuple[int, int] = (0, 0),
        channel_patch_shift: Tuple[int, int] = (0, 0),
        mlp_dim: int = 1024,
        mlp_bias: bool = True,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        mini_batch: int = 0,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each attention block.
            spatial_window_size int: Window size for spatial attention blocks.
            channel_patch_size int: Patch size for cross-channel attention blocks.
            spatial_window_shift (Tuple[int, int]): Shift for spatial attention blocks.
            channel_patch_shift (Tuple[int, int]): Shift for cross-channel attention blocks.
            mlp_dim (int): Hidden dimension of the mlp.
            mlp_bias (bool): If True, add a learnable bias to the mlp.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            mini_batch (int): Mini-batch size for patch and windows blocks. If it equals 0, then
                all patches/windows are processed at once.
        """
        super().__init__()
        self.spatial_attn = SpatialAttentionBlock(
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            mlp_dim=mlp_dim,
            mlp_bias=mlp_bias,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            act_layer=act_layer,
            window_size=spatial_window_size,
            window_shift=spatial_window_shift,
            mini_batch=mini_batch
        )

        self.channel_attn = CrossChannelAttentionBlock(
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            mlp_dim=mlp_dim,
            mlp_bias=mlp_bias,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            act_layer=act_layer,
            patch_size=channel_patch_size,
            patch_shift=channel_patch_shift,
            mini_batch=mini_batch
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.spatial_attn(x)
        x = self.channel_attn(x)
        return x

        
class MultiplexImageTransformer(nn.Module):
    """Multiplex Block integrating spatial and cross-channel attention blocks."""

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        num_layers: int,
        spatial_window_sizes: List[int],
        channel_patch_sizes: List[int],
        shift_sec: bool = True,
        mlp_dim: int = 1024,
        mlp_bias: bool = True,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        mini_batch: int = 0,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            num_layers (int): Number of attention blocks.
            spatial_window_sizes (List[int]): Window sizes for spatial attention blocks.
            channel_patch_sizes (List[int]): Patch sizes for cross-channel attention blocks.
            shift_sec (bool): If True, shift patches/windows by half of their size every second layer.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            mini_batch (int): Mini-batch size for patch and windows blocks. If it equals 0, then
                all patches/windows are processed at once.
        """
        super().__init__()
        self.layers = nn.ModuleList()
        assert len(spatial_window_sizes) == num_layers, "Number of spatial window sizes must match number of layers."
        assert len(channel_patch_sizes) == num_layers, "Number of channel patch sizes must match number of layers."
        
        for i in range(num_layers):
            self.layers.append(
                MultiplexBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    spatial_window_size=spatial_window_sizes[i],
                    channel_patch_size=channel_patch_sizes[i],
                    spatial_window_shift=(spatial_window_sizes[i]//2, spatial_window_sizes[i]//2) if shift_sec and i % 2 == 1 else (0, 0),
                    channel_patch_shift=(channel_patch_sizes[i]//2, channel_patch_sizes[i]//2) if shift_sec and i % 2 == 1 else (0, 0),
                    mlp_dim=mlp_dim,
                    mlp_bias=mlp_bias,
                    qkv_bias=qkv_bias,
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    mini_batch=mini_batch
                )
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x

In [4]:
device = 'cuda' if torch.cuda.is_available() else 'mps'

In [5]:
# atn = SpatialAttention(768, 8)
# input = torch.rand(3, 16, 16, 768)
# output = atn(input)

In [6]:
blck = SpatialAttentionBlock(embedding_dim=512, window_size=16, mini_batch=10, window_shift=(8,8)).to(device)
input = torch.rand(1, 40, 100, 100, 512).to(device)
output = blck(input)

In [6]:
blck = CrossChannelAttentionBlock(embedding_dim=512, patch_size=16, mini_batch=0, patch_shift=(8,8)).to(device)
input = torch.rand(1, 40, 100, 100, 512).to(device)
output = blck(input)

In [5]:
blck = MultiplexBlock(embedding_dim=512, num_heads=8, spatial_window_size=8, channel_patch_size=8, mini_batch=32).to(device)
input = torch.rand(1, 40, 100, 100, 512).to(device)
output = blck(input)

In [5]:
embedding_dim = 128

model = MultiplexImageTransformer(
    embedding_dim=embedding_dim,
    num_heads=4,
    num_layers=4,
    mlp_dim=embedding_dim*2,
    spatial_window_sizes=[8, 8, 16],
    channel_patch_sizes=[1, 1, 4],
    mini_batch=1
).to(device)


input = torch.rand(1, 40, 112, 112, embedding_dim).to(device)
# with autocast(device_type=device, dtype=torch.float16):
output = model(input)

RuntimeError: MPS backend out of memory (MPS allocated: 31.62 GB, other allocations: 4.65 GB, max allowed: 36.27 GB). Tried to allocate 128.00 KB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).