# V-JEPA2 → CLIP Alignment (PyTorch-Native Rewrite)

A clean rewrite that embraces PyTorch's internal data structures and vectorization:

- **PEFT LoRA** instead of custom LoRA implementation
- **Pre-allocated circular buffer** for memory bank (no Python list append/pop)
- **Batch video decoding** with decord's native batch API
- **torch.compile** on the projection head (optional; `use_compile` is off by default)
- **Native cross-device autograd** (no custom Function needed)
- **Vectorized contrastive loss** built from batched matrix products in fp32
- **Proper DataLoader workers** with prefetching


In [1]:
%pip install -q peft accelerate transformers decord wandb

[notice] A new release of pip is available: 24.2 -> 25.3
[notice] To update, run: python -m pip install --upgrade pip
Note: you may need to restart the kernel to use updated packages.


In [2]:
import os
import math
import time
from pathlib import Path
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from transformers import AutoModel, AutoVideoProcessor, CLIPModel, CLIPProcessor
from peft import LoraConfig, get_peft_model, TaskType
import wandb

# Report the PyTorch build and the visible CUDA devices once, so the
# notebook output records which hardware the run used.
gpu_total = torch.cuda.device_count()
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU count: {gpu_total}")
for i in range(gpu_total):
    print(f"  GPU{i}: {torch.cuda.get_device_name(i)}")

# Note: torch.set_float32_matmul_precision('high') can cause NaN issues
# with this model, so we leave it at default precision

PyTorch: 2.4.1+cu124
CUDA available: True
GPU count: 2
  GPU0: NVIDIA A40
  GPU1: NVIDIA A40


## Configuration

Using a dataclass for type safety and IDE support.

In [3]:
@dataclass
class Config:
    """Hyperparameters and runtime settings for the V-JEPA -> CLIP alignment run.

    All fields have defaults, so ``Config()`` yields a complete configuration.
    The device defaults are computed once, when the class body is executed.
    """

    # Models (Hugging Face model identifiers)
    vjepa_model: str = "facebook/vjepa2-vitl-fpc64-256"
    clip_model: str = "openai/clip-vit-large-patch14"

    # LoRA (PEFT) - forwarded to peft.LoraConfig
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    # default_factory gives every Config instance its own list
    lora_target_modules: list[str] = field(default_factory=lambda: ["query", "key", "value"])

    # Video sampling
    num_frames: int = 8              # Consecutive frames per training window
    window_stride: int = 4           # Frame step between successive window starts
    max_windows_per_video: int = 16  # Only the first N windows of each video are indexed
    clip_sample_frames: int = 4      # Frames per window that are embedded with CLIP

    # Data
    video_dir: str = "./videos"
    num_videos: Optional[int] = None  # None = use all
    test_split: float = 0.15          # Fraction of windows held out for the test set

    # Training
    batch_size: int = 8
    epochs: int = 10
    lr: float = 5e-5  # Base LR (used for backward compatibility)
    lora_lr: float = 2e-5  # Lower LR for LoRA adapters
    proj_lr: float = 3e-4  # Faster LR for projection head
    weight_decay: float = 0.01
    warmup_ratio: float = 0.1  # Fraction of optimizer steps spent in linear LR warmup
    grad_clip: float = 0.35  # Tighter gradient clipping
    gradient_accum_steps: int = 2  # Larger effective batch, calmer grads

    # Contrastive learning
    temperature: float = 0.07
    label_smoothing: float = 0.05
    logit_scale_max: float = math.log(100)  # Bound on the log-space logit scale
    memory_bank_size: int = 512  # Total embeddings to keep
    lambda_align: float = 0.2    # Small alignment weight for stability

    # Projection head
    proj_hidden_dim: int = 1024
    proj_warmup_steps: int = 20  # Train only projection head for this many optimizer steps

    # System
    seed: int = 42
    num_workers: int = 4
    use_compile: bool = False  # torch.compile can cause NaN issues, disable by default
    use_bf16: bool = True      # bfloat16 when True, float16 otherwise
    gradient_checkpointing: bool = True

    # Wandb logging
    use_wandb: bool = True
    wandb_project: str = "vjepa-clip-alignment"
    wandb_run_name: Optional[str] = None  # Auto-generated if None

    # Devices (auto-configured)
    vjepa_device: str = "cuda:0"
    # Second GPU when more than one is visible, otherwise shares cuda:0
    clip_device: str = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"

# Build the default configuration and echo it so the settings used for this
# run are recorded in the notebook output.
cfg = Config()
print(cfg)



Config(vjepa_model='facebook/vjepa2-vitl-fpc64-256', clip_model='openai/clip-vit-large-patch14', lora_r=8, lora_alpha=16, lora_dropout=0.05, lora_target_modules=['query', 'key', 'value'], num_frames=8, window_stride=4, max_windows_per_video=16, clip_sample_frames=4, video_dir='./videos', num_videos=None, test_split=0.15, batch_size=8, epochs=10, lr=5e-05, weight_decay=0.01, warmup_ratio=0.1, grad_clip=0.5, temperature=0.07, memory_bank_size=512, lambda_align=0.0, proj_hidden_dim=1024, proj_warmup_steps=10, seed=42, num_workers=4, use_compile=False, use_bf16=True, gradient_checkpointing=True, use_wandb=True, wandb_project='vjepa-clip-alignment', wandb_run_name=None, vjepa_device='cuda:0', clip_device='cuda:1')


## Circular Memory Bank (Vectorized)

Pre-allocated tensor with O(1) insert via index tracking. No Python list operations.

