In [7]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torchvision.models.video import swin_transformer
from torch.utils.data import DataLoader, Dataset
import numpy as np

class DummyVideoDataset(Dataset):
    def __init__(self, num_samples=10, channels=3, frames=16, height=224, width=224):
        self.num_samples = num_samples
        self.shape = (channels, frames, height, width)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        video = torch.rand(self.shape)  # shape = (3, 16, 224, 224)
        return video

class MAEPretrainSwin(pl.LightningModule):
    def __init__(self, lr=1e-4, mask_ratio=0.75):
        super().__init__()
        self.save_hyperparameters()

        self.encoder = swin_transformer.swin3d_t(weights=None)
        self.encoder.head = nn.Identity()

        # Feature hook (grab feature before classification)
        self.features = None
        self.encoder.features[4].register_forward_hook(self._hook)

        self.decoder = nn.Sequential(
            nn.Conv3d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv3d(256, 3, kernel_size=1)
        )

        self.criterion = nn.MSELoss()

    def _hook(self, module, input, output):
        # Save the last feature map for reconstruction
        self.features = output  # shape: (B, T', H', W', C)

    def forward(self, x):
        # x: (B, 3, T, H, W)
        print("x",x.shape)
        k = self.encoder(x)  # Just run encoder to trigger the hook
        print("k",k.shape)
        feat = self.features  # (B, T', H', W', C)
        print("feat",feat.shape)

        feat = feat.permute(0, 4, 1, 2, 3)  # → (B, C, T', H', W')
        recon = self.decoder(feat)
        recon = nn.functional.interpolate(recon, size=x.shape[2:], mode='trilinear', align_corners=False)
        print("recon",recon.shape)  # Should match input shape (B, 3, T, H, W)
        return recon

    def training_step(self, batch, batch_idx):
        x = batch  # (B, 3, T, H, W)
        recon = self(x)
        loss = self.criterion(recon, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)

    def random_masking(self, x, ratio):
        B, T, C, H, W = x.shape
        x = x.view(B, T * H * W, C)
        N = x.shape[1]
        len_keep = int(N * (1 - ratio))

        noise = torch.rand(B, N, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_keep = ids_shuffle[:, :len_keep]

        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).expand(-1, -1, C))
        return x_masked, ids_shuffle

# Training Script
# -----------------------------
if __name__ == "__main__":
    dataset = DummyVideoDataset(num_samples=100, frames=16)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)

    model = MAEPretrainSwin()

    trainer = pl.Trainer(
        max_epochs=1,
        accelerator='auto',
        log_every_n_steps=1,
    )

    trainer.fit(model, dataloader)

Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name      | Type              | Params | Mode 
--------------------------------------------------------
0 | encoder   | SwinTransformer3d | 27.9 M | train
1 | decoder   | Sequential        | 2.7 M  | train
2 | criterion | MSELoss           | 0      | train
--------------------------------------------------------
30.5 M    Trainable params
0         Non-trainable params
30.5 M    Total param

Epoch 0:   0%|          | 0/50 [00:00<?, ?it/s] x torch.Size([2, 3, 16, 224, 224])
k torch.Size([2, 768])
feat torch.Size([2, 8, 14, 14, 384])
recon torch.Size([2, 3, 16, 224, 224])
Epoch 0:   2%|▏         | 1/50 [00:00<00:04, 10.35it/s, v_num=6]x torch.Size([2, 3, 16, 224, 224])
k torch.Size([2, 768])
feat torch.Size([2, 8, 14, 14, 384])
recon torch.Size([2, 3, 16, 224, 224])
Epoch 0:   4%|▍         | 2/50 [00:00<00:06,  7.80it/s, v_num=6]x torch.Size([2, 3, 16, 224, 224])
k torch.Size([2, 768])
feat torch.Size([2, 8, 14, 14, 384])
recon torch.Size([2, 3, 16, 224, 224])
Epoch 0:   6%|▌         | 3/50 [00:00<00:06,  6.79it/s, v_num=6]x torch.Size([2, 3, 16, 224, 224])
k torch.Size([2, 768])
feat torch.Size([2, 8, 14, 14, 384])
recon torch.Size([2, 3, 16, 224, 224])
Epoch 0:   8%|▊         | 4/50 [00:00<00:07,  6.45it/s, v_num=6]x torch.Size([2, 3, 16, 224, 224])
k torch.Size([2, 768])
feat torch.Size([2, 8, 14, 14, 384])
recon torch.Size([2, 3, 16, 224, 224])
Epoch 0:  10%|█         | 

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 50/50 [00:09<00:00,  5.17it/s, v_num=6]


