Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

In [None]:
# @title Install dependency

!pip install mediapy thop

In [None]:
# @title Imports

import functools
import math

import einops
import numpy as np
import torch
import jax
import jax.numpy as jnp

### VideoMAE (from GitHub)

In [None]:
# @title Download VideoMAE code

%cd /content
!git clone https://github.com/MCG-NJU/VideoMAE.git

In [None]:
# @title Define VideoMAE model

%cd /content/VideoMAE

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from functools import partial
from timm.models.layers import drop_path, to_2tuple, trunc_normal_

class Mlp(nn.Module):
  """Two-layer feed-forward network: Linear -> activation -> Linear -> Dropout.

  Args:
    in_features: Size of the input feature dimension.
    hidden_features: Width of the hidden layer; defaults to `in_features`.
    out_features: Size of the output feature dimension; defaults to
      `in_features`.
    act_layer: Activation module class, instantiated with no arguments.
    drop: Dropout probability, applied once after the second projection.
  """

  def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
    super().__init__()
    output_dim = out_features or in_features
    hidden_dim = hidden_features or in_features
    self.fc1 = nn.Linear(in_features, hidden_dim)
    self.act = act_layer()
    self.fc2 = nn.Linear(hidden_dim, output_dim)
    self.drop = nn.Dropout(drop)

  def forward(self, x):
    # Dropout is deliberately NOT applied between the activation and fc2;
    # it is only applied to the final output (original BERT behaviour).
    hidden = self.act(self.fc1(x))
    return self.drop(self.fc2(hidden))

class Attention(nn.Module):
  """Multi-head self-attention with optional separate query/value biases.

  The fused `qkv` projection carries no bias of its own. When `qkv_bias` is
  set, learnable `q_bias` and `v_bias` are concatenated around an all-zero
  key bias at call time, so keys are never biased.

  Args:
    dim: Input and output feature size.
    num_heads: Number of attention heads.
    qkv_bias: If True, learn `q_bias` and `v_bias` parameters.
    qk_scale: Optional override for the 1/sqrt(head_dim) logit scale.
    attn_drop: Dropout probability on the attention weights.
    proj_drop: Dropout probability after the output projection.
    attn_head_dim: Optional explicit per-head size; defaults to
      `dim // num_heads`.
  """

  def __init__(
      self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
      proj_drop=0., attn_head_dim=None):
    super().__init__()
    self.num_heads = num_heads
    per_head = dim // num_heads
    if attn_head_dim is not None:
      per_head = attn_head_dim
    inner_dim = per_head * num_heads
    self.scale = qk_scale or per_head ** -0.5

    self.qkv = nn.Linear(dim, inner_dim * 3, bias=False)
    self.q_bias = nn.Parameter(torch.zeros(inner_dim)) if qkv_bias else None
    self.v_bias = nn.Parameter(torch.zeros(inner_dim)) if qkv_bias else None

    self.attn_drop = nn.Dropout(attn_drop)
    self.proj = nn.Linear(inner_dim, dim)
    self.proj_drop = nn.Dropout(proj_drop)

  def forward(self, x):
    batch, tokens, _ = x.shape
    bias = None
    if self.q_bias is not None:
      # Keys get a fixed zero bias; only queries and values are biased.
      key_bias = torch.zeros_like(self.v_bias, requires_grad=False)
      bias = torch.cat((self.q_bias, key_bias, self.v_bias))
    projected = F.linear(input=x, weight=self.qkv.weight, bias=bias)
    # (B, N, 3 * H * D) -> (3, B, H, N, D), then split into q / k / v.
    projected = projected.reshape(batch, tokens, 3, self.num_heads, -1)
    query, key, value = projected.permute(2, 0, 3, 1, 4).unbind(0)

    logits = (query * self.scale) @ key.transpose(-2, -1)
    weights = self.attn_drop(logits.softmax(dim=-1))

    context = (weights @ value).transpose(1, 2).reshape(batch, tokens, -1)
    return self.proj_drop(self.proj(context))

class Block(nn.Module):
  """Pre-norm transformer block: attention and MLP sub-layers with residuals.

  When `init_values` is a positive number, each sub-layer output is scaled
  by a learnable per-channel vector (`gamma_1` / `gamma_2`, "LayerScale").

  Args:
    dim: Token feature size.
    num_heads: Number of attention heads.
    mlp_ratio: MLP hidden width relative to `dim`.
    qkv_bias, qk_scale, attn_drop, attn_head_dim: Forwarded to `Attention`.
    drop: Dropout after the attention projection and inside the MLP.
    drop_path: Stochastic-depth rate; 0 disables it.
    init_values: Initial LayerScale value; None or <= 0 disables LayerScale.
    act_layer: Activation class used by the MLP.
    norm_layer: Normalization layer class/factory, called with `dim`.
  """

  def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
               drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
               attn_head_dim=None):
    super().__init__()
    self.norm1 = norm_layer(dim)
    self.attn = Attention(
      dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
      attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim)
    if drop_path > 0.:
      # `DropPath` is not imported anywhere in this cell; pull it from timm
      # (already a dependency of this notebook) only when it is actually used.
      from timm.models.layers import DropPath
      self.drop_path = DropPath(drop_path)
    else:
      self.drop_path = nn.Identity()
    self.norm2 = norm_layer(dim)
    mlp_hidden_dim = int(dim * mlp_ratio)
    self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    # Treat the default `None` like 0 (LayerScale off) instead of raising
    # TypeError on `None > 0`.
    if init_values is not None and init_values > 0:
      self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
      self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
    else:
      self.gamma_1, self.gamma_2 = None, None

  def forward(self, x):
    if self.gamma_1 is None:
      x = x + self.drop_path(self.attn(self.norm1(x)))
      x = x + self.drop_path(self.mlp(self.norm2(x)))
    else:
      x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
      x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
    return x

class PatchEmbed(nn.Module):
  """Video to tubelet-patch embedding.

  A Conv3d whose kernel equals its stride cuts the clip into non-overlapping
  (tubelet_size, patch_h, patch_w) tubelets and projects each to `embed_dim`.

  Args:
    img_size: Nominal frame size (int or pair); only used for `num_patches`.
    patch_size: Spatial patch size (int or pair).
    in_chans: Number of input channels.
    embed_dim: Output token dimension.
    num_frames: Nominal clip length; only used for `num_patches`.
    tubelet_size: Number of frames folded into one token.
  """

  def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2):
    super().__init__()
    img_size = to_2tuple(img_size)
    patch_size = to_2tuple(patch_size)
    self.tubelet_size = int(tubelet_size)
    grid_h = img_size[0] // patch_size[0]
    grid_w = img_size[1] // patch_size[1]
    grid_t = num_frames // self.tubelet_size
    self.img_size = img_size
    self.patch_size = patch_size
    self.num_patches = grid_w * grid_h * grid_t
    window = (self.tubelet_size, patch_size[0], patch_size[1])
    self.proj = nn.Conv3d(in_channels=in_chans, out_channels=embed_dim, kernel_size=window, stride=window)

  def forward(self, x, **kwargs):
    # The unpack doubles as a check that x is a 5-D (B, C, T, H, W) clip.
    # Input size is not checked against `img_size`.
    _batch, _chans, _frames, _height, _width = x.shape
    tokens = self.proj(x)                      # (B, embed_dim, T', H', W')
    return tokens.flatten(2).transpose(1, 2)   # (B, T' * H' * W', embed_dim)

