# I-JEPA

The relevant needed code snippets defining the model architecture and such things have been copied from the STRL_training notebook. The relevant part is at the very bottom of this notebook where the inference loop is defined.

## Imports

In [1]:
# core train
import os
import copy
import sys
import numpy as np
import torch
import torch.nn.functional as F
import yaml

## Logger

In [2]:
import logging
log_timings = True
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger()

## Config

In [3]:
# Basic global config
_GLOBAL_SEED = 0
np.random.seed(_GLOBAL_SEED)
torch.manual_seed(_GLOBAL_SEED)
torch.backends.cudnn.benchmark = True

In [4]:
args = {
    "data": {
        "batch_size": 4,
        "test_folders": ['/kaggle/input/fmri-train-2-norm-v3/data/noisy_func_train_2.npy'],
        "num_workers": 2,
        "pin_mem": True,
        "root_path": "/kaggle/input",
        "use_horizontal_flip": False
    },
    "logging": {
        "folder": "/kaggle/working/logs",
        "write_tag": "jepa"
    },
    "mask": {
        "allow_overlap": False,
        "aspect_ratio": [0.75, 1.5],
        "enc_mask_scale": [0.85, 1.0],
        "min_keep": 10,
        "num_enc_masks": 1,
        "num_pred_masks": 1,
        "patch_size": 16,
        "pred_mask_scale": [0.15, 0.2]
    },
    "meta": {
        "copy_data": False,
        "load_checkpoint": True,
        "model_name": "vit_small",
        "pred_depth": 12,
        "pred_emb_dim": 384,
        "read_checkpoint": "/kaggle/input/strl-jepa/pytorch/default/1/logs/jepa-latest.pth.tar",
        "use_bfloat16": True
    },
    "optimization": {
        "ema": [0.996, 1.0],
        "epochs": 25,
        "final_lr": 1.0e-5,
        "final_weight_decay": 0.4,
        "ipe_scale": 1.0,
        "lr": 0.001,
        "start_lr": 0.0002,
        "warmup": 50,
        "weight_decay": 0.04
    }
}


In [5]:
resume_preempt = False
rank = 0

# -- META
use_bfloat16 = args['meta']['use_bfloat16']
model_name = args['meta']['model_name']
load_model = args['meta']['load_checkpoint'] or resume_preempt
r_file = args['meta']['read_checkpoint']
copy_data = args['meta']['copy_data']
pred_depth = args['meta']['pred_depth']
pred_emb_dim = args['meta']['pred_emb_dim']
if not torch.cuda.is_available():
    device = torch.device('cpu')
else:
    device = torch.device('cuda:0')
    torch.cuda.set_device(device)

# -- DATA
batch_size = args['data']['batch_size']
pin_mem = args['data']['pin_mem']
num_workers = args['data']['num_workers']
root_path = args['data']['root_path']
test_folders = args['data']['test_folders']
# --

# -- MASK
allow_overlap = args['mask']['allow_overlap']  # whether to allow overlap b/w context and target blocks
patch_size = args['mask']['patch_size']  # patch-size for model training
num_enc_masks = args['mask']['num_enc_masks']  # number of context blocks
min_keep = args['mask']['min_keep']  # min number of patches in context block
enc_mask_scale = args['mask']['enc_mask_scale']  # scale of context blocks
num_pred_masks = args['mask']['num_pred_masks']  # number of target blocks
pred_mask_scale = args['mask']['pred_mask_scale']  # scale of target blocks
aspect_ratio = args['mask']['aspect_ratio']  # aspect ratio of target blocks
# --

# -- OPTIMIZATION
ema = args['optimization']['ema']
ipe_scale = args['optimization']['ipe_scale']  # scheduler scale factor (def: 1.0)
wd = float(args['optimization']['weight_decay'])
final_wd = float(args['optimization']['final_weight_decay'])
num_epochs = args['optimization']['epochs']
warmup = args['optimization']['warmup']
start_lr = args['optimization']['start_lr']
lr = args['optimization']['lr']
final_lr = args['optimization']['final_lr']

# -- LOGGING
folder = args['logging']['folder']
tag = args['logging']['write_tag']