In [20]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torchvision.models.video import swin_transformer
from torch.utils.data import DataLoader, Dataset

class DummyVideoDataset(Dataset):
    def __init__(self, num_samples=10, channels=3, frames=16, height=224, width=224):
        self.num_samples = num_samples
        self.shape = (channels, frames, height, width)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        video = torch.rand(self.shape)  # shape = (3, 16, 224, 224)
        return video

def mask_video_patches(x, patch_size=(2,16,16), mask_ratio=0.75):
    """
    Zero out random patches in the video tensor.

    Args:
        x: (B, C, T, H, W)
        patch_size: Tuple of (patch_temporal, patch_height, patch_width)
        mask_ratio: fraction of patches to mask (zero out)
    Returns:
        x_masked: tensor with masked patches zeroed out
        mask: binary mask of patches (1=visible, 0=masked)
    """
    B, C, T, H, W = x.shape
    pt, ph, pw = patch_size
    assert T % pt == 0 and H % ph == 0 and W % pw == 0, "Video dimensions must be divisible by patch size"
    nt, nh, nw = T // pt, H // ph, W // pw
    num_patches = nt * nh * nw
    len_keep = int(num_patches * (1 - mask_ratio))

    # Random shuffle patches per sample
    noise = torch.rand(B, num_patches, device=x.device)
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_keep = ids_shuffle[:, :len_keep]

    # Create mask tensor: 1 for keep, 0 for masked patches
    mask = torch.zeros(B, num_patches, device=x.device)
    mask.scatter_(1, ids_keep, 1)

    # Reshape mask to patches grid
    mask = mask.view(B, nt, nh, nw, 1, 1, 1)
    mask = mask.expand(-1, -1, -1, -1, pt, ph, pw)
    mask = mask.reshape(B, 1, T, H, W)  # (B, 1, T, H, W)

    # Zero out masked patches in input video tensor
    x_masked = x * mask

    return x_masked, mask

class MAEPretrainSwin(pl.LightningModule):
    def __init__(self, lr=1e-4, mask_ratio=0.75):
        super().__init__()
        self.save_hyperparameters()

        self.encoder = swin_transformer.swin3d_t(weights=None)
        self.encoder.head = nn.Identity()

        # Feature hook (grab feature before classification)
        self.features = None
        self.encoder.features[4].register_forward_hook(self._hook)

        self.decoder = nn.Sequential(
            nn.Conv3d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv3d(256, 3, kernel_size=1)
        )

        self.criterion = nn.MSELoss()

    def _hook(self, module, input, output):
        # Save the last feature map for reconstruction
        self.features = output  # shape: (B, T', H', W', C)

    def forward(self, x):
        # x: (B, 3, T, H, W)
        x_masked, mask = mask_video_patches(x, patch_size=(2,16,16), mask_ratio=self.hparams.mask_ratio)
        print("Input shape:", x.shape)
        print("Masked input shape:", x_masked.shape)
        k = self.encoder(x_masked)  # Run encoder on masked input
        feat = self.features  # (B, T', H', W', C)

        feat = feat.permute(0, 4, 1, 2, 3)  # → (B, C, T', H', W')
        recon = self.decoder(feat)
        recon = nn.functional.interpolate(recon, size=x.shape[2:], mode='trilinear', align_corners=False)
        print("Reconstruction shape:", recon.shape)
        return recon

    def training_step(self, batch, batch_idx):
        x = batch  # (B, 3, T, H, W)
        recon = self(x)
        loss = self.criterion(recon, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)