# sin-cos position encoding
# https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
def get_sinusoid_encoding_table(n_position, d_hid):
  """Build the fixed sine-cosine position-encoding table.

  Entry (pos, j) is pos / 10000 ** (2 * (j // 2) / d_hid), passed through
  sin for even j and cos for odd j.

  Args:
    n_position: Number of positions (rows); may be 0.
    d_hid: Encoding dimension (columns).

  Returns:
    A float32 tensor of shape (1, n_position, d_hid) without grad.
  """
  # Vectorized with NumPy broadcasting instead of a Python double loop over
  # n_position * d_hid elements (slow for video-sized token counts). This also
  # keeps the table 2-D when n_position == 0.
  positions = np.arange(n_position, dtype=np.float64)[:, None]
  exponents = 2 * (np.arange(d_hid) // 2) / d_hid
  sinusoid_table = positions / np.power(10000, exponents)[None, :]
  sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
  sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

  return torch.tensor(sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0)

class PretrainVisionTransformerEncoder(nn.Module):
  """VideoMAE ViT encoder: tubelet patch embedding, transformer blocks, final norm.

  No masking is applied; every patch token is encoded. In `forward_features`
  the position table is resized to the token grid of the actual input. Clips
  may therefore differ in spatial size and length from the nominal
  `img_size` x `img_size` x 16-frame geometry the table was built for.
  """

  def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
               num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
               drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, tubelet_size=2, use_checkpoint=False,
               use_learnable_pos_emb=False):
    super().__init__()
    self.num_classes = num_classes
    self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
    self.patch_embed = PatchEmbed(
      img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, tubelet_size=tubelet_size)
    num_patches = self.patch_embed.num_patches
    # (T', H', W') token grid the position table is laid out as. It is derived
    # from the patch embedding rather than hard-coded to (8, 14, 14).
    grid_h = self.patch_embed.img_size[0] // self.patch_embed.patch_size[0]
    grid_w = self.patch_embed.img_size[1] // self.patch_embed.patch_size[1]
    self.embedding_shape = (num_patches // (grid_h * grid_w), grid_h, grid_w)
    self.use_checkpoint = use_checkpoint

    # TODO: Add the cls token
    if use_learnable_pos_emb:
      # Exactly one entry per patch. There is no cls token, and an extra entry
      # would make the (T', H', W') reshape in interpolate_pos_encoding fail.
      self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
    else:
      # sine-cosine positional embeddings
      self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)

    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
    self.blocks = nn.ModuleList([
      Block(
        dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
        drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
        init_values=init_values)
      for i in range(depth)])
    self.norm = norm_layer(embed_dim)
    self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

  def interpolate_pos_encoding(self, x, h, w, t=None):
    """Resize a flattened position table to a (t, h, w) token grid.

    Args:
      x: Position table of shape (1, T0 * H0 * W0, dim), laid out as
        `self.embedding_shape` = (T0, H0, W0).
      h, w: Target number of patches along height and width.
      t: Target number of temporal tokens; None keeps T0.

    Returns:
      Tensor of shape (1, t * h * w, dim) in (time, height, width) order.
    """
    t0, h0, w0 = self.embedding_shape
    dim = x.shape[-1]
    x = x.reshape(t0, h0, w0, dim)
    if (h, w) != (h0, w0):
      # An explicit output size avoids a float `scale_factor` flooring to
      # one patch fewer than requested.
      x = F.interpolate(x.permute(0, 3, 1, 2), size=(h, w), mode="bicubic", align_corners=False)
      x = x.permute(0, 2, 3, 1)
    if t is not None and t != t0:
      # Linear resampling along time. Without it, any clip whose tubelet
      # count differs from T0 cannot be added to the patch tokens.
      x = x.permute(1, 2, 3, 0).reshape(1, h * w * dim, t0)
      x = F.interpolate(x, size=t, mode="linear", align_corners=False)
      x = x.reshape(h, w, dim, t).permute(3, 0, 1, 2)
    return x.reshape(1, -1, dim)

  def forward_features(self, x):
    """Encode a (B, C, T, H, W) clip into normalized (B, N, embed_dim) tokens."""
    _, _, T, h, w = x.shape
    t = T // self.patch_embed.tubelet_size
    h = h // self.patch_embed.patch_size[0]
    w = w // self.patch_embed.patch_size[1]
    x = self.patch_embed(x)

    pos_embed = self.pos_embed.to(x.device)
    x = x + self.interpolate_pos_encoding(pos_embed, h, w, t)

    B, _, C = x.shape
    x_vis = x.reshape(B, -1, C)  # no masking: every token is visible

    if self.use_checkpoint:
      for blk in self.blocks:
        x_vis = checkpoint.checkpoint(blk, x_vis)
    else:
      for blk in self.blocks:
        x_vis = blk(x_vis)

    x_vis = self.norm(x_vis)
    return x_vis

  def forward(self, x):
    # The classification head is intentionally not applied; token features
    # are returned.
    x = self.forward_features(x)
    return x

class PretrainVisionTransformerDecoder(nn.Module):
  """VideoMAE pixel decoder: transformer blocks, a final norm and a linear head.

  The head maps each token to `num_classes` values, which must equal
  3 * tubelet_size * patch_size ** 2 (one value per RGB pixel of a tubelet).
  `num_patches` is accepted but not used.
  """

  def __init__(self, patch_size=16, num_classes=768, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
               qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
               norm_layer=nn.LayerNorm, init_values=None, num_patches=196, tubelet_size=2, use_checkpoint=False
               ):
    super().__init__()
    self.num_classes = num_classes
    assert num_classes == 3 * tubelet_size * patch_size ** 2
    self.num_features = embed_dim  # num_features for consistency with other models
    self.embed_dim = embed_dim
    self.patch_size = patch_size
    self.use_checkpoint = use_checkpoint

    # Stochastic-depth rates increase linearly from 0 to drop_path_rate.
    drop_path_rates = torch.linspace(0, drop_path_rate, depth).tolist()
    self.blocks = nn.ModuleList(
        Block(
            dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=rate, norm_layer=norm_layer,
            init_values=init_values)
        for rate in drop_path_rates)
    self.norm = norm_layer(embed_dim)
    self.head = nn.Identity() if num_classes <= 0 else nn.Linear(embed_dim, num_classes)

    self.apply(self._init_weights)

  def _init_weights(self, m):
    """Xavier-uniform Linear weights with zero bias; identity LayerNorm."""
    if isinstance(m, nn.LayerNorm):
      nn.init.constant_(m.bias, 0)
      nn.init.constant_(m.weight, 1.0)
    elif isinstance(m, nn.Linear):
      nn.init.xavier_uniform_(m.weight)
      if m.bias is not None:
        nn.init.constant_(m.bias, 0)

  def get_num_layers(self):
    return len(self.blocks)

  @torch.jit.ignore
  def no_weight_decay(self):
    return {'pos_embed', 'cls_token'}

  def get_classifier(self):
    return self.head

  def reset_classifier(self, num_classes, global_pool=''):
    self.num_classes = num_classes
    self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

  def forward(self, x, return_token_num):
    """Run the blocks, then project tokens to pixels.

    If `return_token_num` > 0, only the last `return_token_num` tokens (the
    masked ones) are normalized and projected.
    """
    for blk in self.blocks:
      x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
    if return_token_num > 0:
      x = x[:, -return_token_num:]
    return self.head(self.norm(x))

class PretrainVisionTransformer(nn.Module):
  """VideoMAE model whose forward pass runs only the encoder.

  Decoder-side arguments are still accepted, and `mask_token` / `pos_embed`
  are still created. Neither is used by `forward`, which returns the
  encoder's token features.
  """

  def __init__(self,
               img_size=224,
               patch_size=16,
               encoder_in_chans=3,
               encoder_num_classes=0,
               encoder_embed_dim=768,
               encoder_depth=12,
               encoder_num_heads=12,
               decoder_num_classes=1536,
               decoder_embed_dim=512,
               decoder_depth=8,
               decoder_num_heads=8,
               mlp_ratio=4.,
               qkv_bias=False,
               qk_scale=None,
               drop_rate=0.,
               attn_drop_rate=0.,
               drop_path_rate=0.,
               norm_layer=nn.LayerNorm,
               init_values=0.,
               use_learnable_pos_emb=False,
               use_checkpoint=False,
               tubelet_size=2,
               num_classes=0,  # avoid the error from create_fn in timm
               in_chans=0,  # avoid the error from create_fn in timm
               ):
    super().__init__()
    encoder_kwargs = dict(
        img_size=img_size,
        patch_size=patch_size,
        in_chans=encoder_in_chans,
        num_classes=encoder_num_classes,
        embed_dim=encoder_embed_dim,
        depth=encoder_depth,
        num_heads=encoder_num_heads,
        mlp_ratio=mlp_ratio,
        qkv_bias=qkv_bias,
        qk_scale=qk_scale,
        drop_rate=drop_rate,
        attn_drop_rate=attn_drop_rate,
        drop_path_rate=drop_path_rate,
        norm_layer=norm_layer,
        init_values=init_values,
        tubelet_size=tubelet_size,
        use_checkpoint=use_checkpoint,
        use_learnable_pos_emb=use_learnable_pos_emb,
    )
    self.encoder = PretrainVisionTransformerEncoder(**encoder_kwargs)

    # Decoder-side state: created but unused by forward().
    self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
    num_patches = self.encoder.patch_embed.num_patches
    self.pos_embed = get_sinusoid_encoding_table(num_patches, decoder_embed_dim)

  def forward(self, x):
    """Return encoder token features for a (B, C, T, H, W) clip."""
    # The unpack rejects inputs that are not 5-D.
    _batch, _chans, _frames, _height, _width = x.shape
    return self.encoder(x)

def _pretrain_videomae_patch16_224(encoder_embed_dim, encoder_depth, encoder_num_heads,
                                   decoder_embed_dim, decoder_num_heads):
  """Build a 224px, patch-16 PretrainVisionTransformer with the shared settings."""
  return PretrainVisionTransformer(
      img_size=224,
      patch_size=16,
      encoder_embed_dim=encoder_embed_dim,
      encoder_depth=encoder_depth,
      encoder_num_heads=encoder_num_heads,
      encoder_num_classes=0,
      decoder_num_classes=1536,
      decoder_embed_dim=decoder_embed_dim,
      decoder_num_heads=decoder_num_heads,
      mlp_ratio=4,
      qkv_bias=True,
      norm_layer=partial(nn.LayerNorm, eps=1e-6))

def pretrain_videomae_small_patch16_224():
  """ViT-S/16: 384-d, 12 layers, 6 heads."""
  return _pretrain_videomae_patch16_224(
      encoder_embed_dim=384, encoder_depth=12, encoder_num_heads=6,
      decoder_embed_dim=192, decoder_num_heads=3)

