# V-JEPA 2-AC — Tutorial Notebook

**Goal:** Make it easy to plug in your own dataset and:
- Understand the model inputs
- Run **one-step** and **open-loop** predictions
- Decode tokens back to images
- Save **GT (ground-truth) vs. Pred** comparison videos with **automatic naming** in `videos/`
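The last goal above can be sketched in a few lines. The helpers below are hypothetical (`side_by_side` and `make_video_path` are illustrative names, not part of the V-JEPA 2-AC codebase) and assume frames are uint8 arrays of shape [T, H, W, 3]; the notebook's own saving code may differ.

```python
import datetime
import pathlib

import numpy as np


def side_by_side(gt: np.ndarray, pred: np.ndarray) -> np.ndarray:
    """Place GT and predicted frames next to each other.

    gt, pred: uint8 arrays of shape [T, H, W, 3] with identical shapes.
    Returns a [T, H, 2*W, 3] array with GT on the left and Pred on the right.
    """
    if gt.shape != pred.shape:
        raise ValueError(f"shape mismatch: {gt.shape} vs {pred.shape}")
    return np.concatenate([gt, pred], axis=2)  # concatenate along width


def make_video_path(tag: str, out_dir: str = "videos", ext: str = ".mp4") -> pathlib.Path:
    """Build a timestamped file name such as videos/openloop_20240101-120000.mp4.

    Creates the output folder if it does not exist yet.
    """
    folder = pathlib.Path(out_dir)
    folder.mkdir(parents=True, exist_ok=True)
    stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    return folder / f"{tag}_{stamp}{ext}"
```

Writing the combined frames could then look like `imageio.mimsave(str(make_video_path("openloop")), side_by_side(gt, pred), fps=4)`, assuming the `imageio-ffmpeg` backend is installed for `.mp4` output.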


```python
# ----------------------------------------
# 📦 Imports & Global Config
# ----------------------------------------
import os, sys, time, math, pickle, shutil, datetime, pathlib
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import imageio

import torch
import torch.nn as nn
from torch.nn import functional as F
from torchvision.utils import make_grid
from scipy.spatial.transform import Rotation as R

# Compute device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"CUDA available: {torch.cuda.is_available()} | device: {DEVICE}")

# Visualization defaults
plt.rcParams["figure.figsize"] = (5, 5)
plt.rcParams["axes.grid"] = False
```


```text
CUDA available: True | device: cuda
```


## 📦 V-JEPA 2-AC — What Data Does It Use?

V-JEPA 2-AC is an **action-conditioned latent world model** trained to predict future features from robot interaction data. It takes three inputs:

### 🖼️ 1. Images

Two RGB images captured **back-to-back in time** (consecutive frames sampled at roughly 4 FPS, i.e. about 0.25 s apart).

- Each image should be:
  - Shape: [256, 256, 3] (Height × Width × Channels)
  - Type: uint8 (pixel values from 0 to 255)

- Before passing the pair to the model, stack it into a single PyTorch tensor:
  - Shape: [1, 3, 2, 256, 256] → [Batch, Channels, Time, Height, Width]
  - Type: float32
  - Normalize pixel values to [0, 1] by dividing by 255
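The conversion described above can be written as a small helper. This is a minimal sketch, assuming the two frames arrive as NumPy uint8 arrays in [H, W, 3] layout; `to_model_input` is a hypothetical name, not an API from the V-JEPA 2 repository.

```python
import numpy as np
import torch


def to_model_input(frame_t0: np.ndarray, frame_t1: np.ndarray) -> torch.Tensor:
    """Stack two consecutive RGB frames into a [1, 3, 2, H, W] float32 tensor in [0, 1]."""
    for f in (frame_t0, frame_t1):
        if f.dtype != np.uint8 or f.ndim != 3 or f.shape[-1] != 3:
            raise ValueError(f"expected uint8 [H, W, 3], got {f.dtype} {f.shape}")
    if frame_t0.shape != frame_t1.shape:
        raise ValueError("both frames must have the same shape")
    clip = np.stack([frame_t0, frame_t1], axis=0)  # [T=2, H, W, C]
    x = torch.from_numpy(clip).float().div(255.0)  # float32, values in [0, 1]
    x = x.permute(3, 0, 1, 2).contiguous()         # [C, T, H, W]
    return x.unsqueeze(0)                          # [B=1, C, T, H, W]
```

In the notebook you would then move the result to the compute device, e.g. `x = to_model_input(img0, img1).to(DEVICE)`. Stacking along a new leading time axis and then permuting puts channels first, matching the [Batch, Channels, Time, Height, Width] order listed above.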

> 💡 The model was trained using a **fixed exocentric camera**. Large camera movements may hurt performance.

### 🤖 2. State

The robot’s current state, represented as a **7-dimensional vector**:

