In [25]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from transformers import AutoImageProcessor
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
import torch.nn as nn

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
# from timesformer_pytorch import TimeSformer

import random
import threading
import glob

import numpy as np
import wandb
from torch.utils.data import DataLoader
import os
import random
import cv2
import numpy as np
from tqdm.auto import tqdm
import torch
import torch.nn as nn
from torch.optim import AdamW
import segmentation_models_pytorch as smp
import numpy as np
from torch.utils.data import DataLoader, Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch
from warmup_scheduler import GradualWarmupScheduler
import PIL.Image

PIL.Image.MAX_IMAGE_PIXELS = 933120000

import utils
import models.swin as swin
import models.timesformer_hug as timesformer_hug

class TimesformerDataset(Dataset):
    def __init__(self, images, cfg, xyxys=None, labels=None, transform=None):
        self.images = images
        self.cfg = cfg
        self.labels = labels
        
        self.transform = transform
        self.xyxys=xyxys
        # self.video_transform = T.Compose([
        #     T.ConvertImageDtype(torch.float32), 
        #     T.Normalize(mean=[0.5], std=[0.5])   # shift and scale to [-1, 1]
        #     ])
        self.video_transform = T.Compose([
            T.ConvertImageDtype(torch.float32), 
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
        
    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if self.xyxys is not None: #VALID
            image = self.images[idx]
            label = self.labels[idx]
            xy=self.xyxys[idx]
            if self.transform:
                data = self.transform(image=image, mask=label)
                image = data['image'].unsqueeze(0)
                label = data['mask']
                label=F.interpolate(label.unsqueeze(0),(self.cfg.size//16,self.cfg.size//16)).squeeze(0)
            
            # image = image.permute(1,0,2,3)
            # image = torch.stack([self.video_transform(f) for f in image]) # list of frames
            image = image.permute(1,0,2,3)
            image = image.repeat(1, 3, 1, 1)
            image = torch.stack([self.video_transform(f) for f in image]) # list of frames
            return image, label
        else:
            image = self.images[idx]
            label = self.labels[idx]
                        
            if self.transform:
                data = self.transform(image=image, mask=label)
                image = data['image'].unsqueeze(0)
                label = data['mask']
                label=F.interpolate(label.unsqueeze(0),(self.cfg.size//16,self.cfg.size//16)).squeeze(0)
                
            # image = image.permute(1,0,2,3)
            # image = torch.stack([self.video_transform(f) for f in image]) # list of frames
            # return image, label
            image = image.permute(1,0,2,3)
            image = image.repeat(1, 3, 1, 1)
            image = torch.stack([self.video_transform(f) for f in image]) # list of frames
            return image, label

In [26]:
class CFG:
    # ============== comp exp name =============
    current_dir = './'
    segment_path = './train_scrolls/'
    
    start_idx = 20
    in_chans = 16
    
    size = 128
    tile_size = 128
    stride = tile_size // 1
    
    train_batch_size =  10 # 32
    valid_batch_size = 1
    lr = 1e-4
    num_workers = 8
    # ============== model cfg =============
    scheduler = 'cosine'#, 'linear'
    epochs = 16
    warmup_factor = 10
    
    # Change the size of fragments
    frags_ratio1 = ['frag','re']
    frags_ratio2 = ['202','s4','left']
    ratio1 = 2
    ratio2 = 1
    
    # ============== fold =============
    segments = ['rect5','20231215151901'] 
    valid_id = 'rect5'#20231215151901'
    
    # ============== fixed =============
    min_lr = 1e-7
    weight_decay = 1e-6
    max_grad_norm = 100
    num_workers = 8
    seed = 0
    
    # ============== comp exp name =============
    comp_name = 'vesuvius'
    exp_name = 'pretraining_all'

    outputs_path = f'./outputs/{comp_name}/{exp_name}/'
    model_dir = outputs_path + \
        f'{comp_name}-models/'
        
    # ============== augmentation =============
    train_aug_list = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        # # # A.RandomBrightnessContrast(p=0.75),
        A.ShiftScaleRotate(rotate_limit=360,shift_limit=0.15,scale_limit=0.1,p=0.75),
        # A.OneOf([
        #         A.GaussNoise(var_limit=[10, 50]),
        #         A.GaussianBlur(),
        #         A.MotionBlur(),
        #         ], p=0.4),
        # A.CoarseDropout(max_holes=2, max_width=int(size * 0.2), max_height=int(size * 0.2), 
        #                 mask_fill_value=0, p=0.5),
        # A.Normalize(
        #     mean= [0] * in_chans,
        #     std= [1] * in_chans
        # ),
        ToTensorV2(transpose_mask=True),
    ]

    valid_aug_list = [
        # A.Normalize(
        #     mean= [0] * in_chans,
        #     std= [1] * in_chans
        # ),
        ToTensorV2(transpose_mask=True),  
    ]
    
def get_transforms(data, cfg):
    if data == 'train':
        aug = A.Compose(cfg.train_aug_list)
    elif data == 'valid':
        aug = A.Compose(cfg.valid_aug_list)
    return aug   


# End any existing run (if still active)
if wandb.run is not None:
    wandb.finish()
        
utils.cfg_init(CFG)
torch.set_float32_matmul_precision('medium')

fragment_id = CFG.valid_id
run_slug=f'SWIN_{CFG.segments}_valid={CFG.valid_id}_size={CFG.size}_lr={CFG.lr}_in_chans={CFG.in_chans}'
valid_mask_gt = cv2.imread(f"{CFG.segment_path}{fragment_id}/{fragment_id}_inklabels.png", 0)

if any(sub in fragment_id for sub in CFG.frags_ratio1):
    scale = 1 / CFG.ratio1
    new_w = int(valid_mask_gt.shape[1] * scale)
    new_h = int(valid_mask_gt.shape[0] * scale)
    valid_mask_gt = cv2.resize(valid_mask_gt, (new_w, new_h), interpolation=cv2.INTER_AREA)

elif any(sub in fragment_id for sub in CFG.frags_ratio2):
    scale = 1 / CFG.ratio2
    new_w = int(valid_mask_gt.shape[1] * scale)
    new_h = int(valid_mask_gt.shape[0] * scale)
    valid_mask_gt = cv2.resize(valid_mask_gt, (new_w, new_h), interpolation=cv2.INTER_AREA)
pred_shape=valid_mask_gt.shape

train_images, train_masks, valid_images, valid_masks, valid_xyxys = utils.get_train_valid_dataset(CFG)
train_images = train_images[10:20]
train_masks = train_masks[10:20]

print('train_images',train_images[0].shape)
print("Length of train images:", len(train_images))

valid_xyxys = np.stack(valid_xyxys)
train_dataset = TimesformerDataset(
    train_images, CFG, labels=train_masks, transform=get_transforms(data='valid', cfg=CFG))
valid_dataset = TimesformerDataset(
    valid_images, CFG, xyxys=valid_xyxys, labels=valid_masks, transform=get_transforms(data='valid', cfg=CFG))

train_loader = DataLoader(train_dataset,
                            batch_size=CFG.train_batch_size,
                            shuffle=True,
                            num_workers=0, pin_memory=True, drop_last=True,
                            )
valid_loader = DataLoader(valid_dataset,
                            batch_size=CFG.valid_batch_size,
                            shuffle=False,
                            num_workers=0, pin_memory=True, drop_last=True)

print(f"Train loader length: {len(train_loader)}")
print(f"Valid loader length: {len(valid_loader)}")

reading rect5




100%|██████████| 16/16 [00:00<00:00, 22.30it/s]


 Shape of rect5 segment: (1536, 2048, 16)
reading 20231215151901


100%|██████████| 16/16 [00:02<00:00,  5.55it/s]


 Shape of 20231215151901 segment: (3200, 10496, 16)
train_images (128, 128, 16)
Length of train images: 10
Train loader length: 1
Valid loader length: 78


In [27]:
for x,y in train_loader:
    print(x.max())
    print(x.min())
    break

tensor(1.6814)
tensor(-1.6042)


In [28]:
import matplotlib.pyplot as plt
import numpy as np

def visualize_reconstruction(original, reconstructed, sample_idx=0, num_frames=4):
    """
    Visualize original and reconstructed video frames side by side.

    Args:
        original: tensor (B, T, C, H, W) original video batch
        reconstructed: tensor (B, T, C, H, W) reconstructed video batch
        sample_idx: int, index in batch to visualize
        num_frames: int, number of frames to display
    """
    orig = original[sample_idx]     # (T, C, H, W)
    recon = reconstructed[sample_idx]  # (T, C, H, W)

    # If grayscale, squeeze channel dim
    if orig.shape[1] == 1:
        orig = orig.squeeze(1)
        recon = recon.squeeze(1)

    # Clamp and convert to numpy
    orig = orig.cpu().numpy()
    recon = recon.cpu().detach().numpy()

    fig, axes = plt.subplots(2, num_frames, figsize=(3 * num_frames, 6))

    for i in range(num_frames):
        # Original frame
        ax = axes[0, i]
        ax.imshow(orig[i], cmap='gray')
        ax.set_title(f"Original Frame {i}")
        ax.axis('off')

        # Reconstructed frame
        ax = axes[1, i]
        ax.imshow(recon[i], cmap='gray')
        ax.set_title(f"Reconstructed Frame {i}")
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# TIMS

In [29]:

import math
from torch.optim import AdamW
from warmup_scheduler import GradualWarmupScheduler
from transformers import SegformerForSemanticSegmentation
from torchvision.models.video import swin_transformer
import albumentations as A
from transformers import AutoImageProcessor, TimesformerModel
from transformers import TimesformerModel, TimesformerConfig

import numpy as np

class MAEPretrain(pl.LightningModule):
    def __init__(self, lr=1e-3, mask_ratio=0.75, embed_dim=768, decoder_dim=512, decoder_layers=4):
        super().__init__()
        self.save_hyperparameters()
        
        config = TimesformerConfig(
            num_frames=16,
            image_size=128,
            patch_size=8,
            num_channels=1,
            attention_type="divided_space_time",
        )
        self.encoder = TimesformerModel(config)
        
        self.patch_size = config.patch_size
        self.input_T = config.num_frames
        self.input_H = config.image_size
        self.input_W = config.image_size
        self.mask_ratio = self.hparams.mask_ratio
        mse_loss = nn.MSELoss()
        l1_loss = nn.L1Loss()

        self.criterion = lambda pred, target: 1 * mse_loss(pred, target) #+ 0.5 * l1_loss(pred, target)        
        
        self.N = self.input_T * self.input_H * self.input_W // (self.patch_size**2)
        print(f"Total patches: {self.N}")

        target = (1- self.mask_ratio) * self.N  # example: 500        
        # Start by finding max possible y (perfect square root) that doesn't exceed target/input_T
        max_y = math.floor(math.sqrt(target / self.input_T))

        # Generate candidate numbers (num = y*y * input_T)
        candidates = [y*y * self.input_T for y in range(max_y, 0, -1)]

        # Pick the candidate closest to the target
        closest = min(candidates, key=lambda x: abs(x - target))
        
        self.unmasked_patches =  closest

        print("Unmasked_patches:", closest)

        print(f"Actual patches used: {self.unmasked_patches}/{self.N} : {self.unmasked_patches/self.N:.2f}")


        # Transformer decoder components
        self.decoder_embed = nn.Linear(embed_dim, decoder_dim)
        self.decoder_pos_embed = nn.Parameter(torch.randn(1, self.N, decoder_dim))
        
        decoder_layer = nn.TransformerEncoderLayer(d_model=decoder_dim, nhead=8, dim_feedforward=decoder_dim)
        self.decoder_transformer = nn.TransformerEncoder(decoder_layer, num_layers=decoder_layers)
        self.decoder_pred = nn.Sequential(
            nn.Linear(decoder_dim, self.patch_size**2),
            nn.Tanh()  # ensures outputs in [-1, 1]
        )

        # Mask token for masked patches in decoder
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim))
        nn.init.normal_(self.mask_token, std=0.02)
        

    def patchify(self, x):
        B, T, C, H, W = x.shape # x: (B, T, C, H, W)
        
        # Only unfold spatial dimensions
        x = x.permute(0, 2, 1, 3, 4)  # (B, C, T, H, W)
        x = x.unfold(3, self.patch_size, self.patch_size).unfold(4, self.patch_size, self.patch_size)
        # x: (B, C, T, H_patches, W_patches, patch_size, patch_size)
        
        H_patches = x.size(3)
        W_patches = x.size(4)
        
        # Move patch grid before time: (B, ph, pw, T, C, ps, ps)
        x = x.permute(0, 3, 4, 2, 1, 5, 6)

        # Flatten patch: (B, ph*pw*pt, C*ps*ps)
        x = x.reshape(B, H_patches * W_patches* T , C *self.patch_size**2)

        return x
    

    def unpatchify(self, x, patch_shape):
        """
        x: (B, N, D) from masked patches
        patch_shape: (pt, ph, pw) where
            pt = number of frames (T),
            ph = patches along height,
            pw = patches along width
        """
        B, N, D = x.shape
        pt, ph, pw = patch_shape
        ps = self.patch_size
        C = D // (ps**2)  # only spatial patches
        
        assert ph * pw * pt == N, "Spatial patch count mismatch"

        #     # (B, ph, pw, pt, C, ps, ps)
        # x = x.view(B, pt, ph, pw, C, ps, ps)
        x = x.view(B, ph, pw, pt, C, ps, ps)
        # x = x.permute(0, 4,1, 3, 5, 2, 6).reshape(B, C, pt, ph * ps, pw * ps)
        # # Restore to (B, C, T, H, W)
        x = x.permute(0, 4, 3, 1, 5, 2, 6).reshape(B, C, pt, ph * ps, pw * ps)

        return x

    def random_masking(self, x, mask_ratio=0.75):
        """
        MAE-style random masking with restore indices.
        x: (B, N, D)
        Returns:
            x_masked: (B, n_keep, D)    - visible patches
            ids_keep: (B, n_keep)       - indices of kept patches
            ids_masked: (B, n_mask)     - indices of masked patches
            ids_restore: (B, N)         - to restore original order
        """
        B, N, D = x.shape
        n_keep = self.unmasked_patches

        ids_keep = []
        ids_masked = []
        ids_restore = []

        for b in range(B):
            # 1. Random permutation of all patches
            perm = torch.randperm(N, device=x.device)
            keep = perm[:n_keep]
            masked = perm[n_keep:]

            # 2. Save indices
            ids_keep.append(keep)
            ids_masked.append(masked)

            # 3. Build restore index (inverse of permutation)
            ids_restore_b = torch.empty_like(perm)
            ids_restore_b[perm] = torch.arange(N, device=x.device)
            ids_restore.append(ids_restore_b)
            # ids_restore_b = torch.empty(N, device=x.device, dtype=torch.long)
            # ids_restore_b[keep]   = torch.arange(n_keep, device=x.device)
            # ids_restore_b[masked] = torch.arange(n_keep, N, device=x.device)
            # ids_restore.append(ids_restore_b)

        ids_keep = torch.stack(ids_keep, dim=0)      # (B, n_keep)
        ids_masked = torch.stack(ids_masked, dim=0)  # (B, n_mask)
        ids_restore = torch.stack(ids_restore, dim=0) # (B, N)

        # 4. Gather kept tokens
        x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))

        return x_masked, ids_keep, ids_masked, ids_restore

    def forward(self, x):
        B, T, C, H, W = x.shape # (B, T, C, H, W)
        
        # 1. Patchify input video
        x_patched = self.patchify(x)  # (B, N, D)
        print('x_patched',x_patched.shape) # (B,N,T,D)
        
        N = x_patched.shape[1] # 4096

        # 2. Mask patches
        x_masked, ids_keep, ids_masked, ids_restore = self.random_masking(x_patched, self.mask_ratio)
        
        # print('x_masked',x_masked.shape) # (B,N,T,D)
        ids_keep = ids_keep.long() # Ids of unmasked
        ids_masked = ids_masked.long()
        ids_restore = ids_restore.long()

        # Calculate masked patch indices
        all_ids = torch.arange(N, device=x.device).unsqueeze(0).expand(B, -1)  # (B, N)
        mask = torch.ones_like(all_ids, dtype=torch.bool)
        mask.scatter_(1, ids_keep, False)
        
        pt = T  
        ph = pw = int((self.unmasked_patches // pt) ** 0.5)
        assert ph * pw * pt == self.unmasked_patches, "Patch grid mismatch"

        # We know the tube shape is (pt, square_size, square_size)
        x_masked_video = self.unpatchify(x_masked, (pt, ph, pw))  # (B, C, T, H_mask, W_mask)
        print('x_masked_unpatchify',x_masked_video.shape)
        # x_masked_video = x_masked_video.permute(0,2,1,3,4)  # (B,C,T,H,W)
        
        
        # 4. Encoder forward on masked video        
        x_masked_video = x_masked_video.permute(0,2,1,3,4) # (B,T,C,H,W)
        outputs = self.encoder(x_masked_video, output_hidden_states=True)
        tokens = outputs.last_hidden_state[:,1:,:]  # tuple of all hidden layers
        print('tokens',tokens.shape) # (B, n_visible, D)
        # Group first
        tokens = tokens.view(B, ph, pw,pt, self.hparams.embed_dim)  # (B, 16, 8, 8, D)
        # tokens = tokens.view(B, ph, pw, T, self.hparams.embed_dim)  # (B, 16, 8, 8, D)
        tokens = tokens.permute(0, 3, 1, 2, 4).contiguous()         # (B, ph, pw, T, D)  <-- matches your compact cube order


        # 5. Embed encoder features to decoder_dim
        x_vis = self.decoder_embed(tokens)  # (B, n_visible, decoder_dim)
        
        x_vis = x_vis.view(B,-1,self.hparams.decoder_dim)
        print(x_vis.shape)

        # 6. Prepare mask tokens for masked patches
        # print(ids_masked.shape[1])

        mask_tokens = self.mask_token.expand(B, ids_masked.shape[1], -1)  # (B, n_masked, decoder_dim)
        # x_ = torch.cat([x_vis, mask_tokens], dim=1)   # (B, N, D) but shuffled
        # x_dec = torch.gather(x_, 1, ids_restore.unsqueeze(-1).expand(-1, -1, x_vis.shape[2]))

        # print(mask_tokens.shape)
        # 7. Create full sequence tensor for decoder input
        # Restore to original order
        x_ = torch.cat([x_vis, mask_tokens], dim=1)  # (B, n_keep + n_masked, D)
        x_dec = torch.gather(x_, 1, ids_restore.unsqueeze(-1).expand(-1, -1, x_.shape[2]))
        # 8. Add positional embedding
        x_dec = x_dec + self.decoder_pos_embed
    


        # 9. Decode full sequence
        x_dec = self.decoder_transformer(x_dec)
        # print('x_dec',x_dec.shape)
        pred = self.decoder_pred(x_dec)  # (B, N, patch_dim)
        
        # pred_masked = pred.clone()
        # pred_masked[mask == 0] = 0  # zero out unmasked tokens
        
        # 3. Unpatchify visible patches to video for encoder
        ph = pw = int((self.N // pt) ** 0.5)
        recon = self.unpatchify(pred, (pt, ph, pw))  # (B, C, T, H, W)
        recon = recon.permute(0, 2, 1, 3, 4)  # (B, T, C, H, W)

        return recon, x_masked_video, mask, ids_masked, pred, x_patched

    def training_step(self, batch, batch_idx):
        x, y = batch  # (B, 1, T, H, W)
        B = x.shape[0]

        recon, x_masked, mask, ids_masked, pred, target = self(x)

        # Compute loss only on masked patches
        mask = mask.bool()
        
        # Expand mask for (B, N, D)
        mask_expanded = mask.unsqueeze(-1).expand_as(pred)  # (5, 4096, 64)

        # Apply masking
        pred_masked   = pred[mask_expanded]    # (num_masked * D,)
        target_masked = target[mask_expanded]  # (num_masked * D,)


        # Compute loss only on masked patches
        loss = self.criterion(pred_masked, target_masked)
        self.log("train_loss", loss, prog_bar=True, logger=True)

        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        recon, _, mask, _, _, _ = self(x)
        loss = self.criterion(recon, x)
        self.log('val_loss', loss, prog_bar=True)
        # Save first batch to visualize later
        if batch_idx == 0:
            self.val_batch_for_viz = (x, mask, recon)
        return loss

    def on_validation_epoch_end(self):
        if hasattr(self, 'val_batch_for_viz'):
            x, mask, recon = self.val_batch_for_viz
            visualize_reconstruction(x, recon, sample_idx=0, num_frames=16)

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=1e-6)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
        return [optimizer], [scheduler]


model = MAEPretrain()
trainer = pl.Trainer(
    max_epochs=200,
    accelerator='auto',
    log_every_n_steps=20,
    check_val_every_n_epoch=50,
)
trainer.fit(model, train_loader, train_loader)

Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name                | Type               | Params | Mode 
-------------------------------------------------------------------
0 | encoder             | TimesformerModel   | 120 M  | train
1 | decoder_embed       | Linear             | 393 K  | train
2 | decoder_transformer | TransformerEncoder | 6.3 M  | train
3 | decoder_pred        | Sequential         | 32.8 K | train
  | other params   

Total patches: 4096
Unmasked_patches: 1024
Actual patches used: 1024/4096 : 0.25
Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]x_patched torch.Size([10, 4096, 192])


/home/ubuntu/miniconda3/envs/dion/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:476: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/home/ubuntu/miniconda3/envs/dion/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.


x_masked_unpatchify torch.Size([10, 3, 16, 64, 64])


RuntimeError: Given groups=1, weight of size [768, 1, 8, 8], expected input[160, 3, 64, 64] to have 1 channels, but got 3 channels instead

# SWIN

In [34]:

import math
from torch.optim import AdamW
from warmup_scheduler import GradualWarmupScheduler
from transformers import SegformerForSemanticSegmentation
from torchvision.models.video import swin_transformer
import albumentations as A
from transformers import AutoImageProcessor, TimesformerModel
from transformers import TimesformerModel, TimesformerConfig

import numpy as np

class MAEPretrain(pl.LightningModule):
    def __init__(self, lr=1e-3, mask_ratio=0.75, embed_dim=768, decoder_dim=512, decoder_layers=4):
        super().__init__()
        self.save_hyperparameters()
        
    
        self.encoder  = swin_transformer.swin3d_t(weights="KINETICS400_V1") #KINETICS400_IMAGENET22K_V1

        
        self.patch_size = 16
        self.input_T = 16
        self.input_H = 224
        self.input_W = 224
        self.tubelet_size = 2
        self.mask_ratio = self.hparams.mask_ratio
        mse_loss = nn.MSELoss()
        l1_loss = nn.L1Loss()

        self.criterion = lambda pred, target: 1 * mse_loss(pred, target) #+ 0.5 * l1_loss(pred, target)        
        
        # self.N = self.input_T * self.input_H * self.input_W // (self.patch_size**2*2)
        self.N = (self.input_T // self.tubelet_size) * \
         (self.input_H // self.patch_size) * \
         (self.input_W // self.patch_size)
        print(f"Total patches: {self.N}")

        target = (1- self.mask_ratio) * self.N    
        print("Target unmasked patches:", target)
        # Start by finding max possible y (perfect square root) that doesn't exceed target/input_T
        max_y = math.floor(math.sqrt(target / self.input_T))

        # Generate candidate numbers (num = y*y * input_T)
        candidates = [y*y * self.input_T for y in range(max_y, 0, -1)]

        # Pick the candidate closest to the target
        closest = min(candidates, key=lambda x: abs(x - target))
        
        self.unmasked_patches =  int(target)

        print("Unmasked_patches:", closest)

        print(f"Actual patches used: {self.unmasked_patches}/{self.N} : {self.unmasked_patches/self.N:.2f}")


        # Transformer decoder components
        self.decoder_embed = nn.Linear(embed_dim, decoder_dim)
        self.decoder_pos_embed = nn.Parameter(torch.randn(1, self.N, decoder_dim))
        
        decoder_layer = nn.TransformerEncoderLayer(d_model=decoder_dim, nhead=8, dim_feedforward=decoder_dim)
        self.decoder_transformer = nn.TransformerEncoder(decoder_layer, num_layers=decoder_layers)
        self.decoder_pred = nn.Sequential(
            nn.Linear(decoder_dim, self.patch_size**2),
            nn.Tanh()  # ensures outputs in [-1, 1]
        )

        # Mask token for masked patches in decoder
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim))
        nn.init.normal_(self.mask_token, std=0.02)
        

    # def patchify(self, x):
    #     B, T, C, H, W = x.shape # x: (B, T, C, H, W)
        
    #     # Only unfold spatial dimensions
    #     x = x.permute(0, 2, 1, 3, 4)  # (B, C, T, H, W)
    #     x = x.unfold(3, self.patch_size, self.patch_size).unfold(4, self.patch_size, self.patch_size)
    #     # x: (B, C, T, H_patches, W_patches, patch_size, patch_size)
        
    #     H_patches = x.size(3)
    #     W_patches = x.size(4)
        
    #     # Move patch grid before time: (B, ph, pw, T, C, ps, ps)
    #     x = x.permute(0, 3, 4, 2, 1, 5, 6)

    #     # Flatten patch: (B, ph*pw*pt, C*ps*ps)
    #     x = x.reshape(B, H_patches * W_patches* T , C *self.patch_size**2)

    #     return x
    def patchify(self, x):
        """
        x: (B, T, C, H, W)
        Returns:
            patches: (B, N, patch_dim)
            where N = (T//tubelet_size) * (H//patch_size) * (W//patch_size)
                    patch_dim = C * tubelet_size * patch_size * patch_size
        """
        B, T, C, H, W = x.shape
        tubelet = self.tubelet_size   # e.g., 2
        ps = self.patch_size          # e.g., 4

        # (B, T, C, H, W) → (B, C, T, H, W)
        x = x.permute(0, 2, 1, 3, 4)

        # Unfold temporal and spatial dims
        x = x.unfold(2, tubelet, tubelet) \
            .unfold(3, ps, ps) \
            .unfold(4, ps, ps)
        # shape: (B, C, T_patches, H_patches, W_patches, tubelet, ps, ps)

        T_patches = x.size(2)
        H_patches = x.size(3)
        W_patches = x.size(4)

        # Reorder: (B, T_p, H_p, W_p, tubelet, ps, ps, C)
        x = x.permute(0, 2, 3, 4, 5, 6, 7, 1)

        # Flatten each patch: (B, N, C * tubelet * ps * ps)
        x = x.reshape(B, T_patches * H_patches * W_patches,
                    C * tubelet * ps * ps)

        return x

    

    def unpatchify(self, x, patch_shape):
        """
        x: (B, N, D) from masked patches
        patch_shape: (pt, ph, pw) where
            pt = number of frames (T),
            ph = patches along height,
            pw = patches along width
        """
        B, N, D = x.shape
        pt, ph, pw = patch_shape
        ps = self.patch_size
        C = D // (ps**2)  # only spatial patches
        
        assert ph * pw * pt == N, "Spatial patch count mismatch"

        #     # (B, ph, pw, pt, C, ps, ps)
        # x = x.view(B, pt, ph, pw, C, ps, ps)
        x = x.view(B, ph, pw, pt, C, ps, ps)
        # x = x.permute(0, 4,1, 3, 5, 2, 6).reshape(B, C, pt, ph * ps, pw * ps)
        # # Restore to (B, C, T, H, W)
        x = x.permute(0, 4, 3, 1, 5, 2, 6).reshape(B, C, pt, ph * ps, pw * ps)

        return x

    def random_masking(self, x, mask_ratio=0.75):
        """
        MAE-style random masking with restore indices.
        x: (B, N, D)
        Returns:
            x_masked: (B, n_keep, D)    - visible patches
            ids_keep: (B, n_keep)       - indices of kept patches
            ids_masked: (B, n_mask)     - indices of masked patches
            ids_restore: (B, N)         - to restore original order
        """
        B, N, D = x.shape
        n_keep = self.unmasked_patches

        ids_keep = []
        ids_masked = []
        ids_restore = []

        for b in range(B):
            # 1. Random permutation of all patches
            perm = torch.randperm(N, device=x.device)
            keep = perm[:n_keep]
            masked = perm[n_keep:]

            # 2. Save indices
            ids_keep.append(keep)
            ids_masked.append(masked)

            # 3. Build restore index (inverse of permutation)
            ids_restore_b = torch.empty_like(perm)
            ids_restore_b[perm] = torch.arange(N, device=x.device)
            ids_restore.append(ids_restore_b)
            # ids_restore_b = torch.empty(N, device=x.device, dtype=torch.long)
            # ids_restore_b[keep]   = torch.arange(n_keep, device=x.device)
            # ids_restore_b[masked] = torch.arange(n_keep, N, device=x.device)
            # ids_restore.append(ids_restore_b)

        ids_keep = torch.stack(ids_keep, dim=0)      # (B, n_keep)
        ids_masked = torch.stack(ids_masked, dim=0)  # (B, n_mask)
        ids_restore = torch.stack(ids_restore, dim=0) # (B, N)

        # 4. Gather kept tokens
        x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))

        return x_masked, ids_keep, ids_masked, ids_restore
    
    def tube_masking(self, x, mask_ratio=0.75):
        """
        MAE-style tube masking.
        x: (B, N, D), where N = T * H_patches * W_patches
        """
        B, N, D = x.shape
        T = self.input_T
        H_patches = self.input_H // self.patch_size
        W_patches = self.input_W // self.patch_size
        N_spatial = H_patches * W_patches

        # Number of spatial patches to keep
        n_keep_spatial = int(N_spatial * (1 - mask_ratio))

        ids_keep = []
        ids_masked = []
        ids_restore = []

        for b in range(B):
            # 1. Randomly permute spatial indices
            perm_spatial = torch.randperm(N_spatial, device=x.device)
            
            # 2. Select the spatial locations to keep
            keep_spatial = perm_spatial[:n_keep_spatial]
            masked_spatial = perm_spatial[n_keep_spatial:]

            # 3. Get the full spatio-temporal indices for kept and masked tubes
            keep_full = []
            for s_idx in keep_spatial:
                start_idx = s_idx * T
                end_idx = start_idx + T
                keep_full.append(torch.arange(start_idx, end_idx, device=x.device))
            
            masked_full = []
            for s_idx in masked_spatial:
                start_idx = s_idx * T
                end_idx = start_idx + T
                masked_full.append(torch.arange(start_idx, end_idx, device=x.device))
                
            ids_keep.append(torch.cat(keep_full, dim=0))
            ids_masked.append(torch.cat(masked_full, dim=0))

            # 4. Build restore index (inverse of permutation)
            full_perm = torch.cat([ids_keep[-1], ids_masked[-1]], dim=0)
            ids_restore_b = torch.empty_like(full_perm)
            ids_restore_b[full_perm] = torch.arange(N, device=x.device)
            ids_restore.append(ids_restore_b)

        ids_keep = torch.stack(ids_keep, dim=0)
        ids_masked = torch.stack(ids_masked, dim=0)
        ids_restore = torch.stack(ids_restore, dim=0)
        
        # 5. Gather kept tokens
        x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))
        
        return x_masked, ids_keep, ids_masked, ids_restore




    def forward(self, x):
        B, T, C, H, W = x.shape # (B, T, C, H, W)
        
        # 1. Patchify input video
        x_patched = self.patchify(x)  # (B, N, D)
        print('x_patched',x_patched.shape) # (B,N,T,D)
        
        N = x_patched.shape[1]

        # 2. Mask patches
        x_masked, ids_keep, ids_masked, ids_restore = self.random_masking(x_patched, self.mask_ratio)
        
        # print('x_masked',x_masked.shape) # (B,N,T,D)
        ids_keep = ids_keep.long() # Ids of unmasked
        ids_masked = ids_masked.long()
        ids_restore = ids_restore.long()

        # Calculate masked patch indices
        all_ids = torch.arange(N, device=x.device).unsqueeze(0).expand(B, -1)  # (B, N)
        mask = torch.ones_like(all_ids, dtype=torch.bool)
        mask.scatter_(1, ids_keep, False)
        
        pt = T // self.tubelet_size 
        ph = pw = int((self.unmasked_patches // pt) ** 0.5)
        print('ph,pw,pt',ph,pw,pt)
        assert ph * pw * pt == self.unmasked_patches, "Patch grid mismatch"

        # We know the tube shape is (pt, square_size, square_size)
        x_masked_video = self.unpatchify(x_masked, (pt, ph, pw))  # (B, C, T, H_mask, W_mask)
        print('x_masked_unpatchify',x_masked_video.shape)
        # x_masked_video = x_masked_video.permute(0,2,1,3,4)  # (B,C,T,H,W)
        
        
        # 4. Encoder forward on masked video        
        x_masked_video = x_masked_video.permute(0,2,1,3,4) # (B,T,C,H,W)
        outputs = self.encoder(x_masked_video, output_hidden_states=True)
        tokens = outputs.last_hidden_state[:,1:,:]  # tuple of all hidden layers
        
        # Group first
        tokens = tokens.view(B, ph, pw,pt, self.hparams.embed_dim)  # (B, 16, 8, 8, D)
        # tokens = tokens.view(B, ph, pw, T, self.hparams.embed_dim)  # (B, 16, 8, 8, D)
        tokens = tokens.permute(0, 3, 1, 2, 4).contiguous()         # (B, ph, pw, T, D)  <-- matches your compact cube order


        # 5. Embed encoder features to decoder_dim
        x_vis = self.decoder_embed(tokens)  # (B, n_visible, decoder_dim)
        
        x_vis = x_vis.view(B,-1,self.hparams.decoder_dim)
        # print(x_vis.shape)

        # 6. Prepare mask tokens for masked patches
        # print(ids_masked.shape[1])

        mask_tokens = self.mask_token.expand(B, ids_masked.shape[1], -1)  # (B, n_masked, decoder_dim)
        # x_ = torch.cat([x_vis, mask_tokens], dim=1)   # (B, N, D) but shuffled
        # x_dec = torch.gather(x_, 1, ids_restore.unsqueeze(-1).expand(-1, -1, x_vis.shape[2]))

        # print(mask_tokens.shape)
        # 7. Create full sequence tensor for decoder input
        # Restore to original order
        x_ = torch.cat([x_vis, mask_tokens], dim=1)  # (B, n_keep + n_masked, D)
        x_dec = torch.gather(x_, 1, ids_restore.unsqueeze(-1).expand(-1, -1, x_.shape[2]))
        # 8. Add positional embedding
        x_dec = x_dec + self.decoder_pos_embed
    


        # 9. Decode full sequence
        x_dec = self.decoder_transformer(x_dec)
        # print('x_dec',x_dec.shape)
        pred = self.decoder_pred(x_dec)  # (B, N, patch_dim)
        
        # pred_masked = pred.clone()
        # pred_masked[mask == 0] = 0  # zero out unmasked tokens
        
        # 3. Unpatchify visible patches to video for encoder
        ph = pw = int((self.N // pt) ** 0.5)
        recon = self.unpatchify(pred, (pt, ph, pw))  # (B, C, T, H, W)
        recon = recon.permute(0, 2, 1, 3, 4)  # (B, T, C, H, W)

        return recon, x_masked_video, mask, ids_masked, pred, x_patched

    def training_step(self, batch, batch_idx):
        x, y = batch  # (B, 1, T, H, W)
        B = x.shape[0]

        recon, x_masked, mask, ids_masked, pred, target = self(x)

        # Compute loss only on masked patches
        mask = mask.bool()
        
        # Expand mask for (B, N, D)
        mask_expanded = mask.unsqueeze(-1).expand_as(pred)  # (5, 4096, 64)

        # Apply masking
        pred_masked   = pred[mask_expanded]    # (num_masked * D,)
        target_masked = target[mask_expanded]  # (num_masked * D,)


        # Compute loss only on masked patches
        loss = self.criterion(pred_masked, target_masked)
        self.log("train_loss", loss, prog_bar=True, logger=True)

        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        recon, _, mask, _, _, _ = self(x)
        loss = self.criterion(recon, x)
        self.log('val_loss', loss, prog_bar=True)
        # Save first batch to visualize later
        if batch_idx == 0:
            self.val_batch_for_viz = (x, mask, recon)
        return loss

    def on_validation_epoch_end(self):
        if hasattr(self, 'val_batch_for_viz'):
            x, mask, recon = self.val_batch_for_viz
            visualize_reconstruction(x, recon, sample_idx=0, num_frames=16)

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=1e-6)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
        return [optimizer], [scheduler]


model = MAEPretrain()
trainer = pl.Trainer(
    max_epochs=200,
    accelerator='auto',
    log_every_n_steps=20,
    check_val_every_n_epoch=50,
)
trainer.fit(model, train_loader, train_loader)

Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name                | Type               | Params | Mode 
-------------------------------------------------------------------
0 | encoder             | SwinTransformer3d  | 28.2 M | train
1 | decoder_embed       | Linear             | 393 K  | train
2 | decoder_transformer | TransformerEncoder | 6.3 M  | train
3 | decoder_pred        | Sequential         | 131 K  | train
  | other params   

Total patches: 1568
Target unmasked patches: 392.0
Unmasked_patches: 256
Actual patches used: 392/1568 : 0.25
Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]x_patched torch.Size([10, 512, 1536])
ph,pw,pt 7 7 8
x_masked_unpatchify torch.Size([10, 6, 8, 112, 112])


/home/ubuntu/miniconda3/envs/dion/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:476: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/home/ubuntu/miniconda3/envs/dion/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.


TypeError: SwinTransformer3d.forward() got an unexpected keyword argument 'output_hidden_states'