if __name__ == "__main__":
    dataset = DummyVideoDataset(num_samples=100, frames=16)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)

    model = MAEPretrainSwin(mask_ratio=0.75)

    trainer = pl.Trainer(
        max_epochs=1,
        accelerator='auto',
        log_every_n_steps=1,
    )

    trainer.fit(model, dataloader)


Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name      | Type              | Params | Mode 
--------------------------------------------------------
0 | encoder   | SwinTransformer3d | 27.9 M | train
1 | decoder   | Sequential        | 2.7 M  | train
2 | criterion | MSELoss           | 0      | train
--------------------------------------------------------
30.5 M    Trainable params
0         Non-trainable params
30.5 M    Total param

Epoch 0:   0%|          | 0/50 [00:00<?, ?it/s] Input shape: torch.Size([2, 3, 16, 224, 224])
Masked input shape: torch.Size([2, 3, 16, 224, 224])
Reconstruction shape: torch.Size([2, 3, 16, 224, 224])
Epoch 0:   2%|▏         | 1/50 [00:00<00:05,  9.73it/s, v_num=19]Input shape: torch.Size([2, 3, 16, 224, 224])
Masked input shape: torch.Size([2, 3, 16, 224, 224])
Reconstruction shape: torch.Size([2, 3, 16, 224, 224])
Epoch 0:   4%|▍         | 2/50 [00:00<00:06,  7.87it/s, v_num=19]Input shape: torch.Size([2, 3, 16, 224, 224])
Masked input shape: torch.Size([2, 3, 16, 224, 224])
Reconstruction shape: torch.Size([2, 3, 16, 224, 224])
Epoch 0:   6%|▌         | 3/50 [00:00<00:06,  7.31it/s, v_num=19]Input shape: torch.Size([2, 3, 16, 224, 224])
Masked input shape: torch.Size([2, 3, 16, 224, 224])
Reconstruction shape: torch.Size([2, 3, 16, 224, 224])
Epoch 0:   8%|▊         | 4/50 [00:00<00:06,  6.94it/s, v_num=19]Input shape: torch.Size([2, 3, 16, 224, 224])
Masked input shape: torch.Size

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 50/50 [00:09<00:00,  5.25it/s, v_num=19]


# HERE

In [43]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torchvision.models.video import swin_transformer
from torch.utils.data import DataLoader, Dataset
import utils
import albumentations as A
from albumentations.pytorch import ToTensorV2


class CFG:
    # ============== comp exp name =============
    current_dir = './'
    segment_path = './train_scrolls/'
    
    start_idx = 24
    in_chans = 16
    
    size = 224
    tile_size = 224
    stride = tile_size // 8 
    
    train_batch_size =  10 # 32
    valid_batch_size = 10
    
    lr = 1e-4
    num_workers = 8
    # ============== model cfg =============
    scheduler = 'linear' # 'cosine', 'linear'
    epochs = 30
    warmup_factor = 10
    
    # Size of fragments
    frags_ratio1 = ["rem",'rect','frag']
    frags_ratio2 = ['nothing']
    ratio1 = 2
    ratio2 = 1
    
    # ============== fold =============
    segments = ['frag5',"rect5"] 
    valid_id = 'frag5'
    # ============== fixed =============
    min_lr = 1e-7
    weight_decay = 1e-6
    max_grad_norm = 100
    num_workers = 8
    seed = 0
    
        # ============== augmentation =============
    train_aug_list = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.75),
        A.ShiftScaleRotate(rotate_limit=360,shift_limit=0.15,scale_limit=0.15,p=0.75),
        A.OneOf([
                A.GaussNoise(var_limit=[10, 50]),
                A.GaussianBlur(),
                A.MotionBlur(),
                ], p=0.4),
        A.CoarseDropout(max_holes=2, max_width=int(size * 0.2), max_height=int(size * 0.2), 
                        mask_fill_value=0, p=0.5),
        A.Normalize(
            mean= [0] * in_chans,
            std= [1] * in_chans
        ),
        ToTensorV2(transpose_mask=True),
    ]

    valid_aug_list = [
        A.Normalize(
            mean= [0] * in_chans,
            std= [1] * in_chans
        ),
        ToTensorV2(transpose_mask=True),  
    ]

train_images, train_masks, valid_images, valid_masks, valid_xyxys = utils.get_train_valid_dataset(CFG)

reading frag5


100%|██████████| 16/16 [00:01<00:00, 11.59it/s]


 Shape of frag5 segment: (3696, 2352, 16)