In [4]:
class CircularMemoryBank(nn.Module):
    """Fixed-capacity ring buffer of L2-normalized embeddings.

    The storage is a pre-allocated ``[capacity, embed_dim]`` buffer that is
    written with ``index_copy_``; once the bank is full, new rows overwrite
    the oldest ones. ``bank``, ``ptr`` and ``count`` are registered buffers,
    so they follow ``.to()`` and are included in ``state_dict()``.

    Args:
        capacity: Maximum number of embeddings kept. A capacity of 0 yields a
            bank that ignores every push and always returns an empty tensor.
        embed_dim: Size of each embedding vector.
        device: Device on which the buffers are allocated.
    """

    def __init__(self, capacity: int, embed_dim: int, device: torch.device):
        super().__init__()
        self.capacity = capacity
        self.embed_dim = embed_dim

        # Pre-allocate the buffer (not a parameter, just persistent state)
        self.register_buffer('bank', torch.zeros(capacity, embed_dim, device=device))
        # Next slot to write, and number of valid rows (saturates at capacity)
        self.register_buffer('ptr', torch.zeros(1, dtype=torch.long, device=device))
        self.register_buffer('count', torch.zeros(1, dtype=torch.long, device=device))

    @torch.no_grad()
    def push(self, embeddings: torch.Tensor) -> None:
        """Normalize ``embeddings`` ([B, D]) and write them into the ring.

        The input is moved to the bank's device first, so callers may pass
        tensors that live elsewhere. When ``B`` exceeds the capacity only the
        last ``capacity`` rows are stored; the write pointer still ends up
        where it would after pushing all ``B`` rows one by one.
        """
        num_new = embeddings.shape[0]
        if num_new == 0 or self.capacity == 0:
            return

        # Normalize before storing
        embeddings = F.normalize(embeddings.to(self.bank.device).float(), dim=-1)

        ptr = int(self.ptr.item())
        if num_new > self.capacity:
            # Writing more rows than slots would hand index_copy_ duplicate
            # indices, whose outcome is not deterministic. Skip the rows that
            # would be overwritten anyway and advance the pointer past them.
            ptr = (ptr + num_new - self.capacity) % self.capacity
            embeddings = embeddings[-self.capacity:]
            num_new = self.capacity

        # Modular arithmetic handles wraparound; indices are unique here.
        indices = (torch.arange(num_new, device=self.bank.device) + ptr) % self.capacity
        self.bank.index_copy_(0, indices, embeddings)

        # Update pointer and count
        self.ptr[0] = (ptr + num_new) % self.capacity
        self.count[0] = min(int(self.count.item()) + num_new, self.capacity)

    def get_all(self) -> torch.Tensor:
        """Return a copy of the valid rows, shape ``[size, embed_dim]``.

        Rows are returned in slot order, which is not insertion order once
        the buffer has wrapped around.
        """
        return self.bank[:self.size].clone()

    @property
    def size(self) -> int:
        """Number of valid embeddings currently stored."""
        return int(self.count.item())

    def reset(self) -> None:
        """Clear the bank: zero the storage, the write pointer and the count."""
        self.ptr.zero_()
        self.count.zero_()
        self.bank.zero_()

## Vectorized Contrastive Loss

Using plain tensor ops (`@` and element-wise products) in fp32 and letting PyTorch optimize the underlying kernels.

In [5]:
class InfoNCELoss(nn.Module):
    """InfoNCE contrastive loss plus a cosine alignment term, computed in fp32.

    Each query is scored against the keys of its own batch and, when the
    memory tensor is non-empty, against the memory rows as extra negatives.
    The alignment term is ``1 - cos(query, key)`` averaged over the batch and
    is added with weight ``lambda_align``. Non-finite intermediates are
    reported with ``print`` and replaced rather than raised.
    """

    def __init__(self, temperature: float = 0.07, lambda_align: float = 1.0, label_smoothing: float = 0.0, logit_scale_max: float = math.log(100)):
        super().__init__()
        # Temperatures below 0.01 are raised to 0.01 to keep the scale bounded
        self.temperature = max(temperature, 0.01)
        self.lambda_align = lambda_align
        self.label_smoothing = label_smoothing
        self.logit_scale_max = logit_scale_max
        self.eps = 1e-8
        # The scale is stored in log space and initialised to log(1 / temperature)
        initial_log_scale = math.log(1.0 / self.temperature)
        self.logit_scale = nn.Parameter(torch.tensor(initial_log_scale))

    @staticmethod
    def _replace_non_finite(tensor: torch.Tensor, warning: str) -> torch.Tensor:
        """Return ``tensor`` unchanged when finite; otherwise warn and map NaN/+inf/-inf to 0/1/-1."""
        if torch.isfinite(tensor).all():
            return tensor
        print(warning)
        return torch.nan_to_num(tensor, nan=0.0, posinf=1.0, neginf=-1.0)

    def forward(
        self,
        queries: torch.Tensor,      # [B, D] - projected V-JEPA embeddings
        keys: torch.Tensor,         # [B, D] - CLIP embeddings for this batch
        memory: torch.Tensor,       # [M, D] - memory bank embeddings (can be empty)
    ) -> dict[str, torch.Tensor]:
        """Evaluate the combined loss and its logging metrics.

        Returns a dict with:
            - loss: contrastive + lambda_align * alignment (keeps the graph)
            - contrastive: detached InfoNCE term
            - alignment: detached cosine alignment term
            - acc: fraction of queries whose top-scoring key is their own pair
            - cos: detached mean (clamped) cosine similarity of the pairs
        """
        batch = queries.shape[0]
        device = queries.device

        # Everything below runs in fp32, whatever dtype the inputs arrive in
        queries = queries.float()
        keys = keys.float()
        has_memory = memory.numel() > 0
        if has_memory:
            memory = memory.float()

        # Unit-normalize both sides, then scrub any non-finite entries
        q = F.normalize(queries, dim=-1, eps=self.eps)  # [B, D]
        k = F.normalize(keys, dim=-1, eps=self.eps)     # [B, D]
        q = self._replace_non_finite(q, "[WARN] Non-finite values in normalized queries")
        k = self._replace_non_finite(k, "[WARN] Non-finite values in normalized keys")

        # --- Alignment term: pull each query towards its paired key ---
        pair_cos = (q * k).sum(dim=-1)  # [B]
        cos_sim = torch.clamp(pair_cos, -1.0 + self.eps, 1.0 - self.eps)
        align_loss = (1.0 - cos_sim).mean()

        # --- Contrastive term: memory rows go first, batch keys after them ---
        if has_memory:
            bank_rows = memory.shape[0]
            memory = F.normalize(memory, dim=-1, eps=self.eps)
            memory = self._replace_non_finite(memory, "[WARN] Non-finite values in memory bank")
            all_keys = torch.cat([memory, k], dim=0)  # [M+B, D]
        else:
            bank_rows = 0
            all_keys = k  # [B, D]

        # The log-scale is clamped to +/- logit_scale_max before exponentiation
        bounded_log_scale = self.logit_scale.clamp(-self.logit_scale_max, self.logit_scale_max)
        similarities = q @ all_keys.T  # [B, M+B]
        # Keep logits within [-50, 50] so softmax cannot overflow in fp32
        logits = torch.clamp(similarities * bounded_log_scale.exp(), -50.0, 50.0)

        if not torch.isfinite(logits).all():
            print("[WARN] Non-finite logits detected, clamping")
            logits = torch.nan_to_num(logits, nan=0.0, posinf=50.0, neginf=-50.0)

        # The positive for query i sits right after the memory rows, at column M + i
        labels = torch.arange(bank_rows, bank_rows + batch, device=device, dtype=torch.long)
        contrastive_loss = F.cross_entropy(logits, labels, label_smoothing=self.label_smoothing)

        # A non-finite term is swapped for a fresh zero so backward() still runs
        if not torch.isfinite(contrastive_loss):
            print("[ERROR] Non-finite contrastive loss, setting to 0")
            contrastive_loss = torch.tensor(0.0, device=device, requires_grad=True)
        if not torch.isfinite(align_loss):
            print("[ERROR] Non-finite alignment loss, setting to 0")
            align_loss = torch.tensor(0.0, device=device, requires_grad=True)

        total_loss = contrastive_loss + self.lambda_align * align_loss

        # --- Metrics (no grad) ---
        with torch.no_grad():
            if torch.isfinite(logits).all():
                hits = logits.argmax(dim=1) == labels
                acc = hits.float().mean()
            else:
                acc = torch.tensor(0.0, device=device)

        return {
            'loss': total_loss,
            'contrastive': contrastive_loss.detach(),
            'alignment': align_loss.detach(),
            'acc': acc,
            'cos': cos_sim.mean().detach(),
        }