os.makedirs(folder, exist_ok=True)
dump = os.path.join(folder, 'params-ijepa.yaml')
with open(dump, 'w') as f:
    yaml.dump(args, f)

## Dataset creation and preprocessing

### Data transformation

#### Masking

In [6]:
import math
from multiprocessing import Value

In [7]:
class TimeSeriesMaskCollator:
    def __init__(self, num_frames=300, frame_size=16, nenc=1, npred=1):
        self.num_frames = num_frames
        self.frame_size = frame_size
        self._itr_counter = Value('i', -1)
        self.npred = 1
        self.nenc = 1

    def step(self):
        i = self._itr_counter
        with i.get_lock():
            i.value += 1
            v = i.value
        return v
    
    def collate_merge_batches(self, batch):
        merged = torch.cat([item for item in batch], dim=0)
        return merged

    def _sample_frame_mask(self, frame_idx):
        mask = torch.zeros(self.num_frames, dtype=torch.int32)
        mask[frame_idx] = 1
        mask = torch.nonzero(mask.flatten()).squeeze()
        return mask, frame_idx

    def build_encoder_mask_from_pred(self, pred_masks):
        enc_mask = torch.ones(self.num_frames, dtype=torch.int32)
        for pred_mask in pred_masks:
            enc_mask[pred_mask] = 0  # Zero out the masked regions
        return torch.nonzero(enc_mask.flatten()).squeeze()

    def __call__(self, batch, frame_idx=None):
        collated_batch = self.collate_merge_batches(batch)
        B = collated_batch.shape[0]

        seed = self.step()
        g = torch.Generator()
        g.manual_seed(seed)

        collated_masks_enc, collated_masks_pred = [], []

        for _ in range(B):
            pred_masks = []
            pred_frame_idxs = []
            for _ in range(self.npred):
                mask, idx = self._sample_frame_mask(frame_idx = frame_idx)
                pred_masks.append(mask)
                pred_frame_idxs.append(idx)

            collated_masks_pred.append(pred_masks)

            enc_masks = []
            for _ in range(self.nenc):
                enc_mask = self.build_encoder_mask_from_pred(pred_masks)
                enc_masks.append(enc_mask)

            collated_masks_enc.append(enc_masks)

        collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred)
        collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc)

        return collated_batch, collated_masks_enc, collated_masks_pred


In [8]:
class CollatorWrapper:
    def __init__(self, collator):
        self.collator = collator
        self.frame_idx = 0

    def set_frame_idx(self, idx):
        self.frame_idx = idx

    def __call__(self, batch):
        return self.collator(batch, frame_idx=self.frame_idx)

In [9]:
def apply_masks(x, masks):
    """
    :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)]
    :param masks: list of tensors containing indices of patches in [N] to keep
    """
    all_x = []
    for m in masks:
        mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1))
        all_x += [torch.gather(x, dim=1, index=mask_keep)]
    return torch.cat(all_x, dim=0)

In [10]:
mask_collator = TimeSeriesMaskCollator() # defaults to 300 frames of size 16x16
collate_wrapper = CollatorWrapper(mask_collator)

#### Image transforms

In [11]:
import numpy as np
import torch # Import PyTorch

class Compose:
    """
    Composes several transforms together.
    Args:
        transforms (list of callables): list of transforms to compose.
    """
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img_array):
        """
        Applies the composed transforms to the input array.
        The input can be a NumPy array. The output will be a torch.Tensor
        if torch.from_numpy is the last transform.
        Args:
            img_array (numpy.ndarray): Input image array (C, H, W).
        Returns:
            torch.Tensor: Transformed image tensor.
        """
        for t in self.transforms:
            img_array = t(img_array) # Note: img_array will become a torch.Tensor at the end
        return img_array


# Build the transform list
transform_list = []

