In [None]:
# @title Install dependency

!pip install mediapy

In [None]:
# @title Imports

import glob
import math
import random
import os

import cv2
import einops
from PIL import Image
import scipy.io as sio
import mediapy as media
import numpy as np
import queue
import seaborn as sns
import tensorflow as tf
import torch
import tqdm
import jax
import jax.numpy as jnp
import sklearn

In [None]:
# @title Download checkpoint

%mkdir /content/rvm
%cd /content/rvm

!wget https://storage.googleapis.com/representations4d/checkpoints/pretrain_rvm_large16_256_175558463.npz

### Prepare Qualitative Examples

In [None]:
# @title Download DAVIS videos

# Clips to fetch from the DAVIS 2017 bucket; uncomment entries to add more.
video_names = [
    'goat',
    # 'india',
    # 'soapbox',
]

# Download each selected clip into the current working directory.
for video_name in video_names:
  !wget https://storage.googleapis.com/representations4d/datasets/davis2017/{video_name}.mp4

# Read every clip back, resize it to (480, 880), keep it in `davis_videos`
# and show a GIF preview.
davis_videos = []
for video_name in video_names:
  video = media.read_video(f"{video_name}.mp4")
  video = media.resize_video(video, (480, 880))
  davis_videos.append(video)
  media.show_video(video, fps=16, height=128, codec='gif')

In [None]:
# @title Download MoCA videos

# Clips to fetch from the MoCA bucket; uncomment entries to add more.
# Note: this rebinds `video_names`, which the DAVIS cell also assigns.
video_names = [
    'scorpionfish_0',
    # 'snow_leopard_8',
    # 'white_tailed_ptarmigan',
    # 'grasshopper_1',
    # 'plaice',
    # 'pygmy_seahorse_0',
    # 'stick_insect_0',
    # 'crab',
    # 'crab_1',
    # 'copperhead_snake',
]

# Download each selected clip into the current working directory.
for video_name in video_names:
  !wget https://storage.googleapis.com/representations4d/datasets/MoCA/{video_name}.mp4

# Read every clip back, resize it to (360, 640), keep it in `moca_videos`
# and show a GIF preview.
moca_videos = []
for video_name in video_names:
  video = media.read_video(f"{video_name}.mp4")
  video = media.resize_video(video, (360, 640))
  moca_videos.append(video)
  media.show_video(video, fps=16, width=256, codec='gif')

In [None]:
# @title Download Dalmatian videos

# Fetch the single "dalmatian illusion" clip into the current working directory.
!wget https://storage.googleapis.com/representations4d/assets/dalmatian_illusion.mp4

# Read the clip, resize it to (360, 480) and show a GIF preview.
dalmatian_video = media.read_video("dalmatian_illusion.mp4")
dalmatian_video = media.resize_video(dalmatian_video, (360, 480))
media.show_video(dalmatian_video, fps=16, width=256, codec='gif')

In [None]:
# @title Create noise video