## Projection Head

Simple MLP, optionally compiled with `torch.compile`.

In [6]:
class ProjectionHead(nn.Module):
    """Two-layer MLP mapping V-JEPA embeddings to unit vectors in CLIP space.

    The forward pass standardizes each input row, applies
    ``fc1 -> GELU -> fc2`` and L2-normalizes the result. A ``dropout`` module
    is constructed but is not applied anywhere in ``forward``.
    """

    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int = 1024):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(p=0.1)
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and zero biases for both linear layers."""
        for layer in (self.fc1, self.fc2):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` in fp32 and return L2-normalized rows.

        Non-finite input or output entries are reported with ``print`` and
        replaced by zeros; non-finite weights are only reported.
        """
        x = x.float()

        # Zero out NaN/Inf entries of the input
        finite_in = torch.isfinite(x)
        if not finite_in.all():
            print("[PROJ] Input has NaN/Inf")
            x = torch.where(finite_in, x, torch.zeros_like(x))

        # Per-row standardization: zero mean, unit (biased) standard deviation
        centered = x - x.mean(dim=-1, keepdim=True)
        spread = x.var(dim=-1, keepdim=True, unbiased=False).sqrt()
        x = centered / (spread + 1e-6)

        # Corrupted weights are reported; the forward pass continues regardless
        if not torch.isfinite(self.fc1.weight).all():
            print("[PROJ] fc1.weight has NaN")
        hidden = self.act(self.fc1(x))

        if not torch.isfinite(self.fc2.weight).all():
            print("[PROJ] fc2.weight has NaN")
        out = self.fc2(hidden)

        # L2 normalize; the clamp keeps zero vectors from dividing by zero
        out = out / out.norm(dim=-1, keepdim=True).clamp(min=1e-6)

        # Zero out any non-finite entries that are left
        finite_out = torch.isfinite(out)
        if not finite_out.all():
            print("[PROJ] Output has NaN after normalize")
            out = torch.where(finite_out, out, torch.zeros_like(out))

        return out


## Dataset with Batch Video Decoding

Using decord's batch API and pre-building the full index.

In [7]:
from decord import VideoReader, cpu
import warnings
warnings.filterwarnings('ignore', category=UserWarning, module='decord')


def build_window_index(video_dir: Path, cfg: Config) -> list[tuple[Path, int, int]]:
    """Enumerate the sliding windows of every readable video in ``video_dir``.

    Files are visited in sorted order; when ``cfg.num_videos`` is set only the
    first that many files are considered. Videos that cannot be opened are
    reported and skipped, and videos shorter than ``cfg.num_frames`` are
    skipped silently.

    Args:
        video_dir: Directory that is scanned (non-recursively) for video files.
        cfg: Supplies ``num_videos``, ``num_frames``, ``window_stride`` and
            ``max_windows_per_video``.

    Returns:
        A list of ``(video_path, start_frame, total_frames)`` tuples, one per
        window, grouped by video in sorted file order.
    """
    video_exts = {".mp4", ".avi", ".mov", ".mkv", ".webm"}
    # Scan the directory the caller asked for, not cfg.video_dir
    video_files = sorted(
        f for f in Path(video_dir).iterdir()
        if f.is_file() and f.suffix.lower() in video_exts
    )

    if cfg.num_videos:
        video_files = video_files[:cfg.num_videos]

    index = []
    for vpath in video_files:
        # Open the file only to learn its frame count
        try:
            vr = VideoReader(str(vpath), ctx=cpu(0))
            total = len(vr)
            del vr
        except Exception as exc:
            # Best effort: one broken file must not abort the whole scan,
            # but make the skip visible instead of dropping it silently.
            print(f"[WARN] Skipping unreadable video {vpath}: {exc}")
            continue

        if total < cfg.num_frames:
            continue

        # Window starts every `window_stride` frames such that a full window
        # still fits; keep only the first `max_windows_per_video` of them.
        max_start = total - cfg.num_frames
        starts = range(0, max_start + 1, cfg.window_stride)[:cfg.max_windows_per_video]
        index.extend((vpath, start, total) for start in starts)

    return index