def pretrain_videomae_base_patch16_224():
  """ViT-B/16: 768-d, 12 layers, 12 heads."""
  return _pretrain_videomae_patch16_224(
      encoder_embed_dim=768, encoder_depth=12, encoder_num_heads=12,
      decoder_embed_dim=384, decoder_num_heads=6)

def pretrain_videomae_large_patch16_224():
  """ViT-L/16: 1024-d, 24 layers, 16 heads."""
  return _pretrain_videomae_patch16_224(
      encoder_embed_dim=1024, encoder_depth=24, encoder_num_heads=16,
      decoder_embed_dim=512, decoder_num_heads=8)

def pretrain_videomae_huge_patch16_224():
  """ViT-H/16: 1280-d, 32 layers, 16 heads."""
  return _pretrain_videomae_patch16_224(
      encoder_embed_dim=1280, encoder_depth=32, encoder_num_heads=16,
      decoder_embed_dim=640, decoder_num_heads=8)

In [None]:
# @title Compute Memory and Latency {form-width: "20%"}

import time

model = pretrain_videomae_large_patch16_224()
# The inputs are created on the GPU, so the model must live there too.
# eval() plus disabled autograd keep backward buffers out of the memory numbers.
model = model.cuda().eval()
torch.set_grad_enabled(False)

n_runs = 1
for num_frames in [16, 32, 64, 128, 256, 512, 1024]:
  try:
    inputs = torch.randn(1, 3, num_frames, 256, 256, device="cuda")

    # Warm-up pass so one-time CUDA / kernel-selection overhead is not timed.
    model(inputs)
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    # Measure latency
    start = time.perf_counter()
    for _ in range(n_runs):
      model(inputs)
    torch.cuda.synchronize()  # Wait for all queued kernels before stopping timer
    end = time.perf_counter()
  except torch.cuda.OutOfMemoryError:
    torch.cuda.empty_cache()
    print(f"Frames: {num_frames} | out of GPU memory, stopping sweep")
    break

  avg_latency = (end - start) / n_runs  # seconds per inference
  peak_memory_gb = torch.cuda.max_memory_allocated() / 1024**3

  print(f"Frames: {num_frames} | Average Latency: {avg_latency*1000:.2f} ms | "
        f"Peak Memory: {peak_memory_gb:.2f} GB")

### VideoMAE v2

In [None]:
# @title Define VideoMAE v2 model

import transformers

def get_sinusoid_encoding_table(n_position, d_hid):
  """Return the (1, n_position, d_hid) float32 sine-cosine position table.

  Entry (pos, i) is pos / 10000 ** (2 * (i // 2) / d_hid), with sin applied
  to even columns and cos to odd columns.
  """
  # Broadcasting replaces the per-element Python loops. It also keeps the
  # table 2-D when n_position == 0.
  positions = np.arange(n_position, dtype=np.float64)[:, None]
  table = positions / np.power(10000, 2 * (np.arange(d_hid) // 2) / d_hid)[None, :]
  table[:, 0::2], table[:, 1::2] = np.sin(table[:, 0::2]), np.cos(table[:, 1::2])
  return torch.FloatTensor(table).unsqueeze(0)

class VideoMAEv2Config(transformers.configuration_utils.PretrainedConfig):
  """Configuration class for the VideoMAEv2 checkpoints.

  Keyword arguments are forwarded unchanged to `PretrainedConfig`.
  """
  model_type = 'VideoMAEv2_Base'

  def __init__(self, **kwargs):
    super(VideoMAEv2Config, self).__init__(**kwargs)

class Mlp(nn.Module):
  """Feed-forward block Linear -> activation -> Linear, without dropout.

  `hidden_features` and `out_features` both default to `in_features`.
  """

  def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU):
    super().__init__()
    width = hidden_features or in_features
    self.fc1 = nn.Linear(in_features, width)
    self.act = act_layer()
    self.fc2 = nn.Linear(width, out_features or in_features)

  def forward(self, x):
    hidden = self.fc1(x)
    hidden = self.act(hidden)
    return self.fc2(hidden)

class Attention(nn.Module):
  """Multi-head self-attention with optional query/value biases (keys unbiased).

  Arguments are positional-compatible with the caller in `Block`:
  (dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop, attn_head_dim).
  """

  def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., attn_head_dim=None):
    super().__init__()
    self.num_heads = num_heads
    head_dim = attn_head_dim or dim // num_heads
    inner = head_dim * num_heads
    self.scale = qk_scale or head_dim ** -0.5
    self.qkv = nn.Linear(dim, inner * 3, bias=False)
    if not qkv_bias:
      self.q_bias = None
      self.v_bias = None
    else:
      self.q_bias = nn.Parameter(torch.zeros(inner))
      self.v_bias = nn.Parameter(torch.zeros(inner))
    self.attn_drop = nn.Dropout(attn_drop)
    self.proj = nn.Linear(inner, dim)
    self.proj_drop = nn.Dropout(proj_drop)

  def forward(self, x):
    B, N, _ = x.shape
    bias = None
    if self.q_bias is not None:
      # Zero key bias sandwiched between the learned query and value biases.
      bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
    qkv = F.linear(input=x, weight=self.qkv.weight, bias=bias)
    q, k, v = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4).unbind(0)
    scores = (q * self.scale) @ k.transpose(-2, -1)
    probs = self.attn_drop(scores.softmax(dim=-1))
    merged = (probs @ v).transpose(1, 2).reshape(B, N, -1)
    return self.proj_drop(self.proj(merged))

class Block(nn.Module):
  """Pre-norm transformer block with optional LayerScale.

  `drop_path` and `cos_attn` are accepted (presumably for signature
  compatibility with the checkpoint's config) but ignored. `drop_path` is
  always an identity here, and attention is plain scaled dot-product.
  LayerScale (`gamma_1` / `gamma_2`) is enabled only when `init_values` is a
  positive number; None or <= 0 disables it.
  """

  def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_head_dim=None, cos_attn=False):
    super().__init__()
    self.norm1, self.norm2 = norm_layer(dim), norm_layer(dim)
    self.attn = Attention(dim, num_heads, qkv_bias, qk_scale, attn_drop, drop, attn_head_dim)
    self.drop_path = nn.Identity()
    self.mlp = Mlp(dim, int(dim * mlp_ratio), act_layer=act_layer)
    # The default init_values=None must mean "no LayerScale" rather than
    # raising TypeError on `None > 0`.
    if init_values is not None and init_values > 0:
      self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
      self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
    else:
      self.gamma_1 = self.gamma_2 = None

  def forward(self, x):
    if self.gamma_1 is None:
      x = x + self.drop_path(self.attn(self.norm1(x)))
      x = x + self.drop_path(self.mlp(self.norm2(x)))
    else:
      x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
      x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
    return x

def to_2tuple(x):
  """Return `x` unchanged if it is already a tuple, else the pair `(x, x)`."""
  if isinstance(x, tuple):
    return x
  return (x, x)

