In [None]:
# @title Install dependency

!pip install mediapy

In [None]:
# @title Imports

import collections
import functools
import glob
import math
import os
import pickle
import queue
import random

import cv2
import einops
from PIL import Image
import scipy.io as sio
import mediapy as media
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
import tensorflow_datasets as tfds
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
import jax
import jax.numpy as jnp
import sklearn
import sklearn.cluster
import sklearn.decomposition

### Prepare videos for qualitative visualization

In [None]:
# @title Download DAVIS videos

# DAVIS clips used for qualitative visualization; uncomment names to fetch more.
video_names = [
    'goat',
    # 'india',
    # 'soapbox',
]

# Download each clip from the dm-tapnet bucket into the working directory.
for video_name in video_names:
  !wget https://storage.googleapis.com/dm-tapnet/davis2017/{video_name}.mp4

davis_videos = []
for video_name in video_names:
  video = media.read_video(f"{video_name}.mp4")
  # (height, width) = (480, 880), the same resolution evaluate_davis uses by default.
  video = media.resize_video(video, (480, 880))
  davis_videos.append(video)
  media.show_video(video, fps=16, height=128, codec='gif')

In [None]:
# @title Download MoCA videos

# MoCA clips used for qualitative visualization; uncomment names to fetch more.
video_names = [
    'scorpionfish_0',
    # 'snow_leopard_8',
    # 'white_tailed_ptarmigan',
    # 'grasshopper_1',
    # 'plaice',
    # 'pygmy_seahorse_0',
    # 'stick_insect_0',
    # 'crab',
    # 'crab_1',
    # 'copperhead_snake',
]

# Download each clip from the dm-tapnet bucket into the working directory.
for video_name in video_names:
  !wget https://storage.googleapis.com/dm-tapnet/MoCA/{video_name}.mp4

moca_videos = []
for video_name in video_names:
  video = media.read_video(f"{video_name}.mp4")
  # Resize to (height, width) = (360, 640).
  video = media.resize_video(video, (360, 640))
  moca_videos.append(video)
  media.show_video(video, fps=16, width=256, codec='gif')

In [None]:
# @title Download Dalmatian videos

# Single dalmatian-illusion clip, resized to (height, width) = (360, 480).
!wget https://storage.googleapis.com/dm-tapnet/tmp/dalmatian_illusion.mp4

dalmatian_video = media.read_video("dalmatian_illusion.mp4")
dalmatian_video = media.resize_video(dalmatian_video, (360, 480))
media.show_video(dalmatian_video, fps=16, width=256, codec='gif')

In [None]:
# @title Create noise video

def create_noise_movie(T, h, w):
  """Build a synthetic clip in which only motion reveals the object.

  A static uniform-noise background carries a half-size uniform-noise square.
  The square is vertically centred and slides horizontally from the left edge
  as t grows. Since both regions are drawn from the same distribution, a
  single frame shows no visible square.

  Args:
    T: number of frames.
    h, w: frame height and width in pixels.

  Returns:
    uint8 array of shape (T, h, w, 3).
  """
  background = np.random.rand(h, w, 3)
  patch = np.random.rand(h // 2, w // 2, 3)
  patch_h, patch_w = h // 2, w // 2
  top = (h - patch_h) // 2  # fixed vertical offset (centred)

  def frame_at(t):
    canvas = background.copy()
    left = (t * (w - patch_w)) // T  # horizontal offset grows linearly with t
    canvas[top:top + patch_h, left:left + patch_w] = patch
    return canvas

  frames = np.array([frame_at(t) for t in range(T)])
  return (frames * 255.0).astype(np.uint8)

# 64 frames at 512x512; the commented-out line is a smaller, faster variant.
# noise_video = create_noise_movie(32, 256, 256)
noise_video = create_noise_movie(64, 512, 512)
media.show_video(noise_video, fps=16, height=128, codec='gif')

In [None]:
# @title Combine qualitative videos

# Qualitative set: DAVIS clips, MoCA clips and the synthetic noise clip.
# NOTE(review): dalmatian_video (downloaded above) is not included here — confirm this is intentional.
videos = davis_videos + moca_videos + [noise_video]

### Prepare datasets and metrics for quantitative evaluation

#### Download datasets

In [None]:
# @title Download DAVIS-2017 dataset

# Clone the DAVIS-2017 repo and run its download script.
# NOTE(review): evaluate_davis defaults to dataset_path='/content/davis-2017/DAVIS/'; presumably get_davis.sh unpacks the data there — confirm.
%cd /content
!git clone https://github.com/davisvideochallenge/davis-2017
%cd /content/davis-2017
!./data/get_davis.sh

In [None]:
# @title Download JHMDB dataset

# JHMDB archive; create_jhmdb_dataset reads its splits/, joint_positions/ and Rename_Images/ folders.
# NOTE(review): evaluate_jhmdb defaults to '/content/jhmdb/'; presumably the zip unpacks there — confirm.
%cd /content
!wget https://storage.googleapis.com/dm-tapnet/jhmdb.zip
!unzip jhmdb.zip

In [None]:
# @title Download VIP dataset

# VIP archive; create_vip_dataset reads lists/val_videos.txt, Images/ and Annotations/Category_ids/.
# NOTE(review): evaluate_vip defaults to '/content/VIP/'; presumably the zip unpacks there — confirm.
%cd /content
!wget https://storage.googleapis.com/dm-tapnet/VIP.zip
!unzip VIP.zip

#### Helper functions: datasets, label propagation, and visualization

In [None]:
# @title Dataset functions

def create_davis_dataset(root_path):
  """DAVIS dataset, including fields required for video segmentation evaluation.

  Reads video ids from `ImageSets/2017/val.txt` under `root_path` and skips
  blank lines. Yields one dict per video:
    'video': (T, H, W, 3) uint8 RGB frames from JPEGImages/480p/<video_id>/*.jpg,
    'mask': stacked arrays decoded by PIL from Annotations/480p/<video_id>/*.png
            (presumably palette indices = object ids; 0 is treated as
            background downstream),
    'video_id': the id string.
  """

  def read_image(path, mode=None):
    # Decode fully inside the `with` so the handle is closed deterministically.
    with tf.io.gfile.GFile(path, 'rb') as f:
      image = Image.open(f)
      if mode is not None:
        image = image.convert(mode)
      return np.array(image)

  with tf.io.gfile.GFile(os.path.join(root_path, 'ImageSets/2017/val.txt'), 'r') as f:
    video_ids = [line.strip() for line in f]

  for video_id in video_ids:
    if not video_id:
      continue  # a blank line would glob nothing and np.stack([]) would raise
    frame_files = sorted(tf.io.gfile.glob(os.path.join(root_path, 'JPEGImages/480p', video_id, '*.jpg')))
    frames = np.stack([read_image(f, 'RGB') for f in frame_files])
    label_files = sorted(tf.io.gfile.glob(os.path.join(root_path, 'Annotations/480p', video_id, '*.png')))
    labels = np.stack([read_image(f) for f in label_files])
    yield {'video': frames, 'mask': labels, 'video_id': video_id}

def create_jhmdb_dataset(jhmdb_path):
  """JHMDB dataset, including fields required for PCK evaluation.

  Collects the entries flagged 2 (presumably the test set) in every
  `splits/*split1.txt` file, shuffles them with the global `random` RNG, and
  yields one dict per video:
    'video': (T, H, W, 3) uint8 RGB frames from Rename_Images/<video_id>/*.png,
    'pose': (num_frames, num_points, 2) float32 joint positions from
            joint_positions/<video_id>/joint_positions.mat, shifted by -1
            from MATLAB's 1-based indexing,
    'video_id': '<split-file prefix>/<video file name without extension>'.
  Blank lines in split files are ignored.
  """

  def read_rgb(path):
    # Context manager closes the file once the pixels are decoded.
    with Image.open(path) as image:
      return np.array(image.convert('RGB'))

  splits_dir = os.path.join(jhmdb_path, 'splits')
  video_ids = []
  for file in os.listdir(splits_dir):
    if not file.endswith('split1.txt'):
      continue
    # Drop the last two '_'-separated fields of the file name to get the video folder.
    video_folder = '_'.join(file.split('_')[:-2])
    with open(os.path.join(splits_dir, file), 'r') as f:
      lines = f.readlines()
    for line in lines:
      fields = line.split()
      if not fields:
        continue
      video_name, traintest = fields
      if int(traintest) == 2:
        video_ids.append(os.path.join(video_folder, video_name.split('.')[0]))

  random.shuffle(video_ids)
  for video_id in video_ids:
    joints = os.path.join(jhmdb_path, 'joint_positions', video_id, 'joint_positions.mat')
    pose = sio.loadmat(joints)['pos_img'] - 1  # matlab -> python
    pose = pose.astype(np.float32).transpose(2, 1, 0)  # (num_frames, num_points, 2)
    frame_files = sorted(glob.glob(os.path.join(jhmdb_path, 'Rename_Images', video_id, '*.png')))
    frames = np.stack([read_rgb(f) for f in frame_files])
    yield {'video': frames, 'pose': pose, 'video_id': video_id}

def create_vip_dataset(root_path):
  """VIP dataset, including fields required for video segmentation evaluation.

  Reads video ids from `lists/val_videos.txt` under `root_path` and skips
  blank lines. Yields one dict per video:
    'video': (T, H, W, 3) uint8 RGB frames from Images/<video_id>/*.jpg,
    'mask': stacked arrays decoded by PIL from
            Annotations/Category_ids/<video_id>/*.png,
    'video_id': the id string.
  """

  def read_image(path, mode=None):
    # Decode fully inside the `with` so the handle is closed deterministically.
    with tf.io.gfile.GFile(path, 'rb') as f:
      image = Image.open(f)
      if mode is not None:
        image = image.convert(mode)
      return np.array(image)

  with tf.io.gfile.GFile(os.path.join(root_path, 'lists/val_videos.txt'), 'r') as f:
    video_ids = [line.strip() for line in f]

  for video_id in video_ids:
    if not video_id:
      continue  # a blank line would glob nothing and np.stack([]) would raise
    frame_files = sorted(tf.io.gfile.glob(os.path.join(root_path, 'Images', video_id, '*.jpg')))
    frames = np.stack([read_image(f, 'RGB') for f in frame_files])
    label_files = sorted(tf.io.gfile.glob(os.path.join(root_path, 'Annotations/Category_ids', video_id, '*.png')))
    labels = np.stack([read_image(f) for f in label_files])
    yield {'video': frames, 'mask': labels, 'video_id': video_id}

def create_tapvid_davis_dataset(pickle_path):
  """TAPVid-DAVIS dataset, including fields required for point tracking evaluation.

  The pickle maps video_id -> dict with 'video', 'points' and 'occluded'.
  Points are scaled by (width, height) of the frames, so presumably they are
  stored normalized; they are transposed to (num_frames, num_points, 2).
  Only points visible in frame 0 are kept.

  NOTE: pickle.load can execute arbitrary code; only load trusted files.
  """
  with tf.io.gfile.GFile(pickle_path, 'rb') as f:
    point_tracks = pickle.load(f)

  for video_id, track in point_tracks.items():
    frames = track['video']
    xy_scale = np.array([frames.shape[2], frames.shape[1]])  # (W, H) to match (x, y)
    points = (track['points'] * xy_scale).astype(np.float32).transpose(1, 0, 2)  # (num_frames, num_points, 2)
    visible = np.logical_not(track['occluded']).transpose(1, 0)  # (num_frames, num_points)

    in_first_frame = visible[0]
    yield {
        'video': frames,
        'points': points[:, in_first_frame],
        'visible': visible[:, in_first_frame],
        'video_id': video_id,
    }

In [None]:
# @title Label propagation functions

def draw_labelmap_np(img, pt, sigma=0.5):
  """Draw a 2D gaussian peak around `pt` into `img` (in place) and return it.

  Adopted from https://github.com/anewell/pose-hg-train/blob/master/src/pypose/draw.py

  Args:
    img: 2D array; the covered pixels are overwritten by the gaussian patch.
    pt: (x, y) position; x indexes columns (img.shape[1]) and y indexes rows.
    sigma: gaussian std in pixels; the patch spans 6 * sigma + 1 pixels.

  Returns:
    `img`, returned unchanged when the patch lies fully outside it.
  """
  # Patch corners (upper-left inclusive, bottom-right exclusive). Use
  # math.floor rather than int(): int() truncates toward zero, which for
  # pt - 3 * sigma in (-1, 0) etc. shifts the patch one pixel right relative
  # to points elsewhere in the image.
  ul = [math.floor(pt[0] - 3 * sigma), math.floor(pt[1] - 3 * sigma)]
  br = [math.floor(pt[0] + 3 * sigma + 1), math.floor(pt[1] + 3 * sigma + 1)]
  if ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or br[0] < 0 or br[1] < 0:
    # No part of the gaussian is in-bounds.
    return img

  # Generate gaussian
  size = 6 * sigma + 1
  x = np.arange(0, size, 1, float)
  y = x[:, np.newaxis]
  x0 = y0 = size // 2
  g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))

  # Part of the gaussian patch that lands inside the image...
  g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]
  g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]
  # ...and the matching image range.
  img_x = max(0, ul[0]), min(br[0], img.shape[1])
  img_y = max(0, ul[1]), min(br[1], img.shape[0])

  img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
  return img

def mask2heatmap(mask, h, w, height, width):
  """Convert segmentation mask to heatmap (resize and one-hot encode).

  Args:
    mask: 2D integer label image.
    h, w: target (feature-grid) resolution of the heatmap.
    height, width: intermediate resolution the mask is first resized to
      (nearest neighbour) before one-hot encoding.

  Returns:
    heatmap: float32 array (h, w, len(label_map)); channel i is the bilinearly
      downsampled indicator of label label_map[i].
    label_map: sorted unique labels of the input mask.
  """
  label_map = np.unique(mask)
  mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
  heatmap = np.stack([mask == l for l in label_map], axis=-1).astype(np.float32)
  heatmap = cv2.resize(heatmap, (w, h), interpolation=cv2.INTER_LINEAR)
  if heatmap.ndim == 2:
    # cv2.resize drops a trailing channel axis of size 1 (mask with a single
    # label); restore it so callers always get (h, w, n_labels).
    heatmap = heatmap[..., None]
  return heatmap, label_map

def pose2heatmap(pose, h, w, height, width):
  """Convert pose to heatmap (resize and one-hot encode).

  Args:
    pose: (num_joints, 2) array of (x, y) positions in a (height, width) image.
    h, w: heatmap (feature-grid) resolution.
    height, width: resolution of the image `pose` refers to.

  Returns:
    float32 array (h, w, num_joints + 1). Channel j + 1 holds a small gaussian
    drawn at joint j after rescaling to the grid. Channel 0 is 1.0 wherever no
    joint gaussian was drawn (background).
  """
  num_joints = pose.shape[0]
  grid_coords = pose * np.array([w / width, h / height])
  heatmap = np.zeros((h, w, num_joints + 1), dtype=np.float32)
  for joint in range(num_joints):
    heatmap[..., joint + 1] = draw_labelmap_np(np.zeros((h, w)), grid_coords[joint])
  no_joint = heatmap.sum(axis=-1) == 0
  heatmap[..., 0] = no_joint
  return heatmap

def process_pose(preds, h, w, ori_height, ori_width, topk=5):
  """Decode propagated joint heatmaps into pixel coordinates and visibility.

  For every frame, channel 0 (background) is dropped. The `topk` strongest
  grid cells of each joint are averaged with their normalized scores. The
  resulting (x, y) is rescaled from the (h, w) grid to (ori_height, ori_width)
  pixels. Joints whose heatmap sums to zero get (-1, -1).

  Args:
    preds: tensor indexed as preds[t] -> (h, w, 1 + num_joints).

  Returns:
    coords: np.ndarray (T, num_joints, 2) of (x, y).
    visible: np.ndarray (T, num_joints), 1.0 where x >= 0 else 0.0.
  """
  frame_coords, frame_visibility = [], []
  for frame_pred in preds:
    joint_maps = frame_pred[..., 1:]
    flat = joint_maps.flatten(0, 1)  # (h * w, num_joints)
    top_vals, top_ids = torch.topk(flat, k=topk, dim=0)
    top_vals /= top_vals.sum(0)[None]
    grid_w = joint_maps.shape[1]
    cols, rows = top_ids % grid_w, top_ids // grid_w
    coord = torch.stack([(cols * top_vals).sum(0), (rows * top_vals).sum(0)], dim=0)
    coord[0, :] = coord[0, :] / w * ori_width
    coord[1, :] = coord[1, :] / h * ori_height
    coord[:, flat.sum(0) == 0] = -1
    coord = coord.cpu().numpy().T  # (num_joints, 2)
    frame_coords.append(coord)
    frame_visibility.append((coord[:, 0] >= 0).astype(float))
  return np.stack(frame_coords, axis=0), np.stack(frame_visibility, axis=0)