(3696, 2352)


In [44]:
def get_transforms(data, cfg):
    if data == 'train':
        aug = A.Compose(cfg.train_aug_list)
    elif data == 'valid':
        aug = A.Compose(cfg.valid_aug_list)
    return aug  
from models import swin
train_dataset = swin.TimesformerDataset(
    valid_images[:100], CFG, labels=valid_masks, transform=get_transforms(data='valid', cfg=CFG))

train_loader = DataLoader(train_dataset,
                            batch_size=CFG.train_batch_size,
                            shuffle=True,
                            num_workers=CFG.num_workers, pin_memory=True, drop_last=True,
                            )

In [None]:
# class DummyVideoDataset(Dataset):
#     def __init__(self, num_samples=10, channels=3, frames=16, height=224, width=224):
#         self.num_samples = num_samples
#         self.shape = (channels, frames, height, width)

#     def __len__(self):
#         return self.num_samples

#     def __getitem__(self, idx):
#         video = torch.rand(self.shape)  # (C, T, H, W)
#         return video

def mask_video_patches(x, patch_size=(2,16,16), mask_ratio=0.75):
    B, C, T, H, W = x.shape
    pt, ph, pw = patch_size
    assert T % pt == 0 and H % ph == 0 and W % pw == 0, "Video must divide evenly by patch size"
    nt, nh, nw = T // pt, H // ph, W // pw
    num_patches = nt * nh * nw
    len_keep = int(num_patches * (1 - mask_ratio))

    noise = torch.rand(B, num_patches, device=x.device)
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_keep = ids_shuffle[:, :len_keep]

    mask = torch.zeros(B, num_patches, device=x.device)
    mask.scatter_(1, ids_keep, 1)

    mask = mask.view(B, nt, nh, nw, 1, 1, 1).expand(-1, -1, -1, -1, pt, ph, pw)
    mask = mask.reshape(B, 1, T, H, W)
    x_masked = x * mask
    return x_masked

class MAEPretrainSwin(pl.LightningModule):
    def __init__(self, lr=1e-5, mask_ratio=0.75):
        super().__init__()
        self.save_hyperparameters()

        self.encoder = swin_transformer.swin3d_t(weights="KINETICS400_V1")
        self.encoder.head = nn.Identity()

        self.features = None
        self.encoder.features[-1].register_forward_hook(self._hook)

        self.decoder = nn.Sequential(
            nn.Conv3d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv3d(256, 3, kernel_size=1)
        )

        self.criterion = nn.MSELoss()

    def _hook(self, module, input, output):
        self.features = output  # (B, T', H', W', C)

    def forward(self, x):
        x_masked = mask_video_patches(x, patch_size=(2,16,16), mask_ratio=self.hparams.mask_ratio)
        _ = self.encoder(x_masked)
        feat = self.features.permute(0, 4, 1, 2, 3)  # (B, C, T', H', W')
        recon = self.decoder(feat)
        recon = nn.functional.interpolate(recon, size=x.shape[2:], mode='trilinear', align_corners=False)
        return recon

    def training_step(self, batch, batch_idx):
        x = batch  # (B, 3, T, H, W)
        recon = self(x)
        loss = self.criterion(recon, x)
        print(f"Epoch {self.current_epoch} | Step {batch_idx} | Loss: {loss.item():.4f}")
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)

if __name__ == "__main__":
    dataset = train_dataset#DummyVideoDataset(num_samples=5)  # Small set to test overfitting
    dataloader = train_loader#DataLoader(dataset, batch_size=1, shuffle=True)

    model = MAEPretrainSwin(mask_ratio=0.)

    trainer = pl.Trainer(
        max_epochs=100,
        accelerator='auto',
        logger=False,
        enable_checkpointing=False,
    )

    trainer.fit(model, dataloader)


Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name      | Type              | Params | Mode 
--------------------------------------------------------
0 | encoder   | SwinTransformer3d | 27.9 M | train
1 | decoder   | Sequential        | 2.7 M  | train
2 | criterion | MSELoss           | 0      | train
--------------------------------------------------------
30.5 M    Trainable params
0         Non-trainable params
30.5 M    Total params
122.023   Total estimated model params size (MB)
183       Modules in train mode
0         Modules in eval mode


Epoch 0:   0%|          | 0/10 [00:00<?, ?it/s] 

TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

In [16]:
import torch
from transformers import AutoImageProcessor, TimesformerModel
import torchvision.transforms as T

# Convert to PIL and then to 3 channels
pil_transform = T.Compose([
    T.ToPILImage(),                    # convert (C, H, W) to PIL
    T.Grayscale(num_output_channels=3),  # convert to 3 channels
])

# from transformers import TimeSformerModel, TimeSformerConfig, TimeSformerImageProcessor
import numpy as np

# Load pretrained TimeSformer
model = TimesformerModel.from_pretrained("facebook/timesformer-hr-finetuned-k600")
processor = AutoImageProcessor.from_pretrained("facebook/timesformer-base-finetuned-k400")

# Dummy video tensor: shape [batch, num_frames, height, width, channels]
video = torch.rand(3,8, 224, 224)  # 8 frames of 224x224 RGB

image = video.permute(1,0,2,3)
frames = [pil_transform(frame.squeeze(0)) for frame in image] 

encoding = processor(
    [frame for frame in frames],   # list of PIL
    return_tensors='pt'
    )
processed = encoding["pixel_values"].squeeze(0)
print(processed)

# Preprocess
# video_np = pixel_values.numpy()
# processed = processor(list(video_np), return_tensors="pt")

# Run through model to get patch embeddings
with torch.no_grad():
    outputs = model(processed, output_hidden_states=True)
    embeddings = outputs.last_hidden_state  # [1, num_patches+1, hidden_dim]

print(f"Embeddings shape: {embeddings.shape}")  # [1, num_tokens, 768]

# Drop CLS token and mask patches
patches = embeddings[:, 1:, :]  # Remove CLS token, shape [1, N, D]
B, N, D = patches.shape

# Random masking (75%)
mask_ratio = 0.75
len_keep = int(N * (1 - mask_ratio))

noise = torch.rand(B, N)
ids_shuffle = torch.argsort(noise, dim=1)
ids_keep = ids_shuffle[:, :len_keep]
ids_restore = torch.argsort(ids_shuffle, dim=1)

# Gather visible patches
patches_visible = torch.gather(patches, 1, ids_keep.unsqueeze(-1).repeat(1, 1, D))
print(f"Visible patches shape: {patches_visible.shape}")