class PatchEmbed(nn.Module):
  """Cut a (B, C, T, H, W) clip into tubelets and project each to `embed_dim`.

  `img_size` and `num_frames` only determine the nominal `num_patches`; the
  convolution itself accepts other sizes.
  """

  def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2):
    super().__init__()
    img_size = to_2tuple(img_size)
    patch_size = to_2tuple(patch_size)
    self.img_size = img_size
    self.patch_size = patch_size
    self.num_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) * (num_frames // tubelet_size)
    self.tubelet_size = tubelet_size
    # Kernel == stride gives non-overlapping tubelets.
    window = (tubelet_size, patch_size[0], patch_size[1])
    self.proj = nn.Conv3d(in_chans, embed_dim, window, window)

  def forward(self, x):
    _batch, _chans, _frames, _height, _width = x.shape  # requires a 5-D input
    return self.proj(x).flatten(start_dim=2).transpose(1, 2)

class VisionTransformer(nn.Module):
  """VideoMAEv2 ViT backbone that returns per-token features.

  `forward` does the following in order:
    1. Embeds a (B, C, T, H, W) clip into tubelet tokens.
    2. Adds a position table resized to the clip's token grid.
    3. Runs the transformer blocks.
    4. Returns the normalized tokens.
  `head` is constructed but never applied.
  """

  def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., head_drop_rate=0., norm_layer=nn.LayerNorm, layer_norm_eps=1e-12, init_values=0., use_learnable_pos_emb=False, init_scale=0., num_frames=16, tubelet_size=2, use_mean_pooling=True, with_cp=False, cos_attn=False):
    super().__init__()
    self.num_classes, self.num_features, self.embed_dim, self.tubelet_size = num_classes, embed_dim, embed_dim, tubelet_size
    self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim, num_frames, tubelet_size)
    self.patch_size = patch_size
    num_patches = self.patch_embed.num_patches
    # (T', H', W') grid the position table is laid out as. It is derived from
    # the patch embedding instead of hard-coding (8, 14, 14).
    grid_h = self.patch_embed.img_size[0] // self.patch_embed.patch_size[0]
    grid_w = self.patch_embed.img_size[1] // self.patch_embed.patch_size[1]
    self.embedding_shape = (num_patches // (grid_h * grid_w), grid_h, grid_w)
    self.with_cp = with_cp
    norm_layer = functools.partial(self._resolve_norm_layer(norm_layer), eps=layer_norm_eps)
    self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) if use_learnable_pos_emb else get_sinusoid_encoding_table(num_patches, embed_dim)
    self.pos_drop = nn.Dropout(p=drop_rate)
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
    self.blocks = nn.ModuleList([Block(embed_dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate, dpr[i], norm_layer=norm_layer, init_values=init_values, cos_attn=cos_attn) for i in range(depth)])
    self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
    self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
    self.head_dropout = nn.Dropout(head_drop_rate)
    self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
    if use_learnable_pos_emb:
      nn.init.trunc_normal_(self.pos_embed, std=.02)

  @staticmethod
  def _resolve_norm_layer(norm_layer):
    """Turn a config value such as "nn.LayerNorm" into a layer class without eval().

    The value can come from a downloaded config, so it must never be executed
    as code. Callables (such as the default `nn.LayerNorm` class) are returned
    unchanged. Strings must name an `nn.Module` subclass in `torch.nn`.

    Raises:
      ValueError: if a string does not name a `torch.nn` module class.
    """
    if callable(norm_layer):
      return norm_layer
    name = str(norm_layer)
    for prefix in ("torch.nn.", "nn."):
      if name.startswith(prefix):
        name = name[len(prefix):]
        break
    layer = getattr(nn, name, None)
    if not (isinstance(layer, type) and issubclass(layer, nn.Module)):
      raise ValueError(f"Unsupported norm_layer: {norm_layer!r}")
    return layer

  def interpolate_pos_encoding(self, x, h, w, t=None):
    """Resize a flattened (1, T0 * H0 * W0, dim) position table to (t, h, w).

    Args:
      x: Position table laid out as `self.embedding_shape` = (T0, H0, W0).
      h, w: Target patch counts along height and width.
      t: Target temporal token count; None keeps T0.

    Returns:
      Tensor of shape (1, t * h * w, dim) in (time, height, width) order.
    """
    t0, h0, w0 = self.embedding_shape
    dim = x.shape[-1]
    x = x.reshape(t0, h0, w0, dim)
    if (h, w) != (h0, w0):
      # Explicit size: a float scale_factor can floor to one patch too few.
      x = F.interpolate(x.permute(0, 3, 1, 2), size=(h, w), mode="bicubic", align_corners=False)
      x = x.permute(0, 2, 3, 1)
    if t is not None and t != t0:
      # Linear resampling along time, so clips of any length line up with
      # the patch tokens.
      x = x.permute(1, 2, 3, 0).reshape(1, h * w * dim, t0)
      x = F.interpolate(x, size=t, mode="linear", align_corners=False)
      x = x.reshape(h, w, dim, t).permute(3, 0, 1, 2)
    return x.reshape(1, -1, dim)

  def forward(self, x):
    _, _, T, h, w = x.shape
    x = self.patch_embed(x)
    # The position table is treated as a constant here (detached copy).
    pos_embed = self.pos_embed.type_as(x).to(x.device).clone().detach()
    t = T // self.patch_embed.tubelet_size
    h = h // self.patch_embed.patch_size[0]
    w = w // self.patch_embed.patch_size[1]
    x = x + self.interpolate_pos_encoding(pos_embed, h, w, t)
    x = self.pos_drop(x)
    for blk in self.blocks:
      x = blk(x)
    # With use_mean_pooling=False, fc_norm is None and the final norm lives
    # in self.norm instead.
    return self.fc_norm(x) if self.fc_norm is not None else self.norm(x)


class VideoMAEv2(transformers.PreTrainedModel):
  """Hugging Face wrapper around the VideoMAEv2 `VisionTransformer` backbone.

  `config.model_config` is a mapping of `VisionTransformer` keyword arguments.
  """
  config_class = VideoMAEv2Config

  def __init__(self, config=None):
    super().__init__(config=config)
    self.model_config = config.model_config
    self.model = VisionTransformer(**self.model_config)

  def forward(self, video):
    """Return the backbone's token features for `video`."""
    features = self.model(video)
    return features

In [None]:
# @title Compute Memory and Latency {form-width: "20%"}

import time

model = VideoMAEv2.from_pretrained('OpenGVLab/VideoMAEv2-Large')
# from_pretrained leaves the model on the CPU while the inputs are on the GPU;
# move it over and disable autograd so gradients don't inflate memory.
model = model.cuda().eval()
torch.set_grad_enabled(False)

n_runs = 1
for num_frames in [16, 32, 64, 128, 256, 512, 1024]:
  try:
    inputs = torch.randn(1, 3, num_frames, 256, 256, device="cuda")

    # Warm-up pass so one-time CUDA / kernel-selection overhead is not timed.
    model(inputs)
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    # Measure latency
    start = time.perf_counter()
    for _ in range(n_runs):
      model(inputs)
    torch.cuda.synchronize()  # Wait for all queued kernels before stopping timer
    end = time.perf_counter()
  except torch.cuda.OutOfMemoryError:
    torch.cuda.empty_cache()
    print(f"Frames: {num_frames} | out of GPU memory, stopping sweep")
    break

  avg_latency = (end - start) / n_runs  # seconds per inference
  peak_memory_gb = torch.cuda.max_memory_allocated() / 1024**3

  print(f"Frames: {num_frames} | Average Latency: {avg_latency*1000:.2f} ms | "
        f"Peak Memory: {peak_memory_gb:.2f} GB")

### V-JEPA

In [None]:
# @title Download V-JEPA code
%cd /content
!git clone https://github.com/facebookresearch/jepa.git
# !pip install -e jepa
# !pip install timm einops torchcodec torchvision torch==2.6.0
%cd /content/jepa

In [None]:
# @title Compute Memory and Latency {form-width: "20%"}

%cd /content/jepa

import src.models.vision_transformer as vit
import time

model = vit.VisionTransformer(
    patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
    qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), img_size=224, num_frames=16, tubelet_size=2,
    uniform_power=False, use_sdpa=False, use_SiLU=False, tight_SiLU=True)
model = model.cuda()
model = model.eval()
torch.set_grad_enabled(False)

n_runs = 1
for num_frames in [16, 32, 64, 128, 256, 512, 1024]:
  try:
    inputs = torch.randn(1, 3, num_frames, 256, 256, device="cuda")

    # Warm-up pass so one-time CUDA / kernel-selection overhead is not timed.
    model(inputs)
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    # Measure latency
    start = time.perf_counter()
    for _ in range(n_runs):
      model(inputs)
    torch.cuda.synchronize()  # Wait for all queued kernels before stopping timer
    end = time.perf_counter()
  except torch.cuda.OutOfMemoryError:
    torch.cuda.empty_cache()
    print(f"Frames: {num_frames} | out of GPU memory, stopping sweep")
    break

  avg_latency = (end - start) / n_runs  # seconds per inference
  peak_memory_gb = torch.cuda.max_memory_allocated() / 1024**3

  print(f"Frames: {num_frames} | Average Latency: {avg_latency*1000:.2f} ms | "
        f"Peak Memory: {peak_memory_gb:.2f} GB")

### V-JEPA 2

In [None]:
# @title Define V-JEPA 2 model

import transformers

class VJEPA2Model(transformers.VJEPA2PreTrainedModel):
  """V-JEPA 2 model whose forward pass runs only the encoder.

  The predictor is still constructed, presumably so the checkpoint's
  predictor weights have a place to load into, but `forward` never calls it.
  """

  def __init__(self, config):
    super().__init__(config)
    self.config = config
    vjepa2_modules = transformers.models.vjepa2.modeling_vjepa2
    self.encoder = vjepa2_modules.VJEPA2Encoder(config)
    self.predictor = vjepa2_modules.VJEPA2Predictor(config)

  def forward(self, video):
    """Return the encoder's last hidden state for `video`."""
    outputs = self.encoder(
        pixel_values_videos=video,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
    )
    return outputs.last_hidden_state

In [None]:
# @title Compute Memory and Latency {form-width: "20%"}

import time

model = VJEPA2Model.from_pretrained("facebook/vjepa2-vitl-fpc64-256")
model = model.cuda()
model = model.eval()
torch.set_grad_enabled(False)

n_runs = 1
for num_frames in [16, 32, 64, 128, 256, 512, 1024]:
  try:
    # This model takes frames-first input: (B, T, C, H, W).
    inputs = torch.randn(1, num_frames, 3, 256, 256, device="cuda")

    # Warm-up pass so one-time CUDA / kernel-selection overhead is not timed.
    model(inputs)
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    # Measure latency
    start = time.perf_counter()
    for _ in range(n_runs):
      model(inputs)
    torch.cuda.synchronize()  # Wait for all queued kernels before stopping timer
    end = time.perf_counter()
  except torch.cuda.OutOfMemoryError:
    torch.cuda.empty_cache()
    print(f"Frames: {num_frames} | out of GPU memory, stopping sweep")
    break

  avg_latency = (end - start) / n_runs  # seconds per inference
  peak_memory_gb = torch.cuda.max_memory_allocated() / 1024**3

  print(f"Frames: {num_frames} | Average Latency: {avg_latency*1000:.2f} ms | "
        f"Peak Memory: {peak_memory_gb:.2f} GB")