def process_points(preds, h, w, ori_height, ori_width, topk=5):
  """Decode propagated point heatmaps into pixel coordinates.

  For every frame, channel 0 (background) is dropped. Each point's position is
  the score-weighted mean of its `topk` strongest grid cells, rescaled from
  the (h, w) grid to (ori_height, ori_width) pixels. Points whose heatmap sums
  to zero get (-1, -1).

  Args:
    preds: tensor indexed as preds[t] -> (h, w, 1 + num_points).

  Returns:
    np.ndarray (T, num_points, 2) of (x, y).
  """
  tracks = []
  for frame_pred in preds:
    point_maps = frame_pred[..., 1:]
    flat = point_maps.flatten(0, 1)  # (h * w, num_points)
    weights, cell_ids = torch.topk(flat, k=topk, dim=0)
    weights /= weights.sum(0)[None]
    grid_w = point_maps.shape[1]
    xy = torch.stack([((cell_ids % grid_w) * weights).sum(0),
                      ((cell_ids // grid_w) * weights).sum(0)], dim=0)
    xy[0, :] = xy[0, :] / w * ori_width
    xy[1, :] = xy[1, :] / h * ori_height
    xy[:, flat.sum(0) == 0] = -1
    tracks.append(xy.cpu().numpy().T)
  return np.stack(tracks, axis=0)

def process_segmentation(preds, lbl_map, height, width, ori_height, ori_width):
  """Turn propagated soft label maps into hard label images at the original size.

  Args:
    preds: tensor indexed as preds[t] -> (h, w, n_labels) soft scores.
    lbl_map: label value of each channel (e.g. from mask2heatmap).
    height, width: resolution at which the argmax is taken.
    ori_height, ori_width: output resolution.

  Returns:
    np.ndarray (T, ori_height, ori_width) int32 label images.
  """
  label_lookup = np.array(lbl_map, dtype=np.int32)
  pred_lbls = []
  for pred in preds:
    # Upsample predicted soft label maps
    pred_dist = cv2.resize(pred.cpu().numpy(), (width, height))
    if pred_dist.ndim == 2:
      # cv2.resize drops a size-1 channel axis; without it argmax would run
      # over the width axis instead of over labels.
      pred_dist = pred_dist[..., None]
    # Argmax over labels, then map channel index -> label value.
    pred_lbl = label_lookup[np.argmax(pred_dist, axis=-1)]
    pred_lbl = cv2.resize(pred_lbl, (ori_width, ori_height), interpolation=cv2.INTER_NEAREST_EXACT)
    pred_lbls.append(pred_lbl)
  return np.stack(pred_lbls, axis=0)

def label_propagation(feats, heatmap, n_context=20, temperature=0.7, topk=7, radius=20, restrict_neighborhood=True, norm_mask=False):
  """Propagation of the heatmap based on feature similarity.

  For each frame t, every grid cell compares its feature with all cells of a
  context made of frame 0 (labelled by `heatmap`) plus the `n_context` most
  recent frames (labelled by their own predictions). The `topk` highest
  affinities are softmax-weighted to average the context labels.

  Args:
    feats: tensor (T, h, w, c); callers L2-normalize it along c.
    heatmap: tensor (h, w, n_class) with the frame-0 labels.
    n_context: number of recent frames kept as context (0 = frame 0 only).
    temperature: affinity temperature applied before top-k / softmax.
    topk: number of context cells each cell attends to.
    radius: when restrict_neighborhood is True, cells of the recent-frame
      contexts farther than `radius` grid cells away are masked out (frame 0
      is never restricted).
    norm_mask: min-max normalize each returned prediction over classes. The
      context always keeps the unnormalized prediction.

  Returns:
    tensor (T, h, w, n_class) of propagated label maps, on feats' device.
  """
  device = feats.device
  heatmap = heatmap.to(device)

  # Additive neighbourhood mask over grid positions: 0 within `radius`, -1e10 outside.
  h, w = feats.shape[1], feats.shape[2]
  gx, gy = torch.meshgrid(torch.arange(0, h), torch.arange(0, w), indexing="ij")  # (h, w)
  dist = ((gx[None, None, :, :] - gx[:, :, None, None]) ** 2 + (gy[None, None, :, :] - gy[:, :, None, None]) ** 2).float() ** 0.5
  neighborhood = (dist < radius).float().to(device)
  neighborhood[neighborhood == 0] = -1e10
  neighborhood[neighborhood == 1] = 0
  neighborhood = neighborhood.permute(2, 3, 0, 1)  # (h, w, h, w)

  # Recent (features, labels) pairs; maxlen evicts the oldest entry. Unlike
  # queue.Queue(0), maxlen=0 is a valid "no recent frames" setting.
  context = collections.deque([(feats[0], heatmap)] * n_context, maxlen=n_context)

  preds = []
  for t in range(feats.shape[0]):
    # Use first and previous frames as context
    ctx_feats = torch.stack([feats[0]] + [f for f, _ in context])
    ctx_lbls = torch.stack([heatmap] + [l for _, l in context])

    aff = torch.einsum('hwc, tmnc -> hwtmn', feats[t], ctx_feats) / temperature  # (h, w, n_ctx+1, h, w)
    if restrict_neighborhood:
      aff[:, :, 1:] += neighborhood[:, :, None]
    aff = aff.reshape(aff.shape[0], aff.shape[1], -1)  # (h, w, (n_ctx+1) * h * w)

    weights, ids = torch.topk(aff, topk, dim=-1)  # (h, w, topk)
    weights = F.softmax(weights, dim=-1)
    ctx_lbls = ctx_lbls.reshape(-1, ctx_lbls.shape[-1])  # ((n_ctx+1) * h * w, n_class)
    pred = torch.einsum('hwlk, hwl -> hwk', ctx_lbls[ids], weights)  # (h, w, n_class)

    context.append((feats[t], pred))

    if norm_mask:
      # Out-of-place so the tensor stored in `context` stays unnormalized.
      pred = pred - pred.min(-1)[0][..., None]
      pred = pred / pred.max(-1)[0][..., None]

    preds.append(pred)
  return torch.stack(preds)

In [None]:
# @title Visualization functions

def visualize_pca(features, video):
  """Render features as RGB via their top-3 whitened PCA components.

  Args:
    features: (t, n, m, c) feature grid (rank fixed by the einops pattern).
    video: array whose shape is used as the output shape.

  Returns:
    jax array of `video.shape`, scaled to [0, 1] with one global min/max and
    upsampled with nearest-neighbour resizing.
  """
  flat = einops.rearrange(features, 't n m c -> (t n m) c')
  projected = sklearn.decomposition.PCA(n_components=3, whiten=True).fit_transform(flat)
  projected = projected.reshape(features.shape[:-1] + (3,))
  lo, hi = projected.min(), projected.max()
  normalized = (projected - lo) / (hi - lo)
  return jax.image.resize(normalized, video.shape, method='nearest')

def segmentations_to_video(masks):
  """Colorize an integer label video into float RGB.

  Labels are assumed to be numbered consecutively from 0 to max(masks). Label
  i gets the i-th color of the current seaborn palette.

  Args:
    masks: integer array (T, H, W).

  Returns:
    float array (T, H, W, 3).
  """
  max_label = np.max(masks)  # assume consecutive numbering
  colors = sns.color_palette(n_colors=max_label + 1)
  rgb = np.zeros(masks.shape[:3] + (3,))
  for label in range(max_label + 1):
    rgb[masks == label] = colors[label]
  return rgb

def visualize_kmeans(features, video, n_clusters=5):
  """Render a k-means clustering of the features as a color-coded video.

  Args:
    features: (t, n, m, c) feature grid.
    video: array whose shape is used as the output shape.
    n_clusters: number of k-means clusters.

  Returns:
    jax array of `video.shape` (nearest-neighbour upsampled cluster colors).
  """
  flat = einops.rearrange(features, 't n m c -> (t n m) c')
  cluster_ids = sklearn.cluster.KMeans(n_clusters, init='k-means++').fit(flat).labels_
  label_grid = jnp.reshape(cluster_ids, features.shape[:-1])
  colored = segmentations_to_video(label_grid)
  return jax.image.resize(colored, video.shape, method='nearest')

#### Evaluation Metrics

In [None]:
# @title DAVIS video segmentation evaluation metric
# https://github.com/davisvideochallenge/davis2017-evaluation/blob/master/davis2017/metrics.py

import math
import cv2

def db_eval_iou(annotation, segmentation, void_pixels=None):
  """ Compute region similarity as the Jaccard Index.

  Sums run over the last two axes, so a single mask (H, W) gives a scalar and
  a stack (..., H, W) gives one score per mask. Pixels marked in
  `void_pixels` are ignored. A mask pair whose union is empty scores 1.

  Arguments:
      annotation   (ndarray): binary annotation   map.
      segmentation (ndarray): binary segmentation map.
      void_pixels  (ndarray): optional mask with void pixels

  Return:
      jaccard (float or ndarray): region similarity
  """
  assert annotation.shape == segmentation.shape, \
      f'Annotation({annotation.shape}) and segmentation:{segmentation.shape} dimensions do not match.'
  gt = annotation.astype(bool)
  pred = segmentation.astype(bool)

  if void_pixels is None:
    keep = np.logical_not(np.zeros_like(pred))
  else:
    assert annotation.shape == void_pixels.shape, \
        f'Annotation({annotation.shape}) and void pixels:{void_pixels.shape} dimensions do not match.'
    keep = np.logical_not(void_pixels.astype(bool))

  inters = np.sum((pred & gt) & keep, axis=(-2, -1))
  union = np.sum((pred | gt) & keep, axis=(-2, -1))
  empty_union = np.isclose(union, 0)

  j = inters / union
  if j.ndim == 0:
    return 1 if empty_union else j
  j[empty_union] = 1
  return j


def db_eval_boundary(annotation, segmentation, void_pixels=None, bound_th=0.008):
  """Boundary F-measure for one mask (H, W) or per frame of a stack (T, H, W).

  Returns a scalar for 2D input and a float array of length T for 3D input.
  Raises ValueError for any other rank.
  """
  assert annotation.shape == segmentation.shape
  if void_pixels is not None:
    assert annotation.shape == void_pixels.shape
  if annotation.ndim == 2:
    return f_measure(segmentation, annotation, void_pixels, bound_th=bound_th)
  if annotation.ndim != 3:
    raise ValueError(f'db_eval_boundary does not support tensors with {annotation.ndim} dimensions')
  scores = np.zeros(annotation.shape[0])
  for idx in range(annotation.shape[0]):
    frame_void = None if void_pixels is None else void_pixels[idx]
    scores[idx] = f_measure(segmentation[idx], annotation[idx], frame_void, bound_th=bound_th)
  return scores


def f_measure(foreground_mask, gt_mask, void_pixels=None, bound_th=0.008):
  """Boundary F-measure between a predicted and a ground-truth binary mask.

  Calculates precision/recall for boundaries between foreground_mask and
  gt_mask using morphological operators to speed it up. Both masks are
  reduced to 1-pixel boundaries (void pixels zeroed first), and each boundary
  is dilated by a disk. A boundary pixel counts as matched when it falls
  inside the other mask's dilated boundary.

  Arguments:
      foreground_mask (ndarray): binary segmentation image (2D, or with a
                                 trailing singleton axis).
      gt_mask         (ndarray): binary annotated image.
      void_pixels     (ndarray): optional mask with void pixels
      bound_th        (float):   boundary tolerance. Values >= 1 are used as
                                 the disk radius in pixels; smaller values are
                                 a fraction of the image diagonal, rounded up.

  Returns:
      F (float): boundaries F-measure
  """
  assert np.atleast_3d(foreground_mask).shape[2] == 1
  if void_pixels is not None:
    void_pixels = void_pixels.astype(bool)
  else:
    void_pixels = np.zeros_like(foreground_mask).astype(bool)

  # Tolerance radius in pixels.
  bound_pix = bound_th if bound_th >= 1 else \
      np.ceil(bound_th * np.linalg.norm(foreground_mask.shape))

  # Get the pixel boundaries of both masks
  fg_boundary = _seg2bmap(foreground_mask * np.logical_not(void_pixels))
  gt_boundary = _seg2bmap(gt_mask * np.logical_not(void_pixels))

  from skimage.morphology import disk

  # cv2.dilate with the same disk structuring element replaces the
  # commented-out binary_dilation calls.
  # fg_dil = binary_dilation(fg_boundary, disk(bound_pix))
  fg_dil = cv2.dilate(fg_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8))
  # gt_dil = binary_dilation(gt_boundary, disk(bound_pix))
  gt_dil = cv2.dilate(gt_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8))

  # Boundary pixels of one mask lying within the tolerance band of the other.
  gt_match = gt_boundary * fg_dil
  fg_match = fg_boundary * gt_dil

  # Number of boundary pixels in each mask.
  n_fg = np.sum(fg_boundary)
  n_gt = np.sum(gt_boundary)

  # % Compute precision and recall
  # An empty boundary on either side is scored by fixed conventions.
  if n_fg == 0 and n_gt > 0:
    precision = 1
    recall = 0
  elif n_fg > 0 and n_gt == 0:
    precision = 0
    recall = 1
  elif n_fg == 0 and n_gt == 0:
    precision = 1
    recall = 1
  else:
    precision = np.sum(fg_match) / float(n_fg)
    recall = np.sum(gt_match) / float(n_gt)

  # Compute F measure (harmonic mean; 0 when both terms are 0)
  if precision + recall == 0:
    F = 0
  else:
    F = 2 * precision * recall / (precision + recall)

  return F


def _seg2bmap(seg, width=None, height=None):
  """
  From a segmentation, compute a binary boundary map with 1 pixel wide
  boundaries.  The boundary pixels are offset by 1/2 pixel towards the
  origin from the actual segment boundary.
  Arguments:
      seg     : Segments labeled from 1..k.
      width   : Width of desired bmap  <= seg.shape[1]
      height  : Height of desired bmap <= seg.shape[0]
  Returns:
      bmap (ndarray): Binary boundary map.
   David Martin <dmartin@eecs.berkeley.edu>
   January 2003
  """

  seg = seg.astype(bool)
  seg[seg > 0] = 1

  assert np.atleast_3d(seg).shape[2] == 1

  width = seg.shape[1] if width is None else width
  height = seg.shape[0] if height is None else height

  h, w = seg.shape[:2]

  ar1 = float(width) / float(height)
  ar2 = float(w) / float(h)

  assert not (
      width > w | height > h | abs(ar1 - ar2) > 0.01
  ), "Can't convert %dx%d seg to %dx%d bmap." % (w, h, width, height)

  e = np.zeros_like(seg)
  s = np.zeros_like(seg)
  se = np.zeros_like(seg)

  e[:, :-1] = seg[:, 1:]
  s[:-1, :] = seg[1:, :]
  se[:-1, :-1] = seg[1:, 1:]

  b = seg ^ e | seg ^ s | seg ^ se
  b[-1, :] = seg[-1, :] ^ e[-1, :]
  b[:, -1] = seg[:, -1] ^ s[:, -1]
  b[-1, -1] = 0

  if w == width and h == height:
    bmap = b
  else:
    bmap = np.zeros((height, width))
    for x in range(w):
      for y in range(h):
        if b[y, x]:
          j = 1 + math.floor((y - 1) + height / h)
          i = 1 + math.floor((x - 1) + width / h)
          bmap[j, i] = 1

  return bmap

In [None]:
# @title JHMDB pck evaluation metric
# https://github.com/Liusifei/UVC/blob/jhmdb/eval_pck.py

def compute_human_boxes(gts, jnt_visible_set):
  """Per-frame PCK normalization scale: 0.6 x diagonal of the visible joints' box.

  Args:
    gts: list of (T, num_joints, 2) joint arrays, one per video.
    jnt_visible_set: list of (T, num_joints) flags; a joint counts when == 1.

  Returns:
    list of (T,) float arrays; frames without visible joints get 0.
  """

  def frame_scale(points, visible):
    pts = points[visible == 1]
    if not pts.size:
      return 0.0
    return 0.6 * np.linalg.norm(pts.max(axis=0) - pts.min(axis=0))

  boxes = []
  for video_gt, video_visible in zip(gts, jnt_visible_set):
    scales = np.zeros(video_gt.shape[0])
    for t in range(video_gt.shape[0]):
      scales[t] = frame_scale(video_gt[t], video_visible[t])
    boxes.append(scales)
  return boxes

def compute_distances(gts, preds, human_boxes, jnt_visible_set):
  """Normalized per-joint prediction errors pooled over all videos.

  For every video, joint i and frame t >= 1 (frame 0 supplies the input pose),
  a visible joint contributes ||pred - gt|| / human_boxes[video][t].

  Args:
    gts, preds: lists of (T, num_joints, 2) arrays.
    human_boxes: list of (T,) normalization scales.
    jnt_visible_set: list of (T, num_joints) visibility flags.

  Returns:
    dict joint index -> 1D float64 array of distances. Keys 0..14 are always
    present (presumably JHMDB's 15 joints); additional joints get their own
    key instead of raising KeyError.
  """
  # Accumulate in lists: np.append inside the loop copied the array each time (O(n^2)).
  dist_lists = {pidx: [] for pidx in range(15)}
  for nowgt, predres, now_boxes, jnt_visible in zip(gts, preds, human_boxes, jnt_visible_set):
    for i in range(nowgt.shape[1]):
      joint_dists = dist_lists.setdefault(i, [])
      for t in range(1, nowgt.shape[0]):
        if jnt_visible[t, i]:
          joint_dists.append(np.linalg.norm(predres[t, i] - nowgt[t, i]) / now_boxes[t])
  return {pidx: np.asarray(d, dtype=float) for pidx, d in dist_lists.items()}

def computePCK(distAll, distThresh):
  """Percentage of Correct Keypoints at a normalized distance threshold.

  Args:
    distAll: dict joint index (0..n-1) -> 1D array of normalized distances.
    distThresh: a distance <= this threshold counts as correct.

  Returns:
    (mean PCK over joints, per-joint PCK array), both in percent.
  """
  per_joint = np.zeros(len(distAll))
  for joint, dists in distAll.items():
    hits = np.sum(dists <= distThresh)
    per_joint[joint] = 100.0 * hits / len(dists)
  return np.mean(per_joint), per_joint

In [None]:
# @title VIP video segmentation evaluation metric

def fast_hist(a, b, n):
  """Confusion matrix (n x n) between ground truth `a` (rows) and prediction `b` (cols).

  Pixels whose ground-truth label falls outside [0, n) are ignored.
  Predictions are assumed to already lie in [0, n).
  """
  valid = np.logical_and(a >= 0, a < n)
  flat_index = a[valid].astype(int) * n + b[valid]
  counts = np.bincount(flat_index, minlength=n ** 2)
  return counts.reshape(n, n)

def compute_iou(hist):
  """Per-class and mean IoU from a 20x20 VIP confusion matrix (rows = ground truth).

  Returns:
    (nan-mean IoU over classes, dict class name -> IoU). Classes with an
    empty union get nan and are skipped by the mean.
  """
  classes = ['background', 'hat', 'hair', 'sun-glasses', 'upper-clothes', 'dress',
           'coat', 'socks', 'pants', 'gloves', 'scarf', 'skirt', 'torso-skin',
           'face', 'right-arm', 'left-arm', 'right-leg', 'left-leg', 'right-shoe', 'left-shoe']

  intersection = np.diag(hist)  # correctly labelled pixels per class
  union = hist.sum(1) + hist.sum(0) - intersection
  per_class = intersection / union
  iou_per_class = {name: per_class[i] for i, name in enumerate(classes)}
  return np.nanmean(per_class), iou_per_class

#### Evaluation Functions

In [None]:
# @title DAVIS evaluation function

def evaluate_davis(model, extract_feature_function, dataset_path='/content/davis-2017/DAVIS/', height=480, width=880, patch_size=16, params=None):
  """Semi-supervised video object segmentation on DAVIS-2017 val via label propagation.

  For each video, frames are resized to (height, width) and features are
  extracted and L2-normalized. The first-frame mask is propagated with
  label_propagation. Region similarity J and boundary F-measure F are computed
  per object, excluding the first and last frames.

  Args:
    model: passed through to `extract_feature_function`.
    extract_feature_function: called as f(model, video) for PyTorch models, or
      f(model, params, video) when `params` is given (JAX models). It must
      return features shaped (T, height // patch_size, width // patch_size, C).
    dataset_path: DAVIS root containing ImageSets/, JPEGImages/, Annotations/.
    height, width: resolution at which features are extracted.
    patch_size: pixels per feature cell.
    params: model parameters for JAX models, else None.

  Returns:
    dict with 'J&F-Mean', 'J-Mean' and 'F-Mean', each averaged over all
    (video, object) pairs; the same values are printed.
  """
  davis_dataset = create_davis_dataset(dataset_path)
  h, w = height // patch_size, width // patch_size  # feature resolution
  device = 'cuda' if torch.cuda.is_available() else 'cpu'

  # One score per (video, object); a flat list has no fixed cap on objects per video.
  j_scores, f_scores = [], []
  for sample in tqdm.tqdm(davis_dataset):
    ori_height, ori_width = sample['video'].shape[1], sample['video'].shape[2]

    # Extract features
    video = media.resize_video(sample['video'], (height, width))
    if params is None:
      feats = extract_feature_function(model, video)  # PyTorch models
    else:
      feats = extract_feature_function(model, params, video)  # Jax models
    if not isinstance(feats, torch.Tensor):
      feats = torch.tensor(np.asarray(feats)).to(device)
    feats = torch.nn.functional.normalize(feats, dim=-1)

    # Prepare downscaled first frame segmentation (resize and one-hot encode)
    lbls_small, lbl_map = mask2heatmap(sample['mask'][0], h, w, height, width)
    lbls_small = torch.tensor(lbls_small).to(feats.device)

    pred_lbls = label_propagation(feats, lbls_small, n_context=20, temperature=0.7, topk=7, radius=20, restrict_neighborhood=True, norm_mask=False)
    pred_lbls = process_segmentation(pred_lbls, lbl_map, height, width, ori_height, ori_width)

    n_class = int(sample['mask'].max())  # objects assumed numbered 1..n_class

    # One-hot over frames 1..T-2, background channel dropped.
    all_gt_masks = np.eye(n_class + 1)[sample['mask'][1:-1]][..., 1:]
    all_res_masks = np.eye(n_class + 1)[pred_lbls[1:-1]][..., 1:]

    for ii in range(n_class):
      j_scores.append(np.nanmean(db_eval_iou(all_gt_masks[..., ii], all_res_masks[..., ii], None)))
      f_scores.append(np.nanmean(db_eval_boundary(all_gt_masks[..., ii], all_res_masks[..., ii], None)))

  j_metric = np.mean(j_scores)
  f_metric = np.mean(f_scores)
  fj_metric = (j_metric + f_metric) / 2.

  print('')
  print('J&F-Mean',fj_metric)
  print('J-Mean', j_metric)
  print('F-Mean', f_metric)
  return {'J&F-Mean': fj_metric, 'J-Mean': j_metric, 'F-Mean': f_metric}

In [None]:
# @title JHMDB evaluation function

def evaluate_jhmdb(model, extract_feature_function, dataset_path='/content/jhmdb/', height=320, width=320, patch_size=16, params=None):
  """Pose keypoint propagation on JHMDB, scored with PCK.

  The first-frame joints are rasterized into a heatmap on the feature grid and
  propagated with label_propagation. Propagated heatmaps are decoded back to
  pixel coordinates and compared with the ground truth for frames t >= 1.

  Args:
    model, extract_feature_function, params: as in evaluate_davis. Features
      must be shaped (T, height // patch_size, width // patch_size, C).
    dataset_path: JHMDB root (splits/, joint_positions/, Rename_Images/).
    height, width: resolution at which features are extracted.
    patch_size: pixels per feature cell.

  Returns:
    dict threshold -> mean PCK over joints for thresholds 0.1 ... 0.5,
    the same values that are printed.
  """
  jhmdb_dataset = create_jhmdb_dataset(jhmdb_path=dataset_path)
  h, w = height // patch_size, width // patch_size  # feature resolution
  device = 'cuda' if torch.cuda.is_available() else 'cpu'

  gts, preds, jnt_visible_set = [], [], []
  for sample in tqdm.tqdm(jhmdb_dataset):
    ori_height, ori_width = sample['video'].shape[1], sample['video'].shape[2]

    # Extract features
    video = media.resize_video(sample['video'], (height, width))
    if params is None:
      feats = extract_feature_function(model, video)  # PyTorch models
    else:
      feats = extract_feature_function(model, params, video)  # Jax models
    if not isinstance(feats, torch.Tensor):
      feats = torch.tensor(np.asarray(feats)).to(device)
    feats = torch.nn.functional.normalize(feats, dim=-1)

    # Prepare downscaled first frame heatmap (resize and one-hot encode)
    lbls_small = pose2heatmap(sample['pose'][0], h, w, ori_height, ori_width)
    lbls_small = torch.tensor(lbls_small).to(feats.device)

    pred_lbls = label_propagation(feats, lbls_small, n_context=20, temperature=0.7, topk=7, radius=20, restrict_neighborhood=True, norm_mask=False)

    pred_pose, jnt_visible = process_pose(pred_lbls, h, w, ori_height, ori_width)
    preds.append(pred_pose)
    jnt_visible_set.append(jnt_visible)
    gts.append(sample['pose'])

  # NOTE(review): the visibility used for scoring comes from the predictions
  # (process_pose), not from the ground truth — confirm against the reference protocol.
  human_boxes = compute_human_boxes(gts, jnt_visible_set)
  distAll = compute_distances(gts, preds, human_boxes, jnt_visible_set)

  pck_by_threshold = {}
  for threshold in [0.1, 0.2, 0.3, 0.4, 0.5]:
    pck_by_threshold[threshold] = computePCK(distAll, threshold)[0]
    print(f"PCK@{threshold}: {pck_by_threshold[threshold]}")
  return pck_by_threshold

In [None]:
# @title VIP evaluation function

def evaluate_vip(model, extract_feature_function, dataset_path='/content/VIP/', height=448, width=880, patch_size=16, params=None):
  """Human part segmentation propagation on VIP val, scored with per-class IoU.

  The first-frame part mask is propagated with label_propagation. A
  confusion matrix over all frames (frame 0 included) of all videos is
  accumulated and turned into per-class and mean IoU.

  Args:
    model, extract_feature_function, params: as in evaluate_davis. Features
      must be shaped (T, height // patch_size, width // patch_size, C).
    dataset_path: VIP root (lists/, Images/, Annotations/Category_ids/).
    height, width: resolution at which features are extracted.
    patch_size: pixels per feature cell.

  Returns:
    dict with 'mean IoU' and 'per_class' (class name -> IoU); these values
    are also printed.
  """
  vip_dataset = create_vip_dataset(root_path=dataset_path)
  h, w = height // patch_size, width // patch_size  # feature resolution
  device = 'cuda' if torch.cuda.is_available() else 'cpu'

  n_class = 20  # number of VIP categories (matches the class list in compute_iou)
  hist = np.zeros((n_class, n_class))
  for sample in tqdm.tqdm(vip_dataset):
    ori_height, ori_width = sample['video'].shape[1], sample['video'].shape[2]

    # Extract features
    video = media.resize_video(sample['video'], (height, width))
    if params is None:
      feats = extract_feature_function(model, video)  # PyTorch models
    else:
      feats = extract_feature_function(model, params, video)  # Jax models
    if not isinstance(feats, torch.Tensor):
      feats = torch.tensor(np.asarray(feats)).to(device)
    feats = torch.nn.functional.normalize(feats, dim=-1)

    # Prepare downscaled first frame segmentation (resize and one-hot encode)
    lbls_small, lbl_map = mask2heatmap(sample['mask'][0], h, w, height, width)
    lbls_small = torch.tensor(lbls_small).to(feats.device)

    pred_lbls = label_propagation(feats, lbls_small, n_context=20, temperature=0.7, topk=10, radius=20, restrict_neighborhood=True, norm_mask=False)
    pred_lbls = process_segmentation(pred_lbls, lbl_map, height, width, ori_height, ori_width)

    for t in range(pred_lbls.shape[0]):
      hist += fast_hist(sample['mask'][t], pred_lbls[t], n_class)

  iou, iou_per_class = compute_iou(hist)

  print('')
  for key in iou_per_class:
    print('%-15s: %f' % (key, iou_per_class[key]))
  print('mean IoU', iou)
  return {'mean IoU': iou, 'per_class': iou_per_class}

### VideoMAE

In [None]:
# @title Load VideoMAE model

import transformers

def get_sinusoid_encoding_table(n_position, d_hid):
  """Sinusoid position encoding table of shape (1, n_position, d_hid), float32.

  Angle (p, j) = p / 10000 ** (2 * (j // 2) / d_hid). Even columns hold its
  sine and odd columns its cosine.
  """

  # TODO: make it with torch instead of numpy
  def get_position_angle_vec(position):
    return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

  sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
  sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
  sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

  return torch.FloatTensor(sinusoid_table).unsqueeze(0)

class VideoMAEPatchEmbeddings(nn.Module):
  """Tubelet embedding of a video with a strided Conv3d.

  Maps pixel values (B, T, C, H, W) to tokens (B, num_tokens, hidden_size),
  where num_tokens = (T // tubelet_size) * (H // patch_h) * (W // patch_w),
  flattened in (time, row, column) order.
  """

  def __init__(self, config):
    super().__init__()

    image_size = config.image_size
    patch_size = config.patch_size
    num_channels = config.num_channels
    hidden_size = config.hidden_size
    num_frames = config.num_frames
    tubelet_size = config.tubelet_size

    image_size = image_size if isinstance(image_size, (tuple, list)) else (image_size, image_size)
    patch_size = patch_size if isinstance(patch_size, (tuple, list)) else (patch_size, patch_size)
    self.image_size = image_size
    self.patch_size = patch_size
    self.tubelet_size = int(tubelet_size)
    self.num_channels = num_channels
    self.num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) * (num_frames // self.tubelet_size)
    self.projection = nn.Conv3d(
      in_channels=num_channels,
      out_channels=hidden_size,
      kernel_size=(self.tubelet_size, patch_size[0], patch_size[1]),
      stride=(self.tubelet_size, patch_size[0], patch_size[1]),
    )

  def forward(self, pixel_values):
    # permute to (batch_size, num_channels, num_frames, height, width)
    pixel_values = pixel_values.permute(0, 2, 1, 3, 4)
    embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
    return embeddings

class VideoMAEEmbeddings(nn.Module):
  """Patch embeddings plus fixed sin-cos position embeddings resized to the input.

  The position table covers the token grid of the configured input
  (num_frames // tubelet_size, image_h // patch_h, image_w // patch_w). The
  forward pass bicubically resizes its spatial part to the actual patch grid,
  so inputs of any height/width can be embedded. The temporal extent is not
  resized, so inputs must have config.num_frames frames.
  """

  def __init__(self, config):
    super().__init__()
    self.patch_embeddings = VideoMAEPatchEmbeddings(config)
    self.num_patches = self.patch_embeddings.num_patches
    image_size = self.patch_embeddings.image_size
    patch_size = self.patch_embeddings.patch_size
    # Token grid (T', H', W') of the position table, derived from the config.
    # Previously hard-coded as (8, 14, 14) — presumably the videomae-large
    # setting (16 frames / tubelet 2, 224 px / patch 16).
    self.embedding_shape = (
        config.num_frames // self.patch_embeddings.tubelet_size,
        image_size[0] // patch_size[0],
        image_size[1] // patch_size[1],
    )
    # fixed sin-cos embedding
    self.position_embeddings = get_sinusoid_encoding_table(self.num_patches, config.hidden_size)
    self.config = config

  def interpolate_pos_encoding(self, x, h, w):
    """Resize position embeddings (1, T'*H'*W', dim) -> (1, T'*h*w, dim) over space."""
    x = x.reshape(self.embedding_shape + (-1,))  # (T', H', W', dim)
    dim = x.shape[-1]
    # An explicit output size avoids floor(H' * (h / H')) rounding down to h - 1.
    x = F.interpolate(
      x.permute(0, 3, 1, 2),
      size=(h, w),
      mode="bicubic",
    )
    return x.permute(0, 2, 3, 1).reshape(1, -1, dim)

  def forward(self, pixel_values):
    """Embed pixel values (B, T, C, H, W) into tokens (B, T'*h*w, hidden_size)."""
    embeddings = self.patch_embeddings(pixel_values)
    position_embeddings = self.position_embeddings.to(embeddings.device)
    _, _, _, h, w = pixel_values.shape
    h = h // self.patch_embeddings.patch_size[0]
    w = w // self.patch_embeddings.patch_size[1]
    embeddings = embeddings + self.interpolate_pos_encoding(position_embeddings, h, w)
    return embeddings

class VideoMAEModel(transformers.VideoMAEPreTrainedModel):
  """VideoMAE encoder that returns the full token sequence.

  Puts this notebook's VideoMAEEmbeddings (which resizes the positional table
  to the input resolution) in front of the Hugging Face VideoMAEEncoder.
  """

  def __init__(self, config):
    super().__init__(config)
    self.config = config
    self.embeddings = VideoMAEEmbeddings(config)
    self.encoder = transformers.models.videomae.modeling_videomae.VideoMAEEncoder(config)
    # No final LayerNorm is applied when the config requests mean pooling.
    self.layernorm = (
        None if config.use_mean_pooling
        else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps))
    self.post_init()

  def forward(self, pixel_values):
    """Return the encoder's first output (the token sequence), optionally normed."""
    tokens = self.embeddings(pixel_values)
    hidden = self.encoder(tokens, head_mask=None)[0]
    return hidden if self.layernorm is None else self.layernorm(hidden)

# Load the Hugging Face VideoMAE-Large weights and put the model in inference mode.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = VideoMAEModel.from_pretrained("MCG-NJU/videomae-large").to(device)
model.eval()
torch.set_grad_enabled(False)
print(model)

In [None]:
# @title Feature extraction function

WINDOW_LENGTH = 16  # frames per clip fed to the model
PATCH_SIZE = 16  # spatial patch size in pixels

@torch.no_grad()
def extract_features(model, video):
  """Extract dense per-frame features with the Hugging Face VideoMAE model.

  The video is divided by 255 and normalized with the mean/std constants
  below (the standard ImageNet values). It is then zero-padded (in normalized
  space) to a multiple of WINDOW_LENGTH frames and fed to `model` one
  WINDOW_LENGTH-frame clip at a time. Each clip yields WINDOW_LENGTH // 2
  token grids (a temporal tubelet of 2 frames is assumed). These grids are
  trilinearly upsampled back to WINDOW_LENGTH frames.

  Args:
    model: module called as model(clip) with clip of shape
      (1, WINDOW_LENGTH, 3, H, W). Row 0 of its output must hold
      (WINDOW_LENGTH // 2) * (H // PATCH_SIZE) * (W // PATCH_SIZE) tokens.
    video: array of shape (T, H, W, 3), presumably uint8 in [0, 255].

  Returns:
    Tensor of shape (T, H // PATCH_SIZE, W // PATCH_SIZE, C), on the device
    of the model's parameters.
  """
  video = video.astype(np.float32) / 255.0
  video = (video - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])
  h, w = video.shape[1] // PATCH_SIZE, video.shape[2] // PATCH_SIZE  # feature resolution

  num_frames = video.shape[0]
  if num_frames % WINDOW_LENGTH:  # VideoMAE requires 16 frames per clip
    video = np.pad(video, ((0, WINDOW_LENGTH - num_frames % WINDOW_LENGTH), (0, 0), (0, 0), (0, 0)))

  # Follow the model's device instead of hard-coding .cuda(), so this also
  # works when the model was loaded on the CPU fallback.
  param = next(model.parameters(), None)
  device = param.device if param is not None else torch.device('cpu')
  if device.type == 'cuda':
    torch.cuda.empty_cache()
  video = torch.tensor(video).permute(0, 3, 1, 2).float().to(device)  # (T, 3, H, W)
  features = []
  for t in range(0, video.shape[0], WINDOW_LENGTH):
    feature = model(video[t : t + WINDOW_LENGTH][None])[0]  # (WINDOW_LENGTH // 2 * h * w, c)
    feature = feature.view(WINDOW_LENGTH // 2, h, w, -1)  # (WINDOW_LENGTH // 2, h, w, c)
    feature = F.interpolate(feature.permute(3, 0, 1, 2)[None], size=(WINDOW_LENGTH, h, w), mode='trilinear')
    features.append(feature[0].permute(1, 2, 3, 0))  # (WINDOW_LENGTH, h, w, c)
  features = torch.concatenate(features, dim=0)  # (padded T, h, w, c)
  return features[:num_frames]

In [None]:
# @title PCA visualization and k-means clustering {form-width: "20%"}

# For each clip: extract features, then show the visualize_pca output, the
# visualize_kmeans output, and the k-means output blended 50/50 with the
# frames divided by 255.
for video in videos:
  features = extract_features(model, video).cpu().numpy()
  print(features.shape)

  pca_video = visualize_pca(features, video)
  kmeans_video = visualize_kmeans(features, video)
  mixed_video = 0.5 * video / 255.0 + 0.5 * kmeans_video
  video_titles = ['pca', 'kmeans', 'mixed']
  media.show_videos(
      [pca_video, kmeans_video, mixed_video],
      titles=video_titles,
      height=128,
      codec='gif',
      fps=16,
      columns=3,
  )

In [None]:
# @title DAVIS evaluation

evaluate_davis(
    model,
    extract_features,
    dataset_path='/content/davis-2017/DAVIS/',
    height=480,
    width=880,
    patch_size=16,
)

In [None]:
# @title JHMDB evaluation

evaluate_jhmdb(
    model,
    extract_features,
    dataset_path='/content/jhmdb/',
    height=320,
    width=320,
    patch_size=16,
)

In [None]:
# @title VIP evaluation

evaluate_vip(
    model,
    extract_features,
    dataset_path='/content/VIP/',
    height=448,
    width=880,
    patch_size=16,
)

### VideoMAE from GitHub

In [None]:
# @title Download VideoMAE code

%cd /content
# Clone the upstream VideoMAE repository (creates /content/VideoMAE).
!git clone https://github.com/MCG-NJU/VideoMAE.git

In [None]:
# @title Download VideoMAE checkpoint

%cd /content/VideoMAE

import gdown

# Download the pretrained VideoMAE large/patch16/224 checkpoint from Google
# Drive into the current directory (/content/VideoMAE). The two commented
# lines below fetch the huge/patch16/224 checkpoint instead.
# file_id = "1AJQR1Rsi2N1pDn9tLyJ8DQrUREiBA1bO"
# gdown.download(f"https://drive.google.com/uc?id={file_id}", "pretrain_videomae_huge_patch16_224.pth", quiet=False)
file_id = "1qLOXWb_MGEvaI7tvuAe94CV7S2HXRwT3"
gdown.download(f"https://drive.google.com/uc?id={file_id}", "pretrain_videomae_large_patch16_224.pth", quiet=False)

In [None]:
# @title Load VideoMAE model

%cd /content/VideoMAE

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from functools import partial
from timm.models.layers import drop_path, to_2tuple, trunc_normal_

class Mlp(nn.Module):
  """Feed-forward block: fc1 -> activation -> fc2 -> dropout.

  Dropout is applied only once, after fc2. There is no dropout between the
  activation and fc2.
  """
  def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
    super().__init__()
    hidden = hidden_features or in_features
    out = out_features or in_features
    self.fc1 = nn.Linear(in_features, hidden)
    self.act = act_layer()
    self.fc2 = nn.Linear(hidden, out)
    self.drop = nn.Dropout(drop)

  def forward(self, x):
    activated = self.act(self.fc1(x))
    return self.drop(self.fc2(activated))

class Attention(nn.Module):
  """Multi-head self-attention with optional separate query/value biases.

  The fused qkv projection has no bias of its own. When qkv_bias is set,
  learnable q and v biases are concatenated around an all-zero key bias at
  call time, so keys are never biased.
  """
  def __init__(
      self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
      proj_drop=0., attn_head_dim=None):
    super().__init__()
    self.num_heads = num_heads
    head_dim = attn_head_dim if attn_head_dim is not None else dim // num_heads
    inner_dim = head_dim * num_heads
    self.scale = qk_scale or head_dim ** -0.5

    self.qkv = nn.Linear(dim, inner_dim * 3, bias=False)
    self.q_bias = nn.Parameter(torch.zeros(inner_dim)) if qkv_bias else None
    self.v_bias = nn.Parameter(torch.zeros(inner_dim)) if qkv_bias else None

    self.attn_drop = nn.Dropout(attn_drop)
    self.proj = nn.Linear(inner_dim, dim)
    self.proj_drop = nn.Dropout(proj_drop)

  def forward(self, x):
    batch, tokens, _ = x.shape
    bias = None
    if self.q_bias is not None:
      k_bias = torch.zeros_like(self.v_bias, requires_grad=False)
      bias = torch.cat((self.q_bias, k_bias, self.v_bias))
    qkv = F.linear(input=x, weight=self.qkv.weight, bias=bias)
    # (3, B, heads, N, head_dim)
    qkv = qkv.reshape(batch, tokens, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)

    scores = (q * self.scale) @ k.transpose(-2, -1)
    weights = self.attn_drop(scores.softmax(dim=-1))

    merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, -1)
    return self.proj_drop(self.proj(merged))

class Block(nn.Module):
  """Pre-norm transformer block: attention then MLP, each with a residual.

  When init_values > 0, each residual branch is multiplied by a learnable
  per-channel scale (gamma_1 / gamma_2) initialised to init_values.
  """

  def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
               drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
               attn_head_dim=None):
    super().__init__()
    self.norm1 = norm_layer(dim)
    self.attn = Attention(
      dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
      attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim)
    # Stochastic depth only when requested. DropPath is not among this cell's
    # imports, and the `drop_path` argument shadows timm's function of that
    # name, so import the class where it is actually needed.
    if drop_path > 0.:
      from timm.models.layers import DropPath
      self.drop_path = DropPath(drop_path)
    else:
      self.drop_path = nn.Identity()
    self.norm2 = norm_layer(dim)
    mlp_hidden_dim = int(dim * mlp_ratio)
    self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    # init_values defaults to None, and `None > 0` raises TypeError in Python 3.
    if init_values is not None and init_values > 0:
      self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
      self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
    else:
      self.gamma_1, self.gamma_2 = None, None

  def forward(self, x):
    if self.gamma_1 is None:
      x = x + self.drop_path(self.attn(self.norm1(x)))
      x = x + self.drop_path(self.mlp(self.norm2(x)))
    else:
      x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
      x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
    return x

class PatchEmbed(nn.Module):
  """Video-to-token embedding via a non-overlapping 3-D convolution.

  Each tubelet_size x patch x patch cube becomes one embed_dim vector.
  `num_patches` is the token count for an img_size input with num_frames
  frames. The input size itself is not checked in forward.
  """
  def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2):
    super().__init__()
    img_hw = to_2tuple(img_size)
    patch_hw = to_2tuple(patch_size)
    self.tubelet_size = int(tubelet_size)
    tokens_per_frame = (img_hw[1] // patch_hw[1]) * (img_hw[0] // patch_hw[0])
    self.img_size = img_hw
    self.patch_size = patch_hw
    self.num_patches = tokens_per_frame * (num_frames // self.tubelet_size)
    cube = (self.tubelet_size, patch_hw[0], patch_hw[1])
    self.proj = nn.Conv3d(in_channels=in_chans, out_channels=embed_dim, kernel_size=cube, stride=cube)

  def forward(self, x, **kwargs):
    # Unpacking requires a 5-D (B, C, T, H, W) input; sizes are not enforced.
    _b, _c, _t, _h, _w = x.shape
    embedded = self.proj(x)  # (B, embed_dim, T', H', W')
    return embedded.flatten(2).transpose(1, 2)

# sin-cos position encoding
# https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
def get_sinusoid_encoding_table(n_position, d_hid):
  """Build a fixed sine/cosine position table of shape (1, n_position, d_hid).

  Entry (pos, j) starts as pos / 10000 ** (2 * (j // 2) / d_hid). Even
  columns then take its sine and odd columns its cosine. The result is a
  float32 tensor that does not require grad.
  """
  rows = []
  for pos in range(n_position):
    rows.append([pos / np.power(10000, 2 * (j // 2) / d_hid) for j in range(d_hid)])
  table = np.array(rows)
  table[:, 0::2] = np.sin(table[:, 0::2])  # even dims: sin
  table[:, 1::2] = np.cos(table[:, 1::2])  # odd dims: cos
  return torch.tensor(table, dtype=torch.float, requires_grad=False).unsqueeze(0)

class PretrainVisionTransformerEncoder(nn.Module):
  """VideoMAE ViT encoder, run without masking, returning normed patch tokens.

  The fixed positional table is laid out on an (8, 14, 14) token grid and is
  resized spatially in `forward_features`, so inputs with other spatial sizes
  can be encoded. The temporal extent is not resized, so a clip must yield
  8 temporal tokens.
  """
  def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
               num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
               drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, tubelet_size=2, use_checkpoint=False,
               use_learnable_pos_emb=False):
    super().__init__()
    self.num_classes = num_classes
    self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
    self.patch_embed = PatchEmbed(
      img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, tubelet_size=tubelet_size)
    num_patches = self.patch_embed.num_patches
    self.embedding_shape = (8, 14, 14)  # Manually added for resizing position embedding
    self.use_checkpoint = use_checkpoint

    # TODO: Add the cls token
    # NOTE(review): the learnable table has num_patches + 1 rows, which does
    # not factor as embedding_shape, so interpolate_pos_encoding will not work
    # with it. Only the sinusoidal path is usable as written.
    if use_learnable_pos_emb:
      self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
    else:
      # sine-cosine positional embeddings
      self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)

    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
    self.blocks = nn.ModuleList([
      Block(
        dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
        drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
        init_values=init_values)
      for i in range(depth)])
    self.norm = norm_layer(embed_dim)
    self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

  def interpolate_pos_encoding(self, x, h, w):
    """Resize a (1, 8*14*14, dim) positional table to (1, 8*h*w, dim)."""
    x = x.reshape(self.embedding_shape + (-1,))  # (T', 14, 14, dim)
    dim = x.shape[-1]
    # Request the target size directly instead of deriving it from
    # scale_factor, which depends on floor(14 * (h / 14)) == h holding in floats.
    x = F.interpolate(x.permute(0, 3, 1, 2), size=(h, w), mode="bicubic")
    # reshape (not view): the permuted result is not always view-compatible.
    return x.permute(0, 2, 3, 1).reshape(1, -1, dim)

  def forward_features(self, x):
    """Encode a (B, C, T, H, W) clip into (B, N, embed_dim) normed tokens."""
    _, _, _, h, w = x.shape
    h = h // self.patch_embed.patch_size[0]
    w = w // self.patch_embed.patch_size[1]
    x = self.patch_embed(x)

    pos_embed = self.pos_embed.to(device=x.device, dtype=x.dtype)
    x = x + self.interpolate_pos_encoding(pos_embed, h, w)

    B, _, C = x.shape
    x_vis = x.reshape(B, -1, C)  # no masking: every token is visible

    if self.use_checkpoint:
      # Import locally rather than relying on the global `checkpoint`, a name
      # this notebook also uses for loaded state dicts.
      from torch.utils.checkpoint import checkpoint as grad_checkpoint
      for blk in self.blocks:
        x_vis = grad_checkpoint(blk, x_vis, use_reentrant=False)
    else:
      for blk in self.blocks:
        x_vis = blk(x_vis)

    return self.norm(x_vis)

  def forward(self, x):
    # The classification head is intentionally skipped; tokens are returned.
    return self.forward_features(x)

class PretrainVisionTransformerDecoder(nn.Module):
  """VideoMAE pixel-reconstruction decoder: transformer blocks, norm, linear head.

  The head predicts num_classes == 3 * tubelet_size * patch_size**2 values per
  token (asserted in __init__).
  """
  def __init__(self, patch_size=16, num_classes=768, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
               qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
               norm_layer=nn.LayerNorm, init_values=None, num_patches=196, tubelet_size=2, use_checkpoint=False
               ):
    super().__init__()
    self.num_classes = num_classes
    assert num_classes == 3 * tubelet_size * patch_size ** 2
    self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
    self.patch_size = patch_size
    self.use_checkpoint = use_checkpoint

    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
    self.blocks = nn.ModuleList([
      Block(
        dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
        drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
        init_values=init_values)
      for i in range(depth)])
    self.norm = norm_layer(embed_dim)
    self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    self.apply(self._init_weights)

  def _init_weights(self, m):
    """Xavier-uniform Linear weights with zero bias; LayerNorm set to identity."""
    if isinstance(m, nn.Linear):
      nn.init.xavier_uniform_(m.weight)
      if m.bias is not None:
        nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
      nn.init.constant_(m.bias, 0)
      nn.init.constant_(m.weight, 1.0)

  def get_num_layers(self):
    return len(self.blocks)

  @torch.jit.ignore
  def no_weight_decay(self):
    return {'pos_embed', 'cls_token'}

  def get_classifier(self):
    return self.head

  def reset_classifier(self, num_classes, global_pool=''):
    self.num_classes = num_classes
    self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

  def forward(self, x, return_token_num):
    """Run the blocks, then project the last `return_token_num` tokens (all if <= 0)."""
    if self.use_checkpoint:
      # Import locally rather than relying on the global `checkpoint`, a name
      # this notebook also uses for loaded state dicts.
      from torch.utils.checkpoint import checkpoint as grad_checkpoint
      for blk in self.blocks:
        x = grad_checkpoint(blk, x, use_reentrant=False)
    else:
      for blk in self.blocks:
        x = blk(x)

    if return_token_num > 0:
      x = self.head(self.norm(x[:, -return_token_num:]))  # only return the mask tokens predict pixels
    else:
      x = self.head(self.norm(x))

    return x

class PretrainVisionTransformer(nn.Module):
  """Encoder-only VideoMAE pretraining model used here as a feature extractor.

  Only the encoder is built, and forward returns its tokens. Most decoder_*
  arguments are accepted but unused. decoder_embed_dim still sizes
  `mask_token` and `pos_embed`, which are kept so checkpoint keys line up.
  """
  def __init__(self,
               img_size=224,
               patch_size=16,
               encoder_in_chans=3,
               encoder_num_classes=0,
               encoder_embed_dim=768,
               encoder_depth=12,
               encoder_num_heads=12,
               decoder_num_classes=1536, #  decoder_num_classes=768,
               decoder_embed_dim=512,
               decoder_depth=8,
               decoder_num_heads=8,
               mlp_ratio=4.,
               qkv_bias=False,
               qk_scale=None,
               drop_rate=0.,
               attn_drop_rate=0.,
               drop_path_rate=0.,
               norm_layer=nn.LayerNorm,
               init_values=0.,
               use_learnable_pos_emb=False,
               use_checkpoint=False,
               tubelet_size=2,
               num_classes=0, # avoid the error from create_fn in timm
               in_chans=0, # avoid the error from create_fn in timm
               ):
    super().__init__()
    encoder_kwargs = dict(
        img_size=img_size,
        patch_size=patch_size,
        in_chans=encoder_in_chans,
        num_classes=encoder_num_classes,
        embed_dim=encoder_embed_dim,
        depth=encoder_depth,
        num_heads=encoder_num_heads,
        mlp_ratio=mlp_ratio,
        qkv_bias=qkv_bias,
        qk_scale=qk_scale,
        drop_rate=drop_rate,
        attn_drop_rate=attn_drop_rate,
        drop_path_rate=drop_path_rate,
        norm_layer=norm_layer,
        init_values=init_values,
        tubelet_size=tubelet_size,
        use_checkpoint=use_checkpoint,
        use_learnable_pos_emb=use_learnable_pos_emb,
    )
    self.encoder = PretrainVisionTransformerEncoder(**encoder_kwargs)
    # The reconstruction decoder and encoder_to_decoder projection are
    # deliberately not built: only encoder features are needed here.
    self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
    self.pos_embed = get_sinusoid_encoding_table(self.encoder.patch_embed.num_patches, decoder_embed_dim)

  def forward(self, x):
    """Encode a 5-D (B, C, T, H, W) clip and return all (unmasked) tokens."""
    _b, _c, _t, _h, _w = x.shape  # require a 5-D input, as before
    return self.encoder(x)

def _pretrain_videomae_patch16_224(embed_dim, depth, num_heads, decoder_embed_dim, decoder_num_heads):
  # Shared builder: every size defined here uses 224px inputs, 16px patches,
  # MLP ratio 4, q/v biases and LayerNorm(eps=1e-6). Only widths, depth and
  # head counts differ.
  return PretrainVisionTransformer(
      img_size=224,
      patch_size=16,
      encoder_embed_dim=embed_dim,
      encoder_depth=depth,
      encoder_num_heads=num_heads,
      encoder_num_classes=0,
      decoder_num_classes=1536,
      decoder_embed_dim=decoder_embed_dim,
      decoder_num_heads=decoder_num_heads,
      mlp_ratio=4,
      qkv_bias=True,
      norm_layer=partial(nn.LayerNorm, eps=1e-6))

def pretrain_videomae_small_patch16_224():
  """Small: 384-dim encoder, 12 layers, 6 heads (decoder 192-dim, 3 heads)."""
  return _pretrain_videomae_patch16_224(384, 12, 6, 192, 3)

def pretrain_videomae_base_patch16_224():
  """Base: 768-dim encoder, 12 layers, 12 heads (decoder 384-dim, 6 heads)."""
  return _pretrain_videomae_patch16_224(768, 12, 12, 384, 6)

def pretrain_videomae_large_patch16_224():
  """Large: 1024-dim encoder, 24 layers, 16 heads (decoder 512-dim, 8 heads)."""
  return _pretrain_videomae_patch16_224(1024, 24, 16, 512, 8)

def pretrain_videomae_huge_patch16_224():
  """Huge: 1280-dim encoder, 32 layers, 16 heads (decoder 640-dim, 8 heads)."""
  return _pretrain_videomae_patch16_224(1280, 32, 16, 640, 8)

model = pretrain_videomae_large_patch16_224()
# Keep the loaded dict under its own name: this cell imports
# torch.utils.checkpoint as `checkpoint`, and rebinding that global to a dict
# breaks any code that later calls `checkpoint.checkpoint`.
videomae_ckpt = torch.load('pretrain_videomae_large_patch16_224.pth', map_location='cpu')
# strict=False because only the encoder is built here (the decoder is
# commented out), so checkpoint entries without a matching parameter are
# expected. Print what was skipped so a wholesale key mismatch is visible.
load_result = model.load_state_dict(videomae_ckpt['model'], strict=False)
print(f"missing keys: {load_result.missing_keys}")
print(f"unexpected keys: {len(load_result.unexpected_keys)}")
# model = pretrain_videomae_huge_patch16_224()
# videomae_ckpt = torch.load('pretrain_videomae_huge_patch16_224.pth', map_location='cpu')
# model.load_state_dict(videomae_ckpt['model'])

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()
torch.set_grad_enabled(False)
print(model)

In [None]:
# @title Feature extraction function

WINDOW_LENGTH = 16  # frames per clip fed to the model
PATCH_SIZE = 16  # spatial patch size in pixels

@torch.no_grad()
def extract_features(model, video):
  """Extract dense per-frame features with a channels-first video model.

  The video is divided by 255 and normalized with the mean/std constants
  below (the standard ImageNet values). It is then zero-padded (in normalized
  space) to a multiple of WINDOW_LENGTH frames and fed to `model` one
  WINDOW_LENGTH-frame clip at a time. Each clip yields WINDOW_LENGTH // 2
  token grids (a temporal tubelet of 2 frames is assumed). These grids are
  trilinearly upsampled back to WINDOW_LENGTH frames.

  Args:
    model: module called as model(clip) with clip of shape
      (1, 3, WINDOW_LENGTH, H, W). Row 0 of its output must hold
      (WINDOW_LENGTH // 2) * (H // PATCH_SIZE) * (W // PATCH_SIZE) tokens.
    video: array of shape (T, H, W, 3), presumably uint8 in [0, 255].

  Returns:
    Tensor of shape (T, H // PATCH_SIZE, W // PATCH_SIZE, C), on the device
    of the model's parameters.
  """
  video = video.astype(np.float32) / 255.0
  video = (video - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])
  h, w = video.shape[1] // PATCH_SIZE, video.shape[2] // PATCH_SIZE  # feature resolution

  num_frames = video.shape[0]
  if num_frames % WINDOW_LENGTH:  # VideoMAE requires 16 frames per clip
    video = np.pad(video, ((0, WINDOW_LENGTH - num_frames % WINDOW_LENGTH), (0, 0), (0, 0), (0, 0)))

  # Follow the model's device instead of hard-coding .cuda(), so this also
  # works when the model was loaded on the CPU fallback.
  param = next(model.parameters(), None)
  device = param.device if param is not None else torch.device('cpu')
  if device.type == 'cuda':
    torch.cuda.empty_cache()
  video = torch.tensor(video).permute(3, 0, 1, 2).float().to(device)  # (3, T, H, W)
  features = []
  for t in range(0, video.shape[1], WINDOW_LENGTH):
    feature = model(video[:, t : t + WINDOW_LENGTH][None])[0]  # (WINDOW_LENGTH // 2 * h * w, c)
    feature = feature.view(WINDOW_LENGTH // 2, h, w, -1)  # (WINDOW_LENGTH // 2, h, w, c)
    feature = F.interpolate(feature.permute(3, 0, 1, 2)[None], size=(WINDOW_LENGTH, h, w), mode='trilinear')
    features.append(feature[0].permute(1, 2, 3, 0))  # (WINDOW_LENGTH, h, w, c)
  features = torch.concatenate(features, dim=0)  # (padded T, h, w, c)
  return features[:num_frames]

In [None]:
# @title PCA visualization and k-means clustering {form-width: "20%"}

# For each clip: extract features, then show the visualize_pca output, the
# visualize_kmeans output, and the k-means output blended 50/50 with the
# frames divided by 255.
for video in videos:
  features = extract_features(model, video).cpu().numpy()
  print(features.shape)

  pca_video = visualize_pca(features, video)
  kmeans_video = visualize_kmeans(features, video)
  mixed_video = 0.5 * video / 255.0 + 0.5 * kmeans_video
  video_titles = ['pca', 'kmeans', 'mixed']
  media.show_videos(
      [pca_video, kmeans_video, mixed_video],
      titles=video_titles,
      height=128,
      codec='gif',
      fps=16,
      columns=3,
  )

In [None]:
# @title DAVIS evaluation

evaluate_davis(
    model,
    extract_features,
    dataset_path='/content/davis-2017/DAVIS/',
    height=480,
    width=880,
    patch_size=16,
)

In [None]:
# @title JHMDB evaluation

evaluate_jhmdb(
    model,
    extract_features,
    dataset_path='/content/jhmdb/',
    height=320,
    width=320,
    patch_size=16,
)

In [None]:
# @title VIP evaluation

evaluate_vip(
    model,
    extract_features,
    dataset_path='/content/VIP/',
    height=448,
    width=880,
    patch_size=16,
)

### VideoMAE v2

In [None]:
# @title Load VideoMAE v2 model

import transformers

def get_sinusoid_encoding_table(n_position, d_hid):
  """Fixed sine/cosine positional table of shape (1, n_position, d_hid).

  Column j of row pos starts as pos / 10000 ** (2 * (j // 2) / d_hid). Even
  columns take the sine and odd columns the cosine. Returned as float32.
  """
  def angles_for(position):
    return [position / np.power(10000, 2 * (j // 2) / d_hid) for j in range(d_hid)]

  grid = np.array([angles_for(p) for p in range(n_position)])
  sines = np.sin(grid[:, 0::2])
  cosines = np.cos(grid[:, 1::2])
  grid[:, 0::2] = sines
  grid[:, 1::2] = cosines
  return torch.FloatTensor(grid).unsqueeze(0)

class VideoMAEv2Config(transformers.configuration_utils.PretrainedConfig):
  """Minimal config class that lets VideoMAEv2 load via from_pretrained.

  All keyword arguments (including `model_config`) go straight to
  PretrainedConfig.
  """
  model_type = 'VideoMAEv2_Base'

  def __init__(self, **kwargs):
    super(VideoMAEv2Config, self).__init__(**kwargs)

class Mlp(nn.Module):
  """Two-layer feed-forward network: fc1 -> activation -> fc2 (no dropout)."""
  def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU):
    super().__init__()
    hidden = hidden_features if hidden_features else in_features
    out = out_features if out_features else in_features
    self.fc1 = nn.Linear(in_features, hidden)
    self.act = act_layer()
    self.fc2 = nn.Linear(hidden, out)

  def forward(self, x):
    hidden = self.fc1(x)
    hidden = self.act(hidden)
    return self.fc2(hidden)

class Attention(nn.Module):
  """Multi-head self-attention with optional q/v biases (keys stay unbiased)."""
  def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., attn_head_dim=None):
    super().__init__()
    self.num_heads = num_heads
    per_head = attn_head_dim or dim // num_heads
    total = per_head * num_heads
    self.scale = qk_scale or per_head ** -0.5
    self.qkv = nn.Linear(dim, 3 * total, bias=False)
    if not qkv_bias:
      self.q_bias = None
      self.v_bias = None
    else:
      self.q_bias = nn.Parameter(torch.zeros(total))
      self.v_bias = nn.Parameter(torch.zeros(total))
    self.attn_drop = nn.Dropout(attn_drop)
    self.proj = nn.Linear(total, dim)
    self.proj_drop = nn.Dropout(proj_drop)

  def forward(self, x):
    B, N, _ = x.shape
    if self.q_bias is None:
      bias = None
    else:
      zero_k = torch.zeros_like(self.v_bias, requires_grad=False)
      bias = torch.cat((self.q_bias, zero_k, self.v_bias))
    projected = F.linear(input=x, weight=self.qkv.weight, bias=bias)
    # (3, B, heads, N, head_dim) -> q, k, v
    q, k, v = projected.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4).unbind(0)
    scores = (q * self.scale) @ k.transpose(-2, -1)
    probs = self.attn_drop(scores.softmax(dim=-1))
    merged = (probs @ v).transpose(1, 2).reshape(B, N, -1)
    return self.proj_drop(self.proj(merged))

class Block(nn.Module):
  """Pre-norm transformer block used by the VideoMAE v2 encoder.

  Stochastic depth is not implemented: drop_path is accepted but drop_path is
  always Identity. cos_attn is accepted but unused. When init_values > 0, the
  two residual branches are scaled by learnable per-channel gammas.
  """
  def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_head_dim=None, cos_attn=False):
    super().__init__()
    self.norm1, self.norm2 = norm_layer(dim), norm_layer(dim)
    self.attn = Attention(dim, num_heads, qkv_bias, qk_scale, attn_drop, drop, attn_head_dim)
    self.drop_path = nn.Identity()
    self.mlp = Mlp(dim, int(dim * mlp_ratio), act_layer=act_layer)
    # init_values defaults to None, and `None > 0` raises TypeError in Python 3.
    if init_values is not None and init_values > 0:
      self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
      self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
    else:
      self.gamma_1 = self.gamma_2 = None

  def forward(self, x):
    if self.gamma_1 is None:
      x = x + self.drop_path(self.attn(self.norm1(x)))
      x = x + self.drop_path(self.mlp(self.norm2(x)))
    else:
      x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
      x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
    return x

def to_2tuple(x):
  """Return x unchanged if it is already a tuple, otherwise the pair (x, x)."""
  if isinstance(x, tuple):
    return x
  return (x, x)

class PatchEmbed(nn.Module):
  """Embed a (B, C, T, H, W) video into (B, N, embed_dim) tubelet tokens."""
  def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2):
    super().__init__()
    self.img_size = to_2tuple(img_size)
    self.patch_size = to_2tuple(patch_size)
    grid_h = self.img_size[0] // self.patch_size[0]
    grid_w = self.img_size[1] // self.patch_size[1]
    self.num_patches = grid_h * grid_w * (num_frames // tubelet_size)
    self.tubelet_size = tubelet_size
    kernel = (tubelet_size, self.patch_size[0], self.patch_size[1])
    # kernel == stride: non-overlapping tubelets
    self.proj = nn.Conv3d(in_chans, embed_dim, kernel, kernel)

  def forward(self, x):
    _b, _c, _t, _h, _w = x.shape  # insist on a 5-D input
    return self.proj(x).flatten(start_dim=2).transpose(1, 2)

class VisionTransformer(nn.Module):
  """VideoMAE v2 ViT backbone returning normalized per-token features.

  Built from `VideoMAEv2.model_config` keyword arguments. `forward` skips
  pooling and the classification head. The fixed positional table is laid
  out on an (8, 14, 14) token grid and is resized spatially to the input size.
  """
  def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., head_drop_rate=0., norm_layer=nn.LayerNorm, layer_norm_eps=1e-12, init_values=0., use_learnable_pos_emb=False, init_scale=0., num_frames=16, tubelet_size=2, use_mean_pooling=True, with_cp=False, cos_attn=False):
    super().__init__()
    self.num_classes, self.num_features, self.embed_dim, self.tubelet_size = num_classes, embed_dim, embed_dim, tubelet_size
    self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim, num_frames, tubelet_size)
    self.patch_size = patch_size
    num_patches = self.patch_embed.num_patches
    self.embedding_shape = (8, 14, 14)  # token grid the positional table is laid out on
    self.with_cp = with_cp
    # norm_layer may be a class or a string naming one (presumably from the
    # model config). Only strings are eval()'d, so the class default works.
    if isinstance(norm_layer, str):
      # NOTE(security): eval() executes whatever string the config provides.
      norm_layer = eval(norm_layer)
    norm_layer = functools.partial(norm_layer, eps=layer_norm_eps)
    self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) if use_learnable_pos_emb else get_sinusoid_encoding_table(num_patches, embed_dim)
    self.pos_drop = nn.Dropout(p=drop_rate)
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
    self.blocks = nn.ModuleList([Block(embed_dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate, dpr[i], norm_layer=norm_layer, init_values=init_values, cos_attn=cos_attn) for i in range(depth)])
    # With mean pooling the final norm lives in fc_norm and `norm` is Identity;
    # without it, `norm` is the final LayerNorm and fc_norm is None.
    self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
    self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
    self.head_dropout = nn.Dropout(head_drop_rate)
    self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
    if use_learnable_pos_emb:
      nn.init.trunc_normal_(self.pos_embed, std=.02)

  def interpolate_pos_encoding(self, x, h, w):
    """Resize a (1, 8*14*14, dim) positional table to (1, 8*h*w, dim)."""
    x = x.reshape(self.embedding_shape + (-1,))  # (T', 14, 14, dim)
    dim = x.shape[-1]
    # Request the target size directly instead of deriving it from
    # scale_factor, which depends on floor(14 * (h / 14)) == h holding in floats.
    x = F.interpolate(x.permute(0, 3, 1, 2), size=(h, w), mode="bicubic")
    # reshape (not view): the permuted result is not always view-compatible.
    return x.permute(0, 2, 3, 1).reshape(1, -1, dim)

  def forward(self, x):
    """Encode a (B, C, T, H, W) clip into (B, N, embed_dim) token features."""
    _, _, _, h, w = x.shape
    x = self.patch_embed(x)
    pos_embed = self.pos_embed.type_as(x).to(x.device).clone().detach()
    # patch_embed stores a (height, width) pair, so tuple patch sizes also work.
    patch_h, patch_w = self.patch_embed.patch_size
    x = x + self.interpolate_pos_encoding(pos_embed, h // patch_h, w // patch_w)
    x = self.pos_drop(x)
    for blk in self.blocks:
      x = blk(x)
    x = self.norm(x)
    return x if self.fc_norm is None else self.fc_norm(x)


class VideoMAEv2(transformers.PreTrainedModel):
  """Hugging Face wrapper around the VideoMAE v2 VisionTransformer.

  `config.model_config` holds the keyword arguments for VisionTransformer.
  """
  config_class = VideoMAEv2Config

  def __init__(self, config=None):
    super().__init__(config=config)
    self.model_config = config.model_config
    self.model = VisionTransformer(**config.model_config)

  def forward(self, video):
    features = self.model(video)
    return features

# Load the VideoMAE v2 Large weights and put the model in inference mode.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = VideoMAEv2.from_pretrained('OpenGVLab/VideoMAEv2-Large').to(device)
model.eval()
torch.set_grad_enabled(False)
print(model)

In [None]:
# @title Feature extraction function

WINDOW_LENGTH = 16  # frames per clip fed to the model
PATCH_SIZE = 16  # spatial patch size in pixels

@torch.no_grad()
def extract_features(model, video):
  """Extract dense per-frame features with the VideoMAE v2 model.

  The video is divided by 255 and normalized with the mean/std constants
  below (the standard ImageNet values). It is then zero-padded (in normalized
  space) to a multiple of WINDOW_LENGTH frames and fed to `model` one
  WINDOW_LENGTH-frame clip at a time. Each clip yields WINDOW_LENGTH // 2
  token grids (a temporal tubelet of 2 frames is assumed). These grids are
  trilinearly upsampled back to WINDOW_LENGTH frames.

  Args:
    model: module called as model(clip) with clip of shape
      (1, 3, WINDOW_LENGTH, H, W). Row 0 of its output must hold
      (WINDOW_LENGTH // 2) * (H // PATCH_SIZE) * (W // PATCH_SIZE) tokens.
    video: array of shape (T, H, W, 3), presumably uint8 in [0, 255].

  Returns:
    Tensor of shape (T, H // PATCH_SIZE, W // PATCH_SIZE, C), on the device
    of the model's parameters.
  """
  video = video.astype(np.float32) / 255.0
  video = (video - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])
  h, w = video.shape[1] // PATCH_SIZE, video.shape[2] // PATCH_SIZE  # feature resolution

  num_frames = video.shape[0]
  if num_frames % WINDOW_LENGTH:  # VideoMAE requires 16 frames per clip
    video = np.pad(video, ((0, WINDOW_LENGTH - num_frames % WINDOW_LENGTH), (0, 0), (0, 0), (0, 0)))

  # Follow the model's device instead of hard-coding .cuda(), so this also
  # works when the model was loaded on the CPU fallback.
  param = next(model.parameters(), None)
  device = param.device if param is not None else torch.device('cpu')
  if device.type == 'cuda':
    torch.cuda.empty_cache()
  video = torch.tensor(video).permute(3, 0, 1, 2).float().to(device)  # (3, T, H, W)
  features = []
  for t in range(0, video.shape[1], WINDOW_LENGTH):
    feature = model(video[:, t : t + WINDOW_LENGTH][None])[0]  # (WINDOW_LENGTH // 2 * h * w, c)
    feature = feature.view(WINDOW_LENGTH // 2, h, w, -1)  # (WINDOW_LENGTH // 2, h, w, c)
    feature = F.interpolate(feature.permute(3, 0, 1, 2)[None], size=(WINDOW_LENGTH, h, w), mode='trilinear')
    features.append(feature[0].permute(1, 2, 3, 0))  # (WINDOW_LENGTH, h, w, c)
  features = torch.concatenate(features, dim=0)  # (padded T, h, w, c)
  return features[:num_frames]

In [None]:
# @title PCA visualization and k-means clustering {form-width: "20%"}

# For each clip: extract features, then show the visualize_pca output, the
# visualize_kmeans output, and the k-means output blended 50/50 with the
# frames divided by 255.
for video in videos:
  features = extract_features(model, video).cpu().numpy()
  print(features.shape)

  pca_video = visualize_pca(features, video)
  kmeans_video = visualize_kmeans(features, video)
  mixed_video = 0.5 * video / 255.0 + 0.5 * kmeans_video
  video_titles = ['pca', 'kmeans', 'mixed']
  media.show_videos(
      [pca_video, kmeans_video, mixed_video],
      titles=video_titles,
      height=128,
      codec='gif',
      fps=16,
      columns=3,
  )

In [None]:
# @title DAVIS evaluation

evaluate_davis(
    model,
    extract_features,
    dataset_path='/content/davis-2017/DAVIS/',
    height=480,
    width=880,
    patch_size=16,
)

In [None]:
# @title JHMDB evaluation

evaluate_jhmdb(
    model,
    extract_features,
    dataset_path='/content/jhmdb/',
    height=320,
    width=320,
    patch_size=16,
)

In [None]:
# @title VIP evaluation

evaluate_vip(
    model,
    extract_features,
    dataset_path='/content/VIP/',
    height=448,
    width=880,
    patch_size=16,
)

### V-JEPA

In [None]:
# @title Download V-JEPA code
%cd /content
# Clone the V-JEPA repository (creates /content/jepa). Later cells import its
# `src` package from there. The pip install lines are left commented out.
!git clone https://github.com/facebookresearch/jepa.git
# !pip install -e jepa
# !pip install timm einops torchcodec torchvision torch==2.6.0
%cd /content/jepa

In [None]:
# @title Download V-JEPA checkpoints
# Download the V-JEPA vitl16 checkpoint into ./checkpoints, relative to the
# current directory. On a re-run `mkdir` reports that the directory already
# exists; wget -P still writes into it.
!mkdir checkpoints
!wget -P checkpoints https://dl.fbaipublicfiles.com/jepa/vitl16/vitl16.pth.tar

In [None]:
# @title Load V-JEPA encoder

%cd /content/jepa

import src.models.vision_transformer as vit

# ViT encoder configuration: 16px patches, 1024-dim, 24 layers, 16 heads,
# 224x224 inputs, 16-frame clips, tubelets of 2 frames.
model = vit.VisionTransformer(
    patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
    norm_layer=functools.partial(nn.LayerNorm, eps=1e-6), img_size=224, num_frames=16, tubelet_size=2,
    uniform_power=False, use_sdpa=False, use_SiLU=False, tight_SiLU=True)

# NOTE: this rebinds the global name `checkpoint`, which an earlier cell
# imported as torch.utils.checkpoint.
checkpoint = torch.load('checkpoints/vitl16.pth.tar', map_location='cpu')
# Use the 'target_encoder' weights and strip the 'module.' and 'backbone.'
# key prefixes so the names match this bare encoder.
pretrained_dict = checkpoint['target_encoder']
pretrained_dict = {k.replace('module.', ''): v for k, v in pretrained_dict.items()}
pretrained_dict = {k.replace('backbone.', ''): v for k, v in pretrained_dict.items()}
# strict=False silently ignores missing/unexpected keys.
# NOTE(review): consider inspecting the returned missing/unexpected keys.
model.load_state_dict(pretrained_dict, strict=False)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()
torch.set_grad_enabled(False)
print(model)

In [None]:
# @title Load V-JEPA encoder
# Alternative to the cell above: a local re-implementation of the V-JEPA ViT
# (defined below) that interpolates its position embeddings so clips at
# resolutions other than 224x224 can be encoded. It reuses the repo's patch
# embedding and transformer block modules.

%cd /content/jepa

from src.models.utils.patch_embed import PatchEmbed3D
from src.models.utils.modules import Block

class VisionTransformer(nn.Module):
  """V-JEPA ViT encoder that accepts clips of any spatial/temporal size.

  Builds patch_embed / pos_embed / blocks / norm submodules meant to mirror the
  jepa repo's encoder, so its `target_encoder` checkpoint can be loaded, and
  resizes the fixed position embeddings whenever the input clip is not
  `num_frames x img_size x img_size`. Returns all patch tokens (no [CLS] token).
  """

  def __init__(self, img_size=224, patch_size=16, num_frames=16, tubelet_size=2, in_chans=3,
               embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None,
               norm_layer=nn.LayerNorm):
    super().__init__()

    self.input_size = img_size
    self.patch_size = patch_size

    self.num_frames = num_frames
    self.tubelet_size = tubelet_size
    grid_size, grid_depth = img_size // patch_size, num_frames // tubelet_size

    self.patch_embed = PatchEmbed3D(patch_size=patch_size, tubelet_size=tubelet_size,
                                    in_chans=in_chans, embed_dim=embed_dim)
    self.num_patches = grid_depth * grid_size * grid_size
    # Non-trainable embeddings for the training token grid; presumably filled
    # from the checkpoint.
    self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim), requires_grad=False)

    self.blocks = nn.ModuleList([
        Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
              qk_scale=qk_scale, drop=0.0, attn_drop=0.0, norm_layer=norm_layer,
              act_layer=nn.GELU, grid_size=grid_size, grid_depth=grid_depth)
        for _ in range(depth)
    ])
    # NOTE: the final norm always uses eps=1e-6, independent of `norm_layer`.
    self.norm = nn.LayerNorm(embed_dim, eps=1e-6)

  def forward(self, x):
    """Encodes a (B, C, T, H, W) clip into (B, T' * H' * W', embed_dim) tokens."""
    x = self.patch_embed(x) + self.interpolate_pos_encoding(x, self.pos_embed)
    for block in self.blocks:
      x = block(x, mask=None)
    return self.norm(x)

  def interpolate_pos_encoding(self, x, pos_embed):
    """Resizes `pos_embed` to the token grid that `x` produces.

    Args:
      x: input clip of shape (B, C, T, H, W).
      pos_embed: (1, N_t * N_h * N_w, dim) embeddings for the training grid.

    Returns:
      `pos_embed` unchanged if `x` matches the training size, otherwise a
      (1, T' * H' * W', dim) tensor with T' = T // tubelet_size,
      H' = H // patch_size and W' = W // patch_size.
    """
    dim = pos_embed.shape[-1]
    _, _, T, H, W = x.shape
    if H == self.input_size and W == self.input_size and T == self.num_frames:
      return pos_embed

    T, H, W = T // self.tubelet_size, H // self.patch_size, W // self.patch_size
    N_t = self.num_frames // self.tubelet_size
    N_h = N_w = self.input_size // self.patch_size

    # Request the target grid explicitly. With a float scale_factor (e.g. 55/14)
    # the output size is floor(N * scale), which can land one token short of
    # what the patch embedding produces.
    pos_embed = F.interpolate(
        pos_embed.reshape(1, N_t, N_h, N_w, dim).permute(0, 4, 1, 2, 3),
        size=(T, H, W),
        mode='trilinear')
    # (1, dim, T', H', W') -> (1, T' * H' * W', dim), matching token order.
    return pos_embed.permute(0, 2, 3, 4, 1).reshape(1, -1, dim)

