In [19]:
import torch
import torch.nn as nn
from functools import partial

import sys
sys.path.append("/home/maxihuber/eeg-foundation")

from src.utils.rope_utils import random_masking_smart
from src.models.components.vit_rope import select_freqs_cis
from timm.models.vision_transformer import Mlp as Mlp
from src.models.mae_rope_encoder import EncoderViTRoPE
from src.models.mae_rope_decoder import DecoderViTRoPE

In [451]:
from src.models.components.models_v2 import (
    vit_models,
    Layer_scale_init_Block,
    Attention,
)

class Flexible_RoPE_Layer_scale_init_Block(Layer_scale_init_Block):
    """Layer-scale transformer block that always uses ``FlexibleRoPEAttention``.

    Adapted, with slight modifications, from
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

    All constructor arguments are forwarded to ``Layer_scale_init_Block``; the
    ``Attention_block`` keyword is overwritten with ``FlexibleRoPEAttention``
    regardless of what the caller supplied.
    """

    def __init__(self, *args, **kwargs):
        # Force the RoPE-aware attention implementation.
        kwargs.update(Attention_block=FlexibleRoPEAttention)
        super().__init__(*args, **kwargs)

    def forward(self, x, freqs_cis, mask, nr_meta_tokens):
        """Apply the attention and MLP sub-layers, each as a scaled residual.

        Args:
            x: Input token tensor.
            freqs_cis: Rotary frequencies, passed through to the attention layer.
            mask: Passed through to the attention layer.
            nr_meta_tokens: Number of leading tokens the attention layer should
                exclude from the rotary embedding.

        Returns:
            The updated token tensor.
        """
        # Attention branch: x + drop_path(gamma_1 * attn(norm1(x)))
        attn_out = self.attn(
            self.norm1(x),
            freqs_cis=freqs_cis,
            mask=mask,
            nr_meta_tokens=nr_meta_tokens,
        )
        x = x + self.drop_path(self.gamma_1 * attn_out)

        # MLP branch: x + drop_path(gamma_2 * mlp(norm2(x)))
        mlp_out = self.mlp(self.norm2(x))
        x = x + self.drop_path(self.gamma_2 * mlp_out)
        return x