### DINO

In [None]:
# @title Download DINO code

%cd /content
!git clone https://github.com/facebookresearch/dino.git
%cd /content/dino

In [None]:
# @title Compute Memory and Latency {form-width: "20%"}

%cd /content/dino

import vision_transformer as vits
import time

model = vits.__dict__['vit_large'](patch_size=16, num_classes=0)
model = model.cuda()
model = model.eval()
torch.set_grad_enabled(False)

n_runs = 1
for num_frames in [16, 32, 64, 128, 256, 512, 1024]:
  try:
    inputs = torch.randn(1, num_frames, 3, 256, 256, device="cuda")

    # Warm-up on one frame so one-time CUDA setup is not timed.
    model(inputs[:, 0])
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    # Measure latency: the model is applied frame by frame.
    start = time.perf_counter()
    for _ in range(n_runs):
      for t in range(num_frames):
        model(inputs[:, t])
    torch.cuda.synchronize()  # Wait for all queued kernels before stopping timer
    end = time.perf_counter()
  except torch.cuda.OutOfMemoryError:
    torch.cuda.empty_cache()
    print(f"Frames: {num_frames} | out of GPU memory, stopping sweep")
    break

  avg_latency = (end - start) / n_runs  # seconds per inference
  peak_memory_gb = torch.cuda.max_memory_allocated() / 1024**3

  print(f"Frames: {num_frames} | Average Latency: {avg_latency*1000:.2f} ms | "
        f"Peak Memory: {peak_memory_gb:.2f} GB")

### DINO v2

In [None]:
# @title Define DINO v2 model

import transformers
from transformers.models.dinov2.modeling_dinov2 import Dinov2Embeddings, Dinov2Encoder

class Dinov2Model(transformers.Dinov2PreTrainedModel):
  """DINOv2 image backbone returning layer-normalized token features."""

  def __init__(self, config):
    super().__init__(config)
    self.config = config
    self.embeddings = Dinov2Embeddings(config)
    self.encoder = Dinov2Encoder(config)
    self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

  def forward(self, image):
    """Embed `image` (no masking), encode it and apply the final LayerNorm."""
    tokens = self.embeddings(image, bool_masked_pos=None)
    hidden_states = self.encoder(tokens, head_mask=None)[0]
    return self.layernorm(hidden_states)

In [None]:
# @title Compute Memory and Latency {form-width: "20%"}

import time

model = Dinov2Model.from_pretrained("facebook/dinov2-large")
model = model.cuda()
model = model.eval()
torch.set_grad_enabled(False)

n_runs = 1
for num_frames in [16, 32, 64, 128, 256, 512, 1024]:
  try:
    inputs = torch.randn(1, num_frames, 3, 256, 256, device="cuda")

    # Warm-up on one frame so one-time CUDA setup is not timed.
    model(inputs[:, 0])
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    # Measure latency: the model is applied frame by frame.
    start = time.perf_counter()
    for _ in range(n_runs):
      for t in range(num_frames):
        model(inputs[:, t])
    torch.cuda.synchronize()  # Wait for all queued kernels before stopping timer
    end = time.perf_counter()
  except torch.cuda.OutOfMemoryError:
    torch.cuda.empty_cache()
    print(f"Frames: {num_frames} | out of GPU memory, stopping sweep")
    break

  avg_latency = (end - start) / n_runs  # seconds per inference
  peak_memory_gb = torch.cuda.max_memory_allocated() / 1024**3

  print(f"Frames: {num_frames} | Average Latency: {avg_latency*1000:.2f} ms | "
        f"Peak Memory: {peak_memory_gb:.2f} GB")

### CropMAE

In [None]:
# @title Download CropMAE code

%cd /content
!git clone https://github.com/alexandre-eymael/CropMAE.git
%cd /content/CropMAE

In [None]:
# @title Load CropMAE model

%cd /content/CropMAE

from timm.models.vision_transformer import Block

class PatchEmbed(nn.Module):
  """Image to patch embedding via a Conv2d with kernel == stride == patch_size.

  Args:
    img_size: Square input size in pixels (int); used only for `num_patches`.
    patch_size: Square patch size in pixels (int).
    in_chans: Number of input channels.
    embed_dim: Output token dimension.
  """

  def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
    super().__init__()
    self.img_size = img_size
    self.patch_size = patch_size
    self.num_patches = (img_size // patch_size) ** 2
    self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

  def forward(self, x):
    # For a batched (B, C, H, W) input: (B, embed_dim, H', W') -> (B, H'*W', embed_dim).
    feature_map = self.proj(x)
    return feature_map.flatten(2).transpose(1, 2)

class MaskedAutoencoderViT(nn.Module):
  """ Masked Autoencoder with VisionTransformer backbone (encoder only).

  The pipeline is: PatchEmbed -> prepend [CLS] -> add positional embedding
  (resized to the input grid when needed) -> `depth` timm Blocks -> norm.
  No masking or decoder is built. `ckpt_path` is accepted but unused: this
  class loads no weights.
  """
  def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4., norm_layer=nn.LayerNorm, ckpt_path=None):
    super().__init__()

    self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
    seq_len = self.patch_embed.num_patches + 1  # patch tokens + [CLS]

    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    # Non-trainable positional embedding. It is meant to hold a fixed sin-cos
    # table, but nothing in this class fills it in, so it stays zero.
    self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim), requires_grad=False)

    self.blocks = nn.ModuleList([
        Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
        for _ in range(depth)
    ])
    self.norm = norm_layer(embed_dim)

  def interpolate_pos_encoding(self, x, w, h):
    """Returns `pos_embed` resized (bicubic) to the token grid of `x`.

    `w`/`h` are the input's spatial sizes in the order they are read from the
    image tensor in `forward` (dim 2, dim 3). The stored table is returned
    unchanged when the patch count already matches and the input is square.
    """
    num_tokens = x.shape[1] - 1
    num_table = self.pos_embed.shape[1] - 1
    if num_tokens == num_table and w == h:
      return self.pos_embed

    dim = x.shape[-1]
    side = math.sqrt(num_table)
    cls_pos = self.pos_embed[:, 0]
    grid_pos = self.pos_embed[:, 1:].reshape(1, int(side), int(side), dim).permute(0, 3, 1, 2)

    # The +0.1 avoids floating point error in the interpolation; see
    # https://github.com/facebookresearch/dino/issues/8
    rows = w // self.patch_embed.patch_size + 0.1
    cols = h // self.patch_embed.patch_size + 0.1
    grid_pos = nn.functional.interpolate(
        grid_pos,
        scale_factor=(rows / side, cols / side),
        mode="bicubic",
    )
    assert int(rows) == grid_pos.shape[-2] and int(cols) == grid_pos.shape[-1]
    grid_pos = grid_pos.permute(0, 2, 3, 1).view(1, -1, dim)
    return torch.cat((cls_pos.unsqueeze(0), grid_pos), dim=1)

  def forward(self, x):
    """Encodes a (B, C, H, W) image batch into normalised (B, 1 + N, D) tokens."""
    batch, _, w, h = x.shape
    tokens = self.patch_embed(x)
    cls = self.cls_token.expand(batch, -1, -1)
    tokens = torch.cat((cls, tokens), dim=1)
    tokens = tokens + self.interpolate_pos_encoding(tokens, w, h)
    for block in self.blocks:
      tokens = block(tokens)
    return self.norm(tokens)

In [None]:
# @title Compute Memory and Latency {form-width: "20%"}

import time

model = MaskedAutoencoderViT(
    patch_size=16, embed_dim=1024, depth=24, num_heads=16,
    mlp_ratio=4, norm_layer=functools.partial(nn.LayerNorm, eps=1e-6))
model = model.cuda()
model = model.eval()
torch.set_grad_enabled(False)

# Untimed warm-up pass so one-time CUDA start-up costs (lazy initialisation,
# kernel selection) are not charged to the first (num_frames=16) measurement.
_ = model(torch.randn(1, 3, 256, 256).cuda())
torch.cuda.synchronize()

for num_frames in [16, 32, 64, 128, 256, 512, 1024]:
  inputs = torch.randn(1, num_frames, 3, 256, 256).cuda()

  # Measure latency (total time to push all num_frames frames through the
  # model, one frame per forward pass) and peak GPU memory. The peak includes
  # the model weights and the whole input clip, which are already allocated
  # when the counter is reset.
  n_runs = 1
  torch.cuda.reset_peak_memory_stats()
  torch.cuda.synchronize()  # Make sure no earlier GPU work is inside the timer
  start = time.perf_counter()

  for _ in range(n_runs):
    for t in range(num_frames):
      _ = model(inputs[:, t])

  torch.cuda.synchronize()  # Wait for all queued kernels before stopping timer
  end = time.perf_counter()

  avg_latency = (end - start) / n_runs  # seconds per pass over the clip
  peak_mem_gib = torch.cuda.max_memory_allocated() / 1024**3

  print(f"num_frames={num_frames}: Average Latency: {avg_latency*1000:.2f} ms, "
        f"Peak Memory: {peak_mem_gib:.2f} GiB")

### RVM

In [None]:
# @title Download checkpoint

# Create a working directory for the RVM model and download the pretrained
# weights (.npz) into it; the "Load checkpoint" cell reads this file by name.
%mkdir /content/rvm
%cd /content/rvm

!wget https://storage.googleapis.com/dm-tapnet/tmp/pretrain_rvm_large16_256_xid175558463_wid1.npz

In [None]:
# @title Define model {form-width: "20%"}

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import einops
import re
import dataclasses

class PatchEmbedding(nn.Module):
  """Channels-last patch embedding for RGB frames.

  Maps a (B, H, W, 3) tensor to (B, H // ph, W // pw, num_features) with one
  non-overlapping strided convolution. The submodule is named `Conv_0` so
  `flax_to_torch` can find its weights by name in the Flax checkpoint.
  """
  def __init__(self, patch_size=(16, 16), num_features=1024):
    super().__init__()
    self.patch_size = patch_size
    self.num_features = num_features
    self.Conv_0 = nn.Conv2d(
        in_channels=3,
        out_channels=num_features,
        kernel_size=patch_size,
        stride=patch_size,
        padding=0,
    )

  def forward(self, x):
    channels_first = x.permute(0, 3, 1, 2)  # NHWC -> NCHW for Conv2d
    patches = self.Conv_0(channels_first)
    return patches.permute(0, 2, 3, 1)  # NCHW -> NHWC