def create_noise_movie(T, h, w):
  """Renders a noise square sliding horizontally over a static noise background.

  Args:
    T: Number of frames to render.
    h: Frame height in pixels.
    w: Frame width in pixels.

  Returns:
    A uint8 array holding the T rendered (h, w, 3) frames.
  """
  # The background is drawn before the square so the order of random draws is fixed.
  background = np.random.rand(h, w, 3)
  square = np.random.rand(h // 2, w // 2, 3)

  patch_h, patch_w = h // 2, w // 2
  top = (h - patch_h) // 2  # The square stays vertically centred.

  def render_frame(t):
    # The left edge advances linearly with the frame index.
    left = (t * (w - patch_w)) // T
    frame = background.copy()
    frame[top:top + patch_h, left:left + patch_w] = square
    return frame

  frames = [render_frame(t) for t in range(T)]
  return (np.array(frames) * 255.0).astype(np.uint8)

# Render a 64-frame 512x512 noise clip and show a GIF preview of it.
noise_video = create_noise_movie(64, 512, 512)
media.show_video(noise_video, fps=16, height=128, codec='gif')

In [None]:
# @title Combine qualitative videos

# Single list of every clip prepared by the cells above, in display order.
videos = davis_videos + moca_videos + [dalmatian_video, noise_video]

In [None]:
# @title Visualization functions

def visualize_pca(features, video):
  """Maps features to a 3-channel video via their top three whitened PCA components.

  Args:
    features: Array of shape (t, n, m, c) holding one feature vector per token.
    video: Array whose shape the result is resized to.

  Returns:
    The PCA projection, min-max scaled over the whole clip and resized to
    `video.shape` with nearest-neighbour interpolation.
  """
  flat_features = einops.rearrange(features, 't n m c -> (t n m) c')
  projector = sklearn.decomposition.PCA(n_components=3, whiten=True)
  components = projector.fit_transform(flat_features)
  components = components.reshape(features.shape[:-1] + (3,))

  # A single global min/max is used so colours stay comparable across frames.
  lowest, highest = components.min(), components.max()
  scaled = (components - lowest) / (highest - lowest)
  return jax.image.resize(scaled, video.shape, method='nearest')

def segmentations_to_video(masks):
  """Colours an integer label video with one seaborn palette colour per label.

  Args:
    masks: Integer array with three axes; labels are assumed to be numbered
      consecutively from 0.

  Returns:
    Float array of shape `masks.shape + (3,)` with the colour of each label.
  """
  highest_label = np.max(masks)  # assume consecutive numbering
  colors = sns.color_palette(n_colors=highest_label + 1)
  colored = np.zeros(tuple(masks.shape[:3]) + (3,))
  for label in range(highest_label + 1):
    colored[masks == label] = colors[label]
  return colored

def visualize_kmeans(features, video, n_clusters=5):
  """Clusters all token features with k-means and renders the labels as a video.

  Args:
    features: Array of shape (t, n, m, c) holding one feature vector per token.
    video: Array whose shape the result is resized to.
    n_clusters: Number of k-means clusters.

  Returns:
    The colour-coded cluster assignment resized to `video.shape` with
    nearest-neighbour interpolation.
  """
  flat_features = einops.rearrange(features, 't n m c -> (t n m) c')
  fitted = sklearn.cluster.KMeans(n_clusters, init='k-means++').fit(flat_features)
  label_grid = jnp.reshape(fitted.labels_, features.shape[:-1])
  colored = segmentations_to_video(label_grid)
  return jax.image.resize(colored, video.shape, method='nearest')

### Flax Model

In [None]:
# @title Define model {form-width: "20%"}

import dataclasses
import re
from typing import Any
import jax.numpy as jnp
from flax import linen as nn

class PatchEmbedding(nn.Module):
  """Projects non-overlapping image patches with one learned convolution."""
  patch_size: list[int]  # Used as both the kernel size and the stride.
  num_features: int  # Number of output channels per patch.

  @nn.compact
  def __call__(self, images):
    # Kernel size equals stride and padding is 'VALID', so every patch is
    # visited exactly once and partial patches at the border are dropped.
    projection = nn.Conv(
        features=self.num_features,
        kernel_size=self.patch_size,
        strides=self.patch_size,
        padding='VALID',
    )
    return projection(images)

def get_mae_sinusoid_encoding_table(n_position, d_hid, dtype=jnp.float32):
  """Builds the MAE-style sinusoidal positional encoding table.

  Args:
    n_position: Number of positions (rows) in the table.
    d_hid: Number of channels per position.
    dtype: dtype of the returned array.

  Returns:
    Array of shape (1, n_position, d_hid); even channels hold sines and odd
    channels hold cosines of the position-dependent angles.
  """
  def angles_for(position):
    # Channels 2i and 2i+1 share the same frequency.
    return [position / np.power(10000, 2 * (channel // 2) / d_hid) for channel in range(d_hid)]

  table = np.array([angles_for(position) for position in range(n_position)])
  even, odd = slice(0, None, 2), slice(1, None, 2)
  table[:, even] = np.sin(table[:, even])  # dim 2i
  table[:, odd] = np.cos(table[:, odd])  # dim 2i+1

  # Prepend a broadcastable leading axis.
  return jnp.asarray(table, dtype)[None, ...]

class SincosPosEmb(nn.Module):
  """Produces a sinusoidal positional embedding for a given token grid shape."""
  # Grid (h, w) the table is generated for; when None the token grid itself is used.
  base_token_shape: list[int] | None = None

  @nn.compact
  def __call__(self, tokens_shape):
    """Returns the embedding for tokens of shape (*batch, tokens_h, tokens_w, d)."""
    d = tokens_shape[-1]
    if self.base_token_shape is None:
      h, w = tokens_shape[-3], tokens_shape[-2]
    else:
      h, w = self.base_token_shape

    flat_table = get_mae_sinusoid_encoding_table(np.prod((h, w)), d)
    posenc = einops.rearrange(flat_table, '... (h w) d -> ... h w d', h=h, w=w)

    *b, tokens_h, tokens_w, _ = tokens_shape
    # The table already carries one leading axis; add one more per extra batch axis.
    for _ in range(len(b) - 1):
      posenc = jnp.expand_dims(posenc, axis=0)

    if (tokens_h, tokens_w) == (h, w):
      return posenc
    # The token grid differs from the base grid: resample the table bicubically.
    return jax.image.resize(posenc, (*b, tokens_h, tokens_w, d), method='bicubic')

class Tokenizer(nn.Module):
  """Embeds image patches and adds a positional encoding to them."""
  patch_embedding: nn.Module  # Maps an image to a grid of patch tokens.
  posenc: nn.Module  # Called with the token shape; returns the encoding to add.

  @nn.compact
  def __call__(self, image):
    patch_tokens = self.patch_embedding(image)
    return patch_tokens + self.posenc(patch_tokens.shape)

class TransformerMLP(nn.Module):
  """Two-layer GELU MLP used inside Transformer blocks."""
  hidden_size: int = None  # Defaults to 4 times input dims.

  @nn.compact
  def __call__(self, inputs):
    in_features = inputs.shape[-1]
    width = self.hidden_size if self.hidden_size is not None else 4 * in_features

    expand = nn.Dense(
        features=width,
        name='dense_in',
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.zeros,
        dtype=inputs.dtype,
    )
    hidden = nn.gelu(expand(inputs))

    # Project back to the input width so the block can be used residually.
    project = nn.Dense(
        features=in_features,
        name='dense_out',
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.zeros,
        dtype=hidden.dtype,
    )
    return project(hidden)

class PreNormBlock(nn.Module):
  """Pre-LN Transformer layer: normalise, self-attend, normalise, MLP."""

  attention: Any  # Called with inputs_q / inputs_k / inputs_v keyword arguments.
  mlp: nn.Module
  attention_norm: Any
  mlp_norm: Any

  @nn.compact
  def __call__(self, tokens):
    # Self-attention sub-layer with a residual connection.
    normed = self.attention_norm(tokens)
    attended = tokens + self.attention(
        inputs_q=normed,
        inputs_k=normed,
        inputs_v=normed,
    )
    # MLP sub-layer with a residual connection.
    return attended + self.mlp(self.mlp_norm(attended))

# Variant name -> (hidden_size, num_layers, mlp_size, num_heads).
VIT_SIZES = {
    'mu': (32, 1, 128, 2),
    'Ti': (192, 12, 768, 3),
    'S': (384, 12, 1536, 6),
    'M': (512, 12, 2048, 8),
    'B': (768, 12, 3072, 12),
    'L': (1024, 24, 4096, 16),
    'H': (1280, 32, 5120, 16),
    'g': (1408, 40, 6144, 16),
    'G': (1664, 48, 8192, 16),
    'e': (1792, 56, 15360, 16),
}

@dataclasses.dataclass(frozen=True)
class ViTSpec:
  """Size specification of a Vision Transformer."""

  hidden_size: int  # Dimension of tokens passed between blocks.
  num_layers: int  # Number of transformer blocks.
  mlp_size: int  # Hidden dimension of the MLP in each block.
  num_heads: int  # Number of attention heads.
  patch_size: int = None  # Patch size of initial image patches.

  @classmethod
  def from_variant_string(cls, variant_str: str):
    """Builds a spec from strings such as "ViT-L", "B", or "Ti/16".

    Raises:
      ValueError: If `variant_str` does not have the expected format.
      KeyError: If the variant name is not a key of `VIT_SIZES`.
    """
    match = re.match(
        r'^([Vv][Ii][Tt][-_])?(?P<name>[a-zA-Z]{1,2})(/(?P<patch>\d+))?$',
        variant_str,
    )
    if match is None:
      raise ValueError(f'Invalid variant string: {variant_str!r}.')
    groups = match.groupdict()
    spec = cls(*VIT_SIZES[groups['name']])

    # The "/<patch>" suffix is optional; without it patch_size stays None.
    if groups['patch'] is None:
      return spec
    return dataclasses.replace(spec, patch_size=int(groups['patch']))

  @property
  def kwargs(self):
    """The spec as a keyword dict; `patch_size` is omitted when it is None."""
    fields = dataclasses.asdict(self)
    if fields['patch_size'] is None:
      fields.pop('patch_size')
    return fields

class Transformer(nn.Module):
  """Applies a sequence of layers followed by a final LayerNorm."""
  layers: tuple[Any]

  @nn.compact
  def __call__(self, tokens):
    for block in self.layers:
      tokens = block(tokens)
    return nn.LayerNorm(dtype=tokens.dtype)(tokens)

  @classmethod
  def from_variant_str(cls, variant_str: str, **kwargs):
    """Builds the model from a ViT variant string; `kwargs` override the spec."""
    merged = {**ViTSpec.from_variant_string(variant_str).kwargs, **kwargs}
    # `from_spec` takes neither of these; drop them so they are not forwarded.
    for unused in ('patch_size', 'hidden_size'):
      merged.pop(unused, None)
    return cls.from_spec(**merged)

  @classmethod
  def from_spec(
      cls,
      num_heads: int,
      num_layers: int,
      mlp_size = None,
      dtype=jnp.float32,
      qk_features = None,
      v_features = None,
      **kwargs,
  ):
    """Builds `num_layers` pre-norm blocks; extra `kwargs` go to the constructor."""

    def make_block():
      return PreNormBlock(
          attention_norm=nn.LayerNorm(dtype=dtype),
          mlp_norm=nn.LayerNorm(dtype=dtype),
          attention=ImprovedMultiHeadDotProductAttention(
              num_heads=num_heads,
              qk_features=qk_features,
              v_features=v_features,
          ),
          mlp=TransformerMLP(hidden_size=mlp_size),
      )

    return cls(
        layers=tuple(make_block() for _ in range(num_layers)),
        **kwargs,
    )

class GatedTransformerCore(nn.Module):
  """Recurrent core that blends a transformer update into the state via gates."""
  transformer: nn.Module  # Called as transformer(inputs, inputs_kv=...).
  initializer: nn.Module  # Not used by __call__; exposed as an attribute.
  token_dim: int  # Output width of the four bias-free gate projections.
  state_layer_norm: nn.Module  # Applied to the state before it is reset-gated.

  def setup(self):
    self.input_update = nn.Dense(self.token_dim, use_bias=False)
    self.input_reset = nn.Dense(self.token_dim, use_bias=False)
    self.state_update = nn.Dense(self.token_dim, use_bias=False)
    self.state_reset = nn.Dense(self.token_dim, use_bias=False)

  def __call__(self, inputs, state):
    """Returns `(output, new_state)`; both are the same gated mixture."""
    update_gate = jax.nn.sigmoid(self.input_update(inputs) + self.state_update(state))
    reset_gate = jax.nn.sigmoid(self.input_reset(inputs) + self.state_reset(state))

    # The candidate reads the reset-gated, normalised state as its key/value input.
    gated_state = reset_gate * self.state_layer_norm(state)
    candidate = self.transformer(inputs, inputs_kv=gated_state)

    # Convex combination of the previous state and the candidate.
    new_state = (1-update_gate)*state + update_gate * candidate
    return new_state, new_state

def softmax(x):
  """Softmax over the last axis, computed in and returned as float32."""
  as_float32 = x.astype(jnp.float32)
  return jax.nn.softmax(as_float32, axis=-1).astype(jnp.float32)

def dot_product_attention_weights(query, key):
  """Scaled dot-product attention weights.

  Args:
    query: Array of shape (..., q, heads, d).
    key: Array of shape (..., k, heads, d).

  Returns:
    float32 array of shape (..., heads, q, k) that sums to one over the last axis.
  """
  head_dim = query.shape[-1]
  scaled_query = query / jnp.sqrt(head_dim)
  logits = jnp.einsum('...qhd,...khd->...hqk', scaled_query, key)
  return softmax(logits)

class ImprovedMultiHeadDotProductAttention(nn.Module):
  """Multi-head dot-product attention.

  Attributes:
    num_heads: Number of attention heads.
    qk_features: Total query/key width over all heads; defaults to the last
      dimension of `inputs_q`.
    v_features: Total value width over all heads; defaults to `qk_features`.
    out_features: Width of the output; defaults to the last dimension of
      `inputs_q`.
  """

  num_heads: int
  qk_features: int = None
  v_features: int = None
  out_features: int = None

  @nn.compact
  def __call__(
      self,
      inputs_q,
      inputs_k = None,
      inputs_v = None,
      *,
      bias = None,
      mask = None,
  ):
    """Applies attention.

    Args:
      inputs_q: Queries of shape (..., q, features).
      inputs_k: Keys of shape (..., k, features); defaults to `inputs_q`.
      inputs_v: Values of shape (..., k, features); defaults to `inputs_k`.
      bias: Optional array broadcastable to (..., heads, q, k) that is added
        to the attention logits.
      mask: Optional boolean array broadcastable to (..., heads, q, k);
        positions where it is False receive (near-)zero attention weight.

    Returns:
      Array of shape (..., q, out_features).
    """
    qk_features = self.qk_features or inputs_q.shape[-1]
    v_features = self.v_features or qk_features

    if inputs_k is None:
      inputs_k = inputs_q
    if inputs_v is None:
      inputs_v = inputs_k

    def dense(name, x, features):
      # Projects to (..., num_heads, features // num_heads).
      return nn.DenseGeneral(
          features=(self.num_heads, features // self.num_heads),
          kernel_init=nn.initializers.lecun_normal(),
          bias_init=nn.initializers.zeros_init(),
          use_bias=True,
          dtype=x.dtype,
          name=name,
      )(x)

    query = dense('query', inputs_q, qk_features)
    key = dense('key', inputs_k, qk_features)
    value = dense('value', inputs_v, v_features)

    # Compute attention weights.
    if bias is None and mask is None:
      # Unmasked path, unchanged from the original implementation.
      attn_weights = dot_product_attention_weights(query=query, key=key)
    else:
      # `bias` and `mask` used to be accepted but silently ignored. Apply them
      # to the logits before the float32 softmax.
      scaled_query = query / jnp.sqrt(query.shape[-1])
      logits = jnp.einsum('...qhd,...khd->...hqk', scaled_query, key)
      if bias is not None:
        logits = logits + bias
      if mask is not None:
        # The most negative finite value avoids NaNs for fully masked rows.
        logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
      attn_weights = jax.nn.softmax(logits.astype(jnp.float32), axis=-1)

    # Return weighted sum over values for each query position.
    x = jnp.einsum('...hqk,...khd->...qhd', attn_weights, value)

    # Back to the original input dimensions.
    return nn.DenseGeneral(
        features=self.out_features or inputs_q.shape[-1],
        axis=(-2, -1),
        kernel_init=nn.initializers.lecun_normal(),
        bias_init=nn.initializers.zeros_init(),
        use_bias=True,
        dtype=x.dtype,
        name='out',
    )(x)

class CrossAttentionTransformer(nn.Module):
  """Stack of cross-attention blocks followed by an output LayerNorm."""
  num_heads: int
  num_layers: int
  num_feats: int
  mlp_dim: int
  dtype: Any

  def setup(self):
    block_config = dict(
        num_heads=self.num_heads,
        num_feats=self.num_feats,
        mlp_dim=self.mlp_dim,
        dtype=self.dtype,
    )
    self.xa_blocks = [CrossAttentionBlock(**block_config) for _ in range(self.num_layers)]
    self.output_norm = nn.LayerNorm(dtype=self.dtype)

  def __call__(self, inputs, inputs_kv):
    """Runs every block on `inputs`, each one attending to the same `inputs_kv`."""
    tokens = inputs
    for block in self.xa_blocks:
      tokens = block(tokens, inputs_kv)
    return self.output_norm(tokens)

class CrossAttentionBlock(nn.Module):
  """Cross-attention, then MLP, then self-attention; each with a residual."""

  num_heads: int
  num_feats: int  # Used as both qk_features and v_features of the attentions.
  mlp_dim: int
  dtype: Any  # dtype of the three LayerNorms.

  def setup(self):
    self.attention_norm = nn.LayerNorm(dtype=self.dtype)
    self.mlp_norm = nn.LayerNorm(dtype=self.dtype)
    self.ca_attention_norm = nn.LayerNorm(dtype=self.dtype)

    attention_config = dict(
        num_heads=self.num_heads,
        qk_features=self.num_feats,
        v_features=self.num_feats,
    )
    self.attention = ImprovedMultiHeadDotProductAttention(**attention_config)
    self.ca_attention = ImprovedMultiHeadDotProductAttention(**attention_config)
    self.mlp = TransformerMLP(hidden_size=self.mlp_dim)

  def __call__(self, inputs, inputs_kv):
    # Cross-attention: normalised inputs query the (un-normalised) inputs_kv.
    queries = self.ca_attention_norm(inputs)
    x = inputs + self.ca_attention(inputs_q=queries, inputs_k=inputs_kv, inputs_v=inputs_kv)
    # MLP sub-layer.
    x = x + self.mlp(self.mlp_norm(x))
    # Self-attention sub-layer.
    return x + self.attention(self.attention_norm(x))

class RandomStateInit(nn.Module):
  """Non-learnable state initialization.

  The state is drawn from a normal distribution and then multiplied by 0, so
  the returned state is all zeros.
  """

  @nn.compact
  def __call__(self, inputs, batch_shape):
    # The state takes its last two dimensions from `inputs`.
    token_shape = inputs.shape[-2:]
    noise = jax.random.normal(key=self.make_rng("default"), shape=batch_shape + token_shape)
    return 0 * noise

class VideoSiamMAE(nn.Module):
  """Video Siamese masked autoencoder model, applied one frame at a time."""

  tokenizer: nn.Module  # Maps a frame to a grid of tokens.
  encoder: nn.Module  # Encodes the [cls] + patch token sequence.
  rnn_core: nn.Module  # Recurrent core; also provides `initializer`.
  latent_emb_dim: int = 384  # Width of the learned cls token.

  def setup(self):
    self.cls_token = self.param('cls_token', nn.initializers.normal(stddev=0.02), (1, self.latent_emb_dim))

  @nn.compact
  def __call__(self, frame, state = None):
    """Encodes one frame and advances the recurrent state.

    Args:
      frame: Input frame passed to the tokenizer.
      state: Recurrent state from the previous frame, or None to initialise it.

    Returns:
      Dict with `features` (core output) and `state` (new recurrent state).
    """
    # Flatten the token grid into a sequence.
    patch_tokens = einops.rearrange(self.tokenizer(frame), '... h w D -> ... (h w) D')

    # Prepend the learned cls token to every sequence in the batch.
    *batch_dims, _, _ = patch_tokens.shape
    cls_tokens = jnp.broadcast_to(self.cls_token, batch_dims + [1, self.cls_token.shape[-1]])
    sequence = jnp.concatenate([cls_tokens, patch_tokens], axis=-2)

    encoded = self.encoder(sequence)
    if state is None:
      state = self.rnn_core.initializer(encoded, batch_shape=(1,))
    features, state = self.rnn_core(encoded, state)

    return dict(features=features, state=state)

# Assemble the Flax model: 16x16 patch tokenizer, a ViT-'L' encoder and a
# gated recurrent core built from a 4-layer cross-attention transformer.
model = VideoSiamMAE(
    tokenizer=Tokenizer(
        # 1024 features per patch; the positional table is generated for a
        # 16x16 token grid and resized by SincosPosEmb for other grid sizes.
        patch_embedding=PatchEmbedding(patch_size=[1, 16, 16], num_features=1024),
        posenc=SincosPosEmb(base_token_shape=[16, 16]),
    ),
    # 'L' is looked up in VIT_SIZES; `dtype` is forwarded to the block LayerNorms.
    encoder=Transformer.from_variant_str(variant_str='L', dtype=jax.numpy.bfloat16),
    rnn_core=GatedTransformerCore(
        transformer=CrossAttentionTransformer(
            num_layers=4,
            num_heads=16,
            num_feats=1024,
            mlp_dim=4096,
            dtype=jax.numpy.bfloat16,
        ),
        initializer=RandomStateInit(),
        token_dim=1024,
        state_layer_norm=nn.LayerNorm(epsilon=0.0001, use_scale=True, use_bias=False),
    ),
    # Width of the cls token; matches the 1024-wide patch tokens above.
    latent_emb_dim=1024,
)

In [None]:
# @title Load checkpoint

# Switch to the directory the checkpoint was downloaded into.
%cd /content/rvm

def recover_tree(flat_dict):
  """Rebuilds a nested dict from a flat mapping with "/"-separated keys.

  Args:
    flat_dict: Mapping whose keys are paths such as "a/b/c".

  Returns:
    Nested dict in which every path component but the last is a sub-dict and
    the last component maps to the original value.
  """
  root = {}
  for path, leaf in flat_dict.items():
    *parents, name = path.split("/")
    branch = root
    for parent in parents:
      # Descend, creating intermediate levels on first use.
      branch = branch.setdefault(parent, {})
    branch[name] = leaf
  return root

# Read the flat .npz archive (pickle loading disabled) and rebuild the nested
# parameter tree from its "/"-separated keys.
restored_params = recover_tree(np.load("pretrain_rvm_large16_256_175558463.npz", allow_pickle=False))

In [None]:
# @title Feature extraction function

PATCH_SIZE = 16

def extract_features(model, params, video):
  """Runs the recurrent model over a video and collects per-frame patch features.

  Args:
    model: Object whose `apply` returns a dict with 'features' and 'state'.
    params: Parameters passed to `model.apply` under the 'params' collection.
    video: Array of frames indexed as (frame, height, width, ...); values are
      divided by 255 after conversion to float32.

  Returns:
    Numpy array of shape (num_frames, height // PATCH_SIZE,
    width // PATCH_SIZE, feature_dim).
  """
  frames = video.astype(np.float32) / 255.0
  # Resolution of the feature grid.
  grid_h = frames.shape[1] // PATCH_SIZE
  grid_w = frames.shape[2] // PATCH_SIZE

  @jax.jit
  def forward(params, frame, model_state, rng_key):
    return model.apply(
        {'params': params},
        frame,  # [B, H, W, 3]
        model_state,  # [B, N, D]
        capture_intermediates=False,
        rngs=rng_key,
    )

  rng_key = jax.random.PRNGKey(0)
  model_state = None  # The model initialises its own state on the first frame.
  per_frame = []
  for frame in frames:
    output = forward(params, frame[None], model_state, rng_key)
    model_state = output['state']
    # Token 0 is the cls token; the remaining tokens form the patch grid.
    patch_tokens = output['features'][0, 1:, :]
    per_frame.append(patch_tokens.reshape(grid_h, grid_w, -1))
  return np.stack(per_frame, axis=0)

In [None]:
# @title Label propagation functions (jax)

def draw_labelmap_np(img, pt, sigma=0.5):
  """Draws a 2D Gaussian around a point into `img` (modified in place).

  Adopted from https://github.com/anewell/pose-hg-train/blob/master/src/pypose/draw.py

  Args:
    img: 2D array indexed as [y, x].
    pt: Point as (x, y).
    sigma: Standard deviation of the Gaussian.

  Returns:
    `img`, unchanged when the Gaussian's window lies entirely outside it.
  """
  img_h, img_w = img.shape[0], img.shape[1]
  reach = 3 * sigma

  # Window of the Gaussian as (x, y) corners, truncated to integers.
  upper_left = [int(pt[0] - reach), int(pt[1] - reach)]
  bottom_right = [int(pt[0] + reach + 1), int(pt[1] + reach + 1)]
  window_is_outside = (
      upper_left[0] >= img_w or upper_left[1] >= img_h
      or bottom_right[0] < 0 or bottom_right[1] < 0
  )
  if window_is_outside:
    return img

  # Un-normalised Gaussian with a peak value of 1 at the window centre.
  size = 6 * sigma + 1
  xs = np.arange(0, size, 1, float)
  ys = xs[:, np.newaxis]
  center = size // 2
  gaussian = np.exp(- ((xs - center) ** 2 + (ys - center) ** 2) / (2 * sigma ** 2))

  # Part of the Gaussian that falls inside the image ...
  g_cols = slice(max(0, -upper_left[0]), min(bottom_right[0], img_w) - upper_left[0])
  g_rows = slice(max(0, -upper_left[1]), min(bottom_right[1], img_h) - upper_left[1])
  # ... and the matching region of the image.
  img_cols = slice(max(0, upper_left[0]), min(bottom_right[0], img_w))
  img_rows = slice(max(0, upper_left[1]), min(bottom_right[1], img_h))

  img[img_rows, img_cols] = gaussian[g_rows, g_cols]
  return img

def mask2heatmap(mask, h, w, height, width):
  """Converts a segmentation mask to a soft one-hot heatmap.

  Args:
    mask: Integer label image.
    h: Height of the returned heatmap.
    w: Width of the returned heatmap.
    height: Height the mask is first resized to (nearest neighbour).
    width: Width the mask is first resized to (nearest neighbour).

  Returns:
    Tuple `(heatmap, label_map)`: the linearly resized one-hot encoding and
    the sorted unique labels of the input mask (one channel per label).
  """
  labels = np.unique(mask)
  resized_mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
  one_hot = np.stack([resized_mask == label for label in labels], axis=-1).astype(np.float32)
  soft = cv2.resize(one_hot, (w, h), interpolation=cv2.INTER_LINEAR)
  return soft, labels

def pose2heatmap(pose, h, w, height, width):
  """Converts keypoints to a heatmap with one Gaussian channel per keypoint.

  Args:
    pose: Array of shape (n_class, 2) with (x, y) coordinates in a
      `width` x `height` image.
    h: Height of the heatmap.
    w: Width of the heatmap.
    height: Height of the image the coordinates refer to.
    width: Width of the image the coordinates refer to.

  Returns:
    float32 array of shape (h, w, n_class + 1); channel 0 marks pixels where
    all keypoint channels are zero.
  """
  n_class = pose.shape[0]
  # Rescale the coordinates from image resolution to heatmap resolution.
  scaled = pose * np.array([w / width, h / height])
  heatmap = np.zeros((h, w, n_class + 1), dtype=np.float32)
  for channel, point in enumerate(scaled, start=1):
    heatmap[..., channel] = draw_labelmap_np(np.zeros((h, w)), point)
  heatmap[..., 0] = heatmap.sum(axis=-1) == 0
  return heatmap

def heatmap2pose(preds, h, w, ori_height, ori_width, topk=5):
  """Converts per-frame heatmaps to keypoint coordinates.

  Args:
    preds: Array of shape (T, grid_h, grid_w, n_class + 1); channel 0 is skipped.
    h: Heatmap height used to rescale y coordinates.
    w: Heatmap width used to rescale x coordinates.
    ori_height: Height of the original image.
    ori_width: Width of the original image.
    topk: Number of strongest responses averaged per keypoint.

  Returns:
    Tuple `(coords, visible)` of shapes (T, n_class, 2) and (T, n_class).
    Keypoints whose channel sums to zero get coordinates -1 and visibility 0.
  """
  coords_per_frame, visible_per_frame = [], []
  for t in range(preds.shape[0]):
    joint_maps = preds[t][..., 1:]  # [h, w, n_class]
    grid_w = joint_maps.shape[1]
    flat = joint_maps.reshape(-1, joint_maps.shape[-1])  # [h * w, n_class]

    # Strongest responses per keypoint, normalised to weights that sum to one.
    scores, flat_ids = jax.lax.top_k(flat.T, k=topk)  # [n_class, k]
    scores = scores / scores.sum(1, keepdims=True)

    # Weighted mean of the grid positions of those responses.
    cols, rows = flat_ids % grid_w, flat_ids // grid_w
    coord = np.stack([(cols * scores).sum(1), (rows * scores).sum(1)], axis=1)  # [n_class, 2]

    # Rescale from heatmap resolution to the original image size.
    coord[:, 0] = coord[:, 0] / w * ori_width
    coord[:, 1] = coord[:, 1] / h * ori_height

    # Keypoints without any response are marked invisible.
    coord[flat.sum(0) == 0] = -1
    visible = (coord[:, 0] >= 0).astype(float)

    coords_per_frame.append(coord)
    visible_per_frame.append(visible)
  return np.stack(coords_per_frame, axis=0), np.stack(visible_per_frame, axis=0)

def heatmap2mask(preds, lbl_map, height, width, ori_height, ori_width):
  """Converts per-frame soft label maps to hard segmentation masks.

  Args:
    preds: Array of shape (T, h, w, n_class) with soft label maps.
    lbl_map: Label value for every class index.
    height: Height the soft maps are upsampled to before the argmax.
    width: Width the soft maps are upsampled to before the argmax.
    ori_height: Height of the returned masks.
    ori_width: Width of the returned masks.

  Returns:
    int32 array of shape (T, ori_height, ori_width) with label values.
  """
  def to_mask(soft_labels):
    # Upsample the soft maps, then take the hard label per pixel.
    upsampled = cv2.resize(np.array(soft_labels), (width, height))
    class_index = np.argmax(upsampled, axis=-1)
    labels = np.array(lbl_map, dtype=np.int32)[class_index]
    return cv2.resize(labels, (ori_width, ori_height), interpolation=cv2.INTER_NEAREST_EXACT)

  return np.stack([to_mask(preds[t]) for t in range(preds.shape[0])], axis=0)

def label_propagation(feats, heatmap, n_context=20, temperature=0.7, topk=7, radius=20, restrict_neighborhood=True, norm_mask=False):
  """Propagation of the heatmap based on feature similarity.

  Every frame is labelled by a softmax-weighted vote of its `topk` most
  similar context positions. The context is the first frame plus a FIFO of
  the `n_context` most recently processed frames and their predictions.

  Args:
    feats: Array of shape (T, h, w, c) with per-frame feature grids.
    heatmap: Array of shape (h, w, n_class) with the labels of the first frame.
    n_context: Number of recent frames kept as context (0 keeps only the
      first frame).
    temperature: Divisor applied to the affinities before the softmax.
    topk: Number of context positions that vote for each query position.
    radius: Grid distance below which a position of a recent context frame
      may vote; the first-frame context entry is never restricted.
    restrict_neighborhood: Whether to apply the `radius` restriction.
    norm_mask: Whether to min-max normalise each returned prediction over the
      class axis.

  Returns:
    Array of shape (T, h, w, n_class) with the propagated heatmaps.
  """
  # Take the grid size from the features instead of relying on globals.
  h, w = feats.shape[1], feats.shape[2]

  if restrict_neighborhood:
    # Additive mask: 0 for grid positions closer than `radius`, -1e10 otherwise.
    gx, gy = np.meshgrid(np.arange(0, h), np.arange(0, w), indexing="ij")  # (h, w)
    D = (gx[None, None, :, :] - gx[:, :, None, None])**2 + (gy[None, None, :, :] - gy[:, :, None, None])**2
    D = D.astype(np.float32) ** 0.5
    D = (D < radius).astype(np.float32)  # (h, w, h, w)
    D[D == 0] = -1e10
    D[D == 1] = 0
    D = D.transpose(2, 3, 0, 1)  # (h, w, h, w)

  # The queue stores the context frames
  que = queue.Queue(n_context)
  for _ in range(n_context):
    que.put([feats[0], heatmap])

  preds = []
  for t in tqdm.tqdm(range(feats.shape[0])):
    # Use first and previous frames as context
    ctx_feats = jnp.stack([feats[0]] + [pair[0] for pair in que.queue])
    ctx_lbls = jnp.stack([heatmap] + [pair[1] for pair in que.queue])

    aff = jnp.einsum('hwc,tmnc->hwtmn', feats[t], ctx_feats) / temperature  # (h, w, n_context+1, h, w)
    if restrict_neighborhood:
      # JAX arrays are immutable: `.at[...].add` returns a new array, so the
      # result has to be assigned back for the mask to take effect.
      aff = aff.at[:, :, 1:].add(D[:, :, None])  # (h, w, n_context+1, h, w)
    aff = aff.reshape(aff.shape[0], aff.shape[1], -1)  # (h, w, n_context+1 * h * w)

    weights, ids = jax.lax.top_k(aff, topk)  # (h, w, topk), (h, w, topk)
    weights = jax.nn.softmax(weights, axis=-1)  # (h, w, topk)
    ctx_lbls = ctx_lbls.reshape(-1, ctx_lbls.shape[-1])  # (n_context+1 * h * w, n_class)
    pred = jnp.einsum('hwlk,hwl->hwk', ctx_lbls[ids], weights)  # (h, w, n_class)

    # With n_context == 0 the queue is unbounded and empty; a `get()` would
    # block forever, so the FIFO update is skipped entirely in that case.
    if n_context > 0:
      if que.qsize() == n_context:
        que.get()
      que.put([feats[t], pred])

    if norm_mask:
      # Per-position min-max normalisation over the class axis. This runs
      # after the prediction was queued, so the context keeps the raw values.
      pred = pred - pred.min(-1, keepdims=True)
      pred = pred / pred.max(-1, keepdims=True)

    preds.append(pred)
  preds = jnp.stack(preds)
  return preds

In [None]:
# @title PCA visualization and k-means clustering {form-width: "20%"}

# For every prepared clip: extract features with the Flax model, then show the
# PCA colouring, the k-means clustering and a 50/50 blend of clip and clusters.
for video in videos:
  features = extract_features(model, restored_params, video)
  print(features.shape)

  pca_video = visualize_pca(features, video)
  kmeans_video = visualize_kmeans(features, video)
  # The clip is uint8, so it is scaled to [0, 1] before blending.
  mixed_video = 0.5 * video / 255.0 + 0.5 * kmeans_video
  video_titles = ['pca', 'kmeans', 'mixed']
  media.show_videos([pca_video, kmeans_video, mixed_video], titles=video_titles, height=128, codec='gif', fps=16, columns=3)

### PyTorch Model

In [None]:
# @title Define model {form-width: "20%"}

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import einops
import re
import dataclasses

class PatchEmbedding(nn.Module):
  """Projects non-overlapping patches of channels-last RGB images with one convolution."""

  def __init__(self, patch_size=(16, 16), num_features=1024):
    super().__init__()
    self.patch_size = patch_size
    self.num_features = num_features
    # Kernel size equals stride, so every patch is visited exactly once.
    self.Conv_0 = nn.Conv2d(
        in_channels=3,
        out_channels=num_features,
        kernel_size=patch_size,
        stride=patch_size,
        padding=0,
    )

  def forward(self, x):
    """Maps [B, H, W, 3] images to [B, H', W', num_features] tokens."""
    channels_first = x.permute(0, 3, 1, 2)
    tokens = self.Conv_0(channels_first)
    return tokens.permute(0, 2, 3, 1)

def get_mae_sinusoid_encoding_table(n_position, d_hid, dtype=torch.float32):
  """Sinusoid positional encoding table for MAE.

  Args:
    n_position: Number of positions (rows) in the table.
    d_hid: Number of channels per position.
    dtype: dtype of the returned tensor.

  Returns:
    CPU tensor of shape (1, n_position, d_hid); even channels hold sines and
    odd channels hold cosines of the position-dependent angles.
  """
  def get_position_angle_vec(position):
    return [position / math.pow(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

  sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
  sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
  sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

  return torch.tensor(sinusoid_table, dtype=dtype)[None, ...]

class SincosPosEmb(nn.Module):
  """Returns sinusoidal positional embedding given the shape of the tokens.

  Args:
    base_token_shape: Grid (h, w) the table is generated for; when None the
      token grid itself is used.
    use_jax_interpolation: Resize the table with `jax.image.resize` instead of
      `torch.nn.functional.interpolate` when the token grid differs from the
      base grid.
  """
  def __init__(self, base_token_shape=None, use_jax_interpolation=False):
    super().__init__()
    self.base_token_shape = base_token_shape
    self.use_jax_interpolation = use_jax_interpolation

  def forward(self, tokens_shape):
    """Builds the embedding for tokens of shape (*batch, tokens_h, tokens_w, d).

    Returns:
      Tensor of shape (*batch, tokens_h, tokens_w, d). The batch axes are
      expanded views of a single table. The tensor lives on the CUDA device
      when one is available and on the CPU otherwise.
    """
    d = tokens_shape[-1]
    if self.base_token_shape is not None:
      h, w = self.base_token_shape
    else:
      h, w = tokens_shape[-3], tokens_shape[-2]

    posenc = get_mae_sinusoid_encoding_table(h * w, d)  # [1, h*w, d]
    posenc = posenc.view(1, h, w, d)  # [1, h, w, d]

    *b, tokens_h, tokens_w, _ = tokens_shape

    if tokens_h != h or tokens_w != w:
      # Resize one copy of the table; it is identical for every batch entry.
      if self.use_jax_interpolation:
        resized = jax.image.resize(
            jnp.array(posenc.numpy()), (1, tokens_h, tokens_w, d), method='bicubic')
        posenc = torch.from_numpy(np.array(resized))
      else:
        posenc = F.interpolate(
          posenc.permute(0, 3, 1, 2),  # [1, D, H, W]
          size=(tokens_h, tokens_w),
          mode='bicubic',
          align_corners=False
        ).permute(0, 2, 3, 1)  # [1, H, W, D]

    # The unconditional `.cuda()` crashed on hosts without a GPU; fall back to
    # the CPU there. Move before expanding so only one copy is transferred.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    posenc = posenc.to(device)

    # Broadcast over the batch axes without copying. `reshape` is used because
    # the permuted tensor above is not contiguous.
    lead = list(b) if b else [1]
    posenc = posenc.reshape(*([1] * len(lead)), tokens_h, tokens_w, d)
    return posenc.expand(*lead, tokens_h, tokens_w, d)

class Tokenizer(nn.Module):
  """Holds a patch-embedding module and a positional-encoding module.

  `forward` returns the raw patch embeddings only; the positional encoding in
  `self.posenc` is stored but not added here.
  """

  def __init__(self, patch_embedding, posenc):
    super().__init__()
    self.patch_embedding = patch_embedding
    self.posenc = posenc

  def forward(self, x):
    # NOTE(review): unlike the Flax Tokenizer, `self.posenc(tokens.shape)` is
    # not added to the tokens here; presumably the caller applies it. TODO confirm.
    return self.patch_embedding(x)

class TransformerMLP(nn.Module):
  """Two-layer GELU MLP used inside Transformer blocks.

  Args:
    input_dim: Width of the input and of the output.
    hidden_size: Width of the hidden layer; defaults to 4 * input_dim.
  """

  def __init__(self, input_dim, hidden_size=None):
    super().__init__()
    self.hidden_size = hidden_size if hidden_size is not None else 4 * input_dim
    self.dense_in = nn.Linear(input_dim, self.hidden_size)
    self.dense_out = nn.Linear(self.hidden_size, input_dim)
    # Xavier-uniform weights and zero biases for both layers.
    for layer in (self.dense_in, self.dense_out):
      nn.init.xavier_uniform_(layer.weight)
      nn.init.zeros_(layer.bias)

  def forward(self, x):
    hidden = self.dense_in(x)
    return self.dense_out(F.gelu(hidden))

def dot_product_attention_weights(query, key):
  """Scaled dot-product attention weights.

  Args:
    query: Tensor of shape (batch, q, heads, d).
    key: Tensor of shape (batch, k, heads, d).

  Returns:
    Tensor of shape (batch, heads, q, k) that sums to one over the last axis.
  """
  scaled_query = query / math.sqrt(query.size(-1))
  logits = torch.einsum('bqhd,bkhd->bhqk', scaled_query, key)
  return F.softmax(logits, dim=-1)

class ImprovedMultiHeadDotProductAttention(nn.Module):
  """Multi-head dot-product attention with an optional mask.

  Args:
    embed_dim: Width of the inputs.
    num_heads: Number of attention heads.
    qk_features: Total query/key width over all heads; defaults to `embed_dim`.
    v_features: Total value width over all heads; defaults to `qk_features`.
    out_features: Width of the output; defaults to `embed_dim`.
  """

  def __init__(self, embed_dim, num_heads, qk_features=None, v_features=None, out_features=None):
    super().__init__()
    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.qk_features = qk_features or embed_dim
    self.v_features = v_features or self.qk_features
    self.out_features = out_features or embed_dim

    # Per-head widths.
    self.head_dim_qk = self.qk_features // self.num_heads
    self.head_dim_v = self.v_features // self.num_heads

    # Input projections, then the output projection.
    self.query = nn.Linear(embed_dim, self.qk_features)
    self.key  = nn.Linear(embed_dim, self.qk_features)
    self.value = nn.Linear(embed_dim, self.v_features)
    self.out = nn.Linear(self.v_features, self.out_features)

  def forward(self, inputs_q, inputs_k=None, inputs_v=None, mask=None):
    """Applies attention.

    Args:
      inputs_q: Queries of shape (batch, q, embed_dim).
      inputs_k: Keys of shape (batch, k, embed_dim); defaults to `inputs_q`.
      inputs_v: Values of shape (batch, k, embed_dim); defaults to `inputs_k`.
      mask: Optional tensor broadcastable to (batch, heads, q, k); positions
        equal to 0 are excluded from the softmax.

    Returns:
      Tensor of shape (batch, q, out_features).
    """
    batch_size, seq_len_q, _ = inputs_q.shape
    keys_in = inputs_q if inputs_k is None else inputs_k
    values_in = keys_in if inputs_v is None else inputs_v
    seq_len_k = keys_in.shape[1]

    def split_heads(projected, length, head_dim):
      # (batch, length, features) -> (batch, length, num_heads, head_dim)
      return projected.view(batch_size, length, self.num_heads, head_dim)

    query = split_heads(self.query(inputs_q), seq_len_q, self.head_dim_qk)
    key = split_heads(self.key(keys_in), seq_len_k, self.head_dim_qk)
    value = split_heads(self.value(values_in), seq_len_k, self.head_dim_v)

    # Scaled dot-product logits of shape (batch, heads, q, k).
    logits = torch.einsum('bqhd,bkhd->bhqk', query / math.sqrt(self.head_dim_qk), key)
    if mask is not None:
      logits = logits.masked_fill(mask == 0, float('-inf'))
    weights = F.softmax(logits, dim=-1)

    # Weighted sum over values, then merge the heads again.
    mixed = torch.einsum('bhqk,bkhd->bqhd', weights, value)
    merged = mixed.reshape(batch_size, seq_len_q, self.num_heads * self.head_dim_v)
    return self.out(merged)

class PreNormBlock(nn.Module):
  """Pre-LN Transformer layer: normalise, self-attend, normalise, MLP."""

  def __init__(self, attention_norm, mlp_norm, attention, mlp):
    super().__init__()
    self.attention_norm = attention_norm
    self.mlp_norm = mlp_norm
    self.attention = attention
    self.mlp = mlp

  def forward(self, x):
    # Self-attention sub-layer with a residual connection.
    attended = x + self.attention(self.attention_norm(x))
    # MLP sub-layer with a residual connection.
    return attended + self.mlp(self.mlp_norm(attended))

# Variant name -> (hidden_size, num_layers, mlp_size, num_heads).
VIT_SIZES = {
    'mu': (32, 1, 128, 2),
    'Ti': (192, 12, 768, 3),
    'S': (384, 12, 1536, 6),
    'M': (512, 12, 2048, 8),
    'B': (768, 12, 3072, 12),
    'L': (1024, 24, 4096, 16),
    'H': (1280, 32, 5120, 16),
    'g': (1408, 40, 6144, 16),
    'G': (1664, 48, 8192, 16),
    'e': (1792, 56, 15360, 16),
}

@dataclasses.dataclass(frozen=True)
class ViTSpec:
  """Size specification of a Vision Transformer."""
  hidden_size: int  # Dimension of tokens passed between blocks.
  num_layers: int  # Number of transformer blocks.
  mlp_size: int  # Hidden dimension of the MLP in each block.
  num_heads: int  # Number of attention heads.
  patch_size: int = None  # Patch size; only set by a "/<patch>" suffix.

  @classmethod
  def from_variant_string(cls, variant_str: str):
    """Builds a spec from strings such as "ViT-L", "B", or "Ti/16".

    Raises:
      ValueError: If `variant_str` does not have the expected format.
      KeyError: If the variant name is not a key of `VIT_SIZES`.
    """
    match = re.match(r'^([Vv][Ii][Tt][-_])?(?P<name>[a-zA-Z]{1,2})(/(?P<patch>\d+))?$', variant_str)
    if match is None:
      raise ValueError(f'Invalid variant string: {variant_str!r}.')
    groups = match.groupdict()
    spec = cls(*VIT_SIZES[groups['name']])
    if groups['patch'] is None:
      return spec
    return dataclasses.replace(spec, patch_size=int(groups['patch']))

class Transformer(nn.Module):
  """Stack of pre-norm self-attention blocks followed by a final LayerNorm.

  Args:
    num_layers: Number of `PreNormBlock`s.
    hidden_size: Token width.
    num_heads: Number of attention heads per block.
    mlp_size: Hidden width of the MLP in each block.
    qk_features: Total query/key width; defaults to `hidden_size`.
    v_features: Total value width; defaults to `hidden_size`.
  """
  def __init__(self, num_layers, hidden_size, num_heads, mlp_size, qk_features=None, v_features=None):
    super().__init__()
    self.layers = nn.ModuleList([
        PreNormBlock(
            attention_norm=nn.LayerNorm(hidden_size, eps=1e-06, dtype=torch.float32),
            mlp_norm=nn.LayerNorm(hidden_size, eps=1e-06, dtype=torch.float32),
            attention=ImprovedMultiHeadDotProductAttention(
                embed_dim=hidden_size,
                num_heads=num_heads,
                qk_features=qk_features or hidden_size,
                v_features=v_features or hidden_size,
            ),
            mlp=TransformerMLP(input_dim=hidden_size, hidden_size=mlp_size),
        )
        for _ in range(num_layers)
    ])
    # Use the same epsilon as the per-block norms above (1e-6, which is also
    # the Flax LayerNorm default) instead of torch's default of 1e-5, so the
    # final norm is numerically consistent with the rest of the stack.
    self.LayerNorm_0 = nn.LayerNorm(hidden_size, eps=1e-06)

  def forward(self, x):
    """Applies every block in order, then the final LayerNorm."""
    for layer in self.layers:
      x = layer(x)
    return self.LayerNorm_0(x)

  @classmethod
  def from_variant_str(cls, variant_str: str, **kwargs):
    """Builds the model from a ViT variant string; `kwargs` override the spec."""
    spec = ViTSpec.from_variant_string(variant_str)
    all_kwargs = dict(
      num_layers=spec.num_layers,
      hidden_size=spec.hidden_size,
      mlp_size=spec.mlp_size,
      num_heads=spec.num_heads,
    )
    all_kwargs.update(kwargs)
    return cls(**all_kwargs)

class CrossAttentionBlock(nn.Module):
  """Residual block combining cross-attention, an MLP and self-attention.

  Args:
    num_heads: attention heads for both attention modules.
    num_feats: token feature dimension.
    mlp_dim: hidden width of the MLP.
    dtype: dtype of the LayerNorm parameters.
  """

  def __init__(self, num_heads, num_feats, mlp_dim, dtype=torch.float32):
    super().__init__()
    self.attention_norm = nn.LayerNorm(num_feats, eps=1e-6, dtype=dtype)
    self.mlp_norm = nn.LayerNorm(num_feats, eps=1e-6, dtype=dtype)
    self.ca_attention_norm = nn.LayerNorm(num_feats, eps=1e-6, dtype=dtype)

    self.attention = ImprovedMultiHeadDotProductAttention(
      embed_dim=num_feats, num_heads=num_heads
    )
    self.ca_attention = ImprovedMultiHeadDotProductAttention(
      embed_dim=num_feats, num_heads=num_heads
    )
    self.mlp = TransformerMLP(input_dim=num_feats, hidden_size=mlp_dim)

  def forward(self, x, x_kv):
    """Updates query tokens `x` using key/value tokens `x_kv`.

    Each sub-layer is pre-normalised and added residually, in the order
    cross-attention -> MLP -> self-attention. `x_kv` is passed to the
    cross-attention as-is (no norm is applied to it here).
    NOTE(review): the sub-layer order is kept from the original port; confirm
    it against the reference model if results look off.
    """
    x = x + self.ca_attention(inputs_q=self.ca_attention_norm(x), inputs_k=x_kv, inputs_v=x_kv)
    x = x + self.mlp(self.mlp_norm(x))
    x = x + self.attention(self.attention_norm(x))
    return x

class CrossAttentionTransformer(nn.Module):
  """Sequence of `CrossAttentionBlock`s followed by an output LayerNorm.

  Args:
    num_layers: number of cross-attention blocks.
    num_heads: attention heads per block.
    num_feats: token feature dimension.
    mlp_dim: hidden width of each block's MLP.
    dtype: dtype of the LayerNorm parameters.
  """

  def __init__(self, num_layers, num_heads, num_feats, mlp_dim, dtype=torch.float32):
    super().__init__()
    self.xa_blocks = nn.ModuleList()
    for _ in range(num_layers):
      self.xa_blocks.append(CrossAttentionBlock(num_heads, num_feats, mlp_dim, dtype=dtype))
    self.output_norm = nn.LayerNorm(num_feats, eps=1e-6, dtype=dtype)

  def forward(self, inputs, inputs_kv):
    """Refines `inputs` block by block; every block sees the same `inputs_kv`."""
    hidden = inputs
    for xa_block in self.xa_blocks:
      hidden = xa_block(hidden, inputs_kv)
    return self.output_norm(hidden)

class RandomStateInit(nn.Module):
  """Non-learnable initial state for the recurrent core.

  Gaussian noise is drawn and then multiplied by 0, so the returned state is
  all zeros (the draw still advances the global RNG).
  """

  def __init__(self):
    super().__init__()

  def forward(self, inputs, batch_shape):
    """Returns a state of shape `batch_shape + inputs.shape[-2:]`.

    The state uses the dtype and device of `inputs`.
    """
    token_shape = inputs.shape[-2:]
    noise = torch.randn(batch_shape + token_shape, dtype=inputs.dtype, device=inputs.device)
    return 0 * noise

class GatedTransformerCore(nn.Module):
  """Recurrent core that blends a transformer update into the state with gates.

  Args:
    transformer: module called as `transformer(inputs, inputs_kv=...)`.
    initializer: module producing the initial state (stored, not called here).
    token_dim: feature dimension of inputs and state.
    state_layer_norm: normalisation applied to the state before it is used
      as keys/values.
  """

  def __init__(self, transformer, initializer, token_dim, state_layer_norm):
    super().__init__()
    self.transformer = transformer
    self.initializer = initializer
    self.token_dim = token_dim
    self.state_layer_norm = state_layer_norm

    # Bias-free projections for the update and reset gates.
    self.input_update = nn.Linear(token_dim, token_dim, bias=False)
    self.input_reset = nn.Linear(token_dim, token_dim, bias=False)
    self.state_update = nn.Linear(token_dim, token_dim, bias=False)
    self.state_reset = nn.Linear(token_dim, token_dim, bias=False)

  def forward(self, inputs, state):
    """Runs one recurrent step.

    Returns:
      `(output, state)`; both are the same tensor, the new state.
    """
    z = F.sigmoid(self.input_update(inputs) + self.state_update(state))
    r = F.sigmoid(self.input_reset(inputs) + self.state_reset(state))
    # The reset gate scales the normalised state the transformer attends to.
    gated_memory = r * self.state_layer_norm(state)
    candidate = self.transformer(inputs, inputs_kv=gated_memory)
    # Per-element interpolation between the old state and the candidate.
    new_state = (1 - z) * state + z * candidate
    return new_state, new_state

class VideoSiamMAE(nn.Module):
  """Frame encoder with a recurrent core, applied one video frame at a time.

  Args:
    tokenizer: maps a frame to patch tokens laid out as [..., h, w, D].
    encoder: transformer run on the [cls] + patch token sequence.
    rnn_core: recurrent module exposing `initializer` and called as
      `rnn_core(tokens, state)`.
    latent_emb_dim: feature dimension D of the learnable cls token.
  """

  def __init__(self, tokenizer, encoder, rnn_core, latent_emb_dim=384):
    super().__init__()
    self.tokenizer = tokenizer
    self.encoder = encoder
    self.rnn_core = rnn_core
    self.latent_emb_dim = latent_emb_dim

    # Learnable cls token, initialised with small Gaussian noise.
    self.cls_token = nn.Parameter(torch.randn(1, 1, latent_emb_dim) * 0.02)

  def forward(self, frame, state=None):
    """Encodes one frame and advances the recurrent state.

    Args:
      frame: input passed to the tokenizer.
      state: recurrent state from the previous frame, or None on the first
        frame, in which case `rnn_core.initializer` creates it.

    Returns:
      `(features, state)` as produced by `rnn_core`; token 0 along the
      sequence axis corresponds to the cls token.
    """
    # Flatten the spatial token grid into a sequence: [..., h*w, D].
    patch_tokens = einops.rearrange(self.tokenizer(frame), '... h w d -> ... (h w) d')

    # Prepend the cls token, broadcast over the leading batch dimensions.
    batch_dims = patch_tokens.shape[:-2]
    cls = self.cls_token.expand(*batch_dims, -1, -1)
    tokens = torch.cat([cls, patch_tokens], dim=-2)

    encoded = self.encoder(tokens)

    if state is None:
      # The initial state is requested with a leading dimension of 1.
      state = self.rnn_core.initializer(encoded, batch_shape=(1,))

    features, state = self.rnn_core(encoded, state)
    return features, state

# Assemble the PyTorch model: tokenizer -> ViT encoder -> gated recurrent core.
# All token dimensions below are 1024 and patches are 16x16.
model = VideoSiamMAE(
    tokenizer=Tokenizer(
        patch_embedding=PatchEmbedding(patch_size=[16, 16], num_features=1024),
        posenc=SincosPosEmb(base_token_shape=[16, 16]),
    ),
    encoder=Transformer.from_variant_str(variant_str='L'),
    rnn_core=GatedTransformerCore(
        transformer=CrossAttentionTransformer(
            num_layers=4,
            num_heads=16,
            num_feats=1024,
            mlp_dim=4096,
            dtype=torch.float32,
        ),
        initializer=RandomStateInit(),
        token_dim=1024,
        # Scale-only LayerNorm (no bias) with a larger epsilon than the other norms.
        state_layer_norm=nn.LayerNorm(1024, eps=0.0001, bias=False),
    ),
    latent_emb_dim=1024,
)
# Move the model to the GPU (requires CUDA) and switch to inference mode.
model = model.cuda()
model = model.eval()
# Disable autograd globally: the notebook only runs inference from here on.
torch.set_grad_enabled(False)

In [None]:
# @title Load checkpoint

# Work from the directory the checkpoint was downloaded into.
%cd /content/rvm

def recover_tree(flat_dict):
  """Rebuilds a nested dict from a flat mapping with '/'-separated keys.

  Example: {'a/b': 1, 'c': 2} -> {'a': {'b': 1}, 'c': 2}.
  """
  root = {}
  for path, value in flat_dict.items():
    *parents, leaf = path.split("/")
    cursor = root
    for name in parents:
      # Create intermediate levels on first use, then descend.
      cursor = cursor.setdefault(name, {})
    cursor[leaf] = value
  return root

def flatten_flax_params(params, parent_key=""):
  """Flattens a nested params dict into {'a.b.c': leaf}.

  Nested dict levels are joined with '.'; every non-dict value is kept as a
  leaf. `parent_key` is the prefix used during recursion.
  """
  flat = {}
  for key, value in params.items():
    path = key if not parent_key else f"{parent_key}.{key}"
    if not isinstance(value, dict):
      flat[path] = value
      continue
    flat.update(flatten_flax_params(value, path))
  return flat

def _flax_key_for(name, flat_flax):
  """Returns the Flax key to look up for PyTorch parameter `name`, or None."""
  # Flax numbers repeated sub-modules as `layers_0`; PyTorch uses `layers.0`.
  flax_name = name.replace('layers.', 'layers_').replace('blocks.', 'blocks_')
  if name == "cls_token":
    return "cls_token"
  if name.endswith("weight"):
    # Linear/Conv weights are stored as `kernel`, LayerNorm weights as `scale`.
    kernel_key = flax_name.replace("weight", "kernel")
    if kernel_key in flat_flax:
      return kernel_key
    return flax_name.replace("weight", "scale")
  if name.endswith("bias"):
    return flax_name  # bias names match directly
  return None


def _to_torch_layout(tensor, param, name):
  """Reorders a Flax kernel so its axes follow the PyTorch parameter layout."""
  if param.ndim == 4:
    # Conv2d: Flax [H, W, in, out] -> Torch [out, in, H, W].
    if tensor.ndim == 5 and tensor.shape[0] == 1:  # drop an extra leading dim
      tensor = tensor[0]
    return tensor.permute(3, 2, 0, 1)
  if param.ndim != 2:
    return tensor
  if tensor.ndim == 2:
    # Dense: [in, out] -> [out, in].
    return tensor.T
  if tensor.ndim == 3:
    # DenseGeneral: decide which two axes to merge from the target shape.
    if param.shape[0] == tensor.shape[-1] * tensor.shape[-2]:  # Q/K/V projection
      return tensor.reshape(tensor.shape[0], -1).T
    return tensor.reshape(-1, tensor.shape[-1]).T  # output projection
  raise ValueError(f"Unexpected kernel shape {tensor.shape} for {name}")


def flax_to_torch(flat_flax, torch_model):
  """Copies weights from a flattened Flax params dict into `torch_model` in place.

  Args:
    flat_flax: mapping from '.'-joined Flax parameter paths to arrays.
    torch_model: module whose named parameters are overwritten.

  Parameters with no matching Flax key are left untouched and reported with a
  `[WARN]` line; every successful copy is logged too.

  Raises:
    ValueError: if a 2-D parameter is paired with a kernel that is neither
      2-D nor 3-D.
  """
  for name, param in torch_model.named_parameters():
    flax_key = _flax_key_for(name, flat_flax)
    if flax_key is None or flax_key not in flat_flax:
      print(f"[WARN] Missing weights for {name} (flax_key={flax_key})")
      continue

    tensor = torch.tensor(np.array(flat_flax[flax_key]))
    # Final reshape covers bias, cls_token and norm parameters.
    tensor = _to_torch_layout(tensor, param, name).reshape(param.shape)

    with torch.no_grad():
      param.copy_(tensor)

    print(f"Loaded {name} from {flax_key}")

# Load the checkpoint archive and rebuild the nested parameter tree from its
# '/'-separated keys.
restored_params = recover_tree(np.load("pretrain_rvm_large16_256_175558463.npz", allow_pickle=False))

# Re-flatten with '.' separators so keys can be matched against PyTorch
# parameter names, then copy the weights into `model` in place.
flat_flax = flatten_flax_params(restored_params)
flax_to_torch(flat_flax, model)

In [None]:
# @title Feature extraction function

PATCH_SIZE = 16

def extract_features(model, video):
  """Runs the recurrent model over a video and returns per-patch features.

  Args:
    model: callable `model(frame, state) -> (tokens, state)` where `tokens`
      has shape (1, 1 + h*w, D) with the cls token at index 0.
    video: numpy array of frames, (T, H, W, C) with values in [0, 255].

  Returns:
    numpy array of shape (T, H // PATCH_SIZE, W // PATCH_SIZE, D).
  """
  # Run on whatever device the model lives on; previously the frames stayed on
  # the CPU while the model was moved to CUDA.
  parameters = getattr(model, 'parameters', None)
  first_param = next(iter(parameters()), None) if callable(parameters) else None
  device = first_param.device if first_param is not None else torch.device('cpu')

  video = torch.as_tensor(video.astype(np.float32) / 255.0)
  h, w = video.shape[1] // PATCH_SIZE, video.shape[2] // PATCH_SIZE  # feature resolution

  model_state = None
  features = []
  with torch.no_grad():
    for t in range(video.shape[0]):
      # Only the current frame is moved to the device to keep memory use low.
      frame = video[t][None].to(device)
      tokens, model_state = model(frame, model_state)
      patch_tokens = tokens[0, 1:, :]  # drop the cls token
      # `.cpu()` is required before `.numpy()` when the model runs on a GPU.
      features.append(patch_tokens.reshape(h, w, -1).detach().cpu().numpy())
  return np.stack(features, axis=0)

In [None]:
# @title Label propagation functions

def draw_labelmap_np(img, pt, sigma=0.5):
  """Writes a 2D Gaussian blob around `pt` into `img` in place.

  Adopted from https://github.com/anewell/pose-hg-train/blob/master/src/pypose/draw.py

  Args:
    img: 2D array indexed as [y, x]; modified in place and returned.
    pt: (x, y) centre of the blob.
    sigma: standard deviation of the Gaussian; the patch is 6 * sigma + 1 wide.

  Returns:
    `img`, unchanged when the patch lies entirely outside the image.
  """
  # Bounding box of the patch (upper-left inclusive, bottom-right exclusive).
  left, top = int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)
  right, bottom = int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)
  rows, cols = img.shape[0], img.shape[1]
  outside = left >= cols or top >= rows or right < 0 or bottom < 0
  if outside:
    return img

  # Unnormalised Gaussian patch with peak value 1 at its centre.
  size = 6 * sigma + 1
  offsets = np.arange(0, size, 1, float)
  center = size // 2
  dx2 = (offsets - center) ** 2
  dy2 = (offsets[:, np.newaxis] - center) ** 2
  kernel = np.exp(- (dx2 + dy2) / (2 * sigma ** 2))

  # Part of the patch that falls inside the image ...
  k_x0, k_x1 = max(0, -left), min(right, cols) - left
  k_y0, k_y1 = max(0, -top), min(bottom, rows) - top
  # ... and the image region it is written to.
  i_x0, i_x1 = max(0, left), min(right, cols)
  i_y0, i_y1 = max(0, top), min(bottom, rows)

  img[i_y0:i_y1, i_x0:i_x1] = kernel[k_y0:k_y1, k_x0:k_x1]
  return img

def mask2heatmap(mask, h, w, height, width):
  """Converts a segmentation mask into a low-resolution soft one-hot heatmap.

  Args:
    mask: 2D integer label map.
    h, w: target (feature) resolution of the heatmap.
    height, width: intermediate resolution the mask is first resized to
      (nearest neighbour) before one-hot encoding.

  Returns:
    heatmap: float32 array of shape (h, w, n_labels), bilinearly downsampled.
    label_map: sorted unique labels of the input mask; channel `i` of the
      heatmap corresponds to `label_map[i]`.
  """
  label_map = np.unique(mask)
  mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
  heatmap = np.stack([mask == l for l in label_map], axis=-1).astype(np.float32)
  heatmap = cv2.resize(heatmap, (w, h), interpolation=cv2.INTER_LINEAR)
  if heatmap.ndim == 2:
    # cv2.resize drops the channel axis when the mask has a single label;
    # restore it so callers always get (h, w, n_labels).
    heatmap = heatmap[..., np.newaxis]
  return heatmap, label_map

def pose2heatmap(pose, h, w, height, width):
  """Converts keypoints into per-joint Gaussian heatmaps plus a background map.

  Args:
    pose: array of (x, y) keypoints, one row per joint, in the coordinate
      frame of an image of size (height, width).
    h, w: resolution of the heatmap.
    height, width: resolution the keypoint coordinates refer to.

  Returns:
    float32 array of shape (h, w, n_joints + 1); channel 0 is 1 wherever no
    joint channel is non-zero, channel i + 1 holds joint i.
  """
  n_class = pose.shape[0]
  # Rescale keypoints from image coordinates to heatmap coordinates.
  scaled = pose * np.array([w / width, h / height])
  heatmap = np.zeros((h, w, n_class + 1), dtype=np.float32)
  for channel in range(1, n_class + 1):
    heatmap[..., channel] = draw_labelmap_np(np.zeros((h, w)), scaled[channel - 1])
  # Background: locations where every joint channel is exactly zero.
  heatmap[..., 0] = heatmap.sum(axis=-1) == 0
  return heatmap

def process_pose(preds, h, w, ori_height, ori_width, topk=5):
  """Turns propagated joint heatmaps into keypoint coordinates.

  Args:
    preds: tensor (T, grid_h, grid_w, n_joints + 1); channel 0 is skipped.
    h, w: heatmap resolution used to rescale coordinates.
    ori_height, ori_width: resolution the returned coordinates refer to.
    topk: number of strongest locations averaged per joint.

  Returns:
    coords: numpy array (T, n_joints, 2) of (x, y); (-1, -1) for joints whose
      heatmap is all zero.
    visible: numpy array (T, n_joints), 1.0 where x >= 0, else 0.0.
  """
  coords_per_frame, visible_per_frame = [], []
  for frame_pred in preds:
    joint_maps = frame_pred[..., 1:]
    grid_w = joint_maps.shape[1]
    flat = joint_maps.flatten(0, 1)
    # Weighted average of the top-k locations, weights normalised to sum to 1.
    weights, flat_idx = torch.topk(flat, k=topk, dim=0)
    weights /= weights.sum(0)[None]
    cols, rows = flat_idx % grid_w, flat_idx // grid_w
    xy = torch.stack([(cols * weights).sum(0), (rows * weights).sum(0)], dim=0)
    xy[0, :] = xy[0, :] / w * ori_width
    xy[1, :] = xy[1, :] / h * ori_height
    # Joints with an empty heatmap get the sentinel -1 (this also overwrites
    # the NaNs produced by the 0/0 normalisation above).
    xy[:, flat.sum(0) == 0] = -1
    xy = xy.cpu().numpy().transpose(1, 0)
    coords_per_frame.append(xy)
    visible_per_frame.append((xy[:, 0] >= 0).astype(float))  # (n_class)
  return np.stack(coords_per_frame, axis=0), np.stack(visible_per_frame, axis=0)

def process_segmentation(preds, lbl_map, height, width, ori_height, ori_width):
  """Converts propagated soft label maps into hard masks at original size.

  Args:
    preds: tensor (T, h, w, n_labels) of soft label maps.
    lbl_map: label id for each channel of `preds`.
    height, width: resolution the soft maps are upsampled to before argmax.
    ori_height, ori_width: resolution of the returned masks.

  Returns:
    int32 numpy array (T, ori_height, ori_width) of label ids.
  """
  labels = np.array(lbl_map, dtype=np.int32)
  pred_lbls = []
  for t in range(preds.shape[0]):
    pred = preds[t].cpu().numpy()
    # Upsample predicted soft label maps
    pred_dist = cv2.resize(pred, (width, height))
    if pred_dist.ndim == 2:
      # cv2.resize drops the channel axis for single-channel input; restore
      # it so the argmax below still runs over labels.
      pred_dist = pred_dist[..., np.newaxis]
    # Argmax gives the channel index, which is mapped back to the label id
    pred_lbl = labels[np.argmax(pred_dist, axis=-1)]
    pred_lbl = cv2.resize(pred_lbl, (ori_width, ori_height), interpolation=cv2.INTER_NEAREST_EXACT)
    pred_lbls.append(pred_lbl)
  return np.stack(pred_lbls, axis=0)

def label_propagation(feats, heatmap, n_context=20, temperature=0.7, topk=7, radius=20, restrict_neighborhood=True, norm_mask=False):
  """Propagates first-frame labels through a video using feature similarity.

  Args:
    feats: tensor (T, h, w, C) of per-patch features.
    heatmap: tensor (h, w, n_class) of first-frame labels, on the same device
      as `feats`.
    n_context: number of most recent frames kept as context, in addition to
      the first frame. 0 uses the first frame only.
    temperature: divisor applied to the affinities before the softmax.
    topk: number of most similar context locations averaged per query.
    radius: spatial radius (in patches) for the neighbourhood restriction.
    restrict_neighborhood: if True, locations of the recent context frames at
      distance >= radius from the query are suppressed. The first-frame
      entry at context index 0 is never restricted.
    norm_mask: if True, min-max normalise each prediction over classes.

  Returns:
    tensor (T, h, w, n_class) of propagated labels.
  """
  h, w = feats.shape[1], feats.shape[2]

  D = None
  if restrict_neighborhood:
    # Additive mask: 0 inside the radius, -1e10 outside. Built on the device of
    # `feats` instead of a hard-coded 'cuda'.
    gx, gy = torch.meshgrid(torch.arange(0, h), torch.arange(0, w), indexing="ij")  # (h, w)
    D = ((gx[None, None, :, :] - gx[:, :, None, None])**2 + (gy[None, None, :, :] - gy[:, :, None, None])**2).float() ** 0.5
    D = (D < radius).float().to(feats.device)
    D[D == 0] = -1e10
    D[D == 1] = 0
    D = D.permute(2, 3, 0, 1)  # (h, w, h, w)

  # FIFO of (features, labels) for the most recent frames, pre-filled with the
  # first frame. A plain list is used: the previous queue.Queue(n_context)
  # blocked forever on get() when n_context was 0.
  context = [(feats[0], heatmap)] * n_context

  preds = []
  for t in range(feats.shape[0]):
    # Use first and previous frames as context
    ctx_feats = torch.stack([feats[0]] + [pair[0] for pair in context])
    ctx_lbls = torch.stack([heatmap] + [pair[1] for pair in context])

    aff = torch.einsum('hwc, tmnc -> hwtmn', feats[t], ctx_feats) / temperature  # (h, w, n_context+1, h, w)
    if D is not None:
      aff[:, :, 1:] += D[:, :, None]  # (h, w, n_context+1, h, w)
    aff = aff.view(aff.shape[0], aff.shape[1], -1)  # (h, w, n_context+1 * h * w)

    weights, ids = torch.topk(aff, topk, dim=-1)  # (h, w, topk), (h, w, topk)
    weights = F.softmax(weights, dim=-1)  # (h, w, topk)
    ctx_lbls = ctx_lbls.view(-1, ctx_lbls.shape[-1])  # (n_context+1 * h * w, n_class)
    pred = torch.einsum('hwlk, hwl -> hwk', ctx_lbls[ids], weights) # (h, w, n_class)

    if n_context > 0:
      # Drop the oldest entry and remember the current frame.
      context.pop(0)
      context.append((feats[t], pred))

    if norm_mask:
      # NOTE(review): these in-place ops also modify the `pred` just stored in
      # the context (same tensor), as in the original code; confirm intended.
      pred -= pred.min(-1)[0][..., None]
      pred /= pred.max(-1)[0][..., None]

    preds.append(pred)
  preds = torch.stack(preds)
  return preds

In [None]:
# @title PCA visualization and k-means clustering {form-width: "20%"}

# Extract per-patch features for `noise_video` (defined outside this cell).
features = extract_features(model, noise_video)
print(features.shape)

# Render the PCA and k-means views of the features, plus a 50/50 blend of the
# input video (rescaled to [0, 1]) with the k-means view.
pca_video = visualize_pca(features, noise_video)
kmeans_video = visualize_kmeans(features, noise_video)
mixed_video = 0.5 * noise_video / 255.0 + 0.5 * kmeans_video
video_titles = ['pca', 'kmeans', 'mixed']
media.show_videos([pca_video, kmeans_video, mixed_video], titles=video_titles, height=128, codec='gif', fps=16, columns=3)

### Quantitative Evaluation

In [None]:
# @title Download DAVIS-2017 dataset

# Clone the davis-2017 repository and run its data download script; the
# evaluation below reads from /content/davis-2017/DAVIS/.
%cd /content
!git clone https://github.com/davisvideochallenge/davis-2017
%cd /content/davis-2017
!./data/get_davis.sh

In [None]:
# @title Download JHMDB dataset

# Download and unpack the JHMDB archive; the evaluation below reads from
# /content/jhmdb/.
%cd /content
!wget https://storage.googleapis.com/representations4d/datasets/jhmdb.zip
!unzip jhmdb.zip

In [None]:
# @title Download VIP dataset

# Download and unpack the VIP archive; the evaluation below reads from
# /content/VIP/.
%cd /content
!wget https://storage.googleapis.com/representations4d/datasets/VIP.zip
!unzip VIP.zip

In [None]:
# @title Dataset functions

def create_davis_dataset(root_path):
  """DAVIS dataset, including fields required for video segmentation evaluation.

  Args:
    root_path: directory containing `ImageSets/2017/val.txt`,
      `JPEGImages/480p/<video>/*.jpg` and `Annotations/480p/<video>/*.png`.

  Yields:
    dict with 'video' (stacked RGB frames), 'mask' (stacked annotation
    images as stored in the PNGs) and 'video_id'.
  """
  def _read_image(path, to_rgb):
    # Decode while the handle is open so it can be closed right away.
    with tf.io.gfile.GFile(path, 'rb') as handle:
      image = Image.open(handle)
      if to_rgb:
        image = image.convert('RGB')
      return np.array(image)

  # Read the split list up front so the file is closed before yielding.
  with tf.io.gfile.GFile(os.path.join(root_path, 'ImageSets/2017/val.txt'), 'r') as split_file:
    video_ids = [line.strip() for line in split_file]

  for video_id in video_ids:
    if not video_id:
      continue  # ignore blank lines in the split file
    frame_files = sorted(tf.io.gfile.glob(os.path.join(root_path, 'JPEGImages/480p', video_id, '*.jpg')))
    frames = np.stack([_read_image(f, to_rgb=True) for f in frame_files])
    label_files = sorted(tf.io.gfile.glob(os.path.join(root_path, 'Annotations/480p', video_id, '*.png')))
    labels = np.stack([_read_image(f, to_rgb=False) for f in label_files])
    yield {'video': frames, 'mask': labels, 'video_id': video_id}

def create_jhmdb_dataset(jhmdb_path):
  """JHMDB dataset, including fields required for PCK evaluation.

  Videos are taken from every `*split1.txt` file under `splits/`, keeping the
  entries flagged with 2, and are yielded in shuffled order.

  Args:
    jhmdb_path: directory containing `splits/`, `joint_positions/` and
      `Rename_Images/`.

  Yields:
    dict with 'video' (stacked RGB frames), 'pose' (float32 array
    (num_frames, num_points, 2), converted to 0-based coordinates) and
    'video_id'.
  """
  video_ids = []
  for file in os.listdir(os.path.join(jhmdb_path, 'splits')):
    if file.endswith('split1.txt'):
      video_folder = '_'.join(file.split('_')[:-2])
      # Context manager closes the split file once it has been read.
      with open(os.path.join(jhmdb_path, 'splits', file), 'r') as split_file:
        for line in split_file:
          if not line.strip():
            continue  # ignore blank lines
          video_name, traintest = line.split()
          if int(traintest) == 2:
            video_id = os.path.join(video_folder, video_name.split('.')[0])
            video_ids.append(video_id)

  random.shuffle(video_ids)
  for video_id in video_ids:
    joints = os.path.join(jhmdb_path, 'joint_positions', video_id, 'joint_positions.mat')
    pose = sio.loadmat(joints)['pos_img'] - 1  # matlab -> python
    pose = pose.astype(np.float32).transpose(2, 1, 0)  # (num_frames, num_points, 2)
    frame_files = sorted(glob.glob(os.path.join(jhmdb_path, 'Rename_Images', video_id, '*.png')))
    frames = []
    for f in frame_files:
      # Decode while the handle is open so it can be closed right away.
      with open(f, 'rb') as handle:
        frames.append(np.array(Image.open(handle).convert('RGB')))
    frames = np.stack(frames)
    yield {'video': frames, 'pose': pose, 'video_id': video_id}

def create_vip_dataset(root_path):
  """VIP dataset, including fields required for video segmentation evaluation.

  Args:
    root_path: directory containing `lists/val_videos.txt`,
      `Images/<video>/*.jpg` and `Annotations/Category_ids/<video>/*.png`.

  Yields:
    dict with 'video' (stacked RGB frames), 'mask' (stacked annotation
    images as stored in the PNGs) and 'video_id'.
  """
  def _read_image(path, to_rgb):
    # Decode while the handle is open so it can be closed right away.
    with tf.io.gfile.GFile(path, 'rb') as handle:
      image = Image.open(handle)
      if to_rgb:
        image = image.convert('RGB')
      return np.array(image)

  # Read the split list up front so the file is closed before yielding.
  with tf.io.gfile.GFile(os.path.join(root_path, 'lists/val_videos.txt'), 'r') as split_file:
    video_ids = [line.strip() for line in split_file]

  for video_id in video_ids:
    if not video_id:
      continue  # ignore blank lines in the split file
    frame_files = sorted(tf.io.gfile.glob(os.path.join(root_path, 'Images', video_id, '*.jpg')))
    frames = np.stack([_read_image(f, to_rgb=True) for f in frame_files])
    label_files = sorted(tf.io.gfile.glob(os.path.join(root_path, 'Annotations/Category_ids', video_id, '*.png')))
    labels = np.stack([_read_image(f, to_rgb=False) for f in label_files])
    yield {'video': frames, 'mask': labels, 'video_id': video_id}

In [None]:
# @title DAVIS video segmentation evaluation metric
# https://github.com/davisvideochallenge/davis2017-evaluation/blob/master/davis2017/metrics.py

import math
import cv2

def db_eval_iou(annotation, segmentation, void_pixels=None):
  """Region similarity (Jaccard index) between two binary masks.

  Arguments:
      annotation   (ndarray): binary annotation map, (H, W) or (..., H, W).
      segmentation (ndarray): binary segmentation map of the same shape.
      void_pixels  (ndarray): optional mask of pixels excluded from the score.

  Return:
      jaccard: intersection over union computed over the last two axes; a
      scalar for 2D input, otherwise one value per leading index. Defined as
      1 where the union is empty.
  """
  assert annotation.shape == segmentation.shape, \
      f'Annotation({annotation.shape}) and segmentation:{segmentation.shape} dimensions do not match.'
  gt = annotation.astype(bool)
  seg = segmentation.astype(bool)

  if void_pixels is None:
    void = np.zeros_like(seg)
  else:
    assert annotation.shape == void_pixels.shape, \
        f'Annotation({annotation.shape}) and void pixels:{void_pixels.shape} dimensions do not match.'
    void = void_pixels.astype(bool)
  valid = np.logical_not(void)

  # Pixel counts over the two spatial axes, ignoring void pixels.
  inters = np.sum((seg & gt) & valid, axis=(-2, -1))
  union = np.sum((seg | gt) & valid, axis=(-2, -1))

  j = inters / union
  empty = np.isclose(union, 0)
  if j.ndim == 0:
    return 1 if empty else j
  j[empty] = 1
  return j


def db_eval_boundary(annotation, segmentation, void_pixels=None, bound_th=0.008):
  """Boundary F-measure between annotation and segmentation.

  Arguments:
      annotation   (ndarray): binary annotation, (H, W) or (T, H, W).
      segmentation (ndarray): binary segmentation of the same shape.
      void_pixels  (ndarray): optional void mask of the same shape.
      bound_th     (float): boundary tolerance forwarded to `f_measure`.

  Returns:
      A single F value for 2D input, or an array with one value per frame
      for 3D input.

  Raises:
      ValueError: if the inputs are neither 2D nor 3D.
  """
  assert annotation.shape == segmentation.shape
  if void_pixels is not None:
    assert annotation.shape == void_pixels.shape

  if annotation.ndim == 2:
    return f_measure(segmentation, annotation, void_pixels, bound_th=bound_th)
  if annotation.ndim != 3:
    raise ValueError(f'db_eval_boundary does not support tensors with {annotation.ndim} dimensions')

  # 3D input: score each frame independently.
  n_frames = annotation.shape[0]
  f_res = np.zeros(n_frames)
  for frame_id in range(n_frames):
    frame_void = None if void_pixels is None else void_pixels[frame_id]
    f_res[frame_id] = f_measure(segmentation[frame_id], annotation[frame_id], frame_void, bound_th=bound_th)
  return f_res


def f_measure(foreground_mask, gt_mask, void_pixels=None, bound_th=0.008):
  """
  Boundary F-measure between a predicted mask and a ground-truth mask.

  Boundaries of both masks are extracted, each is dilated by `bound_pix`
  pixels, and precision/recall are the fractions of boundary pixels of one
  mask that fall inside the dilated boundary of the other.

  Arguments:
      foreground_mask (ndarray): binary segmentation image.
      gt_mask         (ndarray): binary annotated image.
      void_pixels     (ndarray): optional mask with void pixels
      bound_th        (float): tolerance; values >= 1 are pixels, smaller
                               values are a fraction of the image diagonal.

  Returns:
      F (float): boundaries F-measure
  """
  assert np.atleast_3d(foreground_mask).shape[2] == 1
  if void_pixels is None:
    void_pixels = np.zeros_like(foreground_mask).astype(bool)
  else:
    void_pixels = void_pixels.astype(bool)

  if bound_th >= 1:
    bound_pix = bound_th
  else:
    bound_pix = np.ceil(bound_th * np.linalg.norm(foreground_mask.shape))

  # Get the pixel boundaries of both masks, with void pixels zeroed out
  keep = np.logical_not(void_pixels)
  fg_boundary = _seg2bmap(foreground_mask * keep)
  gt_boundary = _seg2bmap(gt_mask * keep)

  from skimage.morphology import disk

  # Dilate both boundaries with the same disk-shaped structuring element
  structuring_element = disk(bound_pix).astype(np.uint8)
  fg_dil = cv2.dilate(fg_boundary.astype(np.uint8), structuring_element)
  gt_dil = cv2.dilate(gt_boundary.astype(np.uint8), structuring_element)

  # Boundary pixels of one mask that lie within tolerance of the other
  gt_match = gt_boundary * fg_dil
  fg_match = fg_boundary * gt_dil

  # Number of boundary pixels in each mask
  n_fg = np.sum(fg_boundary)
  n_gt = np.sum(gt_boundary)

  # Precision and recall; an empty boundary scores 1 on its own side
  if n_fg > 0 and n_gt > 0:
    precision = np.sum(fg_match) / float(n_fg)
    recall = np.sum(gt_match) / float(n_gt)
  else:
    precision = 1 if n_fg == 0 else 0
    recall = 1 if n_gt == 0 else 0

  # Harmonic mean, defined as 0 when both terms are 0
  if precision + recall == 0:
    return 0
  return 2 * precision * recall / (precision + recall)


def _seg2bmap(seg, width=None, height=None):
  """
  From a segmentation, compute a binary boundary map with 1 pixel wide
  boundaries.  The boundary pixels are offset by 1/2 pixel towards the
  origin from the actual segment boundary.
  Arguments:
      seg     : Segments labeled from 1..k.
      width   : Width of desired bmap  <= seg.shape[1]
      height  : Height of desired bmap <= seg.shape[0]
  Returns:
      bmap (ndarray): Binary boundary map.
   David Martin <dmartin@eecs.berkeley.edu>
   January 2003
  """

  seg = seg.astype(bool)
  seg[seg > 0] = 1

  assert np.atleast_3d(seg).shape[2] == 1

  width = seg.shape[1] if width is None else width
  height = seg.shape[0] if height is None else height

  h, w = seg.shape[:2]

  ar1 = float(width) / float(height)
  ar2 = float(w) / float(h)

  assert not (
      width > w | height > h | abs(ar1 - ar2) > 0.01
  ), "Can't convert %dx%d seg to %dx%d bmap." % (w, h, width, height)

  e = np.zeros_like(seg)
  s = np.zeros_like(seg)
  se = np.zeros_like(seg)

  e[:, :-1] = seg[:, 1:]
  s[:-1, :] = seg[1:, :]
  se[:-1, :-1] = seg[1:, 1:]

  b = seg ^ e | seg ^ s | seg ^ se
  b[-1, :] = seg[-1, :] ^ e[-1, :]
  b[:, -1] = seg[:, -1] ^ s[:, -1]
  b[-1, -1] = 0

  if w == width and h == height:
    bmap = b
  else:
    bmap = np.zeros((height, width))
    for x in range(w):
      for y in range(h):
        if b[y, x]:
          j = 1 + math.floor((y - 1) + height / h)
          i = 1 + math.floor((x - 1) + width / h)
          bmap[j, i] = 1

  return bmap

In [None]:
# @title JHMDB pck evaluation metric
# https://github.com/Liusifei/UVC/blob/jhmdb/eval_pck.py

def compute_human_boxes(gts, jnt_visible_set):
  """Per-frame person scale used to normalise keypoint distances.

  Args:
    gts: list of arrays (num_frames, num_joints, 2), one per video.
    jnt_visible_set: list of arrays (num_frames, num_joints) with 1 for
      joints that count as visible.

  Returns:
    list of arrays (num_frames,): 0.6 times the diagonal of the bounding box
    around the visible joints, or 0 for frames with no visible joint.
  """
  human_boxes = []
  for video_gt, video_visible in zip(gts, jnt_visible_set):
    n_frames = video_gt.shape[0]
    scales = np.zeros(n_frames)
    for t in range(n_frames):
      pts = video_gt[t, video_visible[t] == 1]
      if not pts.size:
        continue
      extent = pts.max(axis=0) - pts.min(axis=0)
      scales[t] = 0.6 * np.linalg.norm(extent)
    human_boxes.append(scales)
  return human_boxes

def compute_distances(gts, preds, human_boxes, jnt_visible_set):
  """Collects box-normalised keypoint errors per joint.

  Args:
    gts: list of ground-truth arrays (num_frames, num_joints, 2).
    preds: list of predicted arrays of the same shapes.
    human_boxes: list of per-frame normalisation scales (num_frames,).
    jnt_visible_set: list of arrays (num_frames, num_joints); only truthy
      entries contribute.

  Returns:
    dict joint index -> 1D float array of normalised distances. Keys 0..14
    are always present; frame 0 of every video is skipped.
  """
  # Accumulate in lists and convert once; np.append in the inner loop copied
  # the whole array on every sample.
  per_joint = {pidx: [] for pidx in range(15)}
  for nowgt, predres, now_boxes, jnt_visible in zip(gts, preds, human_boxes, jnt_visible_set):
    for i in range(nowgt.shape[1]):
      # setdefault keeps inputs with more than 15 joints from raising KeyError.
      bucket = per_joint.setdefault(i, [])
      for t in range(1, nowgt.shape[0]):
        if jnt_visible[t, i]:
          bucket.append(np.linalg.norm(predres[t, i] - nowgt[t, i]) / now_boxes[t])
  return {pidx: np.array(values, dtype=float) for pidx, values in per_joint.items()}

def computePCK(distAll, distThresh):
  """Percentage of correct keypoints at a distance threshold.

  Args:
    distAll: dict joint index -> array of normalised distances; the keys are
      used as indices into the result, so they must be 0..len(distAll)-1.
    distThresh: a distance counts as correct when it is <= this value.

  Returns:
    (pck, pckAll): mean over joints, and the per-joint percentages.
  """
  pckAll = np.zeros(len(distAll))
  for pidx in distAll:
    joint_dist = distAll[pidx]
    n_correct = np.sum(joint_dist <= distThresh)
    pckAll[pidx] = 100.0 * n_correct / len(joint_dist)
  return np.mean(pckAll), pckAll

In [None]:
# @title VIP video segmentation evaluation metric

def fast_hist(a, b, n):
  """Confusion matrix between label arrays `a` (rows) and `b` (columns).

  Entries where `a` lies outside [0, n) are ignored. Returns an (n, n) array
  of counts.
  """
  valid = (a >= 0) & (a < n)
  # Encode each (a, b) pair as a single index into the flattened matrix.
  flat_index = n * a[valid].astype(int) + b[valid]
  counts = np.bincount(flat_index, minlength=n**2)
  return counts.reshape(n, n)

def compute_iou(hist):
  """Mean and per-class IoU from a confusion matrix.

  Args:
    hist: square confusion matrix with ground truth along rows; it needs at
      least 20 rows/columns, one per entry of the class list below.

  Returns:
    (iou, iou_per_class): NaN-ignoring mean IoU over all rows of `hist`, and
    a dict mapping each of the 20 class names to its IoU.
  """
  classes = ['background', 'hat', 'hair', 'sun-glasses', 'upper-clothes', 'dress',
           'coat', 'socks', 'pants', 'gloves', 'scarf', 'skirt', 'torso-skin',
           'face', 'right-arm', 'left-arm', 'right-leg', 'left-leg', 'right-shoe', 'left-shoe']

  num_cor_pix = np.diag(hist)  # num of correct pixels
  num_gt_pix = hist.sum(1)  # num of gt pixels
  # Union = ground-truth pixels + predicted pixels - correctly labelled ones.
  union = num_gt_pix + hist.sum(0) - num_cor_pix
  iou_per_class = {name: num_cor_pix[i] / union[i] for i, name in enumerate(classes)}
  iou = np.nanmean(num_cor_pix / union)
  return iou, iou_per_class

In [None]:
# @title DAVIS evaluation function

def evaluate_davis(model, extract_feature_function, dataset_path='/content/davis-2017/DAVIS/', height=480, width=880, patch_size=16, params=None):
  """Runs label propagation on the DAVIS split and prints J, F and J&F means.

  Args:
    model: model handed to `extract_feature_function`.
    extract_feature_function: called as `f(model, video)` when `params` is
      None, otherwise as `f(model, params, video)`.
    dataset_path: DAVIS root directory.
    height, width: resolution videos are resized to before feature extraction.
    patch_size: patch size used to derive the feature resolution.
    params: optional parameters for the three-argument call form.

  Scores are computed per object on all frames except the first and last, and
  averaged over all objects of all videos. Requires a CUDA device.
  """
  n_max_class = 10
  h = height // patch_size  # feature resolution
  w = width // patch_size

  j_metrics, f_metrics, fj_metrics, counters = [], [], [], []
  for sample in tqdm.tqdm(create_davis_dataset(dataset_path)):
    frames, gt_masks = sample['video'], sample['mask']
    ori_height, ori_width = frames.shape[1], frames.shape[2]

    # Extract L2-normalised features
    video = media.resize_video(frames, (height, width))
    if params is not None:
      feats = extract_feature_function(model, params, video)  # Jax models
    else:
      feats = extract_feature_function(model, video)  # PyTorch models
    if not isinstance(feats, torch.Tensor):
      feats = torch.tensor(feats).cuda()
    feats = torch.nn.functional.normalize(feats, dim=-1)

    # Downscaled one-hot encoding of the first-frame segmentation
    lbls_small, lbl_map = mask2heatmap(gt_masks[0], h, w, height, width)
    lbls_small = torch.tensor(lbls_small).cuda()

    propagated = label_propagation(feats, lbls_small, n_context=20, temperature=0.7, topk=7, radius=20, restrict_neighborhood=True, norm_mask=False)
    pred_lbls = process_segmentation(propagated, lbl_map, height, width, ori_height, ori_width)

    n_class = int(gt_masks.max())  # number of objects in the segmentation map

    # One-hot encode, skip the first and last frame, drop the background class
    one_hot = np.eye(n_class + 1)
    all_gt_masks = one_hot[gt_masks[1:-1]][..., 1:]
    all_res_masks = one_hot[pred_lbls[1:-1]][..., 1:]

    num_frames = all_gt_masks.shape[0]
    j_metrics_res = np.zeros((n_class, num_frames))
    f_metrics_res = np.zeros((n_class, num_frames))
    JM = np.zeros((n_max_class,))
    FM = np.zeros((n_max_class,))
    FJM = np.zeros((n_max_class,))
    for ii in range(n_class):
      j_metrics_res[ii, :] = db_eval_iou(all_gt_masks[..., ii], all_res_masks[..., ii], None)
      f_metrics_res[ii, :] = db_eval_boundary(all_gt_masks[..., ii], all_res_masks[..., ii], None)
      JM[ii] = np.nanmean(j_metrics_res[ii])
      FM[ii] = np.nanmean(f_metrics_res[ii])
      FJM[ii] = (JM[ii] + FM[ii]) / 2.
    # Marks which of the n_max_class slots hold a real object
    counter = np.zeros((n_max_class,))
    counter[:n_class] = 1

    j_metrics.append(JM)
    f_metrics.append(FM)
    fj_metrics.append(FJM)
    counters.append(counter)

  # Average over every object of every video
  counters = np.array(counters)
  n_objects = counters.sum()
  fj_metric = (np.array(fj_metrics) * counters).sum() / n_objects
  j_metric = (np.array(j_metrics) * counters).sum() / n_objects
  f_metric = (np.array(f_metrics) * counters).sum() / n_objects

  print('')
  print('J&F-Mean',fj_metric)
  print('J-Mean', j_metric)
  print('F-Mean', f_metric)

In [None]:
# @title JHMDB evaluation function

def evaluate_jhmdb(model, extract_feature_function, dataset_path='/content/jhmdb/', height=320, width=320, patch_size=16, params=None):
  """Runs keypoint propagation on JHMDB and prints PCK at 0.1 .. 0.5.

  Args:
    model: model handed to `extract_feature_function`.
    extract_feature_function: called as `f(model, video)` when `params` is
      None, otherwise as `f(model, params, video)`.
    dataset_path: JHMDB root directory.
    height, width: resolution videos are resized to before feature extraction.
    patch_size: patch size used to derive the feature resolution.
    params: optional parameters for the three-argument call form.

  Requires a CUDA device.
  """
  h = height // patch_size  # feature resolution
  w = width // patch_size

  gts, preds, jnt_visible_set = [], [], []
  for sample in tqdm.tqdm(create_jhmdb_dataset(jhmdb_path=dataset_path)):
    frames, poses = sample['video'], sample['pose']
    ori_height, ori_width = frames.shape[1], frames.shape[2]

    # Extract L2-normalised features
    video = media.resize_video(frames, (height, width))
    if params is not None:
      feats = extract_feature_function(model, params, video)  # Jax models
    else:
      feats = extract_feature_function(model, video)  # PyTorch models
    if not isinstance(feats, torch.Tensor):
      feats = torch.tensor(feats).cuda()
    feats = torch.nn.functional.normalize(feats, dim=-1)

    # First-frame keypoints as a heatmap at feature resolution
    lbls_small = torch.tensor(pose2heatmap(poses[0], h, w, ori_height, ori_width)).cuda()

    propagated = label_propagation(feats, lbls_small, n_context=20, temperature=0.7, topk=7, radius=20, restrict_neighborhood=True, norm_mask=False)

    pred_coords, jnt_visible = process_pose(propagated, h, w, ori_height, ori_width)
    preds.append(pred_coords)
    jnt_visible_set.append(jnt_visible)
    gts.append(poses)

  human_boxes = compute_human_boxes(gts, jnt_visible_set)
  distAll = compute_distances(gts, preds, human_boxes, jnt_visible_set)

  for threshold in [0.1, 0.2, 0.3, 0.4, 0.5]:
    pck, _ = computePCK(distAll, threshold)
    print(f"PCK@{threshold}: {pck}")

In [None]:
# @title VIP evaluation function

def evaluate_vip(model, extract_feature_function, dataset_path='/content/VIP/', height=448, width=880, patch_size=16, params=None):
  """Propagates first-frame segmentations through VIP videos and prints IoU.

  For every sample, the video is resized to (height, width), features are
  extracted and L2-normalized along the last axis, the first-frame mask is
  converted to a heatmap on the feature grid, and `label_propagation` carries
  it through the video. Per-frame predictions are accumulated into a
  20x20 histogram from which `compute_iou` derives the scores.

  Args:
    model: Model object forwarded to `extract_feature_function`.
    extract_feature_function: Feature extractor. Called as
      `f(model, video)` when `params` is None, otherwise as
      `f(model, params, video)`.
    dataset_path: Path forwarded to `create_vip_dataset`.
    height: Frame height used when resizing videos for feature extraction.
    width: Frame width used when resizing videos for feature extraction.
    patch_size: Divisor turning (height, width) into the feature grid size.
    params: Optional model parameters; selects the three-argument call form.

  Returns:
    None. Per-class IoU and the mean IoU are printed.
  """
  n_class = 20
  grid_h = height // patch_size  # feature resolution (rows)
  grid_w = width // patch_size  # feature resolution (columns)
  dataset = create_vip_dataset(root_path=dataset_path)

  hist = np.zeros((n_class, n_class))
  for sample in tqdm.tqdm(dataset):
    src_video = sample['video']
    src_h, src_w = src_video.shape[1], src_video.shape[2]

    # Extract features; the call signature depends on whether params is given.
    resized = media.resize_video(src_video, (height, width))
    feats = (
        extract_feature_function(model, resized)  # PyTorch models
        if params is None
        else extract_feature_function(model, params, resized)  # Jax models
    )
    # Anything that is not already a torch tensor is converted and moved to GPU.
    if not isinstance(feats, torch.Tensor):
      feats = torch.tensor(feats).cuda()
    feats = torch.nn.functional.normalize(feats, dim=-1)

    # First-frame segmentation as a heatmap on the feature grid, plus the
    # label map that process_segmentation needs afterwards.
    first_heatmap, lbl_map = mask2heatmap(sample['mask'][0], grid_h, grid_w, height, width)
    first_heatmap = torch.tensor(first_heatmap).cuda()

    propagated = label_propagation(
        feats,
        first_heatmap,
        n_context=20,
        temperature=0.7,
        topk=10,
        radius=20,
        restrict_neighborhood=True,
        norm_mask=False,
    )

    pred_masks = process_segmentation(propagated, lbl_map, height, width, src_h, src_w)

    # Accumulate the histogram frame by frame against the ground truth.
    num_frames = pred_masks.shape[0]
    for t in range(num_frames):
      hist += fast_hist(sample['mask'][t], pred_masks[t], n_class)

  iou, iou_per_class = compute_iou(hist)

  print('')
  for key in iou_per_class:
    print('%-15s: %f' % (key, iou_per_class[key]))
  print('mean IoU', iou)

In [None]:
# @title DAVIS evaluation

# Run the DAVIS evaluation with the restored parameters.
evaluate_davis(
    model,
    extract_features,
    dataset_path='/content/davis-2017/DAVIS/',
    height=480,
    width=880,
    patch_size=16,
    params=restored_params,
)

In [None]:
# @title JHMDB evaluation

# Run the JHMDB evaluation; passing params selects the
# extract_features(model, params, video) call form.
evaluate_jhmdb(
    model,
    extract_features,
    dataset_path='/content/jhmdb/',
    height=320,
    width=320,
    patch_size=16,
    params=restored_params,
)

In [None]:
# @title VIP evaluation

# Run the VIP evaluation; passing params selects the
# extract_features(model, params, video) call form.
evaluate_vip(
    model,
    extract_features,
    dataset_path='/content/VIP/',
    height=448,
    width=880,
    patch_size=16,
    params=restored_params,
)

In [None]:
# @title Visualize video segmentation on DAVIS

# Propagates the first-frame mask through a hand-picked set of DAVIS videos,
# shows the overlay inline, and writes an .mp4 plus four key frames per video
# under /content/davis_results/<model_name>/.
model_name = 'rvm'

os.makedirs(f'/content/davis_results/{model_name}', exist_ok=True)

davis_dataset = create_davis_dataset('/content/davis-2017/DAVIS')
height, width = 480, 880  # video resolution during processing
h, w = height // PATCH_SIZE, width // PATCH_SIZE  # feature shape

video_names = [
    'bike-packing',
    'bmx-trees',
    'gold-fish',
    'horsejump-high',
    'judo',
    'lab-coat',
    'pigs'
]

for i, sample in enumerate(davis_dataset):
  # Keep only the selected videos whose mask maximum exceeds 1.
  if sample['mask'].max() <= 1 or sample['video_id'] not in video_names:
    continue
  # Bind the id once: indexing sample['video_id'] inside a single-quoted
  # f-string is a SyntaxError on Python versions before 3.12.
  video_id = sample['video_id']
  print(i, video_id)

  # Extract features and L2-normalize them along the last axis
  video = media.resize_video(sample['video'], (height, width))
  feats = extract_features(model, restored_params, video)
  feats = feats / jnp.linalg.norm(feats, axis=-1, keepdims=True)

  # Prepare downscaled first frame segmentation (resize and one-hot encode)
  lbls_small, lbl_map = mask2heatmap(sample['mask'][0], h, w, height, width)
  pred_lbls = label_propagation(feats, lbls_small, n_context=20, temperature=0.7, topk=7, radius=20, restrict_neighborhood=True, norm_mask=False)
  pred_mask = heatmap2mask(pred_lbls, lbl_map, height, width, height, width)

  # Blend the resized video (scaled to [0, 1]) with the segmentation rendering
  seg_video = segmentations_to_video(pred_mask)
  mixed_video = 0.5 * video / 255.0 + 0.5 * seg_video
  media.show_video(mixed_video, height=128, codec='gif', fps=16)

  filename = f'/content/davis_results/{model_name}/{video_id}.mp4'
  media.write_video(filename, mixed_video, fps=16)

  # Save the first, quarter, middle and last frames as standalone files
  num_frames = video.shape[0]
  for t in [0, num_frames // 4, num_frames // 2, num_frames - 1]:
    image = Image.fromarray((np.array(mixed_video[t]) * 255.0).astype(np.uint8))
    image.save(f'/content/davis_results/{model_name}/{video_id}_{t}.pdf')

In [None]:
# @title Visualize pose tracking on JHMDB

# Propagates the first-frame pose through the first few JHMDB videos, shows
# prediction and ground truth side by side, and writes the predicted-pose
# video plus four key frames under /content/jhmdb_results/<model_name>/.
model_name = 'rvm'

os.makedirs(f'/content/jhmdb_results/{model_name}', exist_ok=True)

jhmdb_dataset = create_jhmdb_dataset(jhmdb_path='/content/jhmdb/')
height, width = 320, 320  # video resolution during processing
h, w = height // PATCH_SIZE, width // PATCH_SIZE  # feature shape

num_videos_to_show = 5
for i, sample in enumerate(jhmdb_dataset):
  if i >= num_videos_to_show:
    break

  print(i, sample['video_id'])
  # Only the last path component of the id is used for output file names.
  video_id = sample['video_id'].split('/')[-1]
  ori_height, ori_width = sample['video'].shape[1:3]

  # Extract features and L2-normalize them along the last axis
  video = media.resize_video(sample['video'], (height, width))
  feats = extract_features(model, restored_params, video)
  feats = feats / jnp.linalg.norm(feats, axis=-1, keepdims=True)

  # Prepare downscaled first frame heatmap (resize and one-hot encode)
  lbls_small = pose2heatmap(sample['pose'][0], h, w, ori_height, ori_width)
  pred_lbls = label_propagation(
      feats,
      lbls_small,
      n_context=20,
      temperature=0.7,
      topk=7,
      radius=20,
      restrict_neighborhood=True,
      norm_mask=False,
  )
  pred_pose, jnt_visible = heatmap2pose(pred_lbls, h, w, ori_height, ori_width)

  # Render predictions first, then ground truth, on the original frames.
  pose_video = np.stack(
      [vis_pose(sample['video'][t], pred_pose[t]) for t in range(sample['video'].shape[0])],
      axis=0,
  )
  gt_video = np.stack(
      [vis_pose(sample['video'][t], sample['pose'][t]) for t in range(sample['video'].shape[0])],
      axis=0,
  )
  # Prediction on the left, ground truth on the right.
  videos = np.concatenate([pose_video, gt_video], axis=2)
  media.show_video(videos, height=128, fps=16, codec='gif')

  filename = f'/content/jhmdb_results/{model_name}/{video_id}.mp4'
  media.write_video(filename, pose_video, fps=16)

  # Save the first, quarter, middle and last predicted frames.
  num_frames = video.shape[0]
  for t in (0, num_frames // 4, num_frames // 2, num_frames - 1):
    image = Image.fromarray(pose_video[t])
    image.save(f'/content/jhmdb_results/{model_name}/{video_id}_{t}.pdf')

In [None]:
# @title Visualize video segmentation on VIP

# Propagates the first-frame mask through VIP videos whose mask maximum
# exceeds 5, shows the overlay inline, and writes an .mp4 plus four key frames
# per video under /content/vip_results/<model_name>/.
model_name = 'rvm'

os.makedirs(f'/content/vip_results/{model_name}', exist_ok=True)

vip_dataset = create_vip_dataset(root_path='/content/VIP/')
height, width = 448, 880  # video resolution during processing
h, w = height // PATCH_SIZE, width // PATCH_SIZE  # feature shape

num_videos_to_show = 5
# Count videos that are actually visualized. Comparing the raw dataset index
# against num_videos_to_show let skipped samples use up the budget, so fewer
# than the requested number of videos (possibly none) were shown.
num_shown = 0
for i, sample in enumerate(vip_dataset):
  if num_shown >= num_videos_to_show:
    break
  if sample['mask'].max() <= 5:
    continue
  num_shown += 1

  # Bind the id once: indexing sample['video_id'] inside a single-quoted
  # f-string is a SyntaxError on Python versions before 3.12.
  video_id = sample['video_id']
  print(i, video_id)

  # Extract features and L2-normalize them along the last axis
  video = media.resize_video(sample['video'], (height, width))
  feats = extract_features(model, restored_params, video)
  feats = feats / jnp.linalg.norm(feats, axis=-1, keepdims=True)

  # Prepare downscaled first frame segmentation (resize and one-hot encode)
  lbls_small, lbl_map = mask2heatmap(sample['mask'][0], h, w, height, width)
  pred_lbls = label_propagation(feats, lbls_small, n_context=20, temperature=0.7, topk=7, radius=20, restrict_neighborhood=True, norm_mask=False)
  pred_mask = heatmap2mask(pred_lbls, lbl_map, height, width, height, width)

  # Blend the resized video (scaled to [0, 1]) with the segmentation rendering
  seg_video = segmentations_to_video(pred_mask)
  mixed_video = 0.5 * video / 255.0 + 0.5 * seg_video
  media.show_video(mixed_video, height=128, codec='gif', fps=16)

  filename = f'/content/vip_results/{model_name}/{video_id}.mp4'
  media.write_video(filename, mixed_video, fps=16)

  # Save the first, quarter, middle and last frames as standalone files
  num_frames = video.shape[0]
  for t in [0, num_frames // 4, num_frames // 2, num_frames - 1]:
    image = Image.fromarray((np.array(mixed_video[t]) * 255.0).astype(np.uint8))
    image.save(f'/content/vip_results/{model_name}/{video_id}_{t}.pdf')