def reshape_to_2d_timeline(patch_time_series):
    # Convert to PyTorch tensor immediately for permute/reshape
    patch_time_series_tensor = torch.from_numpy(patch_time_series).float()
    height_patch = patch_time_series_tensor.shape[0]  # H
    width_patch = patch_time_series_tensor.shape[1]   # W
    total_time_steps = patch_time_series_tensor.shape[2]  # T   

    # print(patch_time_series_tensor.shape)

    
    # Apply the corrected reshaping logic here
    # Original tensor shape (H, W, T) -> (16, 16, total_time_steps)
    # Permute to (H, T, W)
    intermediate_tensor = patch_time_series_tensor.permute(0, 2, 1)
    
    sample = intermediate_tensor.reshape(height_patch, total_time_steps * width_patch)

    # Add a channel dimension: (1, H, T*W)
    sample = sample.unsqueeze(0)
    return sample
    #return patch_time_series_tensor

transform_list += [reshape_to_2d_timeline]  # Add reshape transform
transform_list += [lambda x: x.unsqueeze(0)] # add batch dimension

# Composition of transforms
transform = Compose(transform_list)

### Dataset & Dataloader

In [12]:
import random

In [13]:
class NoisyDataset(torch.utils.data.Dataset):
    def __init__(self, noisy_images_paths: list, transform=None, normalize=True):
        """Initialize fMRI dataset for denoising with memory-efficient loading,
        extracting 16x16 patches with full time series per depth channel.

        Args:
            noisy_images_paths (list): List of paths to noisy fMRI volumes (.npy files)
            stats: Dictionary containing 'mean' and 'std' for normalization.
            transform (callable, optional): Optional transform to be applied on a sample.
            normalize (bool): Whether to apply channel-wise normalization.
        """
        self.noisy_paths = noisy_images_paths
        self.transform = transform
        self.normalize = normalize

        self.file_slice_mapping = [] # Stores (file_idx, z_idx, patch_y_idx, patch_x_idx)

        patch_size = 16
        dataset_length = 0

        for i, path in enumerate(noisy_images_paths):
            # Load metadata about the file shape without loading full content
            data_shape = np.load(path, mmap_mode='r').shape
            # Assuming data_shape is (H, W, Z, T)
            total_height, total_width, depth_channels, total_time_steps = data_shape

            # Calculate how many patches fit along each spatial dimension
            num_patches_width = total_width // patch_size
            num_patches_height = total_height // patch_size

            # Iterate over all depth channels (Z)
            for z_idx in range(depth_channels):
                # Iterate through the spatial grid of patches
                for patch_y_idx in range(num_patches_height):
                    for patch_x_idx in range(num_patches_width):
                        # Each combination of (file, z_idx, patch_y_idx, patch_x_idx) is a unique item
                        self.file_slice_mapping.append((i, z_idx, patch_y_idx, patch_x_idx))
                        dataset_length += 1

        self.data_len = dataset_length
        self.patch_size = patch_size # Store patch_size for __getitem__

    def __len__(self):
        return self.data_len

    def __getitem__(self, index):
        # Use the mapping to determine which file, depth channel, and spatial patch to load
        file_idx, z_idx, patch_y_idx, patch_x_idx = self.file_slice_mapping[index]

        # Load data from the specific file
        noisy_file_path = self.noisy_paths[file_idx]

        # Load the full 4D array with mmap_mode to avoid loading everything into RAM
        noisy_volume = np.load(noisy_file_path, mmap_mode='r')

        # Calculate the starting and ending coordinates for the current patch
        start_h = patch_y_idx * self.patch_size
        end_h = start_h + self.patch_size
        start_w = patch_x_idx * self.patch_size
        end_w = start_w + self.patch_size

        # Extract the 16x16 patch for the specific depth channel (z_idx)
        # and include the entire time series.
        # The resulting shape will be (patch_size, patch_size, time_steps)
        patch_time_series = noisy_volume[start_h:end_h, start_w:end_w, z_idx, :].copy()

        # patch = patch_time_series[:,:,0]
        # patch_sum = np.sum(patch)
        # if patch_sum < -25:
        #     indices_to_zero_out.append(index)
            

        # Apply transformations if specified
        if self.transform is not None:
            # Your transform should expect a tensor of shape (C, H, W, T) or (H, W, T)
            # depending on how you structure it. Adjust accordingly.
            sample = self.transform(patch_time_series)

        return sample

In [14]:
test_dataset = NoisyDataset(noisy_images_paths=test_folders, transform=transform, normalize=True)