class FlexibleRoPEAttention(Attention):
    """
    Multi-head attention with rotary position embeddings (RoPE).

    Variant of RoPEAttention that supports a variable number of prepended tokens,
    e.g. one cls token and several mean tokens. Those leading tokens are left
    untouched when the rotary position embeddings are applied.
    """

    def forward(self, x, freqs_cis, mask, nr_meta_tokens):
        """Run self-attention, rotating only the non-meta queries and keys.

        Args:
            x: Tensor of shape (B, N, C).
            freqs_cis: Rotary frequencies handed to ``apply_rotary_emb``.
            mask: Accepted for interface compatibility; not used in this method.
            nr_meta_tokens: Number of leading tokens (along the token axis) that
                are skipped by the rotary embedding.

        Returns:
            Tensor of shape (B, N, C).
        """
        batch, seq_len, channels = x.shape
        head_dim = channels // self.num_heads

        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        projected = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, head_dim)
        projected = projected.permute(2, 0, 3, 1, 4)
        q, k, v = projected[0], projected[1], projected[2]

        # Rotate the non-meta part and write it back into the q/k views in place.
        rotated_q, rotated_k = apply_rotary_emb(
            q[:, :, nr_meta_tokens:], k[:, :, nr_meta_tokens:], freqs_cis=freqs_cis
        )
        q[:, :, nr_meta_tokens:] = rotated_q
        k[:, :, nr_meta_tokens:] = rotated_k

        scores = (q * self.scale) @ k.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))

        # Merge the heads back: (B, heads, N, head_dim) -> (B, N, C)
        merged = (weights @ v).transpose(1, 2).reshape(batch, seq_len, channels)
        return self.proj_drop(self.proj(merged))

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    """Rotate queries and keys by multiplying them with complex frequencies.

    The last dimension of ``xq`` / ``xk`` is grouped into (real, imag) pairs,
    viewed as complex numbers, multiplied by ``freqs_cis`` (reshaped by
    ``reshape_for_broadcast`` against the complex query), and converted back to
    real values.

    Args:
        xq: Query tensor; its last dimension must be even. ``flatten(3)`` below
            means a 4-D input gets its original shape back.
        xk: Key tensor, handled the same way as ``xq``.
        freqs_cis: Rotary frequencies in a layout ``reshape_for_broadcast`` accepts.

    Returns:
        Tuple ``(xq_out, xk_out)`` with the dtype and device of ``xq`` / ``xk``.
    """
    pair_shape_q = (*xq.shape[:-1], -1, 2)
    pair_shape_k = (*xk.shape[:-1], -1, 2)
    xq_ = torch.view_as_complex(xq.float().reshape(pair_shape_q))
    xk_ = torch.view_as_complex(xk.float().reshape(pair_shape_k))

    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    print(f"[apply_rotary_emb] xq_.shape: {xq_.shape}")
    print(f"[apply_rotary_emb] freqs_cis.shape: {freqs_cis.shape}")

    xq_prod = xq_ * freqs_cis
    print(f"[apply_rotary_emb] (xq_ * freqs_cis).shape: {xq_prod.shape}")
    xk_prod = xk_ * freqs_cis

    # Complex -> trailing (real, imag) pairs -> merge the pairs back into one axis.
    rotated_q = torch.view_as_real(xq_prod).flatten(3)
    rotated_k = torch.view_as_real(xk_prod).flatten(3)
    return (
        rotated_q.type_as(xq).to(xq.device),
        rotated_k.type_as(xk).to(xk.device),
    )

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Reshape ``freqs_cis`` so that it broadcasts against ``x``.

    With ``N = x.shape[-2]`` and ``D = x.shape[-1]``, the supported layouts of
    ``freqs_cis`` are, checked in this order:

    1. ``(N, D)``              -> ``(1, ..., 1, N, D)``  (shared by all leading dims)
    2. ``(x.shape[-3], N, D)`` -> ``(1, ..., x.shape[-3], N, D)``
    3. ``(x.shape[0], N, D)``  -> ``(x.shape[0], 1, ..., N, D)``  (one set per sample)

    Args:
        freqs_cis: Frequencies in one of the layouts above.
        x: Tensor to broadcast against; must have at least 2 dimensions.

    Returns:
        A view of ``freqs_cis`` with ``x.ndim`` dimensions.

    Raises:
        ValueError: If ``x`` has fewer than 2 dimensions or ``freqs_cis`` matches
            none of the supported layouts (previously this surfaced as an
            ``UnboundLocalError`` / ``IndexError``).
    """
    print("[reshape_for_broadcast] freqs_cis.shape:", freqs_cis.shape)
    print("[reshape_for_broadcast] xq_.shape:", x.shape)
    ndim = x.ndim
    if ndim < 2:
        raise ValueError(f"x must have at least 2 dimensions, got shape {tuple(x.shape)}")

    seq_len, feat = x.shape[-2], x.shape[-1]
    freqs_shape = tuple(freqs_cis.shape)

    if freqs_shape == (seq_len, feat):
        shape = [1] * (ndim - 2) + [seq_len, feat]
    elif ndim >= 3 and freqs_shape == (x.shape[-3], seq_len, feat):
        # NOTE(review): for 4-D x with x.shape[0] == x.shape[-3] (e.g. batch size
        # equal to the number of heads) a per-sample freqs_cis is indistinguishable
        # from this layout and will be treated as per-x.shape[-3]. The precedence
        # is kept as in the original; callers should avoid that configuration.
        shape = [1] * (ndim - 3) + [x.shape[-3], seq_len, feat]
    elif ndim >= 3 and freqs_shape == (x.shape[0], seq_len, feat):  # custom case
        shape = [x.shape[0]] + [1] * (ndim - 3) + [seq_len, feat]
    else:
        raise ValueError(
            f"freqs_cis with shape {freqs_shape} cannot be broadcast against "
            f"x with shape {tuple(x.shape)}"
        )

    return freqs_cis.view(*shape)

def random_masking_smart(x, mask_ratio, nr_meta_tokens):
    """Draw a per-sample random keep-mask that never drops the leading meta tokens.

    Args:
        x: Tensor of shape (B, N, D); only its shape and device are used.
        mask_ratio: Fraction of the ``N - nr_meta_tokens`` non-meta tokens to mask.
        nr_meta_tokens: Number of leading tokens that are always kept.

    Returns:
        mask: Bool tensor (B, N), ``True`` at positions that are kept.
        ids_restore: Long tensor (B, N, D). Along dim 1 it lists the kept
            positions (ascending) followed by the masked positions (ascending),
            repeated over D. Scattering ``[kept tokens | masked tokens]`` along
            dim 1 with this index puts every token back at its original position.
    """
    batch_size, seq_len, dim = x.shape
    n_maskable = seq_len - nr_meta_tokens
    n_keep = int(n_maskable * (1 - mask_ratio))

    # Random subset of the non-meta positions, sorted and shifted past the meta tokens.
    noise = torch.rand(batch_size, n_maskable, device=x.device)
    keep_ids = noise.argsort(dim=1)[:, :n_keep].sort(dim=1).values
    keep_ids += nr_meta_tokens

    # True where a token is kept; meta tokens are always kept.
    mask = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=x.device)
    mask.scatter_(1, keep_ids, True)
    mask[:, :nr_meta_tokens] = True

    # Column indices of kept / masked positions; every row has the same counts.
    kept_cols = mask.nonzero(as_tuple=True)[1].reshape(batch_size, -1)
    masked_cols = (~mask).nonzero(as_tuple=True)[1].reshape(batch_size, -1)
    packed_order = torch.cat((kept_cols, masked_cols), dim=1)
    ids_restore = packed_order[:, :, None].repeat(1, 1, dim)

    return mask, ids_restore

def random_masking(x, mask_ratio, nr_meta_tokens):
    """
    Per-sample random masking obtained by argsorting random noise.

    x: [B, N, D] sequence; only its shape and device are used.

    Returns ``(mask, ids_restore)``, both of shape [B, N - nr_meta_tokens], i.e.
    they cover the non-meta tokens only. ``mask`` is a float tensor with 0 = keep
    and 1 = remove; ``ids_restore`` holds the rank of each non-meta token in the
    random shuffle.
    """
    batch, seq_len, _ = x.shape  # batch, length, dim
    n_patches = seq_len - nr_meta_tokens
    len_keep = int(n_patches * (1 - mask_ratio))

    # noise in [0, 1]; small values are kept, large values are removed
    noise = torch.rand(batch, n_patches, device=x.device)

    # The constant offset does not change the result of the second argsort.
    ids_shuffle = noise.argsort(dim=1) + nr_meta_tokens
    ids_restore = ids_shuffle.argsort(dim=1)
    print("ids_restore.shape", ids_restore.shape)
    print(ids_restore[0])

    # Binary mask in shuffled order: the first len_keep entries are kept (0).
    mask = torch.ones([batch, seq_len], device=x.device)
    mask[:, :len_keep] = 0
    print("mask.shape", mask.shape)
    # Un-shuffle; gathering with ids_restore also narrows the mask to n_patches columns.
    mask = mask.gather(1, ids_restore)
    print("mask.shape", mask.shape)

    return mask, ids_restore

In [140]:
# Input dimensions and masking configuration for the encoder/decoder smoke test
B, C, H, W = 4, 1, 32, 2048
win_size = 1
nr_meta_patches = 2
mask_ratio = 0.75

# Patch grid for a patch edge of 16: 32 x 2048 -> 2 x 128 = 256 patches.
# (h and w were previously used below without being defined at top level,
# which raised a NameError on a fresh kernel.)
h, w = H // 16, W // 16

# Random stand-in for the embedded tokens: meta patches first, then h * w patches
x = torch.randn(B, h * w + nr_meta_patches, 384)

def encoder(x, win_size, mask_ratio):
    """Smoke-test one RoPE encoder block on randomly masked tokens.

    Reads the notebook globals ``H``, ``W`` and ``nr_meta_patches``.

    Args:
        x: Token tensor of shape (B, N, D) whose first ``nr_meta_patches``
            tokens are meta patches.
        win_size: Forwarded to ``select_freqs_cis``.
        mask_ratio: Fraction of non-meta tokens to drop.

    Returns:
        Tuple ``(x, meta_patches, mask, nr_meta_patches)``: the block output for
        the kept tokens, the untouched meta patches, the boolean keep-mask of
        shape (B, N), and the number of meta patches.
    """
    self = EncoderViTRoPE(channel_names_path="/home/maxihuber/eeg-foundation/src/data/components/channels_to_id.json")

    B = x.shape[0]

    # Keep for reconstruction loss
    meta_patches = x[:, :nr_meta_patches, :]

    # Encoder: randomly mask some patches (excluding metadata patches).
    # random_masking_smart as defined in this notebook returns (mask, ids_restore);
    # the previous `x, mask = ...` unpacking overwrote x with the mask.
    mask, _ = random_masking_smart(
        x=x, mask_ratio=mask_ratio, nr_meta_tokens=nr_meta_patches
    )
    x = x[mask].view(B, -1, x.shape[-1])
    print("[forward_encoder] after random_masking_smart (x.shape):", x.shape, "(B, N, D)")
    print("[forward_encoder] after random_masking_smart (mask.shape):", mask.shape, "(B, N)")

    # Encoder: select correct rotation information for the attention layers.
    # The frequencies are replicated per sample and then filtered with the
    # non-meta part of the keep-mask so they line up with the kept patches.
    freqs_cis = select_freqs_cis(
        self, self.encoder_freqs_cis, H, W, win_size, x.device
    ).unsqueeze(0).repeat(B, 1, 1)
    freqs_cis = freqs_cis[mask[:, nr_meta_patches:]].view(B, -1, freqs_cis.shape[-1])
    print("[forward_encoder] freqs_cis.shape:", freqs_cis.shape, "(B, N, D // num_heads // 2)")

    blk = Flexible_RoPE_Layer_scale_init_Block(
        dim=384,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        act_layer=nn.GELU,
        Attention_block=FlexibleRoPEAttention,
        Mlp_block=Mlp,
        init_values=1e-4,
    )
    x = blk(x, freqs_cis=freqs_cis, mask=mask, nr_meta_tokens=nr_meta_patches)
    print("[forward_encoder] after rope blocks:", x.shape, "(B, N, D)")

    return x, meta_patches, mask, nr_meta_patches

def decoder(x, nr_meta_patches, H, W, win_size):
    """Work-in-progress decoder smoke test.

    Builds a ``DecoderViTRoPE``, replaces ``x`` with random noise of feature size
    512, selects the decoder rotary frequencies and constructs one RoPE block.
    The block is not applied yet and the function falls off the end, so it
    currently returns ``None``.

    Args:
        x: Encoder output; only ``x.shape[-3]`` and ``x.shape[-2]`` are used.
        nr_meta_patches: Currently unused.
        H, W, win_size: Forwarded to ``select_freqs_cis``.
    """
    self = DecoderViTRoPE(
        channel_names_path="/home/maxihuber/eeg-foundation/src/data/components/channels_to_id.json"
    )

    # Placeholder input: same leading sizes as the encoder output, decoder width 512.
    batch, seq_len = x.shape[-3], x.shape[-2]
    x = torch.randn(batch, seq_len, 512)

    freqs_cis = select_freqs_cis(self, self.decoder_freqs_cis, H, W, win_size, x.device)

    # TODO: insert self.mask_token at masked positions

    block_kwargs = dict(
        dim=512,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        act_layer=nn.GELU,
        Attention_block=FlexibleRoPEAttention,
        Mlp_block=Mlp,
        init_values=1e-4,
    )
    blk = Flexible_RoPE_Layer_scale_init_Block(**block_kwargs)

# Run the encoder smoke test, then hand its output to the decoder stub.
print("==" + "encoder pass" + "=" * 100)
x_emb, meta_patches, mask, nr_meta_patches = encoder(
    x=x, win_size=win_size, mask_ratio=mask_ratio
)

print("==" + "decoder pass" + "=" * 100)
x_pred = decoder(
    x=x_emb, nr_meta_patches=nr_meta_patches, H=H, W=W, win_size=win_size
)

[forward_encoder] after random_masking_smart (x.shape): torch.Size([4, 66, 384]) (B, N, D)
[forward_encoder] after random_masking_smart (mask.shape): torch.Size([4, 258]) (B, N)
[forward_encoder] freqs_cis.shape: torch.Size([4, 64, 32]) (B, N, D // num_heads // 2)
[reshape_for_broadcast] freqs_cis.shape: torch.Size([4, 64, 32])
[reshape_for_broadcast] xq_.shape: torch.Size([4, 6, 64, 32])
[apply_rotary_emb] xq_.shape: torch.Size([4, 6, 64, 32])
[apply_rotary_emb] freqs_cis.shape: torch.Size([4, 1, 64, 32])
[apply_rotary_emb] (xq_ * freqs_cis).shape: torch.Size([4, 6, 64, 32])
[forward_encoder] after rope blocks: torch.Size([4, 66, 384]) (B, N, D)


In [450]:
# Round-trip check for random_masking_smart: mask the tokens, pack them as
# [kept | masked], then scatter them back to their original positions.
# Relies on `mask_ratio` and `nr_meta_patches` set in an earlier cell.
B, N, D = 4, 258, 10
# Consecutive integers make every element traceable to its original position.
x = torch.arange(B * N * D).reshape(B, N, D)
x_org = x.clone()

mask, ids_restore = random_masking_smart(x, mask_ratio, nr_meta_patches)

# == Encoder ==

# Keep only the unmasked tokens; boolean indexing preserves their order per row.
x = x[mask].view(B, -1, D)

# == Decoder ==

# In the decoder the masked positions must be filled before restoring the order.
# For this check the true (masked-out) tokens are used; a real decoder would use
# a learnable placeholder instead, along the lines of:
#mask_token = nn.Parameter(torch.zeros(1, 1, D))
#mask_tokens = mask_token.repeat(B, mask.shape[1] + nr_meta_patches - ids_kept.shape[1], 1)
# (`ids_kept` is not defined anywhere in the visible cells.)
mask_tokens = x_org[~mask].view(B, -1, D)
x = torch.cat([x, mask_tokens], dim=1)

# Rearrange: write packed token j of each row to position ids_restore[b, j].
x_reordered = torch.zeros_like(x)
x = x_reordered.scatter_(1, ids_restore, x)

print("restored x.shape:", x.shape)
# NOTE(review): ids_restore lists exactly the positions the packed tokens were taken
# from, so this scatter should reproduce x_org and the diff should be all zeros.
# The original remark here was "not zero :(", which may predate the current
# random_masking_smart -- re-run to confirm.
print("diff:", (x - x_org)[mask].view(B, -1, D)[0])

x.shape: torch.Size([4, 258, 10])
mask.shape: torch.Size([4, 258])
x: tensor([[   0,    1,    2,    3,    4,    5,    6,    7,    8,    9],
        [  10,   11,   12,   13,   14,   15,   16,   17,   18,   19],
        [  30,   31,   32,   33,   34,   35,   36,   37,   38,   39],
        [  60,   61,   62,   63,   64,   65,   66,   67,   68,   69],
        [ 150,  151,  152,  153,  154,  155,  156,  157,  158,  159],
        [ 200,  201,  202,  203,  204,  205,  206,  207,  208,  209],
        [ 270,  271,  272,  273,  274,  275,  276,  277,  278,  279],
        [ 290,  291,  292,  293,  294,  295,  296,  297,  298,  299],
        [ 330,  331,  332,  333,  334,  335,  336,  337,  338,  339],
        [ 340,  341,  342,  343,  344,  345,  346,  347,  348,  349],
        [ 360,  361,  362,  363,  364,  365,  366,  367,  368,  369],
        [ 480,  481,  482,  483,  484,  485,  486,  487,  488,  489],
        [ 590,  591,  592,  593,  594,  595,  596,  597,  598,  599],
        [ 600,  601,

In [None]:
# Learnable placeholder that stands in for every masked token in the decoder.
mask_token = nn.Parameter(torch.zeros(1, 1, D))

# `kept_indices` only exists inside random_masking_smart, so the number of masked
# tokens is derived from the keep-mask instead.
# Assumes `mask` is the boolean (B, N) keep-mask from random_masking_smart
# (True = kept), where every row masks the same number of tokens.
num_masked = int((~mask[0]).sum())
mask_tokens_real = mask_token.repeat(B, num_masked, 1)

In [414]:
import torch

# Toy dimensions: batch size, sequence length, feature size
B, N, D = 4, 258, 10

# Consecutive integers make every element traceable to its original position
x = torch.arange(B * N * D).view(B, N, D)
print("Original tensor x:")
print(x)

# One random permutation of the N positions, shared by all B rows
ids_restore = torch.randperm(N).unsqueeze(0).repeat(B, 1)
print("ids_restore:")
print(ids_restore.shape)

# Broadcast the indices over the feature axis so they match the shape of x
ids_restore_expanded = ids_restore[:, :, None].expand(B, N, D)
print("Expanded ids_restore shape:", ids_restore_expanded.shape)

# Reorder along the token axis: x_reordered[b, j, :] = x[b, ids_restore[b, j], :]
x_reordered = x.gather(1, ids_restore_expanded)
print("Reordered tensor x_reordered:")
print(x_reordered)

# Verify that the reordered tensor matches the expected output
# (check manually or add explicit assertions to confirm correctness)


Original tensor x:
tensor([[[    0,     1,     2,  ...,     7,     8,     9],
         [   10,    11,    12,  ...,    17,    18,    19],
         [   20,    21,    22,  ...,    27,    28,    29],
         ...,
         [ 2550,  2551,  2552,  ...,  2557,  2558,  2559],
         [ 2560,  2561,  2562,  ...,  2567,  2568,  2569],
         [ 2570,  2571,  2572,  ...,  2577,  2578,  2579]],

        [[ 2580,  2581,  2582,  ...,  2587,  2588,  2589],
         [ 2590,  2591,  2592,  ...,  2597,  2598,  2599],
         [ 2600,  2601,  2602,  ...,  2607,  2608,  2609],
         ...,
         [ 5130,  5131,  5132,  ...,  5137,  5138,  5139],
         [ 5140,  5141,  5142,  ...,  5147,  5148,  5149],
         [ 5150,  5151,  5152,  ...,  5157,  5158,  5159]],

        [[ 5160,  5161,  5162,  ...,  5167,  5168,  5169],
         [ 5170,  5171,  5172,  ...,  5177,  5178,  5179],
         [ 5180,  5181,  5182,  ...,  5187,  5188,  5189],
         ...,
         [ 7710,  7711,  7712,  ...,  7717,  7718,