# ViT-L/16 configuration matching the `vitl16` checkpoint.
model = VisionTransformer(
    patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
    norm_layer=functools.partial(nn.LayerNorm, eps=1e-6), img_size=224, num_frames=16, tubelet_size=2)

checkpoint = torch.load('checkpoints/vitl16.pth.tar', map_location='cpu')
pretrained_dict = checkpoint['target_encoder']
# Strip wrapper prefixes so keys match this module's parameter names.
pretrained_dict = {k.replace('module.', ''): v for k, v in pretrained_dict.items()}
pretrained_dict = {k.replace('backbone.', ''): v for k, v in pretrained_dict.items()}
# strict=False tolerates key mismatches; report them, since any missing key
# here means that part of the re-implemented encoder stays randomly initialized.
load_result = model.load_state_dict(pretrained_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
  print('Missing keys:', load_result.missing_keys)
  print('Unexpected keys:', load_result.unexpected_keys)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()
torch.set_grad_enabled(False)
print(model)

In [None]:
# @title Feature extraction function

WINDOW_LENGTH = 16  # frames per clip fed to the encoder (its num_frames)
PATCH_SIZE = 16     # spatial patch size of the encoder


def extract_features(model, video):
  """Extracts dense per-frame V-JEPA features for a whole video.

  The video is normalized with ImageNet mean/std, zero-padded at the end to a
  multiple of WINDOW_LENGTH frames, and encoded window by window. Each window
  yields WINDOW_LENGTH // 2 temporal token slices (tubelets of 2 frames), which
  are trilinearly upsampled back to one feature map per frame.

  Args:
    model: encoder mapping a (1, 3, WINDOW_LENGTH, H, W) clip to
      (1, (WINDOW_LENGTH // 2) * h * w, c) tokens.
    video: (T, H, W, 3) array with values in [0, 255].

  Returns:
    (T, H // PATCH_SIZE, W // PATCH_SIZE, c) tensor on the model's device.
  """
  num_frames = video.shape[0]
  h, w = video.shape[1] // PATCH_SIZE, video.shape[2] // PATCH_SIZE  # feature resolution

  video = video.astype(np.float32) / 255.0
  video = (video - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])

  # The encoder consumes fixed-length clips, so zero-pad the tail.
  if num_frames % WINDOW_LENGTH:
    pad = WINDOW_LENGTH - num_frames % WINDOW_LENGTH
    video = np.pad(video, ((0, pad), (0, 0), (0, 0), (0, 0)))

  # Follow the model's device instead of assuming CUDA.
  device = next(model.parameters()).device
  torch.cuda.empty_cache()
  video = torch.tensor(video).permute(3, 0, 1, 2).float().to(device)  # (3, T_pad, H, W)
  features = []
  for t in range(0, video.shape[1], WINDOW_LENGTH):
    feature = model(video[:, t : t + WINDOW_LENGTH][None])[0]  # (S * h * w, c), S = WINDOW_LENGTH // 2
    feature = feature.view(WINDOW_LENGTH // 2, h, w, -1)  # (S, h, w, c)
    feature = F.interpolate(feature.permute(3, 0, 1, 2)[None], size=(WINDOW_LENGTH, h, w), mode='trilinear')
    features.append(feature[0].permute(1, 2, 3, 0))  # (WINDOW_LENGTH, h, w, c)
  features = torch.concatenate(features, dim=0)  # (T_pad, h, w, c)
  return features[:num_frames]

In [None]:
# @title PCA visualization and k-means clustering {form-width: "20%"}

# For every demo video: extract dense features with the current model, then
# show the PCA view, the k-means view, and a blend of k-means with the input.
for video in videos:
  features = extract_features(model, video).cpu().numpy()
  print(features.shape)

  pca_video = visualize_pca(features, video)
  kmeans_video = visualize_kmeans(features, video)
  # Equal-weight overlay; dividing by 255 brings the input frames to [0, 1].
  mixed_video = 0.5 * video / 255.0 + 0.5 * kmeans_video
  video_titles = ['pca', 'kmeans', 'mixed']
  media.show_videos(
      [pca_video, kmeans_video, mixed_video],
      titles=video_titles,
      height=128,
      codec='gif',
      fps=16,
      columns=3,
  )

In [None]:
# @title DAVIS evaluation

# Use the PATCH_SIZE from the feature-extraction cell so the evaluator's patch
# grid always matches the one extract_features produces.
evaluate_davis(model, extract_features, dataset_path='/content/davis-2017/DAVIS/', height=480, width=880, patch_size=PATCH_SIZE)

In [None]:
# @title JHMDB evaluation

# Use the PATCH_SIZE from the feature-extraction cell so the evaluator's patch
# grid always matches the one extract_features produces.
evaluate_jhmdb(model, extract_features, dataset_path='/content/jhmdb/', height=320, width=320, patch_size=PATCH_SIZE)

In [None]:
# @title VIP evaluation

# Use the PATCH_SIZE from the feature-extraction cell so the evaluator's patch
# grid always matches the one extract_features produces.
evaluate_vip(model, extract_features, dataset_path='/content/VIP/', height=448, width=880, patch_size=PATCH_SIZE)

### V-JEPA 2

In [None]:
# @title Load V-JEPA 2 model

import transformers

class VJEPA2Model(transformers.VJEPA2PreTrainedModel):
  """V-JEPA 2 wrapper around the stock HF encoder.

  forward() returns only the encoder's last hidden state. The predictor is
  never called; presumably it is built so `from_pretrained` has somewhere to
  load the checkpoint's predictor weights.
  """

  def __init__(self, config):
    super().__init__(config)
    self.config = config
    vjepa2_modeling = transformers.models.vjepa2.modeling_vjepa2
    self.encoder = vjepa2_modeling.VJEPA2Encoder(config)
    self.predictor = vjepa2_modeling.VJEPA2Predictor(config)

  def forward(self, video):
    outputs = self.encoder(pixel_values_videos=video, head_mask=None)
    return outputs.last_hidden_state

# Load, then move to whichever device is available. Hard-coding .to("cuda")
# here would crash on CPU-only runtimes before `device` is ever consulted.
model = VJEPA2Model.from_pretrained("facebook/vjepa2-vitl-fpc64-256")
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()
torch.set_grad_enabled(False)

In [None]:
# @title Load V-JEPA 2 encoder

from transformers.models.vjepa2.modeling_vjepa2 import VJEPA2MLP
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

class VJEPA2PatchEmbeddings3D(nn.Module):
  """Tubelet embedding: a strided Conv3d that turns a clip into a token sequence."""

  def __init__(self, config, hidden_size: int = 1024):
    super().__init__()
    self.patch_size = config.patch_size
    self.tubelet_size = config.tubelet_size
    self.hidden_size = hidden_size
    tubelet = (config.tubelet_size, config.patch_size, config.patch_size)
    # kernel == stride, so tubelets do not overlap.
    self.proj = nn.Conv3d(
        in_channels=config.in_chans,
        out_channels=hidden_size,
        kernel_size=tubelet,
        stride=tubelet,
    )

  def forward(self, video):
    """Maps (B, C, T, H, W) to ((B, T' * H' * W', hidden_size), T', H', W')."""
    grid = self.proj(video)
    t_tokens, h_tokens, w_tokens = grid.shape[-3:]
    tokens = grid.flatten(2).transpose(1, 2)
    return tokens, t_tokens, h_tokens, w_tokens

class VJEPA2Embeddings(nn.Module):
  """Tubelet embedding for clips laid out as (B, T, C, H, W)."""

  def __init__(self, config, hidden_size: int = 1024):
    super().__init__()
    self.patch_embeddings = VJEPA2PatchEmbeddings3D(config, hidden_size=hidden_size)

  def forward(self, video):
    # Conv3d wants channels before time: (B, T, C, H, W) -> (B, C, T, H, W).
    channels_first = video.permute(0, 2, 1, 3, 4)
    tokens, t_tokens, h_tokens, w_tokens = self.patch_embeddings(channels_first)
    return tokens, t_tokens, h_tokens, w_tokens

def eager_attention_forward(module, query, key, value, attention_mask, scaling):
  """Plain softmax attention over (B, heads, N, head_dim) inputs.

  `attention_mask`, if given, multiplies the post-softmax probabilities
  (head-mask style) rather than masking logits. `module` is unused.

  Returns:
    (context, probs): context is (B, N, heads, head_dim) and contiguous; probs
    are the (optionally masked) attention probabilities.
  """
  logits = query @ key.transpose(-1, -2) * scaling
  probs = logits.softmax(dim=-1)
  if attention_mask is not None:
    probs = probs * attention_mask
  context = (probs @ value).transpose(1, 2).contiguous()
  return context, probs

def rotate_queries_or_keys(x, pos):
  """Applies a rotary position embedding (RoPE) along a single axis.

  Args:
    x: (B, num_heads, N, D) slice of queries or keys; D must be even for the
      pairwise unflatten below.
    pos: per-token positions. In this notebook it is the 1-D (N,) output of
      `get_position_ids`, whose row/column entries may be fractional.

  Returns:
    Tensor with the same shape as `x`.
  """
  B, num_heads, N, D = x.size()
  omega = torch.arange(D // 2, dtype=x.dtype, device=x.device)
  omega /= D / 2.0
  omega = 1.0 / 10000**omega  # (D/2,) inverse frequencies
  freq = pos.unsqueeze(-1) * omega  # (..., N, D/2), outer product
  emb_sin = freq.sin()  # (..., N, D/2)
  emb_cos = freq.cos()  # (..., N, D/2)
  # NOTE(review): repeat() *tiles* the frequencies as [f_0..f_{D/2-1}, f_0..f_{D/2-1}]
  # (and, for a 2-D `freq`, prepends two singleton dims -> (1, 1, N, D)), whereas
  # the rotation below pairs *adjacent* channels (2i, 2i+1). This presumably
  # mirrors the reference V-JEPA 2 implementation the weights were trained with;
  # do not "fix" it without checking against that implementation.
  emb_sin = emb_sin.squeeze(-1).repeat(1, 1, 1, 2)
  emb_cos = emb_cos.squeeze(-1).repeat(1, 1, 1, 2)
  # Build the rotated partner (-x_{2i+1}, x_{2i}) for each adjacent channel pair.
  y = x.unflatten(-1, (-1, 2))
  y1, y2 = y.unbind(dim=-1)
  y = torch.stack((-y2, y1), dim=-1)
  y = y.flatten(-2)
  return (x * emb_cos) + (y * emb_sin)

class VJEPA2RopeAttention(nn.Module):
  """Multi-head self-attention with 3-D (frame, row, column) rotary embeddings.

  Each head's channels are split into three equal, even-sized chunks that are
  rotated by a token's frame, row and column position respectively. Any
  leftover channels pass through unrotated. Row/column positions are resampled
  from the `crop_size // patch_size` grid to the actual token grid (see
  `get_position_ids`), so inputs at other resolutions keep positions inside the
  range of that grid.
  """

  def __init__(self, config, hidden_size=1024, num_attention_heads=16):
    super().__init__()
    self.config = config
    self.hidden_size = hidden_size
    self.num_attention_heads = num_attention_heads
    self.attention_head_size = hidden_size // num_attention_heads

    self.query = nn.Linear(hidden_size, hidden_size, bias=config.qkv_bias)
    self.key = nn.Linear(hidden_size, hidden_size, bias=config.qkv_bias)
    self.value = nn.Linear(hidden_size, hidden_size, bias=config.qkv_bias)
    self.proj = nn.Linear(hidden_size, hidden_size)

    self.crop_size = self.config.crop_size
    self.grid_size = self.config.crop_size // self.config.patch_size

    # Channels per rotated axis: a third of the head dim, rounded down to even.
    self.d_dim = int(2 * ((self.attention_head_size // 3) // 2))
    self.h_dim = int(2 * ((self.attention_head_size // 3) // 2))
    self.w_dim = int(2 * ((self.attention_head_size // 3) // 2))

    self.scaling = self.attention_head_size**-0.5
    self.is_causal = False

  def _get_frame_pos(self, ids):
    """Frame index of flat token ids on the square training grid (currently unused)."""
    tokens_per_frame = self.grid_size ** 2
    return ids // tokens_per_frame

  def _get_height_pos(self, ids):
    """Row index of flat token ids on the square training grid (currently unused)."""
    return (ids % (self.grid_size ** 2)) // self.grid_size

  def get_position_ids(self, x, height, width):
    """Returns (frame_ids, height_ids, width_ids), each of length x.size(1).

    Tokens are assumed ordered frame-major, then row, then column. Row/column
    ids are bilinearly resampled from the integer grid_size x grid_size grid,
    so they are fractional whenever (height, width) != (grid_size, grid_size).
    """
    device = x.device
    token_size = x.size(1)
    num_frames = token_size // height // width
    ids = torch.arange(token_size, device=device)
    # Original (non-interpolated) position ids:
    # frame_ids = self._get_frame_pos(ids)
    # height_ids = self._get_height_pos(ids)
    # width_ids = ids - self.grid_size ** 2 * frame_ids - self.grid_size * height_ids

    # Interpolated position ids.
    frame_ids = ids // (height * width)
    iy = torch.arange(self.grid_size, device=device, dtype=torch.float32)
    ix = torch.arange(self.grid_size, device=device, dtype=torch.float32)
    iy, ix = torch.meshgrid(iy, ix, indexing='ij')
    height_ids = F.interpolate(iy[None, None], size=(height, width), mode='bilinear')[0, 0]
    width_ids = F.interpolate(ix[None, None], size=(height, width), mode='bilinear')[0, 0]
    height_ids = height_ids.flatten().repeat(num_frames)
    width_ids = width_ids.flatten().repeat(num_frames)
    return frame_ids, height_ids, width_ids

  def apply_rotary_embeddings(self, qk, pos_ids):
    """Rotates consecutive channel chunks of `qk` by frame, row and column position."""
    d, h, w = pos_ids
    s = 0
    qkd = rotate_queries_or_keys(qk[..., s : s + self.d_dim], pos=d)
    s += self.d_dim
    qkh = rotate_queries_or_keys(qk[..., s : s + self.h_dim], pos=h)
    s += self.h_dim
    qkw = rotate_queries_or_keys(qk[..., s : s + self.w_dim], pos=w)
    s += self.w_dim
    if s < self.attention_head_size:
      # Channels left over after three even-sized chunks are not rotated.
      qkr = qk[..., s:]
      return torch.cat([qkd, qkh, qkw, qkr], dim=-1)
    return torch.cat([qkd, qkh, qkw], dim=-1)

  def forward(self, hidden_states, height, width, head_mask=None):
    """Attends over (B, N, hidden) tokens laid out on a (frames, height, width) grid."""
    batch_size = hidden_states.shape[0]
    heads_shape = (batch_size, -1, self.num_attention_heads, self.attention_head_size)
    # (B, N, hidden) -> (B, heads, N, head_dim)
    query = self.query(hidden_states).view(heads_shape).transpose(1, 2)
    key = self.key(hidden_states).view(heads_shape).transpose(1, 2)
    value = self.value(hidden_states).view(heads_shape).transpose(1, 2)

    pos_ids = self.get_position_ids(hidden_states, height, width)
    key = self.apply_rotary_embeddings(key, pos_ids)
    query = self.apply_rotary_embeddings(query, pos_ids)

    attn_impl = self.config._attn_implementation
    if attn_impl == 'eager':
      # HF's ALL_ATTENTION_FUNCTIONS registry presumably has no 'eager' entry
      # (HF models dispatch eager to a module-local function), so indexing it
      # would raise KeyError; use this notebook's eager_attention_forward.
      context, _ = eager_attention_forward(self, query, key, value, head_mask, scaling=self.scaling)
    else:
      attn_fn = ALL_ATTENTION_FUNCTIONS[attn_impl]
      context, _ = attn_fn(self, query, key, value, head_mask, is_causal=self.is_causal, scaling=self.scaling, dropout=0.0)
    # Attention output is (B, N, heads, head_dim); merge the heads.
    context = context.reshape(context.size()[:-2] + (self.hidden_size,))
    return self.proj(context)

class VJEPA2Layer(nn.Module):
  """Pre-norm transformer block: x + Attn(LN(x)), then x + MLP(LN(x))."""

  def __init__(self, config, hidden_size=1024, num_attention_heads=16, mlp_ratio=4.0):
    super().__init__()
    self.norm1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
    self.attention = VJEPA2RopeAttention(config, hidden_size, num_attention_heads)
    self.norm2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
    self.mlp = VJEPA2MLP(config, hidden_size=hidden_size, mlp_ratio=mlp_ratio)

  def forward(self, hidden_states, height, width):
    residual = hidden_states
    attended = self.attention(self.norm1(residual), height, width, head_mask=None) + residual
    return self.mlp(self.norm2(attended)) + attended

class VJEPA2Encoder(nn.Module):
  """Tubelet embedding, a stack of RoPE transformer blocks, then a final LayerNorm."""

  def __init__(self, config):
    super().__init__()
    self.embeddings = VJEPA2Embeddings(config, hidden_size=config.hidden_size)
    blocks = [
        VJEPA2Layer(config, config.hidden_size, config.num_attention_heads, config.mlp_ratio)
        for _ in range(config.num_hidden_layers)
    ]
    self.layer = nn.ModuleList(blocks)
    self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

  def forward(self, video):
    # The token grid's height/width drive the interpolated RoPE positions.
    tokens, _, grid_h, grid_w = self.embeddings(video)
    for block in self.layer:
      tokens = block(tokens, grid_h, grid_w)
    return self.layernorm(tokens)

class VJEPA2Model(transformers.VJEPA2PreTrainedModel):
  """V-JEPA 2 model built on this notebook's VJEPA2Encoder (interpolated RoPE).

  forward() returns the encoder's layer-normed tokens. The predictor is never
  called; presumably it is instantiated so `from_pretrained` finds a home for
  the checkpoint's predictor weights.
  """

  def __init__(self, config):
    super().__init__(config)
    self.config = config
    # (Removed a leftover debug `print(config)` that dumped the config on every load.)
    self.encoder = VJEPA2Encoder(config)
    self.predictor = transformers.models.vjepa2.modeling_vjepa2.VJEPA2Predictor(config)

  def forward(self, video):
    """Encodes a (B, T, C, H, W) clip into (B, num_tokens, hidden_size) tokens."""
    return self.encoder(video)

# Load, then move to whichever device is available. Hard-coding .to("cuda")
# here would crash on CPU-only runtimes before `device` is ever consulted.
model = VJEPA2Model.from_pretrained("facebook/vjepa2-vitl-fpc64-256")

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()
torch.set_grad_enabled(False)

In [None]:
# @title Feature extraction function

WINDOW_LENGTH = 16  # frames per clip passed to the encoder by extract_features
PATCH_SIZE = 16     # spatial patch size of the encoder


def _normalize_video(video):
  """Scales [0, 255] pixels to [0, 1] and applies ImageNet mean/std."""
  video = video.astype(np.float32) / 255.0
  return (video - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])


def _pad_frames(video, multiple):
  """Zero-pads the leading (time) axis of a (T, H, W, C) array to a multiple of `multiple`."""
  remainder = video.shape[0] % multiple
  if remainder:
    video = np.pad(video, ((0, multiple - remainder), (0, 0), (0, 0), (0, 0)))
  return video


def extract_features_whole(model, video, h=None, w=None):
  """Encodes the whole video in a single forward pass.

  Args:
    model: V-JEPA 2 encoder mapping a (1, T, 3, H, W) clip to
      (1, (T // 2) * (H // PATCH_SIZE) * (W // PATCH_SIZE), c) tokens.
    video: (T, H, W, 3) array with values in [0, 255].
    h, w: ignored; the feature grid is always derived from the video size and
      PATCH_SIZE. Kept (now optional) for backward compatibility.

  Returns:
    (T, H // PATCH_SIZE, W // PATCH_SIZE, c) tensor on the model's device.
  """
  num_frames = video.shape[0]
  h, w = video.shape[1] // PATCH_SIZE, video.shape[2] // PATCH_SIZE  # feature resolution
  # Tubelets span 2 frames, so the clip length must be even.
  video = _pad_frames(_normalize_video(video), 2)

  device = next(model.parameters()).device
  torch.cuda.empty_cache()
  video = torch.tensor(video).permute(0, 3, 1, 2).float().to(device)  # (T_pad, 3, H, W)
  features = model(video[None])[0]  # ((T_pad // 2) * h * w, c)
  features = features.reshape(video.shape[0] // 2, h, w, features.shape[-1])
  # Upsample the T_pad // 2 temporal token slices back to one map per frame.
  features = F.interpolate(features.permute(3, 0, 1, 2)[None], size=(video.shape[0], h, w), mode='trilinear')
  return features[0].permute(1, 2, 3, 0)[:num_frames]


def extract_features(model, video):
  """Encodes the video in consecutive WINDOW_LENGTH-frame clips.

  Args:
    model: V-JEPA 2 encoder mapping a (1, WINDOW_LENGTH, 3, H, W) clip to
      (1, (WINDOW_LENGTH // 2) * h * w, c) tokens.
    video: (T, H, W, 3) array with values in [0, 255].

  Returns:
    (T, H // PATCH_SIZE, W // PATCH_SIZE, c) tensor on the model's device.
  """
  num_frames = video.shape[0]
  h, w = video.shape[1] // PATCH_SIZE, video.shape[2] // PATCH_SIZE  # feature resolution
  # Clips are fixed-length, so zero-pad the tail to a whole number of windows.
  video = _pad_frames(_normalize_video(video), WINDOW_LENGTH)

  device = next(model.parameters()).device
  torch.cuda.empty_cache()
  video = torch.tensor(video).permute(0, 3, 1, 2).float().to(device)  # (T_pad, 3, H, W)
  features = []
  for t in range(0, video.shape[0], WINDOW_LENGTH):
    feature = model(video[t : t + WINDOW_LENGTH][None])[0]  # (S * h * w, c), S = WINDOW_LENGTH // 2
    feature = feature.view(WINDOW_LENGTH // 2, h, w, feature.shape[-1])  # (S, h, w, c)
    feature = F.interpolate(feature.permute(3, 0, 1, 2)[None], size=(WINDOW_LENGTH, h, w), mode='trilinear')
    features.append(feature[0].permute(1, 2, 3, 0))  # (WINDOW_LENGTH, h, w, c)
  return torch.concatenate(features, dim=0)[:num_frames]  # (T, h, w, c)

In [None]:
# @title PCA visualization and k-means clustering {form-width: "20%"}

# For every demo video: extract dense features with the current model, then
# show the PCA view, the k-means view, and a blend of k-means with the input.
for video in videos:
  features = extract_features(model, video).cpu().numpy()
  print(features.shape)

  pca_video = visualize_pca(features, video)
  kmeans_video = visualize_kmeans(features, video)
  # Equal-weight overlay; dividing by 255 brings the input frames to [0, 1].
  mixed_video = 0.5 * video / 255.0 + 0.5 * kmeans_video
  video_titles = ['pca', 'kmeans', 'mixed']
  media.show_videos(
      [pca_video, kmeans_video, mixed_video],
      titles=video_titles,
      height=128,
      codec='gif',
      fps=16,
      columns=3,
  )

In [None]:
# @title DAVIS evaluation

# Use the PATCH_SIZE from the feature-extraction cell so the evaluator's patch
# grid always matches the one extract_features produces.
evaluate_davis(model, extract_features, dataset_path='/content/davis-2017/DAVIS/', height=480, width=880, patch_size=PATCH_SIZE)

In [None]:
# @title JHMDB evaluation

# Use the PATCH_SIZE from the feature-extraction cell so the evaluator's patch
# grid always matches the one extract_features produces.
evaluate_jhmdb(model, extract_features, dataset_path='/content/jhmdb/', height=320, width=320, patch_size=PATCH_SIZE)

In [None]:
# @title VIP evaluation

# Use the PATCH_SIZE from the feature-extraction cell so the evaluator's patch
# grid always matches the one extract_features produces.
evaluate_vip(model, extract_features, dataset_path='/content/VIP/', height=448, width=880, patch_size=PATCH_SIZE)

### DINO

In [None]:
# @title Download DINO code
# Clone the DINO repo and cd into it; the next cell imports its
# `vision_transformer` module from the working directory. (git clone prints an
# error and keeps the existing checkout if the directory already exists.)

%cd /content
!git clone https://github.com/facebookresearch/dino.git
%cd /content/dino

In [None]:
# @title Load DINO model

%cd /content/dino

import vision_transformer as vits

model = vits.__dict__['vit_base'](patch_size=16, num_classes=0)
model.cuda()
# url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
# url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth"
url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
# url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
model.load_state_dict(state_dict, strict=True)
torch.set_grad_enabled(False)
model.eval()

In [None]:
# @title Feature extraction function

PATCH_SIZE = 16  # must match the DINO model's patch size


def extract_features(model, video):
  """Extracts per-frame DINO patch features.

  Args:
    model: DINO ViT exposing `get_intermediate_layers(x, n=1)`, whose first
      entry holds (1, 1 + h * w, c) tokens with the [CLS] token first.
    video: (T, H, W, 3) array with values in [0, 255].

  Returns:
    (T, H // PATCH_SIZE, W // PATCH_SIZE, c) tensor on the model's device.
  """
  h, w = video.shape[1] // PATCH_SIZE, video.shape[2] // PATCH_SIZE  # feature resolution
  video = video.astype(np.float32) / 255.0
  video = (video - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])

  # Follow the model's device instead of assuming CUDA.
  device = next(model.parameters()).device
  torch.cuda.empty_cache()
  video = torch.tensor(video).permute(0, 3, 1, 2).float().to(device)  # (T, 3, H, W)
  features = []
  for t in range(video.shape[0]):
    tokens = model.get_intermediate_layers(video[t:t+1], n=1)[0]  # (1, 1 + h * w, c)
    features.append(tokens[0, 1:, :].view(h, w, -1))  # drop [CLS] -> (h, w, c)
  return torch.stack(features, dim=0)  # (T, h, w, c)

In [None]:
# @title PCA visualization and k-means clustering {form-width: "20%"}

# For every demo video: extract dense features with the current model, then
# show the PCA view, the k-means view, and a blend of k-means with the input.
for video in videos:
  features = extract_features(model, video).cpu().numpy()
  print(features.shape)

  pca_video = visualize_pca(features, video)
  kmeans_video = visualize_kmeans(features, video)
  # Equal-weight overlay; dividing by 255 brings the input frames to [0, 1].
  mixed_video = 0.5 * video / 255.0 + 0.5 * kmeans_video
  video_titles = ['pca', 'kmeans', 'mixed']
  media.show_videos(
      [pca_video, kmeans_video, mixed_video],
      titles=video_titles,
      height=128,
      codec='gif',
      fps=16,
      columns=3,
  )

In [None]:
# @title DAVIS evaluation

# Use the PATCH_SIZE from the feature-extraction cell so the evaluator's patch
# grid always matches the one extract_features produces.
evaluate_davis(model, extract_features, dataset_path='/content/davis-2017/DAVIS/', height=480, width=880, patch_size=PATCH_SIZE)

In [None]:
# @title JHMDB evaluation

# Use the PATCH_SIZE from the feature-extraction cell so the evaluator's patch
# grid always matches the one extract_features produces.
evaluate_jhmdb(model, extract_features, dataset_path='/content/jhmdb/', height=320, width=320, patch_size=PATCH_SIZE)

In [None]:
# @title VIP evaluation

# Use the PATCH_SIZE from the feature-extraction cell so the evaluator's patch
# grid always matches the one extract_features produces.
evaluate_vip(model, extract_features, dataset_path='/content/VIP/', height=448, width=880, patch_size=PATCH_SIZE)

### DINO v2

In [None]:
# @title Load DINO v2 model

import transformers
from transformers.models.dinov2.modeling_dinov2 import Dinov2Embeddings, Dinov2Encoder

class Dinov2Model(transformers.Dinov2PreTrainedModel):
  """DINOv2 backbone returning the final layer-normed token sequence.

  The extraction cell below treats token 0 as [CLS] and the rest as patches.
  """

  def __init__(self, config):
    super().__init__(config)
    self.config = config
    self.embeddings = Dinov2Embeddings(config)
    self.encoder = Dinov2Encoder(config)
    self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

  def forward(self, image):
    tokens = self.embeddings(image, bool_masked_pos=None)
    hidden = self.encoder(tokens, head_mask=None)[0]
    return self.layernorm(hidden)

# Pick the device first, then load DINOv2-L and switch to inference mode.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Dinov2Model.from_pretrained("facebook/dinov2-large")
model.to(device)
model.eval()
torch.set_grad_enabled(False)

In [None]:
# @title Feature extraction function

PATCH_SIZE = 14  # DINOv2 patch size


def extract_features(model, video):
  """Extracts per-frame DINOv2 patch features.

  Args:
    model: module mapping a (1, 3, H, W) image to (1, 1 + h * w, c) tokens
      with the [CLS] token first.
    video: (T, H, W, 3) array with values in [0, 255].

  Returns:
    (T, H // PATCH_SIZE, W // PATCH_SIZE, c) tensor on the model's device.
  """
  h, w = video.shape[1] // PATCH_SIZE, video.shape[2] // PATCH_SIZE  # feature resolution
  video = video.astype(np.float32) / 255.0
  video = (video - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])

  # Follow the model's device instead of assuming CUDA.
  device = next(model.parameters()).device
  torch.cuda.empty_cache()
  video = torch.tensor(video).permute(0, 3, 1, 2).float().to(device)  # (T, 3, H, W)
  features = []
  for t in range(video.shape[0]):
    tokens = model(video[t:t+1])[0]  # (1 + h * w, c)
    features.append(tokens[1:, :].view(h, w, tokens.shape[-1]))  # drop [CLS] -> (h, w, c)
  return torch.stack(features, dim=0)  # (T, h, w, c)

In [None]:
# @title PCA visualization and k-means clustering {form-width: "20%"}

# For every demo video: extract dense features with the current model, then
# show the PCA view, the k-means view, and a blend of k-means with the input.
for video in videos:
  features = extract_features(model, video).cpu().numpy()
  print(features.shape)

  pca_video = visualize_pca(features, video)
  kmeans_video = visualize_kmeans(features, video)
  # Equal-weight overlay; dividing by 255 brings the input frames to [0, 1].
  mixed_video = 0.5 * video / 255.0 + 0.5 * kmeans_video
  video_titles = ['pca', 'kmeans', 'mixed']
  media.show_videos(
      [pca_video, kmeans_video, mixed_video],
      titles=video_titles,
      height=128,
      codec='gif',
      fps=16,
      columns=3,
  )

In [None]:
# @title DAVIS evaluation

# Use the PATCH_SIZE from the feature-extraction cell so the evaluator's patch
# grid always matches the one extract_features produces.
evaluate_davis(model, extract_features, dataset_path='/content/davis-2017/DAVIS/', height=480, width=880, patch_size=PATCH_SIZE)

In [None]:
# @title JHMDB evaluation

# Use the PATCH_SIZE from the feature-extraction cell so the evaluator's patch
# grid always matches the one extract_features produces.
evaluate_jhmdb(model, extract_features, dataset_path='/content/jhmdb/', height=320, width=320, patch_size=PATCH_SIZE)

In [None]:
# @title VIP evaluation

# Use the PATCH_SIZE from the feature-extraction cell so the evaluator's patch
# grid always matches the one extract_features produces.
evaluate_vip(model, extract_features, dataset_path='/content/VIP/', height=448, width=880, patch_size=PATCH_SIZE)

### DINO v2 with registers

In [None]:
# @title Load DINO v2 with registers model

import transformers
from transformers.models.dinov2_with_registers.modeling_dinov2_with_registers import Dinov2WithRegistersEmbeddings, Dinov2WithRegistersEncoder

class Dinov2WithRegistersModel(transformers.Dinov2WithRegistersPreTrainedModel):
  """DINOv2-with-registers backbone returning only the layer-normed patch tokens.

  The [CLS] token and the `num_register_tokens` register tokens that precede
  the patches are sliced off before returning.
  """

  def __init__(self, config):
    super().__init__(config)
    self.config = config
    self.embeddings = Dinov2WithRegistersEmbeddings(config)
    self.encoder = Dinov2WithRegistersEncoder(config)
    self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

  def forward(self, image):
    tokens = self.embeddings(image, bool_masked_pos=None)
    hidden = self.layernorm(self.encoder(tokens, head_mask=None)[0])
    num_prefix_tokens = 1 + self.config.num_register_tokens
    return hidden[:, num_prefix_tokens:]

# Pick the device first, then load DINOv2-L with registers and switch to inference mode.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Dinov2WithRegistersModel.from_pretrained("facebook/dinov2-with-registers-large")
model.to(device)
model.eval()
torch.set_grad_enabled(False)

In [None]:
# @title Feature extraction function

PATCH_SIZE = 14  # DINOv2 patch size


def extract_features(model, video):
  """Extracts per-frame DINOv2-with-registers patch features.

  Args:
    model: module mapping a (1, 3, H, W) image to (1, h * w, c) patch tokens;
      the [CLS] and register tokens are already removed by the model.
    video: (T, H, W, 3) array with values in [0, 255].

  Returns:
    (T, H // PATCH_SIZE, W // PATCH_SIZE, c) tensor on the model's device.
  """
  h, w = video.shape[1] // PATCH_SIZE, video.shape[2] // PATCH_SIZE  # feature resolution
  video = video.astype(np.float32) / 255.0
  video = (video - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])

  # Follow the model's device instead of assuming CUDA.
  device = next(model.parameters()).device
  torch.cuda.empty_cache()
  video = torch.tensor(video).permute(0, 3, 1, 2).float().to(device)  # (T, 3, H, W)
  features = []
  for t in range(video.shape[0]):
    tokens = model(video[t:t+1])[0]  # (h * w, c): patch tokens only
    features.append(tokens.view(h, w, tokens.shape[-1]))  # (h, w, c)
  return torch.stack(features, dim=0)  # (T, h, w, c)

In [None]:
# @title PCA visualization and k-means clustering {form-width: "20%"}

# For every demo video: extract dense features with the current model, then
# show the PCA view, the k-means view, and a blend of k-means with the input.
for video in videos:
  features = extract_features(model, video).cpu().numpy()
  print(features.shape)

  pca_video = visualize_pca(features, video)
  kmeans_video = visualize_kmeans(features, video)
  # Equal-weight overlay; dividing by 255 brings the input frames to [0, 1].
  mixed_video = 0.5 * video / 255.0 + 0.5 * kmeans_video
  video_titles = ['pca', 'kmeans', 'mixed']
  media.show_videos(
      [pca_video, kmeans_video, mixed_video],
      titles=video_titles,
      height=128,
      codec='gif',
      fps=16,
      columns=3,
  )

In [None]:
# @title DAVIS evaluation

# Use the PATCH_SIZE from the feature-extraction cell so the evaluator's patch
# grid always matches the one extract_features produces.
evaluate_davis(model, extract_features, dataset_path='/content/davis-2017/DAVIS/', height=480, width=880, patch_size=PATCH_SIZE)

In [None]:
# @title JHMDB evaluation

# Use the PATCH_SIZE from the feature-extraction cell so the evaluator's patch
# grid always matches the one extract_features produces.
evaluate_jhmdb(model, extract_features, dataset_path='/content/jhmdb/', height=320, width=320, patch_size=PATCH_SIZE)

In [None]:
# @title VIP evaluation

# Use the PATCH_SIZE from the feature-extraction cell so the evaluator's patch
# grid always matches the one extract_features produces.
evaluate_vip(model, extract_features, dataset_path='/content/VIP/', height=448, width=880, patch_size=PATCH_SIZE)

### CropMAE

In [None]:
# @title Download CropMAE code
# Clone the CropMAE repo; the following cells cd into it and store the
# checkpoints there. (git clone prints an error and keeps the existing
# checkout if the directory already exists.)

%cd /content
!git clone https://github.com/alexandre-eymael/CropMAE.git
%cd /content/CropMAE

In [None]:
# @title Download CropMAE checkpoints

%cd /content/CropMAE

import gdown

gdown.download('https://drive.google.com/uc?id=1Hwxpck0MGBJkPNpXRxMOyydtehhhtZbj', '/content/CropMAE/cropmae_in_new.pth', quiet=False)
gdown.download('https://drive.google.com/uc?id=1oMXiX_uyGzyQB7S-MYkdJvKFmIuPXkYb', '/content/CropMAE/cropmae_k400.pth', quiet=False)

In [None]:
# @title Load CropMAE model
# Run from the CropMAE checkout so its checkpoint paths resolve; the model
# itself is re-implemented below on top of timm's ViT `Block`.

%cd /content/CropMAE

from timm.models.vision_transformer import Block

class PatchEmbed(nn.Module):
  """Image to patch embedding.

  `num_patches` describes the square `img_size` training grid. forward()
  itself accepts any input size and returns one token per patch, row-major.
  """

  def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
    super().__init__()
    self.img_size = img_size
    self.patch_size = patch_size
    patches_per_side = img_size // patch_size
    self.num_patches = patches_per_side * patches_per_side
    # A conv with kernel == stride == patch_size is a per-patch linear projection.
    self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

  def forward(self, x):
    """(B, C, H, W) -> (B, (H // p) * (W // p), embed_dim)."""
    patch_grid = self.proj(x)
    return patch_grid.flatten(2).transpose(1, 2)

class MaskedAutoencoderViT(nn.Module):
  """Encoder half of a ViT masked autoencoder (CropMAE backbone), run unmasked.

  Returns the final normalized token sequence: [CLS] first, then one token per
  patch in row-major order. The fixed position embeddings are resized
  bicubically when the input grid differs from the `img_size` training grid.
  `ckpt_path` is accepted but not used by this class.
  """

  def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4., norm_layer=nn.LayerNorm, ckpt_path=None):
    super().__init__()

    # ------------------------------------------------------------------------
    # MAE encoder specifics
    self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    # Fixed (non-trainable) embedding for [CLS] plus every training-grid patch.
    self.pos_embed = nn.Parameter(
        torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim), requires_grad=False)

    self.blocks = nn.ModuleList(
        Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
        for _ in range(depth))
    self.norm = norm_layer(embed_dim)

  def interpolate_pos_encoding(self, x, w, h):
    """Returns position embeddings sized for the tokens in `x`.

    Args:
      x: (B, 1 + num_patches, dim) tokens, [CLS] first.
      w: size of the input's dim 2 in pixels (forward passes x.shape[2]).
      h: size of the input's dim 3 in pixels (forward passes x.shape[3]).
    """
    num_tokens = x.shape[1] - 1
    num_trained = self.pos_embed.shape[1] - 1
    if num_tokens == num_trained and w == h:
      return self.pos_embed
    cls_pos = self.pos_embed[:, 0]
    grid_pos = self.pos_embed[:, 1:]
    dim = x.shape[-1]
    side = math.sqrt(num_trained)
    # The +0.1 guards against floor() landing one short in the interpolation
    # size; see https://github.com/facebookresearch/dino/issues/8
    rows = w // self.patch_embed.patch_size + 0.1
    cols = h // self.patch_embed.patch_size + 0.1
    grid_pos = nn.functional.interpolate(
        grid_pos.reshape(1, int(side), int(side), dim).permute(0, 3, 1, 2),
        scale_factor=(rows / side, cols / side),
        mode="bicubic",
    )
    assert int(rows) == grid_pos.shape[-2] and int(cols) == grid_pos.shape[-1]
    grid_pos = grid_pos.permute(0, 2, 3, 1).view(1, -1, dim)
    return torch.cat((cls_pos.unsqueeze(0), grid_pos), dim=1)

  def forward(self, x):
    batch, _, size_dim2, size_dim3 = x.shape
    tokens = self.patch_embed(x)
    # Prepend the [CLS] token, then add (possibly resized) position embeddings.
    tokens = torch.cat((self.cls_token.expand(batch, -1, -1), tokens), dim=1)
    tokens = tokens + self.interpolate_pos_encoding(tokens, size_dim2, size_dim3)
    for block in self.blocks:
      tokens = block(tokens)
    return self.norm(tokens)


# ViT-S/16-sized encoder (embed_dim 384, 12 blocks, 6 heads) for the CropMAE checkpoint.
model = MaskedAutoencoderViT(
    patch_size=16, embed_dim=384, depth=12, num_heads=6,
    mlp_ratio=4, norm_layer=functools.partial(nn.LayerNorm, eps=1e-6))

# NOTE: weights_only=False unpickles arbitrary objects; only use trusted files.
# map_location='cpu' lets a GPU-saved checkpoint load on CPU-only runtimes.
ckpt = torch.load('/content/CropMAE/cropmae_k400.pth', map_location='cpu', weights_only=False)
pretrain_state = ckpt["model"]
model_state = model.state_dict()
# Keep only entries this encoder-only model has; other checkpoint keys
# (presumably the MAE decoder) are dropped.
match_state = {k: v for k, v in pretrain_state.items() if k in model_state}
unloaded = sorted(set(model_state) - set(match_state))
if unloaded:
  print('Parameters not found in checkpoint (left at init):', unloaded)
model_state.update(match_state)
model.load_state_dict(model_state)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()
torch.set_grad_enabled(False)

In [None]:
# @title Feature extraction function

PATCH_SIZE = 16  # CropMAE patch size


def extract_features(model, video):
  """Extracts per-frame CropMAE patch features.

  Args:
    model: module mapping a (1, 3, H, W) image to (1, 1 + h * w, c) tokens
      with the [CLS] token first.
    video: (T, H, W, 3) array with values in [0, 255].

  Returns:
    (T, H // PATCH_SIZE, W // PATCH_SIZE, c) tensor on the model's device.
  """
  h, w = video.shape[1] // PATCH_SIZE, video.shape[2] // PATCH_SIZE  # feature resolution
  video = video.astype(np.float32) / 255.0
  video = (video - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])

  # Follow the model's device instead of assuming CUDA.
  device = next(model.parameters()).device
  torch.cuda.empty_cache()
  video = torch.tensor(video).permute(0, 3, 1, 2).float().to(device)  # (T, 3, H, W)
  features = []
  for t in range(video.shape[0]):
    tokens = model(video[t:t+1])[0]  # (1 + h * w, c)
    features.append(tokens[1:].view(h, w, tokens.shape[-1]))  # drop [CLS] -> (h, w, c)
  return torch.stack(features, dim=0)  # (T, h, w, c)

In [None]:
# @title PCA visualization and k-means clustering {form-width: "20%"}

# For every demo video: extract dense features with the current model, then
# show the PCA view, the k-means view, and a blend of k-means with the input.
for video in videos:
  features = extract_features(model, video).cpu().numpy()
  print(features.shape)

  pca_video = visualize_pca(features, video)
  kmeans_video = visualize_kmeans(features, video)
  # Equal-weight overlay; dividing by 255 brings the input frames to [0, 1].
  mixed_video = 0.5 * video / 255.0 + 0.5 * kmeans_video
  video_titles = ['pca', 'kmeans', 'mixed']
  media.show_videos(
      [pca_video, kmeans_video, mixed_video],
      titles=video_titles,
      height=128,
      codec='gif',
      fps=16,
      columns=3,
  )

In [None]:
# @title DAVIS evaluation

# Use the PATCH_SIZE from the feature-extraction cell so the evaluator's patch
# grid always matches the one extract_features produces.
evaluate_davis(model, extract_features, dataset_path='/content/davis-2017/DAVIS/', height=480, width=880, patch_size=PATCH_SIZE)

In [None]:
# @title JHMDB evaluation

# Use the PATCH_SIZE from the feature-extraction cell so the evaluator's patch
# grid always matches the one extract_features produces.
evaluate_jhmdb(model, extract_features, dataset_path='/content/jhmdb/', height=320, width=320, patch_size=PATCH_SIZE)

In [None]:
# @title VIP evaluation

# Use the PATCH_SIZE from the feature-extraction cell so the evaluator's patch
# grid always matches the one extract_features produces.
evaluate_vip(model, extract_features, dataset_path='/content/VIP/', height=448, width=880, patch_size=PATCH_SIZE)

### 4DS-VideoMAE

In [None]:
# @title Download 4DS-VideoMAE code
# Clone representations4d and install it into the kernel's environment; the
# model-definition cell below imports from its kauldron dependency.

%cd /content
!git clone https://github.com/google-deepmind/representations4d.git
%cd /content/representations4d
!pip install .

In [None]:
# @title Download checkpoint

%cd /content/representations4d

!wget https://storage.googleapis.com/representations4d/checkpoints/scaling4d_dist_b_depth.npz
# !wget https://storage.googleapis.com/representations4d/checkpoints/scaling4d_e.npz

In [None]:
# @title Define model

# WARNING: these imports rebind two names used by earlier cells: `nn` (was
# torch.nn, now flax.linen) and `transformers` (was the Hugging Face package,
# now kauldron.modules.transformers). Re-running any earlier PyTorch/HF cell
# after this one fails unless `import torch.nn as nn` / `import transformers`
# are re-executed first.
from flax import linen as nn
from kauldron import kd
from kauldron import kontext
from kauldron.modules import pos_embeddings, vit as kd_vit, attention, transformers
from kauldron.typing import DType, Initializer

class LearnedEmbedding(nn.Module):
  """Learned positional embedding over a (T, h, w) token grid.

  The parameter table has shape `embedding_shape + (D,)`, where D is the last
  entry of the `shape` given at call time. When the token grid differs from
  `embedding_shape`, the table is bicubically resized to the token grid.
  """
  dtype: DType = jnp.float32
  emb_init: Initializer = nn.initializers.normal(stddev=0.02)  # From BERT.
  emb_name: str = 'embeddings'
  embedding_shape: tuple = (8, 14, 14)

  @nn.compact
  def __call__(self, shape, *, axis):
    """Returns the positional embedding for tokens of the given shape.

    Args:
      shape: Token shape `(*batch, T, h, w, D)`.
      axis: Accepted for call-site compatibility; not used.

    Returns:
      An embedding that broadcasts against (or equals) `shape`.
    """
    emb_shape = self.embedding_shape + (shape[-1],)
    pe = self.param(self.emb_name, self.emb_init, emb_shape, self.dtype)
    h, w = self.embedding_shape[-2], self.embedding_shape[-1]
    *b, tokens_h, tokens_w, d = shape
    # Prepend singleton axes so the table lines up with the leading batch
    # axes (the last entry of `b` is the time axis, which the table has).
    for _ in range(len(b) - 1):
      pe = jnp.expand_dims(pe, axis=0)
    spatial_mismatch = tokens_h != h or tokens_w != w
    # Also compare the temporal token count: if only T differed, the table
    # would otherwise skip the resize and fail to broadcast against tokens.
    temporal_mismatch = bool(b) and b[-1] != self.embedding_shape[-3]
    if spatial_mismatch or temporal_mismatch:
      pe = jax.image.resize(pe, (*b, tokens_h, tokens_w, d), method='bicubic')
    return pe

class GeneralizedTransformer(nn.Module):
  """Transformer stack that returns a list of intermediate activations.

  Attributes:
    layers: Blocks applied in order on every pass (built by `from_spec`).
    n_iter: Number of passes over `layers`. Before every pass after the
      first, the original `tokens` are concatenated to the latent state along
      axis -2.
  """
  layers: list
  n_iter: int = 1

  def __call__(self, tokens):
    """Applies the stack and returns activations from the final pass.

    The list starts with `tokens`, followed by the latent state fed INTO each
    layer during the final pass. The state is recorded before each layer
    runs, so the output of the last layer is never appended. For
    `n_iter == 1` and L layers, the result is
    `[tokens, tokens, out_0, ..., out_{L-2}]`.
    NOTE(review): confirm this is intended. `Model` reads `features[-1]`,
    which is therefore the penultimate layer's output.
    """
    aux = [jnp.reshape(tokens, tokens.shape)]  # Reshape to the same shape (no-op).
    latent_state = tokens
    for h in range(self.n_iter):
      if h > 0:
        latent_state = jnp.concatenate([latent_state, tokens], axis=-2)
      for layer in self.layers:
        if h == self.n_iter - 1:
          # Record the layer's input, and only during the final pass.
          aux.append(latent_state)
        latent_state = layer(latent_state)
        # Collapse every axis between the first and the last into one token axis.
        latent_state = jnp.reshape(latent_state, [latent_state.shape[0], -1, latent_state.shape[-1]])
    return aux

  @classmethod
  def from_variant_str(cls, variant_str, **kwargs):
    """Builds the stack from a kauldron ViT size string (e.g. "B" or "e").

    Keyword arguments override the spec's values. `patch_size` and
    `hidden_size` are dropped before the remaining values go to `from_spec`.
    """
    vit_spec = kd_vit.ViTSpec.from_variant_string(variant_str)
    all_kwargs = vit_spec.kwargs | kwargs
    all_kwargs.pop('patch_size', None)
    all_kwargs.pop('hidden_size', None)
    return cls.from_spec(**all_kwargs)

  @classmethod
  def from_spec(cls, num_heads, num_layers, mlp_size=None, qk_features=None, v_features=None, **kwargs):
    """Creates `num_layers` kauldron PreNormBlocks (LayerNorm + attention + MLP).

    Remaining keyword arguments (e.g. `n_iter`) are passed to the constructor.
    """
    blocks = []
    for _ in range(num_layers):
      blocks.append(
        transformers.PreNormBlock(
          attention_norm=nn.LayerNorm(),
          mlp_norm=nn.LayerNorm(),
          attention=attention.ImprovedMultiHeadDotProductAttention(
            num_heads=num_heads,
            qk_features=qk_features,
            v_features=v_features,
            kernel_init=nn.initializers.lecun_normal()),
          mlp=transformers.TransformerMLP(hidden_size=mlp_size, kernel_init=nn.initializers.xavier_uniform()),
        )
      )
    return cls(layers=tuple(blocks), **kwargs)


class Model(nn.Module):
  """Tokenizes a video with `encoder`, then runs `processor` on the tokens.

  Only the last entry of the processor's output sequence is returned.
  """
  encoder: nn.Module
  processor: nn.Module

  def __call__(self, video):
    per_layer = self.processor(self.encoder(video))
    return per_layer[-1]


class Tokenizer(nn.Module):
  """Patchifies a video, adds positional embeddings, and flattens the grid."""
  patch_embedding: nn.Module
  posenc: nn.Module

  def __call__(self, images):
    patches = self.patch_embedding(images)
    patches = patches + self.posenc(patches.shape, axis=(-4, -3, -2))
    # Merge the (T, h, w) token grid into a single sequence axis.
    return einops.rearrange(patches, '... T h w D -> ... (T h w) D')


class PatchEmbedding(nn.Module):
  """Non-overlapping patch projection, implemented as a strided convolution.

  The stride equals the kernel size and padding is VALID, so every output
  position covers exactly one patch.
  """
  patch_size: tuple
  num_features: int

  @nn.compact
  def __call__(self, images):
    projection = nn.Conv(
        features=self.num_features,
        kernel_size=self.patch_size,
        strides=self.patch_size,
        padding='VALID',
    )
    return projection(images)


class EncoderToReadout(nn.Module):
  """Selects one encoder activation and resamples it to the input frame rate.

  Attributes:
    embedding_shape: Token grid `(T_tokens, h, w)` of the encoder output.
    readout_depth: Fraction of the way through `all_features` to read from;
      1.0 selects the last entry.
    num_input_frames: Length the temporal axis is resampled to.
  """
  embedding_shape: tuple
  readout_depth: float
  num_input_frames: int

  def __call__(self, all_features):
    """Returns features of shape (B, num_input_frames, h * w, D).

    Args:
      all_features: Sequence of activations, each reshapeable to
        (B, T_tokens, h * w, D) — e.g. the list from GeneralizedTransformer.
    """
    num_entries = len(all_features)
    # Clamp so that a small depth does not wrap around to index -1 (the last
    # entry), and a depth above 1 does not index past the end.
    readout_id = min(max(int(num_entries * self.readout_depth) - 1, 0), num_entries - 1)
    features = all_features[readout_id]
    tokens_per_frame = self.embedding_shape[1] * self.embedding_shape[2]
    readout_features = jnp.reshape(
        features,
        (features.shape[0], self.embedding_shape[0], tokens_per_frame, features.shape[-1]),
    )
    out_shape = (readout_features.shape[0], self.num_input_frames, tokens_per_frame, readout_features.shape[3])
    # Only the time axis changes size; cubic interpolation along it.
    return jax.image.resize(readout_features, out_shape, jax.image.ResizeMethod.CUBIC)


class MLP(nn.Module):
  """Two-layer GELU MLP whose output width equals its input width."""
  hidden_size: int

  @nn.compact
  def __call__(self, x):
    width = x.shape[-1]
    hidden = nn.Dense(self.hidden_size)(x)
    hidden = nn.gelu(hidden)
    return nn.Dense(width)(hidden)


class AttentionReadout(nn.Module):
  """Cross-attention readout from a learned query bank to input features.

  `num_queries` learned queries attend over all (T x N) input tokens. The
  result passes through a residual projection (added to the queries), a
  residual MLP, and a final Dense layer with `num_classes` outputs per query.

  The inline shape comments correspond to the commented-out configuration in
  the model cell: 16 frames x 196 tokens x 768 channels, num_queries=6272,
  num_heads=16, num_params=1024, num_classes=128.
  """
  num_classes: int
  num_params: int
  num_heads: int
  num_queries: int

  @nn.compact
  def __call__(self, inputs):
    feats = nn.LayerNorm()(inputs) # (1, 16, 196, 768)
    # Learned embedding along the time axis (axis -3).
    feats += kd.nn.LearnedEmbedding(name='temporal_posenc')(feats.shape, axis=-3)
    feats = einops.rearrange(feats, '... T N C -> ... (T N) C') # (1, 3136, 768)
    # Learned queries, split into heads, shared across the batch.
    query = self.param('query', nn.initializers.normal(0.02), [self.num_queries, self.num_heads, self.num_params // self.num_heads]) # (6272, 16, 64)
    query = jnp.broadcast_to(query, (feats.shape[0],) + query.shape) # (1, 6272, 16, 64)
    key = nn.DenseGeneral(features=(self.num_heads, self.num_params // self.num_heads), axis=-1, use_bias=True, name='key_embedding')(feats) # (1, 3136, 16, 64)
    val = nn.DenseGeneral(features=(self.num_heads, self.num_params // self.num_heads), axis=-1, use_bias=True, name='value_embedding')(feats) # (1, 3136, 16, 64)
    token = nn.dot_product_attention(query=query, key=key, value=val) # (1, 6272, 16, 64)
    token = einops.rearrange(token, '... Q N c -> ... Q (N c)') # (1, 6272, 1024)
    query = einops.rearrange(query, '... Q N c -> ... Q (N c)') # (1, 6272, 1024)
    # Residual connection from the (flattened) queries.
    token = query + nn.Dense(self.num_params)(token)
    # Pre-norm residual MLP with a 4x hidden width.
    token = token + MLP(self.num_params * 4)(nn.LayerNorm()(token))
    out = nn.Dense(self.num_classes)(token) # (1, 6272, 128)
    return out


# 4DS encoder: (2, 16, 16) tubelet patch embedding with a learned positional
# embedding, followed by a ViT-B transformer stack. Model returns the last
# entry of the stack's activation list.
# The commented-out lines show an alternative ViT-e processor and a readout
# head; they are not used here.
model = nn.Sequential([
    Model(
        encoder=Tokenizer(
            patch_embedding=PatchEmbedding(patch_size=(2,16,16), num_features=768),
            posenc=LearnedEmbedding()),
        processor=GeneralizedTransformer.from_variant_str(variant_str="B")),
        # processor=GeneralizedTransformer.from_variant_str(variant_str="e")),
    # EncoderToReadout(embedding_shape=(8,14,14), readout_depth=0.95, num_input_frames=16),
    # AttentionReadout(num_classes=128, num_params=1024, num_heads=16, num_queries=6272)
])

In [None]:
# @title Load checkpoint

def recover_tree(flat_dict):
  """Rebuilds a nested dict from a mapping with '/'-separated keys.

  For example, {'a/b': 1, 'a/c': 2, 'd': 3} becomes
  {'a': {'b': 1, 'c': 2}, 'd': 3}. Values are stored unchanged.
  """
  nested = {}
  for path, value in flat_dict.items():
    *parents, leaf = path.split("/")
    cursor = nested
    for key in parents:
      cursor = cursor.setdefault(key, {})
    cursor[leaf] = value
  return nested

# Read the flat '/'-keyed arrays from the .npz checkpoint and rebuild the
# nested parameter tree. np.load returns an open NpzFile; the `with` block
# closes it once recover_tree has read every array into memory.
with np.load("scaling4d_dist_b_depth.npz", allow_pickle=False) as checkpoint:
  restored_params = recover_tree(checkpoint)
# with np.load("scaling4d_e.npz", allow_pickle=False) as checkpoint:
#   restored_params = recover_tree(checkpoint)

In [None]:
# @title Feature extraction function

WINDOW_LENGTH = 16  # Frames per clip fed to the model.
PATCH_SIZE = 16  # Spatial patch size in pixels.

def extract_features(model, params, video):
  """Extracts per-frame patch features from a video, one window at a time.

  The video is split into non-overlapping windows of WINDOW_LENGTH frames,
  with the last window zero-padded. Each window yields WINDOW_LENGTH // 2
  temporal feature slices, which are cubically resampled back to
  WINDOW_LENGTH slices. Padded frames are dropped from the result.

  Args:
    model: Object whose `apply(params, clip, is_training_property=False)`
      returns a batch of token features for a (1, WINDOW_LENGTH, H, W, C)
      clip.
    params: Variables passed straight to `model.apply`.
    video: Array of shape (T, H, W, C); values are divided by 255, so they
      are presumably in [0, 255].

  Returns:
    NumPy array of shape (T, H // PATCH_SIZE, W // PATCH_SIZE, D).
  """
  frames = video.astype(np.float32) / 255.0
  num_frames = frames.shape[0]
  grid_h = frames.shape[1] // PATCH_SIZE
  grid_w = frames.shape[2] // PATCH_SIZE

  # Zero-pad the time axis up to a multiple of WINDOW_LENGTH.
  leftover = num_frames % WINDOW_LENGTH
  if leftover:
    frames = jnp.pad(frames, ((0, WINDOW_LENGTH - leftover), (0, 0), (0, 0), (0, 0)))

  per_window = []
  for start in range(0, num_frames, WINDOW_LENGTH):
    clip = frames[start:start + WINDOW_LENGTH][None]
    tokens = model.apply(params, clip, is_training_property=False)[0]
    channels = tokens.shape[-1]
    # One token slice per pair of frames, laid out on the spatial grid.
    coarse = tokens.reshape(WINDOW_LENGTH // 2, grid_h, grid_w, channels)
    per_window.append(
        jax.image.resize(coarse, (WINDOW_LENGTH, grid_h, grid_w, channels), jax.image.ResizeMethod.CUBIC)
    )
  return np.concatenate(per_window, axis=0)[:num_frames]

In [None]:
# @title PCA visualization and k-means clustering {form-width: "20%"}

# For each clip: show the visualize_pca / visualize_kmeans outputs (helpers
# defined elsewhere in the notebook) next to a 50/50 blend of the frames and
# the k-means colouring. This extract_features already returns a NumPy array.
for video in videos:
  features = extract_features(model, restored_params, video)
  print(features.shape)

  pca_video = visualize_pca(features, video)
  kmeans_video = visualize_kmeans(features, video)
  # Frames are scaled by 1/255 before blending; assumes kmeans_video is in
  # [0, 1] — NOTE(review): confirm against visualize_kmeans.
  mixed_video = 0.5 * video / 255.0 + 0.5 * kmeans_video
  video_titles = ['pca', 'kmeans', 'mixed']
  media.show_videos([pca_video, kmeans_video, mixed_video], titles=video_titles, height=128, codec='gif', fps=16, columns=3)

In [None]:
# @title DAVIS evaluation

# DAVIS-2017 evaluation of the 4DS model; checkpoint params go in via `params`.
evaluate_davis(model, extract_features, dataset_path='/content/davis-2017/DAVIS/', height=480, width=880, patch_size=16, params=restored_params)

In [None]:
# @title JHMDB evaluation

# JHMDB evaluation of the 4DS model; checkpoint params go in via `params`.
evaluate_jhmdb(model, extract_features, dataset_path='/content/jhmdb/', height=320, width=320, patch_size=16, params=restored_params)

In [None]:
# @title VIP evaluation

# VIP evaluation of the 4DS model; checkpoint params go in via `params`.
evaluate_vip(model, extract_features, dataset_path='/content/VIP/', height=448, width=880, patch_size=16, params=restored_params)

### RVM

In [None]:
# @title Download checkpoint

%mkdir /content/rvm
%cd /content/rvm

!wget https://storage.googleapis.com/dm-tapnet/tmp/pretrain_rvm_large16_256_xid175558463_wid1.npz

In [None]:
# @title Load checkpoint

%cd /content/rvm

def recover_tree(flat_dict):
  """Converts {'a/b': v, ...} into nested dicts {'a': {'b': v}, ...}.

  Each key is split on '/'; all but the last segment become nested dicts,
  and the value is stored under the last segment.
  """
  tree = {}
  for flat_key, leaf_value in flat_dict.items():
    segments = flat_key.split("/")
    subtree = tree
    for segment in segments[:-1]:
      subtree = subtree.setdefault(segment, {})
    subtree[segments[-1]] = leaf_value
  return tree

# Read the flat '/'-keyed arrays from the RVM .npz checkpoint into a nested
# tree. np.load returns an open NpzFile; the `with` block closes it once
# recover_tree has read every array into memory.
with np.load("pretrain_rvm_large16_256_xid175558463_wid1.npz", allow_pickle=False) as checkpoint:
  restored_params = recover_tree(checkpoint)

In [None]:
# @title Define model {form-width: "20%"}

import dataclasses
import re
from typing import Any
import jax.numpy as jnp
from flax import linen as nn

class PatchEmbedding(nn.Module):
  """Embeds non-overlapping patches with one learned linear projection.

  Implemented as a convolution whose stride equals its kernel size, with
  VALID padding, so each output position covers exactly one patch.
  """
  patch_size: list[int]
  num_features: int

  @nn.compact
  def __call__(self, images):
    projection = nn.Conv(
        features=self.num_features,
        kernel_size=self.patch_size,
        strides=self.patch_size,
        padding='VALID',
    )
    return projection(images)

def get_mae_sinusoid_encoding_table(n_position, d_hid, dtype=jnp.float32):
  """Sinusoid positional encoding table for MAE.

  Entry [0, pos, j] is sin(angle) for even j and cos(angle) for odd j, where
  angle = pos / 10000 ** (2 * (j // 2) / d_hid).

  Args:
    n_position: Number of positions (rows); 0 yields an empty table.
    d_hid: Embedding width (columns).
    dtype: dtype of the returned array.

  Returns:
    Array of shape (1, n_position, d_hid).
  """
  # Broadcast positions against per-column frequencies instead of making
  # n_position * d_hid scalar np.power calls in Python loops.
  positions = np.arange(n_position, dtype=np.float64)[:, None]
  exponents = 2 * (np.arange(d_hid) // 2) / d_hid
  sinusoid_table = positions / np.power(10000, exponents)[None, :]
  sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
  sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

  return jnp.asarray(sinusoid_table, dtype)[None, ...]

class SincosPosEmb(nn.Module):
  """Fixed sine/cosine positional embedding laid out on the token grid.

  The MAE sinusoid table is built over the h * w flattened positions of
  `base_token_shape` (or of the token grid itself when unset), reshaped to
  (h, w), and bicubically resized when the token grid differs.
  """
  base_token_shape: list[int] | None = None

  @nn.compact
  def __call__(self, tokens_shape):
    *lead, tokens_h, tokens_w, d = tokens_shape
    if self.base_token_shape is None:
      h, w = tokens_h, tokens_w
    else:
      h, w = self.base_token_shape

    flat = get_mae_sinusoid_encoding_table(np.prod((h, w)), d)
    table = einops.rearrange(flat, '... (h w) d -> ... h w d', h=h, w=w)
    # The table already carries one leading axis; add the rest so it has as
    # many leading axes as the tokens.
    for _ in range(len(lead) - 1):
      table = jnp.expand_dims(table, axis=0)
    if (tokens_h, tokens_w) != (h, w):
      table = jax.image.resize(table, (*lead, tokens_h, tokens_w, d), method='bicubic')

    return table

class Tokenizer(nn.Module):
  """Patch-embeds an image and adds a positional encoding to the tokens."""
  patch_embedding: nn.Module
  posenc: nn.Module

  @nn.compact
  def __call__(self, image):
    patches = self.patch_embedding(image)
    return patches + self.posenc(patches.shape)

class TransformerMLP(nn.Module):
  """Single-hidden-layer GELU MLP used inside transformer blocks.

  The output width equals the input width. The hidden width is
  `hidden_size`, or four times the input width when that is None.
  """
  hidden_size: int = None  # Defaults to 4 times input dims.

  @nn.compact
  def __call__(self, inputs):
    width = inputs.shape[-1]
    hidden_width = self.hidden_size if self.hidden_size is not None else 4 * width
    init_kwargs = dict(
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.zeros,
    )
    hidden = nn.Dense(features=hidden_width, name='dense_in', dtype=inputs.dtype, **init_kwargs)(inputs)
    hidden = nn.gelu(hidden)
    return nn.Dense(features=width, name='dense_out', dtype=hidden.dtype, **init_kwargs)(hidden)

class PreNormBlock(nn.Module):
  """Pre-LN transformer layer: x + Attn(LN(x)), then x + MLP(LN(x))."""

  attention: Any
  mlp: nn.Module
  attention_norm: Any
  mlp_norm: Any

  @nn.compact
  def __call__(self, tokens):
    normed = self.attention_norm(tokens)
    attended = self.attention(
        inputs_q=normed,
        inputs_k=normed,
        inputs_v=normed,
    )
    residual = tokens + attended
    return residual + self.mlp(self.mlp_norm(residual))

VIT_SIZES = {
    'mu': (32, 1, 128, 2),
    'Ti': (192, 12, 768, 3),
    'S': (384, 12, 1536, 6),
    'M': (512, 12, 2048, 8),
    'B': (768, 12, 3072, 12),
    'L': (1024, 24, 4096, 16),
    'H': (1280, 32, 5120, 16),
    'g': (1408, 40, 6144, 16),
    'G': (1664, 48, 8192, 16),
    'e': (1792, 56, 15360, 16),
}

@dataclasses.dataclass(frozen=True)
class ViTSpec:
  """Spec for the size of a Vision Transformer."""

  hidden_size: int  # Dimension of tokens passed between blocks.
  num_layers: int  # Number of trasformer blocks.
  mlp_size: int  # Hidden dimension of the MLP in each block.
  num_heads: int  # Number of attention heads.
  patch_size: int = None  # Patch size of initial image patches.

  @classmethod
  def from_variant_string(cls, variant_str: str):
    """Parse variant strings like "ViT-L", "B", or "Ti/16"."""
    r = re.match(
        r'^([Vv][Ii][Tt][-_])?(?P<name>[a-zA-Z]{1,2})(/(?P<patch>\d+))?$',
        variant_str,
    )
    if r is None:
      raise ValueError(f'Invalid variant string: {variant_str!r}.')
    name = r.groupdict()['name']
    spec = cls(*VIT_SIZES[name])

    patch_size = r.groupdict()['patch']
    if patch_size is not None:
      spec = dataclasses.replace(spec, patch_size=int(patch_size))
    return spec

  @property
  def kwargs(self):
    kwargs = dict(
        hidden_size=self.hidden_size,
        num_layers=self.num_layers,
        mlp_size=self.mlp_size,
        num_heads=self.num_heads,
        patch_size=self.patch_size,
    )
    if self.patch_size is None:
      del kwargs['patch_size']
    return kwargs

class Transformer(nn.Module):
  """Stack of transformer blocks followed by a final LayerNorm."""
  layers: tuple[Any]

  @nn.compact
  def __call__(self, tokens):
    x = tokens
    for block in self.layers:
      x = block(x)
    final_norm = nn.LayerNorm(dtype=x.dtype)
    return final_norm(x)

  @classmethod
  def from_variant_str(cls, variant_str: str, **kwargs):
    """Builds a Transformer from a ViT size string such as "L" or "ViT-B/16".

    Keyword arguments override the spec's values. `patch_size` and
    `hidden_size` are dropped before the rest is forwarded to `from_spec`.
    """
    merged = {**ViTSpec.from_variant_string(variant_str).kwargs, **kwargs}
    for unused in ('patch_size', 'hidden_size'):
      merged.pop(unused, None)
    return cls.from_spec(**merged)

  @classmethod
  def from_spec(
      cls,
      num_heads: int,
      num_layers: int,
      mlp_size = None,
      dtype=jnp.float32,
      qk_features = None,
      v_features = None,
      **kwargs,
  ):
    """Creates `num_layers` PreNormBlocks; `dtype` only reaches the LayerNorms.

    Remaining keyword arguments are passed to the constructor.
    """
    def make_block():
      return PreNormBlock(
          attention_norm=nn.LayerNorm(dtype=dtype),
          mlp_norm=nn.LayerNorm(dtype=dtype),
          attention=ImprovedMultiHeadDotProductAttention(
              num_heads=num_heads,
              qk_features=qk_features,
              v_features=v_features,
          ),
          mlp=TransformerMLP(hidden_size=mlp_size),
      )

    blocks = tuple(make_block() for _ in range(num_layers))
    return cls(layers=blocks, **kwargs)

class GatedTransformerCore(nn.Module):
  """Recurrent core with GRU-style gates around a transformer.

  The candidate state is `transformer(inputs, inputs_kv=...)`, where
  `inputs_kv` is the reset-gated, layer-normed previous state. The new state
  interpolates between the old state and the candidate via the update gate.
  """
  transformer: nn.Module
  initializer: nn.Module
  token_dim: int
  state_layer_norm: nn.Module

  def setup(self):
    self.input_update = nn.Dense(self.token_dim, use_bias=False)
    self.input_reset = nn.Dense(self.token_dim, use_bias=False)
    self.state_update = nn.Dense(self.token_dim, use_bias=False)
    self.state_reset = nn.Dense(self.token_dim, use_bias=False)

  def __call__(self, inputs, state):
    """Returns `(output, new_state)`; both are the same array."""
    z = jax.nn.sigmoid(self.input_update(inputs) + self.state_update(state))
    r = jax.nn.sigmoid(self.input_reset(inputs) + self.state_reset(state))
    candidate = self.transformer(inputs, inputs_kv=r * self.state_layer_norm(state))
    new_state = (1 - z) * state + z * candidate
    return new_state, new_state

def softmax(x):
  """Softmax over the last axis, computed and returned in float32."""
  x32 = x.astype(jnp.float32)
  probs = jax.nn.softmax(x32, axis=-1)
  return probs.astype(jnp.float32)

def dot_product_attention_weights(query, key):
  """Scaled dot-product attention probabilities.

  Args:
    query: Array of shape (..., q, heads, d).
    key: Array of shape (..., k, heads, d).

  Returns:
    float32 weights of shape (..., heads, q, k), normalized over k.
  """
  scaled_query = query / jnp.sqrt(query.shape[-1])
  logits = jnp.einsum('...qhd,...khd->...hqk', scaled_query, key)
  return softmax(logits)

class ImprovedMultiHeadDotProductAttention(nn.Module):
  """Multi-head dot-product attention.

  Queries, keys and values are projected into `num_heads` heads of
  `features // num_heads` dims each, combined with softmax weights computed
  in float32, and projected back to `out_features`.

  Attributes:
    num_heads: Number of attention heads.
    qk_features: Total query/key width (default: width of `inputs_q`).
    v_features: Total value width (default: `qk_features`).
    out_features: Output width (default: width of `inputs_q`).
  """

  num_heads: int
  qk_features: int = None
  v_features: int = None
  out_features: int = None

  @nn.compact
  def __call__(
      self,
      inputs_q,
      inputs_k = None,
      inputs_v = None,
      *,
      bias = None,
      mask = None,
  ):
    """Attends from `inputs_q` to `inputs_k` / `inputs_v`.

    `inputs_k` defaults to `inputs_q`, and `inputs_v` defaults to `inputs_k`.

    Raises:
      NotImplementedError: If `bias` or `mask` is given. They are accepted
        for signature compatibility, but this implementation never applies
        them. Silently ignoring a mask would attend to masked positions.
    """
    if bias is not None or mask is not None:
      raise NotImplementedError(
          'ImprovedMultiHeadDotProductAttention does not support `bias` or '
          '`mask`.'
      )

    qk_features = self.qk_features or inputs_q.shape[-1]
    v_features = self.v_features or qk_features

    if inputs_k is None:
      inputs_k = inputs_q
    if inputs_v is None:
      inputs_v = inputs_k

    def dense(name, x, features):
      # Projects the last axis to (num_heads, features // num_heads).
      return nn.DenseGeneral(
          features=(self.num_heads, features // self.num_heads),
          kernel_init=nn.initializers.lecun_normal(),
          bias_init=nn.initializers.zeros_init(),
          use_bias=True,
          dtype=x.dtype,
          name=name,
      )(x)

    query = dense('query', inputs_q, qk_features)
    key = dense('key', inputs_k, qk_features)
    value = dense('value', inputs_v, v_features)

    # Compute attention weights.
    attn_weights = dot_product_attention_weights(query=query, key=key)

    # Return weighted sum over values for each query position.
    x = jnp.einsum('...hqk,...khd->...qhd', attn_weights, value)

    # Back to the original input dimensions (merges the head axes).
    return nn.DenseGeneral(
        features=self.out_features or inputs_q.shape[-1],
        axis=(-2, -1),
        kernel_init=nn.initializers.lecun_normal(),
        bias_init=nn.initializers.zeros_init(),
        use_bias=True,
        dtype=x.dtype,
        name='out',
    )(x)

class CrossAttentionTransformer(nn.Module):
  """Stack of CrossAttentionBlocks followed by a final LayerNorm.

  Every block receives the same `inputs_kv`; only `inputs` flows through
  the stack.
  """
  num_heads: int
  num_layers: int
  num_feats: int
  mlp_dim: int
  dtype: Any

  def setup(self):
    self.xa_blocks = [
        CrossAttentionBlock(
            num_heads=self.num_heads,
            num_feats=self.num_feats,
            mlp_dim=self.mlp_dim,
            dtype=self.dtype,
        )
        for _ in range(self.num_layers)
    ]
    self.output_norm = nn.LayerNorm(dtype=self.dtype)

  def __call__(self, inputs, inputs_kv):
    x = inputs
    for block in self.xa_blocks:
      x = block(x, inputs_kv)
    return self.output_norm(x)

class CrossAttentionBlock(nn.Module):
  """Cross attention block.

  Each call applies three pre-norm residual sub-layers in this order:
  cross-attention from `inputs` to `inputs_kv`, then an MLP, then
  self-attention. `inputs_kv` is used as keys/values without normalization.
  NOTE(review): this order differs from the common self-attn -> cross-attn ->
  MLP layout. It presumably matches how the checkpoint was trained, so do not
  reorder without verifying.
  """

  num_heads: int
  num_feats: int
  mlp_dim: int
  dtype: Any

  def setup(self):
    # Only the LayerNorms receive `dtype`; the attention and MLP modules
    # take their dtype from their inputs.
    self.attention_norm = nn.LayerNorm(dtype=self.dtype)
    self.mlp_norm = nn.LayerNorm(dtype=self.dtype)
    self.ca_attention_norm = nn.LayerNorm(dtype=self.dtype)
    self.attention = ImprovedMultiHeadDotProductAttention(
        num_heads=self.num_heads,
        qk_features=self.num_feats,
        v_features=self.num_feats,
    )
    self.ca_attention = ImprovedMultiHeadDotProductAttention(
        num_heads=self.num_heads,
        qk_features=self.num_feats,
        v_features=self.num_feats,
    )
    self.mlp = TransformerMLP(hidden_size=self.mlp_dim)

  def __call__(self, inputs, inputs_kv):
    x = inputs
    # 1) Cross-attention: queries from x, keys/values from inputs_kv.
    x = x + self.ca_attention(inputs_q=self.ca_attention_norm(x), inputs_k=inputs_kv, inputs_v=inputs_kv)
    # 2) MLP.
    x = x + self.mlp(self.mlp_norm(x))
    # 3) Self-attention (keys/values default to the queries).
    x = x + self.attention(self.attention_norm(x))
    return x

class RandomStateInit(nn.Module):
  """Random, non-learnable state initialization.

  Despite the name, the returned state is all zeros: a standard-normal sample
  from the "default" RNG stream is multiplied by 0. An RNG must still be
  available to `make_rng("default")` when the model is applied.
  """

  @nn.compact
  def __call__(self, inputs, batch_shape):
    """Returns a state of shape `batch_shape + inputs.shape[-2:]`.

    `batch_shape` must be a tuple, since it is concatenated with a shape tuple.
    """
    shape = inputs.shape[-2:]
    state = 0 * jax.random.normal(key=self.make_rng("default"), shape=batch_shape + shape)
    return state

class VideoSiamMAE(nn.Module):
  """Video Siamese masked autoencoder model, applied one frame at a time.

  Each call tokenizes a frame, prepends a learned CLS token, encodes the
  tokens with `encoder`, and updates the recurrent state with `rnn_core`.
  """

  tokenizer: nn.Module
  encoder: nn.Module
  rnn_core: nn.Module
  latent_emb_dim: int = 384

  def setup(self):
    self.cls_token = self.param('cls_token', nn.initializers.normal(stddev=0.02), (1, self.latent_emb_dim))

  @nn.compact
  def __call__(self, frame, state = None):
    """Processes one frame.

    Args:
      frame: Input frame(s) for `tokenizer`.
      state: Recurrent state from the previous call, or None on the first
        frame (then `rnn_core.initializer` creates it).

    Returns:
      dict with 'features' (rnn_core output; token 0 is the CLS position)
      and 'state' (the updated recurrent state).
    """
    frame_tokens = self.tokenizer(frame)
    frame_tokens = einops.rearrange(frame_tokens, '... h w D -> ... (h w) D')

    *b, _, _ = frame_tokens.shape
    cls_token = jnp.broadcast_to(self.cls_token, b + [1, self.cls_token.shape[-1]])
    frame_tokens = jnp.concatenate([cls_token, frame_tokens], axis=-2)

    encoded_frame_tokens = self.encoder(frame_tokens)
    if state is None:
      # Size the initial state from the actual leading batch dims rather than
      # a hard-coded (1,), so batched frames get a matching state.
      state = self.rnn_core.initializer(encoded_frame_tokens, batch_shape=tuple(b))
    features, state = self.rnn_core(encoded_frame_tokens, state)

    return dict(features=features, state=state)

# RVM model. The per-frame encoder is ViT-L: 16x16 patches, 1024-d tokens, and
# a fixed sin/cos table built for a 16x16 grid and resized to the frame's grid.
# It is followed by a gated recurrent core whose transformer cross-attends to
# the previous state. In the definitions in this cell, the bfloat16 `dtype`
# only reaches LayerNorms; attention and MLP layers use their input dtype.
model = VideoSiamMAE(
    tokenizer=Tokenizer(
        patch_embedding=PatchEmbedding(patch_size=[1, 16, 16], num_features=1024),
        posenc=SincosPosEmb(base_token_shape=[16, 16]),
    ),
    encoder=Transformer.from_variant_str(variant_str='L', dtype=jax.numpy.bfloat16),
    rnn_core=GatedTransformerCore(
        transformer=CrossAttentionTransformer(
            num_layers=4,
            num_heads=16,
            num_feats=1024,
            mlp_dim=4096,
            dtype=jax.numpy.bfloat16,
        ),
        initializer=RandomStateInit(),
        token_dim=1024,
        state_layer_norm=nn.LayerNorm(epsilon=0.0001, use_scale=True, use_bias=False),
    ),
    latent_emb_dim=1024,
)

In [None]:
# @title Feature extraction function

PATCH_SIZE = 16  # Spatial patch size in pixels.

def extract_features(model, params, video):
  """Runs the recurrent model frame by frame and collects patch features.

  The state returned for one frame is fed back in with the next. The first
  frame gets `None`, so the model initializes the state itself.

  Args:
    model: Object with a Flax-style `apply(variables, frame, state, ...)`
      that returns a dict with 'features' and 'state'.
    params: Parameter tree, wrapped as {'params': params} for `apply`.
    video: Array of shape (T, H, W, C); values are divided by 255, so they
      are presumably in [0, 255].

  Returns:
    NumPy array (T, H // PATCH_SIZE, W // PATCH_SIZE, D) of per-patch
    features. The token at index 0 of 'features' (CLS) is dropped.
  """
  frames = video.astype(np.float32) / 255.0
  grid_h = frames.shape[1] // PATCH_SIZE
  grid_w = frames.shape[2] // PATCH_SIZE

  @jax.jit
  def step(variables, frame, recurrent_state, key):
    return model.apply(
        {'params': variables},
        frame,  # [B, H, W, 3]
        recurrent_state,  # [B, N, D]
        capture_intermediates=False,
        rngs=key,
    )

  key = jax.random.PRNGKey(0)
  state = None
  per_frame = []
  for frame in frames:
    out = step(params, frame[None], state, key)
    state = out['state']
    # Skip the CLS token at index 0 and lay the patch tokens out on the grid.
    per_frame.append(out['features'][0, 1:, :].reshape(grid_h, grid_w, -1))
  return np.stack(per_frame, axis=0)

In [None]:
# @title PCA visualization and k-means clustering {form-width: "20%"}

# For each clip: show the visualize_pca / visualize_kmeans outputs (helpers
# defined elsewhere in the notebook) next to a 50/50 blend of the frames and
# the k-means colouring. This extract_features already returns a NumPy array.
for video in videos:
  features = extract_features(model, restored_params, video)
  print(features.shape)

  pca_video = visualize_pca(features, video)
  kmeans_video = visualize_kmeans(features, video)
  # Frames are scaled by 1/255 before blending; assumes kmeans_video is in
  # [0, 1] — NOTE(review): confirm against visualize_kmeans.
  mixed_video = 0.5 * video / 255.0 + 0.5 * kmeans_video
  video_titles = ['pca', 'kmeans', 'mixed']
  media.show_videos([pca_video, kmeans_video, mixed_video], titles=video_titles, height=128, codec='gif', fps=16, columns=3)

In [None]:
# @title DAVIS evaluation

# DAVIS-2017 evaluation of the RVM model; checkpoint params go in via `params`.
evaluate_davis(model, extract_features, dataset_path='/content/davis-2017/DAVIS/', height=480, width=880, patch_size=16, params=restored_params)

In [None]:
# @title JHMDB evaluation

# JHMDB evaluation of the RVM model; checkpoint params go in via `params`.
evaluate_jhmdb(model, extract_features, dataset_path='/content/jhmdb/', height=320, width=320, patch_size=16, params=restored_params)

In [None]:
# @title VIP evaluation

# VIP evaluation of the RVM model; checkpoint params go in via `params`.
evaluate_vip(model, extract_features, dataset_path='/content/VIP/', height=448, width=880, patch_size=16, params=restored_params)