In [16]:
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    collate_fn=collate_wrapper,
    batch_size=batch_size,
    drop_last=False,
    pin_memory=pin_mem,
    num_workers=num_workers,
    persistent_workers=False)

### Inference

In [17]:
import math
from functools import partial
import numpy as np
import torch
import torch.nn as nn

### Model backbone (vision transformer)

In [18]:
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def repeat_interleave_batch(x, B, repeat):
    N = len(x) // B
    x = torch.cat([
        torch.cat([x[i*B:(i+1)*B] for _ in range(repeat)], dim=0)
        for i in range(N)
    ], dim=0)
    return x

In [19]:
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=float)
    grid_w = np.arange(grid_size, dtype=float)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid length
    return:
    pos_embed: [grid_size, embed_dim] or [1+grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid = np.arange(grid_size, dtype=float)
    pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega   # (D/2,)

    pos = pos.reshape(-1)   # (M,)
    out = np.einsum('m,d->md', pos, omega)   # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class MLP(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn


class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, return_attention=False):
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, timeline_pixel_width=16*300, patch_size=16, in_chans=1, embed_dim=768):
        super().__init__()
        num_patches = timeline_pixel_width // patch_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class ConvEmbed(nn.Module):
    """
    3x3 Convolution stems for ViT following ViTC models
    """

    def __init__(self, channels, strides, img_size=224, in_chans=1, batch_norm=True):
        super().__init__()
        # Build the stems
        stem = []
        channels = [in_chans] + channels
        for i in range(len(channels) - 2):
            stem += [nn.Conv2d(channels[i], channels[i+1], kernel_size=3,
                               stride=strides[i], padding=1, bias=(not batch_norm))]
            if batch_norm:
                stem += [nn.BatchNorm2d(channels[i+1])]
            stem += [nn.ReLU(inplace=True)]
        stem += [nn.Conv2d(channels[-2], channels[-1], kernel_size=1, stride=strides[-1])]
        self.stem = nn.Sequential(*stem)

        # Comptute the number of patches
        stride_prod = int(np.prod(strides))
        self.num_patches = (img_size[0] // stride_prod)**2

    def forward(self, x):
        p = self.stem(x)
        return p.flatten(2).transpose(1, 2)


class VisionTransformerPredictor(nn.Module):
    """ Vision Transformer """
    def __init__(
        self,
        num_patches,
        embed_dim=768,
        predictor_embed_dim=384,
        depth=6,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        **kwargs
    ):
        super().__init__()
        self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim))
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        # --
        self.predictor_pos_embed = nn.Parameter(torch.zeros(1, num_patches, predictor_embed_dim),
                                                requires_grad=False)
        predictor_pos_embed = get_1d_sincos_pos_embed(self.predictor_pos_embed.shape[-1],
                                                      int(num_patches),
                                                      cls_token=False)
        self.predictor_pos_embed.data.copy_(torch.from_numpy(predictor_pos_embed).float().unsqueeze(0))
        # --
        self.predictor_blocks = nn.ModuleList([
            Block(
                dim=predictor_embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.predictor_norm = norm_layer(predictor_embed_dim)
        self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
        # ------
        self.init_std = init_std
        trunc_normal_(self.mask_token, std=self.init_std)
        self.apply(self._init_weights)
        self.fix_init_weight()

    def fix_init_weight(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.predictor_blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=self.init_std)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x, masks_x, masks):
        assert (masks is not None) and (masks_x is not None), 'Cannot run predictor without mask indices'

        if not isinstance(masks_x, list):
            masks_x = [masks_x]

        if not isinstance(masks, list):
            masks = [masks]

        # -- Batch Size
        B = len(x) // len(masks_x)

        # -- map from encoder-dim to pedictor-dim
        x = self.predictor_embed(x)

        # -- add positional embedding to x tokens
        x_pos_embed = self.predictor_pos_embed.repeat(B, 1, 1)
        x += apply_masks(x_pos_embed, masks_x)

        _, N_ctxt, D = x.shape

        # -- concat mask tokens to x
        pos_embs = self.predictor_pos_embed.repeat(B, 1, 1)
        pos_embs = apply_masks(pos_embs, masks)
        pos_embs = repeat_interleave_batch(pos_embs, B, repeat=len(masks_x))
        # --
        pred_tokens = self.mask_token.repeat(pos_embs.size(0), pos_embs.size(1), 1)
        # --
        pred_tokens += pos_embs
        x = x.repeat(len(masks), 1, 1)
        x = torch.cat([x, pred_tokens], dim=1)

        # -- fwd prop
        for blk in self.predictor_blocks:
            x = blk(x)
        x = self.predictor_norm(x)

        # -- return preds for mask tokens
        x = x[:, N_ctxt:]
        x = self.predictor_proj(x)

        return x


class VisionTransformer(nn.Module):
    """ Vision Transformer """
    def __init__(
        self,
        timeline_pixel_width=(16*300),
        patch_size=16,
        in_chans=1,
        embed_dim=768,
        predictor_embed_dim=384,
        depth=12,
        predictor_depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        **kwargs
    ):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim
        self.num_heads = num_heads
        # --
        self.patch_embed = PatchEmbed(
            timeline_pixel_width=timeline_pixel_width,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        # --
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim), requires_grad=False)
        pos_embed = get_1d_sincos_pos_embed(self.pos_embed.shape[-1],
                                            int(self.patch_embed.num_patches),
                                            cls_token=False)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
        # --
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # ------
        self.init_std = init_std
        self.apply(self._init_weights)
        self.fix_init_weight()

    def fix_init_weight(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=self.init_std)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x, masks=None):
        if masks is not None:
            if not isinstance(masks, list):
                masks = [masks]

        # -- patchify x
        x = self.patch_embed(x)
        B, N, D = x.shape

        # -- add positional embedding to x
        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
        x = x + pos_embed

        # -- mask x
        if masks is not None:
            x = apply_masks(x, masks)

        # -- fwd prop
        for i, blk in enumerate(self.blocks):
            x = blk(x)

        if self.norm is not None:
            x = self.norm(x)

        return x

    def interpolate_pos_encoding(self, x, pos_embed):
        npatch = x.shape[1] - 1
        N = pos_embed.shape[1] - 1
        if npatch == N:
            return pos_embed
        class_emb = pos_embed[:, 0]
        pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        pos_embed = nn.functional.interpolate(
            pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=math.sqrt(npatch / N),
            mode='bicubic',
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)