class VideoWindowDataset(Dataset):
    """Map-style dataset pairing a decoded video window with its cached CLIP embedding.

    Each item is described by an ``index`` entry ``(video_path, start_frame,
    total_frames)``. The CLIP target is looked up in ``clip_cache`` under the
    key ``(str(video_path), start_frame)``.
    """

    def __init__(
        self,
        index: list[tuple[Path, int, int]],
        processor: AutoVideoProcessor,
        clip_cache: dict[tuple[str, int], torch.Tensor],
        num_frames: int,
    ):
        self.index = index
        self.processor = processor
        self.clip_cache = clip_cache
        self.num_frames = num_frames

    def __len__(self) -> int:
        return len(self.index)

    def __getitem__(self, idx: int) -> dict:
        """Decode window ``idx`` and return ``pixel_values`` plus ``clip_embedding``."""
        vpath, start, total = self.index[idx]
        video_key = str(vpath)

        # Consecutive frame numbers, cut off at the end of the video
        stop = min(start + self.num_frames, total)
        frame_indices = list(range(start, stop))

        # A short window is padded by repeating its last frame
        shortfall = self.num_frames - len(frame_indices)
        if shortfall > 0:
            frame_indices.extend([frame_indices[-1]] * shortfall)

        # One reader per call; all frames are decoded in a single batch request
        reader = VideoReader(video_key, ctx=cpu(0))
        frames = reader.get_batch(frame_indices).asnumpy()  # [T, H, W, C]
        del reader

        # Preprocess for V-JEPA and drop the leading batch dimension
        processed = self.processor(frames, return_tensors="pt")
        pixel_values = processed['pixel_values_videos'].squeeze(0)  # [T, C, H, W]

        return {
            'pixel_values': pixel_values,
            'clip_embedding': self.clip_cache[(video_key, start)],
        }


def collate_fn(batch: list[dict]) -> dict:
    """Stack the per-sample tensors of ``batch`` along a new leading dimension.

    Only the ``pixel_values`` and ``clip_embedding`` entries are collated.
    """
    collated_keys = ('pixel_values', 'clip_embedding')
    return {
        key: torch.stack([sample[key] for sample in batch])
        for key in collated_keys
    }

## Precompute CLIP Embeddings

Batch processing with proper GPU utilization.

In [8]:
@torch.no_grad()
def precompute_clip_embeddings(
    index: list[tuple[Path, int, int]],
    clip_model: CLIPModel,
    clip_processor: CLIPProcessor,
    cfg: Config,
    batch_size: int = 32,
    num_workers: int = 4,
) -> dict[tuple[str, int], torch.Tensor]:
    """Compute one CLIP image embedding per video window.

    For every window, ``cfg.clip_sample_frames`` frames are sampled evenly
    between the window's first and last frame, embedded with CLIP, averaged,
    and L2-normalized. Frames are decoded on background threads while the
    model runs; the number of videos decoded ahead of the model is bounded so
    that the frames of the whole dataset are never held in memory at once.

    Args:
        index: ``(video_path, start_frame, total_frames)`` window tuples.
        clip_model: CLIP model; embeddings are computed on its device.
        clip_processor: Matching CLIP image processor.
        cfg: Supplies ``num_frames`` and ``clip_sample_frames``.
        batch_size: Number of frames per CLIP forward pass.
        num_workers: Number of decoding threads (values below 1 use 1).

    Returns:
        Dict mapping ``(str(video_path), start_frame)`` to a normalized fp32
        embedding stored on the CPU.
    """
    from collections import defaultdict, deque
    from concurrent.futures import ThreadPoolExecutor
    from tqdm.auto import tqdm

    device = next(clip_model.parameters()).device
    clip_model.eval()

    # Group window starts by video so each file is opened only once
    starts_by_video = defaultdict(list)
    for vpath, start, _total in index:
        starts_by_video[vpath].append(start)

    def load_video_frames(vpath, starts):
        """Decode the sampled frames of every window of one video."""
        vr = VideoReader(str(vpath), ctx=cpu(0))
        last_frame = len(vr) - 1

        all_frames = []
        window_info = []  # (start, number of frames decoded for that window)

        for start in starts:
            frame_indices = np.linspace(
                start,
                min(start + cfg.num_frames - 1, last_frame),
                cfg.clip_sample_frames,
                dtype=int
            ).tolist()

            frames = vr.get_batch(frame_indices).asnumpy()
            all_frames.extend(frames)
            window_info.append((start, len(frames)))

        del vr
        return vpath, all_frames, window_info

    cache = {}
    worker_count = max(1, num_workers)
    # Allow each worker one finished and one running video; anything beyond
    # that would only pile decoded frames up in RAM.
    max_in_flight = 2 * worker_count
    pending_videos = iter(starts_by_video.items())
    in_flight = deque()

    with ThreadPoolExecutor(max_workers=worker_count) as executor:

        def submit_next():
            """Queue the next video for decoding, if any is left."""
            item = next(pending_videos, None)
            if item is not None:
                in_flight.append(executor.submit(load_video_frames, *item))

        for _ in range(max_in_flight):
            submit_next()

        with tqdm(total=len(starts_by_video), desc="Precomputing CLIP") as pbar:
            while in_flight:
                # Results are consumed in submission order; popping the future
                # lets its frames be freed once this video is processed.
                vpath, all_frames, window_info = in_flight.popleft().result()
                submit_next()

                # Process frames through CLIP in batches
                all_embeds = []
                for i in range(0, len(all_frames), batch_size):
                    batch_frames = all_frames[i:i + batch_size]
                    inputs = clip_processor(images=batch_frames, return_tensors="pt")
                    inputs = {k: v.to(device) for k, v in inputs.items()}

                    with torch.amp.autocast('cuda', dtype=torch.float16):
                        embeds = clip_model.get_image_features(**inputs)
                    all_embeds.append(embeds)

                all_embeds = torch.cat(all_embeds, dim=0)

                # Split back by window: mean over the window's frames, then normalize
                offset = 0
                for start, num_frames in window_info:
                    window_embeds = all_embeds[offset:offset + num_frames]
                    embed = F.normalize(window_embeds.mean(dim=0).float(), dim=-1)
                    cache[(str(vpath), start)] = embed.cpu()
                    offset += num_frames

                pbar.update(1)

    return cache

## Training Loop

Clean, minimal loop with proper PyTorch patterns.

In [9]:
from tqdm.auto import tqdm


def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """Build a LambdaLR that ramps linearly from 0 to 1, then decays along a half cosine.

    The returned multiplier is ``step / num_warmup_steps`` during warmup and
    ``0.5 * (1 + cos(pi * progress))`` afterwards (never below 0), where
    ``progress`` is the fraction of post-warmup steps completed.
    """
    # Both spans are floored at 1 so neither phase can divide by zero
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step):
        if current_step >= num_warmup_steps:
            progress = float(current_step - num_warmup_steps) / decay_span
            cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
            return max(0.0, cosine)
        return float(current_step) / warmup_span

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


