<p align="center">
  <h1 align="center">TRecViT: A Recurrent Video Transformer</h1>
  <p align="center">
    Viorica Patraucean
    ·
    Xu Owen He
    ·
    Joseph Heyward
    ·
    Chuhan Zhang
    ·
    Mehdi S. M. Sajjadi
    ·
    George-Cristian Muraru
    ·
    Artem Zholus
    ·
    Mahdi Karami
    ·
    Ross Goroshin
    ·
    Yutian Chen
    ·
    Simon Osindero
    ·
    Joao Carreira
    ·
    Razvan Pascanu
    Razvan Pascanu
  </p>
  <h3 align="center"><a href="https://arxiv.org/abs/2412.14294">Paper</a> | <a href="https://github.com/google-deepmind/trecvit">GitHub</a></h3>
  <div align="center"></div>
</p>

<p align="center">
  <a href="https://storage.googleapis.com/trecvit/model_checkpoints/diagram.png">
    <img src="https://storage.googleapis.com/trecvit/model_checkpoints/diagram.png" alt="TRecViT diagram" width="50%">
  </a>
</p>

# Install Dependencies


In [None]:
!git clone https://github.com/google-deepmind/trecvit.git
%cd trecvit
!pip install .
!wget https://storage.googleapis.com/trecvit/model_checkpoints/trecvit_B_k400.npz
!pip install mediapy

In [None]:
import jax
import jax.numpy as jnp
import mediapy as media
from trecvit import trecvit_model
from trecvit import utils

# Initialize Model and Load Checkpoint

In [None]:
# Number of frames per clip; the model and the dummy init input both use it.
num_frames = 32
model = trecvit_model.get_model(num_frames=num_frames)

# Zero-filled dummy batch used only so `model.init` can build the parameter
# tree. Assumed layout: (batch, frames, height, width, channels) -- TODO
# confirm 224x224 is the resolution the checkpoint was trained at.
x = jnp.zeros((1, num_frames, 224, 224, 3), dtype=jnp.float32)
init_params = model.init(jax.random.key(0), x)

# Load the pretrained weights downloaded by the install cell; the freshly
# initialized tree is passed in, presumably as the structure to fill.
path = 'trecvit_B_k400.npz'
params = utils.load_ckpt(init_params, path)


@jax.jit
def forward(params, x):
  """Runs the TRecViT forward pass under `jax.jit`.

  `jax.jit` traces and compiles once per distinct input shape/dtype, so calls
  with the same video shape reuse the compiled function.

  Args:
    params: Model variables, e.g. as returned by `utils.load_ckpt`.
    x: Video batch laid out like the dummy input passed to `model.init`.

  Returns:
    The output of `model.apply`; the inference cell reads its 'probs' entry.
  """
  return model.apply(params, x)


# Run Inference

In [None]:
# Example clip bundled with the repository; the path is relative to the repo
# root, which the install cell changes into.
example_video_path = 'figures/example.mp4'

# Decode the clip into an array of frames and preview it inline.
frames = media.read_video(example_video_path)
media.show_video(frames)

In [None]:
# Add a batch axis and scale pixels (assumed 0..255, hence /255) to [0, 1].
# Casting to float32 first -- the dtype used for `model.init` -- avoids NumPy
# promoting the frames to a float64 host copy that JAX would then silently
# truncate back to float32.
video = jnp.asarray(frames[None], dtype=jnp.float32) / 255.0
out = forward(params, video)

# Index of the highest-probability class along the last axis of 'probs'.
label = out['probs'].argmax(axis=-1, keepdims=True)
print('video_action_label: ', label)

**Expected label:** 213 (petting a cat)