def vit_predictor(**kwargs):
    model = VisionTransformerPredictor(
        mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
    return model


def vit_tiny(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_small(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_base(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_large(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_huge(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_giant(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


VIT_EMBED_DIMS = {
    'vit_tiny': 192,
    'vit_small': 384,
    'vit_base': 768,
    'vit_large': 1024,
    'vit_huge': 1280,
    'vit_giant': 1408,
}

### Helper functions

In [20]:
def load_checkpoint(
    device,
    r_path,
    encoder,
    predictor,
    target_encoder,
    excluded_layers = None
):
    epoch = 0
    checkpoint = torch.load(r_path, map_location=torch.device('cpu'))

    # -- loading encoder with filtering
    pretrained_dict = checkpoint['encoder']
    
    # Remove 'module.' prefix if it exists
    new_pretrained_dict = {}
    for k, v in pretrained_dict.items():
        if k.startswith('module.'):
            new_pretrained_dict[k[7:]] = v  # Remove 'module.' prefix
        else:
            new_pretrained_dict[k] = v
    pretrained_dict = new_pretrained_dict
    
    # Apply excluded_layers filtering
    if excluded_layers != None:
        filtered_dict = {k: v for k, v in pretrained_dict.items() 
                        if not any(keyword in k for keyword in excluded_layers)}
        print(f"Excluded {len(pretrained_dict) - len(filtered_dict)} parameters containing: {excluded_layers}")
        logger.info(f"Excluded {len(pretrained_dict) - len(filtered_dict)} parameters containing: {excluded_layers}")
        pretrained_dict = filtered_dict
    msg = encoder.load_state_dict(pretrained_dict, strict=False)
    print(f'loaded pretrained encoder from epoch {epoch} with msg: {msg}')
    logger.info(f'loaded pretrained encoder from epoch {epoch} with msg: {msg}')

    # -- loading predictor
    pretrained_dict = checkpoint['predictor']
    
    # Remove 'module.' prefix if it exists
    new_pretrained_dict = {}
    for k, v in pretrained_dict.items():
        if k.startswith('module.'):
            new_pretrained_dict[k[7:]] = v  # Remove 'module.' prefix
        else:
            new_pretrained_dict[k] = v
    pretrained_dict = new_pretrained_dict

    # Apply excluded_layers filtering
    if excluded_layers != None:
        filtered_dict = {k: v for k, v in pretrained_dict.items() 
                        if not any(keyword in k for keyword in excluded_layers)}
        print(f"Excluded {len(pretrained_dict) - len(filtered_dict)} parameters containing: {excluded_layers}")
        logger.info(f"Excluded {len(pretrained_dict) - len(filtered_dict)} parameters containing: {excluded_layers}")
        pretrained_dict = filtered_dict

    msg = predictor.load_state_dict(pretrained_dict, strict=False)
    logger.info(f'loaded pretrained predictor from epoch {epoch} with msg: {msg}')

    # -- loading target_encoder
    if target_encoder is not None:
        pretrained_dict = checkpoint['target_encoder']
        
        # Remove 'module.' prefix if it exists
        new_pretrained_dict = {}
        for k, v in pretrained_dict.items():
            if k.startswith('module.'):
                new_pretrained_dict[k[7:]] = v  # Remove 'module.' prefix
            else:
                new_pretrained_dict[k] = v
        pretrained_dict = new_pretrained_dict
        
        if excluded_layers != None:
            filtered_dict = {k: v for k, v in pretrained_dict.items() 
                            if not any(keyword in k for keyword in excluded_layers)}
            logger.info(f"Excluded {len(pretrained_dict) - len(filtered_dict)} parameters containing: {excluded_layers}")
            pretrained_dict = filtered_dict

        msg = target_encoder.load_state_dict(pretrained_dict, strict=False)
        logger.info(f'loaded pretrained target_encoder from epoch {epoch} with msg: {msg}')


    return encoder, predictor, target_encoder



def init_model(
    device,
    patch_size=16,
    model_name='vit_base',
    timeline_pixel_width=16*300,
    pred_depth=6,
    pred_emb_dim=384
):
    encoder = vit_small(
        timeline_pixel_width=timeline_pixel_width,
        patch_size=patch_size)
    predictor = vit_predictor(
        num_patches=encoder.patch_embed.num_patches,
        embed_dim=encoder.embed_dim,
        predictor_embed_dim=pred_emb_dim,
        depth=pred_depth,
        num_heads=encoder.num_heads)

    def init_weights(m):
        if isinstance(m, torch.nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)
        elif isinstance(m, torch.nn.LayerNorm):
            torch.nn.init.constant_(m.bias, 0)
            torch.nn.init.constant_(m.weight, 1.0)

    for m in encoder.modules():
        init_weights(m)

    for m in predictor.modules():
        init_weights(m)

    encoder.to(device)
    predictor.to(device)
    logger.info(encoder)
    return encoder, predictor


### Initializing model

In [21]:
# -- init model
encoder, predictor = init_model(
    device=device,
    patch_size=patch_size,
    timeline_pixel_width=300*16,
    pred_depth=pred_depth,
    pred_emb_dim=pred_emb_dim,
    model_name=model_name)
target_encoder = copy.deepcopy(encoder)

In [22]:
# -- load training checkpoint
if load_model:
    encoder, predictor, target_encoder = load_checkpoint(
        device=device,
        r_path=r_file,
        encoder=encoder,
        predictor=predictor,
        target_encoder=target_encoder,
        excluded_layers = None)
    # for _ in range(start_epoch*ipe):
    #     scheduler.step()
    #     wd_scheduler.step()
    #     next(momentum_scheduler)
    #     mask_collator.step()

loaded pretrained encoder from epoch 0 with msg: <All keys matched successfully>


### Inference loop

In [23]:
def validate():
    """Run validation on the validation dataset and return grouped z tensors."""
    logger.info('Running validation...')
    
    # Set models to eval mode
    encoder.eval()
    predictor.eval()
    target_encoder.eval()

    num_epochs = 300
    group_size = 4
    all_zs = [[] for _ in range(num_epochs)]

    with torch.no_grad():
        for epoch in range(num_epochs):
            if epoch%10==0:
                print(f"epochs: {epoch+1}/300")
            collate_wrapper.set_frame_idx(epoch)
            for itr, (udata, masks_enc, masks_pred) in enumerate(test_loader):
                # if itr > 20:
                #     break
                # Load and process images
                imgs = udata.to(device, non_blocking=True)
                indices_to_zero = []
                for i in range(imgs.shape[0]):
                    slice_2d = imgs[i, 0, :, :16]  # First 16x16 slice
                    value_sum = slice_2d.sum().item()
                    if value_sum < -25:
                        indices_to_zero.append(i)
                
                masks_1 = [u.to(device, non_blocking=True) for u in masks_enc]
                masks_2 = [u.unsqueeze(-1).to(device, non_blocking=True) for u in masks_pred]
                
                # Forward pass
                def val_step():
                    with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=use_bfloat16):
                        # # Target encoding
                        # h = target_encoder(imgs)
                        # h = F.layer_norm(h, (h.size(-1),))
                        # B = len(h)
                        # h = apply_masks(h, masks_2)
                        # h = repeat_interleave_batch(h, B, repeat=len(masks_1))
                        
                        # Context encoding and prediction
                        z = encoder(imgs, masks_1)
                        B, T, F = z.shape  # B=4, T=299, F=384
                        group_size = 299
                        num_groups = T // group_size  # 299 // 13 = 23
                        
                        # Trim to fit exact number of groups
                        z_trimmed = z[:, :num_groups * group_size, :]  # Shape: (4, 299 -> 299, 384) → then trimmed to (4, 299, 384)
                        
                        # Reshape and average
                        z_avg = z_trimmed.view(B, num_groups, group_size, F).mean(dim=2)  # Shape: (4, 23, 384)
                        for idx in indices_to_zero:
                            # print(f"{idx} set to zero")
                            z_avg[idx] = 0
                        # print(z_avg)
                        return z_avg
    
                z = val_step()
                all_zs[epoch].append(z.cpu())

    # Stack each epoch's results
    epoch_tensors = []
    for epoch_zs in all_zs:
        z_epoch_tensor = torch.stack(epoch_zs)  # shape (num_batches_epoch, B, ...)
        num_batches = z_epoch_tensor.size(0)
        num_groups = num_batches // group_size
        z_epoch_tensor = z_epoch_tensor[:num_groups * group_size]  # Trim
        z_epoch_tensor = z_epoch_tensor.view(num_groups, group_size, *z_epoch_tensor.shape[1:])
        epoch_tensors.append(z_epoch_tensor)

    # Final tensor shape: (num_epochs=299, num_groups, group_size, B, ...)
    z_grouped = torch.stack(epoch_tensors)

    logger.info(f"Validation completed. Grouped tensor shape: {z_grouped.shape}")
    return z_grouped


In [24]:
features = validate()

epochs: 1/300


  with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=use_bfloat16):


epochs: 11/300
epochs: 21/300
epochs: 31/300
epochs: 41/300
epochs: 51/300
epochs: 61/300
epochs: 71/300
epochs: 81/300
epochs: 91/300
epochs: 101/300
epochs: 111/300
epochs: 121/300
epochs: 131/300
epochs: 141/300
epochs: 151/300
epochs: 161/300
epochs: 171/300
epochs: 181/300
epochs: 191/300
epochs: 201/300
epochs: 211/300
epochs: 221/300
epochs: 231/300
epochs: 241/300
epochs: 251/300
epochs: 261/300
epochs: 271/300
epochs: 281/300
epochs: 291/300


In [25]:
features.shape

torch.Size([300, 156, 4, 4, 1, 384])

In [26]:
features = features.squeeze(-2)

In [27]:
features.shape

torch.Size([300, 156, 4, 4, 384])

In [28]:
import numpy as np
np.save("temporal_features_1.npy", features.numpy())