tensor([[[[ 0.0566,  0.3355,  0.1786,  ..., -1.2331, -1.1111,  0.8758],
          [ 0.9455,  1.3290,  0.7364,  ...,  1.1547,  0.1089, -0.0479],
          [ 0.6144,  0.1612,  1.5033,  ..., -0.1874,  0.4052,  0.0392],
          ...,
          [-1.0588,  1.5556,  0.7364,  ...,  0.5447,  2.2004, -0.4314],
          [ 0.4749,  0.9630,  0.7887,  ..., -0.7800,  1.8693,  1.7298],
          [ 0.8758, -0.1351, -0.9717,  ...,  0.5621,  0.2832,  0.5447]],

         [[ 0.0566,  0.3355,  0.1786,  ..., -1.2331, -1.1111,  0.8758],
          [ 0.9455,  1.3290,  0.7364,  ...,  1.1547,  0.1089, -0.0479],
          [ 0.6144,  0.1612,  1.5033,  ..., -0.1874,  0.4052,  0.0392],
          ...,
          [-1.0588,  1.5556,  0.7364,  ...,  0.5447,  2.2004, -0.4314],
          [ 0.4749,  0.9630,  0.7887,  ..., -0.7800,  1.8693,  1.7298],
          [ 0.8758, -0.1351, -0.9717,  ...,  0.5621,  0.2832,  0.5447]],

         [[ 0.0566,  0.3355,  0.1786,  ..., -1.2331, -1.1111,  0.8758],
          [ 0.9455,  1.3290,  

ValueError: not enough values to unpack (expected 5, got 4)

In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from transformers import AutoImageProcessor, TimesformerModel
from einops import rearrange
import random
from PIL import Image

# ----------- Configuration -----------
device = "cuda" if torch.cuda.is_available() else "cpu"
num_frames = 8
height = width = 224
patch_dim = 768  # TimeSformer hidden size
mask_ratio = 0.75

# ----------- Load TimeSformer encoder -----------
model = TimesformerModel.from_pretrained("facebook/timesformer-base-finetuned-k400").to(device)
processor = AutoImageProcessor.from_pretrained("facebook/timesformer-base-finetuned-k400")

# ----------- Dummy video data (uint8) -----------
# Simulate a batch of 1 video: (T, C, H, W)
video = torch.randint(0, 255, (num_frames, 3, height, width), dtype=torch.uint8)

# ----------- Convert to PIL frames -----------
to_pil = transforms.ToPILImage()
frames = [to_pil(frame) for frame in video]  # List of PIL images

# ----------- Preprocess video -----------
encoding = processor(frames, return_tensors='pt')
pixel_values = encoding["pixel_values"].to(device)  # (1, T, C, H, W)

# ----------- Encoder forward pass -----------
with torch.no_grad():
    outputs = model(pixel_values=pixel_values, output_hidden_states=True)
    all_tokens = outputs.last_hidden_state  # (1, num_tokens, 768)

# ----------- Remove CLS token -----------
patch_tokens = all_tokens[:, 1:, :]  # (1, N, D)
B, N, D = patch_tokens.shape

# ----------- Random masking -----------
def random_masking(x, mask_ratio):
    len_keep = int(x.shape[1] * (1 - mask_ratio))
    noise = torch.rand(B, N, device=x.device)
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    ids_keep = ids_shuffle[:, :len_keep]
    x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

    return x_masked, ids_restore, ids_keep

visible_tokens, ids_restore, ids_keep = random_masking(patch_tokens, mask_ratio)

# ----------- Tiny decoder -----------
class MAEDecoder(nn.Module):
    def __init__(self, input_dim=768, decoder_dim=512, num_layers=2):
        super().__init__()
        self.linear_in = nn.Linear(input_dim, decoder_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=decoder_dim, nhead=8)
        self.decoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.linear_out = nn.Linear(decoder_dim, input_dim)

    def forward(self, visible_tokens, ids_restore):
        B, N_visible, D = visible_tokens.shape
        N_total = ids_restore.shape[1]

        # Project to decoder dim
        x = self.linear_in(visible_tokens)

        # Create [MASK] tokens for missing patches
        mask_token = nn.Parameter(torch.zeros(1, 1, D, device=x.device))
        x_full = mask_token.expand(B, N_total, -1).clone()
        x_full.scatter_(1, ids_keep.unsqueeze(-1), x)

        # Decode
        x_decoded = self.decoder(x_full)
        return self.linear_out(x_decoded)

decoder = MAEDecoder().to(device)

# ----------- Loss and optimizer -----------
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-4)

# ----------- Forward + Loss -----------
decoder_output = decoder(visible_tokens, ids_restore)  # (B, N, D)

# Compute loss only on masked tokens
mask = torch.ones(B, N, device=device).scatter(1, ids_keep, 0).bool()  # (B, N)
loss = criterion(decoder_output[mask], patch_tokens[mask])

print(f"MAE loss: {loss.item():.4f}")

# ----------- Backprop -----------
loss.backward()
optimizer.step()


RuntimeError: mat1 and mat2 shapes cannot be multiplied (392x768 and 512x512)

In [26]:
import torch
import torch.nn as nn
from torchvision import transforms
from transformers import TimesformerModel, AutoImageProcessor
from einops import rearrange
import random

# ----------- Configuration -----------
device = "cuda" if torch.cuda.is_available() else "cpu"
num_frames = 8
height = width = 224
mask_ratio = 0.75

# ----------- Load TimeSformer encoder -----------
model = TimesformerModel.from_pretrained("facebook/timesformer-base-finetuned-k400").to(device)
processor = AutoImageProcessor.from_pretrained("facebook/timesformer-base-finetuned-k400")

# ----------- Dummy video data -----------
video = torch.randint(0, 255, (num_frames, 3, height, width), dtype=torch.uint8)
to_pil = transforms.ToPILImage()
frames = [to_pil(frame) for frame in video]  # List of (H, W, C) PIL images

# ----------- Preprocess video frames -----------
encoding = processor(frames, return_tensors='pt')
pixel_values = encoding["pixel_values"].to(device)  # shape: (1, T, C, H, W)

# ----------- Embed all patches (no encoder yet) -----------
# We use the TimeSformer’s patch embedding layer directly
patch_embed = model.embeddings.patch_embeddings  # Conv3D patch embedding
B, T, C, H, W = pixel_values.shape
# video_embeds = patch_embed(pixel_values)  # shape: (B, D, T', H', W')
video_embeds = patch_embed(pixel_values)[0]  # <- FIXED HERE ✅
patches = video_embeds

# Flatten to patches
# patches = rearrange(video_embeds, 'b d t h w -> b (t h w) d')  # [B, N, D]
B, N, D = patches.shape

# ----------- Random masking BEFORE encoder -----------
def random_masking(x, mask_ratio):
    len_keep = int(x.shape[1] * (1 - mask_ratio))
    noise = torch.rand(B, N, device=x.device)
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    ids_keep = ids_shuffle[:, :len_keep]
    x_visible = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

    return x_visible, ids_restore, ids_keep

patches_visible, ids_restore, ids_keep = random_masking(patches, mask_ratio)

# ----------- Encoder input: visible patches only -----------
# Add dummy CLS token to match input shape
cls_token = model.embeddings.cls_token.expand(B, -1, -1)
encoder_input = torch.cat([cls_token, patches_visible], dim=1)

# # Positional embeddings (use subset of visible)
# pos_embed = model.embeddings.position_embeddings[:, 1:, :]  # exclude CLS pos
# pos_visible = torch.gather(pos_embed.expand(B, -1, -1), 1, ids_keep.unsqueeze(-1).repeat(1, 1, D))
# pos_input = torch.cat([model.embeddings.position_embeddings[:, :1, :], pos_visible], dim=1)
# Positional embeddings
pos_embed = model.embeddings.position_embeddings[:, 1:, :]  # exclude CLS
pos_embed = pos_embed.expand(B, -1, -1)                     # [B, N, D]
pos_visible = torch.gather(pos_embed, 1, ids_keep.unsqueeze(-1).repeat(1, 1, D))  # [B, N_visible, D]

# Add CLS position embedding (expand to B)
cls_pos = model.embeddings.position_embeddings[:, :1, :].expand(B, -1, -1)  # [B, 1, D]
pos_input = torch.cat([cls_pos, pos_visible], dim=1)  # [B, N_visible+1, D]


# Pass visible tokens to encoder
with torch.no_grad():
    encoded = model.encoder(encoder_input + pos_input)[0][:, 1:, :]  # Remove CLS

# ----------- Tiny decoder -----------
class MAEDecoder(nn.Module):
    def __init__(self, input_dim=768, decoder_dim=512, num_layers=2):
        super().__init__()
        self.linear_in = nn.Linear(input_dim, decoder_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=decoder_dim, nhead=8)
        self.decoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.linear_out = nn.Linear(decoder_dim, input_dim)

    def forward(self, visible_tokens, ids_restore):
        B, N_visible, D = visible_tokens.shape
        N_total = ids_restore.shape[1]

        # Project to decoder dim
        x = self.linear_in(visible_tokens)

        # Create [MASK] tokens for missing patches
        mask_token = nn.Parameter(torch.zeros(1, 1, x.size(-1), device=x.device))
        x_full = mask_token.expand(B, N_total, -1).clone()
        x_full.scatter_(1, ids_keep.unsqueeze(-1), x)

        # Decode
        x_decoded = self.decoder(x_full)
        return self.linear_out(x_decoded)

decoder = MAEDecoder().to(device)

# ----------- Loss and optimizer -----------
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-4)

# ----------- Forward + Loss -----------
reconstructed = decoder(encoded, ids_restore)  # (B, N, D)

# Compute loss only on masked patches
mask = torch.ones(B, N, device=device).scatter(1, ids_keep, 0).bool()
loss = criterion(reconstructed[mask], patches[mask])

print(f"MAE loss: {loss.item():.4f}")

# ----------- Backprop -----------
loss.backward()
optimizer.step()


RuntimeError: shape '[8, 0, 14, 8, 768]' is invalid for input of size 301056