class Trainer:
    """Minimal trainer for V-JEPA → CLIP alignment."""

    def __init__(self, cfg: Config):
        """Seed the RNGs, optionally start a wandb run, then build models, data and optimizer.

        The setup stages run in a fixed order: models first, then data, then
        the optimizer and schedule.
        """
        self.cfg = cfg
        self.vjepa_device, self.clip_device = (
            torch.device(name) for name in (cfg.vjepa_device, cfg.clip_device)
        )

        # Seed PyTorch and NumPy with the same value
        for seed_fn in (torch.manual_seed, np.random.seed):
            seed_fn(cfg.seed)

        # Start the wandb run with the full config attached
        if cfg.use_wandb:
            import wandb
            run_settings = {
                'project': cfg.wandb_project,
                'name': cfg.wandb_run_name,
                'config': vars(cfg),
            }
            wandb.init(**run_settings)

        for stage in (self._setup_models, self._setup_data, self._setup_training):
            stage()

    def _setup_data(self):
        """Build the window index, obtain CLIP targets, and create the data loaders.

        CLIP embeddings are read from a cache file inside the video directory
        when it exists, loads cleanly and covers every indexed window;
        otherwise they are recomputed and the cache is rewritten. The CLIP
        model is dropped afterwards, and the (shuffled) windows are split
        into train and test sets.

        Raises:
            ValueError: If no video windows were found.
        """
        import pickle

        cfg = self.cfg
        video_dir = Path(cfg.video_dir)

        # Build window index
        print(f"Building window index from {video_dir}...")
        index = build_window_index(video_dir, cfg)
        print(f"Found {len(index)} windows")
        if not index:
            # Fail here with a clear message instead of later inside the sampler
            raise ValueError(
                f"No video windows found in {video_dir}; "
                f"need videos with at least {cfg.num_frames} frames"
            )

        # CLIP cache filename based on config
        cache_key = f"{cfg.clip_model.replace('/', '_')}_{cfg.num_frames}f_{cfg.clip_sample_frames}cs_{cfg.window_stride}ws_{cfg.max_windows_per_video}mw_{cfg.num_videos or 'all'}v"
        cache_file = video_dir / f".clip_cache_{cache_key}.pt"

        # Try to load cached embeddings
        clip_cache = None
        if cache_file.exists():
            print(f"Loading cached CLIP embeddings from {cache_file}")
            try:
                # weights_only restricts unpickling to tensors and plain
                # containers, which is all this cache holds.
                clip_cache = torch.load(cache_file, map_location="cpu", weights_only=True)
            except (pickle.UnpicklingError, RuntimeError, EOFError) as exc:
                print(f"Could not read CLIP cache ({exc}), recomputing...")
                clip_cache = None

            if clip_cache is not None:
                # Verify cache has all needed keys
                needed = {(str(v), s) for v, s, _ in index}
                missing = needed - clip_cache.keys()
                if missing:
                    print(f"Cache missing {len(missing)} entries, recomputing...")
                    clip_cache = None

        # Precompute if needed
        if clip_cache is None:
            print("Precomputing CLIP embeddings...")
            clip_cache = precompute_clip_embeddings(
                index, self.clip, self.clip_processor, cfg
            )
            # Save cache
            print(f"Saving CLIP cache to {cache_file}")
            torch.save(clip_cache, cache_file)

        # Free CLIP model memory (we only need embeddings now)
        del self.clip
        torch.cuda.empty_cache()

        # Train/test split.
        # NOTE(review): the shuffle is over individual windows, so windows of
        # the same video can land in both splits - confirm this is intended.
        np.random.shuffle(index)
        split = int(len(index) * (1 - cfg.test_split))
        train_index, test_index = index[:split], index[split:]

        # Datasets
        self.train_dataset = VideoWindowDataset(
            train_index, self.vjepa_processor, clip_cache, cfg.num_frames
        )
        self.test_dataset = VideoWindowDataset(
            test_index, self.vjepa_processor, clip_cache, cfg.num_frames
        )

        # DataLoaders: identical settings except that only training shuffles.
        # persistent_workers / prefetch_factor are only valid with workers.
        use_workers = cfg.num_workers > 0
        loader_kwargs = dict(
            batch_size=cfg.batch_size,
            num_workers=cfg.num_workers,
            collate_fn=collate_fn,
            pin_memory=True,
            persistent_workers=use_workers,
            prefetch_factor=2 if use_workers else None,
        )
        self.train_loader = DataLoader(self.train_dataset, shuffle=True, **loader_kwargs)
        self.test_loader = DataLoader(self.test_dataset, shuffle=False, **loader_kwargs)

        print(f"Train: {len(self.train_dataset)}, Test: {len(self.test_dataset)}")

    def _setup_models(self):
        """Load V-JEPA with LoRA adapters and CLIP, and build the head, loss and memory bank.

        V-JEPA lives on ``self.vjepa_device``; CLIP, the projection head, the
        loss module and the memory bank live on ``self.clip_device``. CLIP is
        frozen and put in eval mode because it only supplies target
        embeddings.
        """
        cfg = self.cfg
        dtype = torch.bfloat16 if cfg.use_bf16 else torch.float16

        # V-JEPA with PEFT LoRA
        print("Loading V-JEPA...")
        vjepa = AutoModel.from_pretrained(
            cfg.vjepa_model,
            torch_dtype=dtype,
            attn_implementation="sdpa",
        )

        # Configure LoRA with PEFT
        lora_config = LoraConfig(
            r=cfg.lora_r,
            lora_alpha=cfg.lora_alpha,
            lora_dropout=cfg.lora_dropout,
            target_modules=cfg.lora_target_modules,
            bias="none",
            task_type=TaskType.FEATURE_EXTRACTION,
        )
        self.vjepa = get_peft_model(vjepa, lora_config)
        self.vjepa.to(self.vjepa_device)
        self.vjepa.print_trainable_parameters()

        # Gradient checkpointing (with correct settings for PEFT)
        if cfg.gradient_checkpointing:
            self.vjepa.enable_input_require_grads()
            self.vjepa.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs={"use_reentrant": False}
            )
            print("Gradient checkpointing enabled")

        # CLIP (for precomputing embeddings)
        print("Loading CLIP...")
        self.clip = CLIPModel.from_pretrained(cfg.clip_model, torch_dtype=dtype)
        self.clip.to(self.clip_device)
        self.clip.eval()
        for p in self.clip.parameters():
            p.requires_grad_(False)

        # Projection head: V-JEPA hidden size -> CLIP projection size
        vjepa_dim = self.vjepa.config.hidden_size
        clip_dim = self.clip.config.projection_dim

        self.proj = ProjectionHead(vjepa_dim, clip_dim, cfg.proj_hidden_dim)
        self.proj.to(self.clip_device)

        if cfg.use_compile and hasattr(torch, 'compile'):
            print("Compiling projection head...")
            self.proj = torch.compile(self.proj, mode="reduce-overhead")

        # Loss: forward every contrastive setting from the config (label
        # smoothing and the logit-scale bound included) and keep its
        # parameter on the device where the loss is evaluated.
        self.criterion = InfoNCELoss(
            temperature=cfg.temperature,
            lambda_align=cfg.lambda_align,
            label_smoothing=cfg.label_smoothing,
            logit_scale_max=cfg.logit_scale_max,
        )
        self.criterion.to(self.clip_device)

        # Memory bank
        self.memory_bank = CircularMemoryBank(
            capacity=cfg.memory_bank_size,
            embed_dim=clip_dim,
            device=self.clip_device,
        )

        # Processors
        self.vjepa_processor = AutoVideoProcessor.from_pretrained(cfg.vjepa_model, use_fast=True)
        self.clip_processor = CLIPProcessor.from_pretrained(cfg.clip_model, use_fast=True)

        # Watch models with wandb
        if cfg.use_wandb:
            wandb.watch(self.vjepa, log='gradients', log_freq=100)
            wandb.watch(self.proj, log='gradients', log_freq=100)

    def _setup_training(self):
        """Create the AdamW optimizer, the warmup + cosine LR schedule and the grad scaler.

        Four parameter groups are built, in this order: LoRA with decay, LoRA
        without decay, projection head with decay, projection head without
        decay. One-dimensional parameters and biases are exempt from weight
        decay. LoRA groups use ``cfg.lora_lr`` and projection groups use
        ``cfg.proj_lr``.

        NOTE(review): the parameters of ``self.criterion`` are not placed in
        any group here, so this optimizer never updates them - confirm that
        is intended.
        """
        cfg = self.cfg

        # Collect trainable parameters and exclude norms/bias from decay
        lora_decay = []
        lora_no_decay = []
        proj_decay = []
        proj_no_decay = []

        # Only parameters that require grad are taken from V-JEPA
        for name, param in self.vjepa.named_parameters():
            if not param.requires_grad:
                continue
            target = lora_no_decay if (param.dim() == 1 or name.endswith('bias')) else lora_decay
            target.append(param)

        for name, param in self.proj.named_parameters():
            target = proj_no_decay if (param.dim() == 1 or name.endswith('bias')) else proj_decay
            target.append(param)

        params = [
            {'params': lora_decay, 'lr': cfg.lora_lr, 'weight_decay': cfg.weight_decay},
            {'params': lora_no_decay, 'lr': cfg.lora_lr, 'weight_decay': 0.0},
            {'params': proj_decay, 'lr': cfg.proj_lr, 'weight_decay': cfg.weight_decay},
            {'params': proj_no_decay, 'lr': cfg.proj_lr, 'weight_decay': 0.0},
        ]

        self.optimizer = torch.optim.AdamW(
            params,
            betas=(0.9, 0.999),
        )

        # LR schedule (use optimizer steps, not dataloader steps)
        steps_per_epoch = math.ceil(len(self.train_loader) / cfg.gradient_accum_steps)
        num_training_steps = steps_per_epoch * cfg.epochs
        num_warmup_steps = int(num_training_steps * cfg.warmup_ratio)

        self.scheduler = get_cosine_schedule_with_warmup(
            self.optimizer, num_warmup_steps, num_training_steps
        )

        # AMP + scaler (enabled only for fp16). The device-generic
        # torch.amp.GradScaler replaces the deprecated torch.cuda.amp one.
        self.amp_dtype = torch.bfloat16 if cfg.use_bf16 else torch.float16
        self.scaler = torch.amp.GradScaler('cuda', enabled=self.amp_dtype == torch.float16)

    def train_epoch(self, epoch: int) -> dict:
        cfg = self.cfg
        self.vjepa.train()
        self.proj.train()

        total_loss = 0.0
        total_acc = 0.0
        total_cos = 0.0
        num_batches = 0
        steps_per_epoch = math.ceil(len(self.train_loader) / cfg.gradient_accum_steps)
        global_step = epoch * steps_per_epoch

        # EMA tracking (decay=0.99)
        ema_decay = 0.99
        ema_loss = None
        ema_acc = None
        ema_cos = None

        grad_norm = 0.0
        grad_max = 0.0

        self.optimizer.zero_grad(set_to_none=True)
        step_in_epoch = 0

        pbar = tqdm(self.train_loader, desc=f"Train {epoch+1}")
        for batch_idx, batch in enumerate(pbar):
            pixel_values = batch['pixel_values'].to(self.vjepa_device, non_blocking=True)
            clip_emb = batch['clip_embedding'].to(self.clip_device, non_blocking=True)

            current_step = global_step + step_in_epoch

            # Forward through V-JEPA (keep in autocast)
            with torch.amp.autocast('cuda', dtype=self.amp_dtype):
                outputs = self.vjepa(pixel_values_videos=pixel_values)
                hidden = outputs.last_hidden_state  # [B, T, D] in bf16/fp16

            # Pool OUTSIDE autocast and immediately convert to fp32 for stability
            pooled = hidden.float().mean(dim=1)  # [B, D] in fp32

            # Check for NaNs in V-JEPA output (skip bad batches)
            if not torch.isfinite(pooled).all():
                print(f"[SKIP] Non-finite V-JEPA outputs at batch {num_batches}")
                # Check and reset corrupted LoRA weights
                for name, param in self.vjepa.named_parameters():
                    if 'lora' in name.lower() and not torch.isfinite(param).all():
                        print(f"  Resetting {name}")
                        with torch.no_grad():
                            if 'lora_A' in name:
                                nn.init.kaiming_uniform_(param)
                            else:  # lora_B
                                nn.init.zeros_(param)
                continue

            # Log V-JEPA output magnitude on first batch for monitoring
            if num_batches == 0:
                with torch.no_grad():
                    print(f"[INFO] V-JEPA pooled stats: mean={pooled.mean().item():.3f}, "
                          f"std={pooled.std().item():.3f}, "
                          f"min={pooled.min().item():.3f}, "
                          f"max={pooled.max().item():.3f}")

            # Move to clip device and project
            pooled = pooled.to(self.clip_device)

            # Projection (fp32, no autocast)
            with torch.amp.autocast('cuda', enabled=False):
                projected = self.proj(pooled)

                # Check projection output
                if not torch.isfinite(projected).all():
                    print(f"[SKIP] Non-finite projection outputs at batch {num_batches}")
                    continue

                # Get memory bank contents
                memory = self.memory_bank.get_all()

                # Compute loss (returns dict with combined loss + components)
                outputs = self.criterion(projected, clip_emb, memory)

            # Check if loss is valid (skip if NaN)
            if not torch.isfinite(outputs['loss']):
                print(f"[SKIP] Non-finite loss")
                continue

            # Update memory bank (after loss computation)
            self.memory_bank.push(clip_emb)

            # Backward with grad scaling (support accumulation)
            loss = outputs['loss'] / cfg.gradient_accum_steps
            self.scaler.scale(loss).backward()

            accum_done = ((batch_idx + 1) % cfg.gradient_accum_steps == 0) or (batch_idx == len(self.train_loader) - 1)
            if accum_done:
                # During warmup, zero out LoRA gradients (train only projection head)
                if current_step < cfg.proj_warmup_steps:
                    for name, param in self.vjepa.named_parameters():
                        if param.grad is not None and ('lora' in name.lower()):
                            param.grad.zero_()

                # Unscale before inspecting/clipping gradients
                self.scaler.unscale_(self.optimizer)

                # Compute gradient stats BEFORE clipping
                all_params = list(self.vjepa.parameters()) + list(self.proj.parameters())
                grads = [p.grad for p in all_params if p.grad is not None]

                # Check for NaN gradients - skip this update if found
                has_nan_grad = any(not torch.isfinite(g).all() for g in grads)
                if has_nan_grad:
                    print(f"[SKIP] NaN gradients at batch {num_batches}")
                    self.optimizer.zero_grad(set_to_none=True)
                    self.scaler.update()
                    continue

                if grads:
                    grad_norms = [g.detach().norm(2).item() for g in grads]
                    grad_norm = (sum(n**2 for n in grad_norms)) ** 0.5
                    grad_max = max(g.detach().abs().max().item() for g in grads)
                else:
                    grad_norm = 0.0
                    grad_max = 0.0

                # Gradient clipping
                if cfg.grad_clip > 0:
                    torch.nn.utils.clip_grad_norm_(all_params, cfg.grad_clip)

                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.scheduler.step()
                self.optimizer.zero_grad(set_to_none=True)
                step_in_epoch += 1

            # Track metrics
            loss_val = outputs['loss'].item()
            acc_val = outputs['acc'].item()
            cos_val = outputs['cos'].item()

            total_loss += loss_val
            total_acc += acc_val
            total_cos += cos_val
            num_batches += 1

            # Update EMAs
            if ema_loss is None:
                ema_loss = loss_val
                ema_acc = acc_val
                ema_cos = cos_val
            else:
                ema_loss = ema_decay * ema_loss + (1 - ema_decay) * loss_val
                ema_acc = ema_decay * ema_acc + (1 - ema_decay) * acc_val
                ema_cos = ema_decay * ema_cos + (1 - ema_decay) * cos_val

            # Log to wandb (per step)
            if cfg.use_wandb and num_batches % 10 == 0:
                wandb.log({
                    # Raw values
                    'train/loss': loss_val,
                    'train/contrastive': outputs['contrastive'].item(),
                    'train/alignment': outputs['alignment'].item(),
                    'train/acc': acc_val,
                    'train/cos_sim': cos_val,
                    'train/lr': self.scheduler.get_last_lr()[0],
                    'train/memory_size': self.memory_bank.size,
                    # EMA values
                    'train/ema_loss': ema_loss,
                    'train/ema_acc': ema_acc,
                    'train/ema_cos': ema_cos,
                    # Gradient stats
                    'grad/norm': grad_norm,
                    'grad/max': grad_max,
                })

            pbar.set_postfix({
                'loss': f"{ema_loss:.3f}",
                'acc': f"{ema_acc*100:.1f}%",
                'cos': f"{ema_cos:.3f}",
                'gnorm': f"{grad_norm:.1f}",
                'lr': f"{self.scheduler.get_last_lr()[0]:.2e}",
            })

        return {
            'loss': total_loss / num_batches,
            'acc': total_acc / num_batches,
            'cos': total_cos / num_batches,
        }

    @torch.no_grad()
    def evaluate(self, epoch: int) -> dict:
        """Score the model on ``self.test_loader`` and return mean metrics.

        Both modules are switched to eval mode (and left in eval mode on
        return) and the whole pass runs under ``torch.no_grad()``. A fresh
        ``CircularMemoryBank`` is built for the pass, so the training bank
        (``self.memory_bank``) is neither read for its contents nor modified.

        Args:
            epoch: Zero-based epoch index; only used for the progress-bar label.

        Returns:
            Dict with the per-batch means of the criterion's ``'loss'``,
            ``'acc'`` and ``'cos'`` outputs. If the test loader yields no
            batches, all three values are NaN and a warning is printed, rather
            than failing with ``ZeroDivisionError``.
        """
        self.vjepa.eval()
        self.proj.eval()

        # Fresh bank so evaluation negatives come only from the test set
        eval_memory = CircularMemoryBank(
            capacity=self.cfg.memory_bank_size,
            embed_dim=self.memory_bank.embed_dim,
            device=self.clip_device,
        )

        total_loss = 0.0
        total_acc = 0.0
        total_cos = 0.0
        num_batches = 0

        for batch in tqdm(self.test_loader, desc=f"Eval {epoch+1}"):
            pixel_values = batch['pixel_values'].to(self.vjepa_device, non_blocking=True)
            clip_emb = batch['clip_embedding'].to(self.clip_device, non_blocking=True)

            # Encoder forward pass under autocast
            with torch.amp.autocast('cuda', dtype=self.amp_dtype):
                encoded = self.vjepa(pixel_values_videos=pixel_values)
                # Mean over dim 1 of the encoder output
                # (presumably the token axis — TODO confirm)
                pooled = encoded.last_hidden_state.mean(dim=1)

            # The projection head and criterion run on the CLIP device
            pooled = pooled.to(self.clip_device)

            # Projection and loss with autocast disabled; input cast to fp32
            with torch.amp.autocast('cuda', enabled=False):
                projected = self.proj(pooled.float())
                memory = eval_memory.get_all()
                metrics = self.criterion(projected, clip_emb, memory)

            # Push only after scoring, so this batch's own targets are not
            # part of the memory it was scored against
            eval_memory.push(clip_emb)

            total_loss += metrics['loss'].item()
            total_acc += metrics['acc'].item()
            total_cos += metrics['cos'].item()
            num_batches += 1

        if num_batches == 0:
            # An empty loader would otherwise divide by zero below
            print("[WARN] Evaluation loader produced no batches; returning NaN metrics")
            nan = float('nan')
            return {'loss': nan, 'acc': nan, 'cos': nan}

        return {
            'loss': total_loss / num_batches,
            'acc': total_acc / num_batches,
            'cos': total_cos / num_batches,
        }

    def save(self, path: str):
        """Write a checkpoint containing the V-JEPA state, projection head and config.

        The checkpoint is a dict with three entries:
          - ``'vjepa_lora'``: ``self.vjepa.state_dict()``
          - ``'proj'``: ``self.proj.state_dict()``
          - ``'config'``: ``self.cfg`` (the object itself, not a plain dict)

        NOTE(review): despite the key name, the state dict is stored
        unfiltered. If ``self.vjepa`` is a PEFT-wrapped model this presumably
        includes the frozen base weights as well as the LoRA adapters —
        confirm before assuming the checkpoint is adapter-sized.

        Args:
            path: Destination file. Missing parent directories are created so
                that saving into a not-yet-existing folder does not fail.
        """
        # Only touch the filesystem for real paths; other objects are handed
        # to torch.save untouched
        if isinstance(path, (str, os.PathLike)):
            Path(path).parent.mkdir(parents=True, exist_ok=True)

        torch.save({
            'vjepa_lora': self.vjepa.state_dict(),
            'proj': self.proj.state_dict(),
            'config': self.cfg,
        }, path)
        print(f"Saved to {path}")

    def run(self) -> dict:
        """Run the full training loop and return the per-epoch metric history.

        For each of ``cfg.epochs`` epochs this trains, evaluates, records both
        metric dicts, optionally logs epoch-level values to wandb, prints a
        summary line, and saves ``best_model.pt`` whenever the test loss is
        strictly lower than the best seen so far.

        Returns:
            ``{'train': [...], 'test': [...]}`` with one metrics dict per
            completed epoch in each list.

        Raises:
            Whatever training, evaluation or saving raises is re-raised
            unchanged. When wandb is enabled the run is first closed with a
            non-zero exit code, so an interrupted or crashed cell does not
            leave the run open.
        """
        cfg = self.cfg
        history = {'train': [], 'test': []}
        best_loss = float('inf')

        try:
            for epoch in range(cfg.epochs):
                train_metrics = self.train_epoch(epoch)
                test_metrics = self.evaluate(epoch)

                history['train'].append(train_metrics)
                history['test'].append(test_metrics)

                # Log epoch-level metrics to wandb
                if cfg.use_wandb:
                    wandb.log({
                        'epoch': epoch + 1,
                        'train/epoch_loss': train_metrics['loss'],
                        'train/epoch_acc': train_metrics['acc'],
                        'train/epoch_cos': train_metrics['cos'],
                        'test/loss': test_metrics['loss'],
                        'test/acc': test_metrics['acc'],
                        'test/cos': test_metrics['cos'],
                    })

                print(
                    f"Epoch {epoch+1}/{cfg.epochs}: "
                    f"train_loss={train_metrics['loss']:.4f} "
                    f"test_loss={test_metrics['loss']:.4f} "
                    f"test_acc={test_metrics['acc']*100:.1f}%"
                )

                # Strict comparison: ties and NaN losses never overwrite the
                # current best checkpoint
                if test_metrics['loss'] < best_loss:
                    best_loss = test_metrics['loss']
                    self.save('best_model.pt')
                    if cfg.use_wandb:
                        wandb.run.summary['best_test_loss'] = best_loss
        except BaseException:
            # BaseException so that KeyboardInterrupt (a stopped notebook
            # cell) also closes the run; it is marked as failed
            if cfg.use_wandb:
                wandb.finish(exit_code=1)
            raise

        # Finish wandb run
        if cfg.use_wandb:
            wandb.finish()

        return history