def get_mae_sinusoid_encoding_table(n_position, d_hid, dtype=torch.float32):
  """Sinusoid positional encoding table for MAE.

  Entry (p, j) is sin(p / 10000**(2*(j//2)/d_hid)) for even j and the cosine
  of the same angle for odd j. Returns a tensor of shape [1, n_position, d_hid]
  in `dtype`.
  """
  exponents = [2 * (j // 2) / d_hid for j in range(d_hid)]
  angles = np.array([
      [pos / math.pow(10000, exponent) for exponent in exponents]
      for pos in range(n_position)
  ])
  angles[:, 0::2] = np.sin(angles[:, 0::2])  # even channels: sine
  angles[:, 1::2] = np.cos(angles[:, 1::2])  # odd channels: cosine

  return torch.tensor(angles, dtype=dtype).unsqueeze(0)

class SincosPosEmb(nn.Module):
  """Returns sinusoidal positional embedding given the shape of the tokens.

  The table is built for a `base_token_shape` = (h, w) grid, or for the
  tokens' own grid when no base shape is given. It is resized with bicubic
  interpolation when the tokens' grid differs and is then broadcast over the
  tokens' leading batch dimensions.
  """
  def __init__(self, base_token_shape=None, use_jax_interpolation=False):
    super().__init__()
    self.base_token_shape = base_token_shape
    self.use_jax_interpolation = use_jax_interpolation

  def forward(self, tokens_shape, device="cuda"):
    """Builds the embedding for tokens of shape `(*b, tokens_h, tokens_w, d)`.

    Args:
      tokens_shape: shape of the token tensor, `(*b, tokens_h, tokens_w, d)`.
      device: device the result is moved to (default "cuda", as before).

    Returns:
      A tensor of shape `(*b, tokens_h, tokens_w, d)` (an expanded view over
      one shared table). Without batch dims it is `(1, h, w, d)` when no
      resize happens and `(tokens_h, tokens_w, d)` after a resize.
    """
    *b, tokens_h, tokens_w, d = tokens_shape
    if self.base_token_shape is not None:
      h, w = self.base_token_shape
    else:
      h, w = tokens_h, tokens_w

    posenc = get_mae_sinusoid_encoding_table(h * w, d)  # [1, h*w, d]
    posenc = posenc.view(1, h, w, d)  # [1, h, w, d]

    resized = tokens_h != h or tokens_w != w
    if resized:
      # Every batch entry shares the same table, so resize it once and
      # broadcast afterwards.
      if self.use_jax_interpolation:
        table = jax.image.resize(
            jnp.array(posenc.numpy()), (1, tokens_h, tokens_w, d), method='bicubic')
        posenc = torch.from_numpy(np.array(table))
      else:
        posenc = F.interpolate(
            posenc.permute(0, 3, 1, 2),  # [1, D, H, W]
            size=(tokens_h, tokens_w),
            mode='bicubic',
            align_corners=False,
        ).permute(0, 2, 3, 1)  # [1, H, W, D]

    if b:
      # Broadcast over *all* leading dims. A loop over range(len(b) - 1) would
      # skip the common single-batch-dim case and break reshaping for B > 1.
      posenc = posenc.expand(*b, -1, -1, -1)
    elif resized:
      posenc = posenc[0]  # unbatched tokens: [H, W, D]

    return posenc.to(device)

class Tokenizer(nn.Module):
  """Turns frames into patch tokens with `patch_embedding`.

  `posenc` is registered on the module but is not applied in `forward`: the
  positional-encoding addition is disabled in this notebook.
  NOTE(review): confirm that skipping posenc is intended for this checkpoint.
  """
  def __init__(self, patch_embedding, posenc):
    super().__init__()
    self.patch_embedding = patch_embedding
    self.posenc = posenc

  def forward(self, x):
    # Positional encoding is not added here; `self.posenc` is unused.
    return self.patch_embedding(x)

class TransformerMLP(nn.Module):
  """Simple MLP with a single hidden layer for use in Transformer blocks.

  Linear -> GELU -> Linear. The hidden width defaults to 4 * input_dim. Both
  projections use Xavier-uniform weights and zero biases.
  """
  def __init__(self, input_dim, hidden_size=None):
    super().__init__()
    if hidden_size is None:
      hidden_size = 4 * input_dim
    self.hidden_size = hidden_size
    self.dense_in = nn.Linear(input_dim, hidden_size)
    self.dense_out = nn.Linear(hidden_size, input_dim)
    # Re-initialise in the same order as the layers were created.
    for layer in (self.dense_in, self.dense_out):
      nn.init.xavier_uniform_(layer.weight)
      nn.init.zeros_(layer.bias)

  def forward(self, x):
    return self.dense_out(F.gelu(self.dense_in(x)))

def dot_product_attention_weights(query, key):
  """Softmax attention weights from per-head queries and keys.

  Following the einsum spec, `query` is indexed (b, q, h, d) and `key`
  (b, k, h, d). The result is (b, h, q, k) and sums to 1 over k. Queries are
  scaled by 1/sqrt(d) before the dot product.
  """
  depth = query.size(-1)
  logits = torch.einsum('bqhd,bkhd->bhqk', query / math.sqrt(depth), key)
  return F.softmax(logits, dim=-1)

class ImprovedMultiHeadDotProductAttention(nn.Module):
  """Multi-head scaled dot-product attention with separate Q/K/V projections.

  With only `inputs_q` this is self-attention. Passing `inputs_k`
  (and optionally `inputs_v`) makes it cross-attention. Query/key use
  `qk_features` channels and value uses `v_features`, each split evenly over
  `num_heads`. The projection names (query, key, value, out) are what
  `flax_to_torch` maps to checkpoint keys.
  """

  def __init__(self, embed_dim, num_heads, qk_features=None, v_features=None, out_features=None):
    super().__init__()
    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.qk_features = qk_features or embed_dim
    self.v_features = v_features or self.qk_features
    self.out_features = out_features or embed_dim

    # Channels per head.
    self.head_dim_qk = self.qk_features // num_heads
    self.head_dim_v = self.v_features // num_heads

    # Input projections, then the output projection (creation order kept).
    self.query = nn.Linear(embed_dim, self.qk_features)
    self.key = nn.Linear(embed_dim, self.qk_features)
    self.value = nn.Linear(embed_dim, self.v_features)
    self.out = nn.Linear(self.v_features, self.out_features)

  def forward(self, inputs_q, inputs_k=None, inputs_v=None, mask=None):
    """Attends from `inputs_q` (batch, len_q, embed_dim) to the key/value inputs.

    `mask`, if given, must broadcast to (batch, heads, len_q, len_k); entries
    equal to 0 are excluded from the softmax.
    """
    if inputs_k is None:
      inputs_k = inputs_q
    if inputs_v is None:
      inputs_v = inputs_k

    batch_size, len_q, _ = inputs_q.shape
    len_k = inputs_k.shape[1]
    heads = self.num_heads

    # Project and split into heads: (batch, seq, heads, head_dim). The query's
    # batch size is used for all three.
    q = self.query(inputs_q).view(batch_size, len_q, heads, self.head_dim_qk)
    k = self.key(inputs_k).view(batch_size, len_k, heads, self.head_dim_qk)
    v = self.value(inputs_v).view(batch_size, len_k, heads, self.head_dim_v)

    logits = torch.einsum('bqhd,bkhd->bhqk', q / math.sqrt(self.head_dim_qk), k)
    if mask is not None:
      logits = logits.masked_fill(mask == 0, float('-inf'))
    probs = F.softmax(logits, dim=-1)

    mixed = torch.einsum('bhqk,bkhd->bqhd', probs, v)
    return self.out(mixed.reshape(batch_size, len_q, heads * self.head_dim_v))

class PreNormBlock(nn.Module):
  """Pre-norm residual block: x += attention(norm(x)); then x += mlp(norm(x))."""
  def __init__(self, attention_norm, mlp_norm, attention, mlp):
    super().__init__()
    self.attention_norm = attention_norm
    self.mlp_norm = mlp_norm
    self.attention = attention
    self.mlp = mlp

  def forward(self, x):
    x = x + self.attention(self.attention_norm(x))
    return x + self.mlp(self.mlp_norm(x))

# ViT variant name -> (hidden_size, num_layers, mlp_size, num_heads).
# The tuples are unpacked positionally into ViTSpec, so this order must
# match ViTSpec's field order.
VIT_SIZES = dict(
    mu=(32, 1, 128, 2),
    Ti=(192, 12, 768, 3),
    S=(384, 12, 1536, 6),
    M=(512, 12, 2048, 8),
    B=(768, 12, 3072, 12),
    L=(1024, 24, 4096, 16),
    H=(1280, 32, 5120, 16),
    g=(1408, 40, 6144, 16),
    G=(1664, 48, 8192, 16),
    e=(1792, 56, 15360, 16),
)

@dataclasses.dataclass(frozen=True)
class ViTSpec:
  """Immutable ViT size description, built from the VIT_SIZES presets."""

  hidden_size: int
  num_layers: int
  mlp_size: int
  num_heads: int
  patch_size: int = None  # Set only when the variant string has "/<patch>".

  @classmethod
  def from_variant_string(cls, variant_str: str):
    """Parses strings such as 'L', 'ViT-B/16' or 'vit_S/32'.

    Raises:
      ValueError: if `variant_str` does not match the expected format.
      KeyError: if the size name is not a key of VIT_SIZES.
    """
    match = re.match(r'^([Vv][Ii][Tt][-_])?(?P<name>[a-zA-Z]{1,2})(/(?P<patch>\d+))?$', variant_str)
    if not match:
      raise ValueError(f'Invalid variant string: {variant_str!r}.')
    spec = cls(*VIT_SIZES[match.group('name')])
    patch = match.group('patch')
    if patch is None:
      return spec
    return dataclasses.replace(spec, patch_size=int(patch))

class Transformer(nn.Module):
  """Pre-norm ViT encoder: `num_layers` PreNormBlocks followed by a LayerNorm.

  The attribute names `layers` and `LayerNorm_0` are what `flax_to_torch`
  maps onto the Flax checkpoint keys, so they must not be renamed.
  """
  def __init__(self, num_layers, hidden_size, num_heads, mlp_size, qk_features=None, v_features=None, final_norm_eps=1e-6):
    super().__init__()
    self.layers = nn.ModuleList([
        PreNormBlock(
            attention_norm=nn.LayerNorm(hidden_size, eps=1e-06, dtype=torch.float32),
            mlp_norm=nn.LayerNorm(hidden_size, eps=1e-06, dtype=torch.float32),
            attention=ImprovedMultiHeadDotProductAttention(
                embed_dim=hidden_size,
                num_heads=num_heads,
                qk_features=qk_features or hidden_size,
                v_features=v_features or hidden_size,
            ),
            mlp=TransformerMLP(input_dim=hidden_size, hidden_size=mlp_size),
        )
        for _ in range(num_layers)
    ])
    # torch.nn.LayerNorm defaults to eps=1e-5, but Flax's LayerNorm defaults to
    # epsilon=1e-6. The per-block norms above already use 1e-6, so the final
    # norm of this Flax port uses the same value by default.
    # NOTE(review): confirm against the original Flax encoder definition.
    self.LayerNorm_0 = nn.LayerNorm(hidden_size, eps=final_norm_eps)

  def forward(self, x):
    """Applies every block in order, then the final LayerNorm."""
    for layer in self.layers:
      x = layer(x)
    return self.LayerNorm_0(x)

  @classmethod
  def from_variant_str(cls, variant_str: str, **kwargs):
    """Builds an encoder from a ViT variant string such as 'L' or 'ViT-B/16'.

    The patch size in the string (if any) is not used here. `kwargs` override
    the preset's constructor arguments.
    """
    spec = ViTSpec.from_variant_string(variant_str)
    all_kwargs = dict(
      num_layers=spec.num_layers,
      hidden_size=spec.hidden_size,
      mlp_size=spec.mlp_size,
      num_heads=spec.num_heads,
    )
    all_kwargs.update(kwargs)
    return cls(**all_kwargs)

class CrossAttentionBlock(nn.Module):
  """One layer of the recurrent core: cross-attention -> MLP -> self-attention.

  Each sub-layer is pre-normalised and added residually. The key/value input
  `x_kv` is used as given; it is not normalised inside this block.
  """
  def __init__(self, num_heads, num_feats, mlp_dim, dtype=torch.float32):
    super().__init__()
    self.attention_norm = nn.LayerNorm(num_feats, eps=1e-6, dtype=dtype)
    self.mlp_norm = nn.LayerNorm(num_feats, eps=1e-6, dtype=dtype)
    self.ca_attention_norm = nn.LayerNorm(num_feats, eps=1e-6, dtype=dtype)

    self.attention = ImprovedMultiHeadDotProductAttention(
      embed_dim=num_feats, num_heads=num_heads
    )
    self.ca_attention = ImprovedMultiHeadDotProductAttention(
      embed_dim=num_feats, num_heads=num_heads
    )
    self.mlp = TransformerMLP(input_dim=num_feats, hidden_size=mlp_dim)

  def forward(self, x, x_kv):
    """Updates `x` by attending to `x_kv`, then applies the MLP and self-attention."""
    x = x + self.ca_attention(inputs_q=self.ca_attention_norm(x), inputs_k=x_kv, inputs_v=x_kv)
    x = x + self.mlp(self.mlp_norm(x))
    x = x + self.attention(self.attention_norm(x))
    return x

class CrossAttentionTransformer(nn.Module):
  """Stack of CrossAttentionBlocks sharing one key/value input, plus an output LayerNorm."""
  def __init__(self, num_layers, num_heads, num_feats, mlp_dim, dtype=torch.float32):
    super().__init__()
    blocks = [
        CrossAttentionBlock(num_heads, num_feats, mlp_dim, dtype=dtype)
        for _ in range(num_layers)
    ]
    self.xa_blocks = nn.ModuleList(blocks)
    self.output_norm = nn.LayerNorm(num_feats, eps=1e-6, dtype=dtype)

  def forward(self, inputs, inputs_kv):
    hidden = inputs
    for xa_block in self.xa_blocks:
      hidden = xa_block(hidden, inputs_kv)
    return self.output_norm(hidden)

class RandomStateInit(nn.Module):
  """Non-learnable recurrent-state initializer that returns an all-zero state.

  Despite the name, the state is deterministic zeros. Scaling random normal
  noise by 0 gives the same values, but wastes a random draw and advances the
  RNG.
  """

  def __init__(self):
    super().__init__()

  def forward(self, inputs, batch_shape):
    """Returns zeros of shape `(*batch_shape, *inputs.shape[-2:])`.

    Args:
      inputs: tensor whose last two dims, dtype and device the state copies.
      batch_shape: leading dims of the state (any sequence of ints).
    """
    shape = tuple(batch_shape) + tuple(inputs.shape[-2:])
    return torch.zeros(shape, dtype=inputs.dtype, device=inputs.device)

class GatedTransformerCore(nn.Module):
  """GRU-style recurrent core whose candidate state comes from a transformer.

  u = sigmoid(Wu x + Uu s) and r = sigmoid(Wr x + Ur s) are computed from the
  inputs x and previous state s. The candidate is
  h = transformer(x, inputs_kv=r * norm(s)), and the new state is
  (1 - u) * s + u * h. The new state is also returned as the output.
  `initializer` is only stored here; callers use it to create the first state.
  """
  def __init__(self, transformer, initializer, token_dim, state_layer_norm):
    super().__init__()
    self.transformer = transformer
    self.initializer = initializer
    self.token_dim = token_dim
    self.state_layer_norm = state_layer_norm

    # Bias-free gate projections (creation order kept).
    self.input_update = nn.Linear(token_dim, token_dim, bias=False)
    self.input_reset = nn.Linear(token_dim, token_dim, bias=False)
    self.state_update = nn.Linear(token_dim, token_dim, bias=False)
    self.state_reset = nn.Linear(token_dim, token_dim, bias=False)

  def forward(self, inputs, state):
    def gate(input_proj, state_proj):
      return torch.sigmoid(input_proj(inputs) + state_proj(state))

    update_gate = gate(self.input_update, self.state_update)
    reset_gate = gate(self.input_reset, self.state_reset)
    candidate = self.transformer(inputs, inputs_kv=reset_gate * self.state_layer_norm(state))
    new_state = (1 - update_gate) * state + update_gate * candidate
    return new_state, new_state

class VideoSiamMAE(nn.Module):
  """Video Siamese masked autoencoder model.

  Each call processes one frame:
  1. The frame is tokenized.
  2. A learnable [CLS] token is prepended.
  3. The tokens are encoded by `encoder`.
  4. The result is fused with the running state by `rnn_core`.

  Call it frame by frame and feed back the returned state.
  """

  def __init__(self, tokenizer, encoder, rnn_core, latent_emb_dim=384):
    super().__init__()
    self.tokenizer = tokenizer
    self.encoder = encoder
    self.rnn_core = rnn_core
    self.latent_emb_dim = latent_emb_dim

    # Learnable [CLS] token, initialised with std 0.02.
    self.cls_token = nn.Parameter(torch.randn(1, 1, latent_emb_dim) * 0.02)

  def forward(self, frame, state=None):
    """Processes one frame.

    Args:
      frame: input for `tokenizer`, which must yield `[..., h, w, D]` tokens.
      state: recurrent state from the previous call, or None for the first
        frame (a fresh state is then created by `rnn_core.initializer`).

    Returns:
      `(features, state)` as produced by `rnn_core`.
    """
    frame_tokens = self.tokenizer(frame)  # [..., h, w, D]
    frame_tokens = einops.rearrange(frame_tokens, '... h w d -> ... (h w) d')

    *batch_dims, _, _ = frame_tokens.shape
    cls_token = self.cls_token.expand(*batch_dims, -1, -1)  # [..., 1, D]
    frame_tokens = torch.cat([cls_token, frame_tokens], dim=-2)

    encoded_frame_tokens = self.encoder(frame_tokens)

    if state is None:
      # Size the initial state to the frame's batch dims. A fixed (1,) batch
      # would not match the tokens' batch for batch sizes > 1, and the core's
      # cross-attention reshapes keys with the query batch size.
      state = self.rnn_core.initializer(encoded_frame_tokens, batch_shape=tuple(batch_dims))

    features, state = self.rnn_core(encoded_frame_tokens, state)
    return features, state

# Assemble the RVM model:
# - a 16x16 patch tokenizer,
# - the 'L' ViT encoder preset,
# - a 4-layer gated cross-attention recurrent core,
# all with width 1024. The sub-modules are created in the same order a single
# nested constructor call would create them, so random initialisation is
# unchanged.
rvm_tokenizer = Tokenizer(
    patch_embedding=PatchEmbedding(patch_size=[16, 16], num_features=1024),
    posenc=SincosPosEmb(base_token_shape=[16, 16]),
)
rvm_encoder = Transformer.from_variant_str(variant_str='L')
rvm_core = GatedTransformerCore(
    transformer=CrossAttentionTransformer(
        num_layers=4,
        num_heads=16,
        num_feats=1024,
        mlp_dim=4096,
        dtype=torch.float32,
    ),
    initializer=RandomStateInit(),
    token_dim=1024,
    state_layer_norm=nn.LayerNorm(1024, eps=0.0001, bias=False),
)
model = VideoSiamMAE(
    tokenizer=rvm_tokenizer,
    encoder=rvm_encoder,
    rnn_core=rvm_core,
    latent_emb_dim=1024,
)
model = model.cuda().eval()
torch.set_grad_enabled(False)

In [None]:
# @title Load checkpoint

# The .npz checkpoint was downloaded into /content/rvm; load it from there.
%cd /content/rvm

def recover_tree(flat_dict):
  """Rebuilds a nested dict from a flat mapping with '/'-separated keys.

  Example: {"a/b": 1, "a/c": 2, "d": 3} -> {"a": {"b": 1, "c": 2}, "d": 3}.
  """
  tree = {}
  for path, leaf in flat_dict.items():
    *parents, last = path.split("/")
    subtree = tree
    for segment in parents:
      if segment not in subtree:
        subtree[segment] = {}
      subtree = subtree[segment]
    subtree[last] = leaf
  return tree

def flatten_flax_params(params, parent_key=""):
  """Flattens a nested params dict into one level with '.'-joined keys.

  Example: {"a": {"b": x, "c": {"d": y}}} -> {"a.b": x, "a.c.d": y}.
  Only `dict` values are descended into; all other values are kept as leaves.
  """
  flat = {}
  for key, value in params.items():
    full_key = key if not parent_key else f"{parent_key}.{key}"
    if not isinstance(value, dict):
      flat[full_key] = value
      continue
    flat.update(flatten_flax_params(value, full_key))
  return flat

def _flax_key_for(name, flat_flax):
  """Maps a PyTorch parameter name to its candidate key in `flat_flax` (or None)."""
  # 'layers.3' / 'blocks.3' in PyTorch correspond to 'layers_3' / 'blocks_3'.
  fixed = name.replace('layers.', 'layers_').replace('blocks.', 'blocks_')
  if name == "cls_token":
    return "cls_token"
  if name.endswith("weight"):
    # Linear/Conv weights are 'kernel' in Flax; LayerNorm weights are 'scale'.
    kernel_key = fixed.replace("weight", "kernel")
    return kernel_key if kernel_key in flat_flax else fixed.replace("weight", "scale")
  if name.endswith("bias"):
    return fixed  # bias names usually match directly
  return None


def _flax_array_to_torch(array, param, name):
  """Converts one Flax array to a tensor laid out and shaped like `param`."""
  tensor = torch.tensor(array)
  if param.ndim == 4:
    # Conv kernel: Flax [H, W, in, out] -> Torch [out, in, H, W].
    if tensor.ndim == 5 and tensor.shape[0] == 1:  # Sometimes an extra batch dim
      tensor = tensor[0]
    tensor = tensor.permute(3, 2, 0, 1)
  elif param.ndim == 2:
    if tensor.ndim == 2:
      tensor = tensor.T  # Dense: [in, out] -> [out, in]
    elif tensor.ndim == 3:
      # DenseGeneral. Q/K/V kernels are [in, heads, head_dim]; the output
      # projection is [heads, head_dim, out].
      if param.shape[0] == tensor.shape[-1] * tensor.shape[-2]:
        tensor = tensor.reshape(tensor.shape[0], -1).T
      else:
        tensor = tensor.reshape(-1, tensor.shape[-1]).T
    else:
      raise ValueError(f"Unexpected kernel shape {tensor.shape} for {name}")
  # Final reshape covers biases, cls_token, norm scales, etc.
  return tensor.reshape(param.shape)


def flax_to_torch(flat_flax, torch_model):
  """Copies weights from a flat Flax param dict into `torch_model` in place.

  `flat_flax` maps '.'-joined Flax paths to arrays (as produced by
  `flatten_flax_params`). Every parameter of `torch_model` is looked up by
  name. A "[WARN] Missing weights" line is printed, and the parameter left
  unchanged, when no matching key exists. Otherwise a "Loaded" line is printed.

  Raises:
    ValueError: for a 2-D parameter whose Flax kernel is neither 2-D nor 3-D.
  """
  for name, param in torch_model.named_parameters():
    flax_key = _flax_key_for(name, flat_flax)
    if flax_key is None or flax_key not in flat_flax:
      print(f"[WARN] Missing weights for {name} (flax_key={flax_key})")
      continue

    tensor = _flax_array_to_torch(np.array(flat_flax[flax_key]), param, name)
    with torch.no_grad():
      param.copy_(tensor)

    print(f"Loaded {name} from {flax_key}")

# Open the .npz as a context manager so its file handle is closed once
# recover_tree has read every array into memory.
with np.load("pretrain_rvm_large16_256_xid175558463_wid1.npz", allow_pickle=False) as ckpt:
  restored_params = recover_tree(ckpt)

flat_flax = flatten_flax_params(restored_params)
flax_to_torch(flat_flax, model)

In [None]:
# @title Compute Memory and Latency {form-width: "20%"}

import time

# Untimed warm-up pass so one-time CUDA start-up costs (lazy initialisation,
# kernel selection) are not charged to the first (num_frames=16) measurement.
_ = model(torch.randn(1, 256, 256, 3).cuda(), None)
torch.cuda.synchronize()

for num_frames in [16, 32, 64, 128, 256, 512, 1024]:
  inputs = torch.randn(1, num_frames, 256, 256, 3).cuda()

  # Measure latency (total time to run the recurrent model over all
  # num_frames frames) and peak GPU memory. The peak includes the model
  # weights and the whole input clip, which are already allocated when the
  # counter is reset.
  n_runs = 1
  torch.cuda.reset_peak_memory_stats()
  torch.cuda.synchronize()  # Make sure no earlier GPU work is inside the timer
  start = time.perf_counter()

  for _ in range(n_runs):
    state = None  # every run starts from a fresh recurrent state
    for t in range(num_frames):
      output, state = model(inputs[:, t], state)

  torch.cuda.synchronize()  # Wait for all queued kernels before stopping timer
  end = time.perf_counter()

  avg_latency = (end - start) / n_runs  # seconds per pass over the clip
  peak_mem_gib = torch.cuda.max_memory_allocated() / 1024**3

  print(f"num_frames={num_frames}: Average Latency: {avg_latency*1000:.2f} ms, "
        f"Peak Memory: {peak_mem_gib:.2f} GiB")