## Run Training

In [None]:
# Build the trainer from the notebook-level `cfg` and run the full training loop.
# `history` is the dict returned by Trainer.run(): {'train': [...], 'test': [...]}
# with one metrics dict per epoch.
trainer = Trainer(cfg)
history = trainer.run()

[34m[1mwandb[0m: Currently logged in as: [33mc-daly[0m ([33mdefiant-duck[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Loading V-JEPA...


`torch_dtype` is deprecated! Use `dtype` instead!
`torch_dtype` is deprecated! Use `dtype` instead!


trainable params: 1,400,832 || all params: 327,372,160 || trainable%: 0.4279
Gradient checkpointing enabled
Loading CLIP...
Building window index from videos...
Found 4992 windows
Precomputing CLIP embeddings...


Precomputing CLIP:   0%|          | 0/312 [00:00<?, ?it/s]

## Plot Results

In [None]:
import matplotlib.pyplot as plt

def plot_history(history):
    """Plot per-epoch train/test curves for loss, accuracy and cosine similarity.

    Args:
        history: Dict with 'train' and 'test' lists of per-epoch metric dicts,
            each providing 'loss', 'acc' and 'cos'. Accuracy is shown as a
            percentage (value * 100).

    Draws a 1x3 figure and displays it with ``plt.show()``; returns nothing.
    """
    xs = range(1, len(history['train']) + 1)

    # One entry per panel: (value extractor, y-axis label, panel title)
    panels = (
        (lambda m: m['loss'], 'Loss', 'Loss'),
        (lambda m: m['acc']*100, 'Accuracy (%)', 'Contrastive Accuracy'),
        (lambda m: m['cos'], 'Cosine Similarity', 'Mean Cosine Similarity'),
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    for ax, (value_of, y_label, title) in zip(axes, panels):
        # Train curve first, then test, so line colours stay consistent
        for split in ('train', 'test'):
            ax.plot(xs, [value_of(m) for m in history[split]], label=split)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.legend()
        ax.set_title(title)

    plt.tight_layout()
    plt.show()

# Plot the per-epoch curves from the `history` dict returned by trainer.run()
plot_history(history)