# Skate Physics Preserver - Colab T4

A single notebook for the full pipeline: Tracking → Generation → Validation.

**Target:** Google Colab with T4 GPU (16GB VRAM).  
**Flow:** Extract masks/poses → Reskin via Wan 2.1 VACE → IoU validation.

## Cell 1: Setup (GPU, Drive, Clone Repo)

In [None]:
# Enable T4 GPU: Runtime → Change runtime type → GPU (T4)
import os
import shutil
import subprocess
import sys

# Mount Google Drive so downloaded model weights persist across sessions.
from google.colab import drive
drive.mount('/content/drive')

# Clone the repo cleanly (safe to rerun).
REPO_URL = 'https://github.com/YOUR_USER/skate-physics-preserver.git'  # Replace with your repo URL
REPO_DIR = '/content/skate-physics-preserver'
BRANCH = 'main'

if os.path.isdir(f'{REPO_DIR}/.git'):
    # NOTE: `git clean -fdx` removes every untracked file, ignored or not, so
    # anything later cells created inside REPO_DIR (e.g. output/ and the
    # checkpoint symlink) is deleted; rerun those cells after this one.
    print('Repo exists: hard-reset + clean + pull latest...')
    subprocess.run(['git', '-C', REPO_DIR, 'fetch', 'origin', BRANCH], check=True)
    subprocess.run(['git', '-C', REPO_DIR, 'reset', '--hard', f'origin/{BRANCH}'], check=True)
    subprocess.run(['git', '-C', REPO_DIR, 'clean', '-fdx'], check=True)
else:
    # A leftover non-git directory would make `git clone` fail; remove it first.
    if os.path.exists(REPO_DIR):
        shutil.rmtree(REPO_DIR)
    print('Cloning repo...')
    subprocess.run(['git', 'clone', '--branch', BRANCH, '--single-branch', REPO_URL, REPO_DIR], check=True)

# Work from the repo root (uses REPO_DIR instead of repeating the path literal).
os.chdir(REPO_DIR)

# Report the GPU. A CPU-only runtime has no `nvidia-smi` (or it prints nothing);
# warn instead of crashing or printing an empty GPU name.
try:
    gpu_info = subprocess.run(
        ['nvidia-smi', '--query-gpu=name,memory.total', '--format=csv,noheader'],
        capture_output=True,
        text=True,
    ).stdout.strip()
except FileNotFoundError:
    gpu_info = ''

if gpu_info:
    print('Setup complete. GPU:', gpu_info)
else:
    print('Setup complete, but NO GPU detected. Enable it via Runtime → Change runtime type → GPU (T4).')

## Cell 2: Install Dependencies

In [None]:
# Install dependencies into the *kernel's* interpreter (`sys.executable -m pip`).
# A bare `!pip` can resolve to a different environment, and its failures do not
# stop the notebook; here a failed install raises with pip's error output.
import subprocess
import sys

REPO_DIR = '/content/skate-physics-preserver'


def _pip(*args, check=True):
    """Run `python -m pip <args>` for the kernel interpreter; raise on failure if `check`."""
    result = subprocess.run([sys.executable, '-m', 'pip', *args], capture_output=True, text=True)
    if check and result.returncode != 0:
        raise RuntimeError(f"pip {' '.join(args)} failed:\n{result.stderr[-4000:]}")
    return result


# NOTE(review): Colab's preinstalled torch targets CUDA 12, and current
# onnxruntime-gpu wheels on PyPI also target CUDA 12. Replacing torch with the
# cu118 build may leave those mismatched; confirm cu118 is still required.
_pip('install', '-q', 'torch', 'torchvision', '--index-url', 'https://download.pytorch.org/whl/cu118')
# Absolute path so this cell does not depend on the cwd left by the setup cell.
_pip('install', '-q', '-r', f'{REPO_DIR}/requirements.txt')
_pip('install', '-q', 'git+https://github.com/facebookresearch/sam2.git')
# Swap the CPU onnxruntime for the GPU build. Uninstalling a package that is not
# installed is harmless, so this step does not raise.
_pip('uninstall', '-y', 'onnxruntime', check=False)
_pip('install', '-q', 'onnxruntime-gpu')

print('Dependencies installed.')

## Cell 3: Download Models

In [None]:
import os
from pathlib import Path

MODELS_DIR = Path('/content/drive/MyDrive/skate-physics-models')
MODELS_DIR.mkdir(parents=True, exist_ok=True)

# SAM 2.1 Hiera-Small (T4-friendly).
# The Hugging Face repo publishes this checkpoint as `sam2.1_hiera_small.pt`.
# It is stored locally under the short name the tracking/validation cells expect.
SAM_REPO_ID = 'facebook/sam2.1-hiera-small'
SAM_REMOTE_NAME = 'sam2.1_hiera_small.pt'
SAM_CKPT = MODELS_DIR / 'sam2.1_hiera_s.pt'


def _download_sam_checkpoint(dest):
    """Fetch the SAM 2.1 Hiera-Small checkpoint and move it to `dest`.

    Tries `huggingface_hub` first and falls back to a plain HTTPS download.
    `dest` only appears once a non-empty file has been fetched.

    Raises:
        urllib.error.URLError: if the fallback HTTPS download fails.
        RuntimeError: if the downloaded file is empty.
    """
    import urllib.request

    try:
        from huggingface_hub import hf_hub_download
        fetched = Path(hf_hub_download(repo_id=SAM_REPO_ID, filename=SAM_REMOTE_NAME, local_dir=str(MODELS_DIR)))
    except Exception as err:  # library missing or Hub API error: fall back to HTTPS
        print(f'hf_hub_download failed ({err!r}); falling back to direct download')
        fetched = dest.with_name(dest.name + '.part')
        url = f'https://huggingface.co/{SAM_REPO_ID}/resolve/main/{SAM_REMOTE_NAME}'
        # urlretrieve raises on HTTP errors (e.g. 404), so a failed download is
        # never reported as success.
        urllib.request.urlretrieve(url, str(fetched))
    if fetched.stat().st_size == 0:
        fetched.unlink()
        raise RuntimeError(f'Downloaded {SAM_REMOTE_NAME} is empty')
    os.replace(fetched, dest)


# An earlier failed download may have left a zero-byte file; treat it as missing.
if not SAM_CKPT.exists() or SAM_CKPT.stat().st_size == 0:
    _download_sam_checkpoint(SAM_CKPT)
    print('Downloaded SAM 2.1 Hiera-Small')
else:
    print('SAM 2.1 checkpoint exists')

# Create checkpoints symlink in repo.
CKPT_LINK = Path('/content/skate-physics-preserver/checkpoints/sam2.1_hiera_s.pt')
CKPT_LINK.parent.mkdir(parents=True, exist_ok=True)
# os.path.exists() is False for a dangling symlink, after which os.symlink()
# would raise FileExistsError. Drop a dangling link, then create the link only
# if nothing (not even a link) is at that path.
if CKPT_LINK.is_symlink() and not CKPT_LINK.exists():
    CKPT_LINK.unlink()
if not os.path.lexists(CKPT_LINK):
    os.symlink(SAM_CKPT, CKPT_LINK)

# Wan 2.1 VACE Q4_K_S (10.6GB) - download to ComfyUI models when we install ComfyUI
print('Wan VACE model will be downloaded in ComfyUI bootstrap cell.')

## Cell 4: Tracking (extract_physics)

In [None]:
import os
import sys

REPO_DIR = '/content/skate-physics-preserver'
# Guard so rerunning the cell does not stack duplicate sys.path entries.
if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)

from src.extract_physics import run_extraction, get_bbox_colab

# Upload your video or use a path
VIDEO_PATH = '/content/skate-physics-preserver/sample_input.mp4'  # Replace: upload video to Colab
OUTPUT_DIR = '/content/skate-physics-preserver/output'
FRAME_CAP = 50  # T4 limit

# Bbox: [x1, y1, x2, y2] around skateboard on frame 0
# Option A: Provide manually (e.g. from inspecting first frame)
BBOX = [100, 200, 300, 280]  # Example - adjust for your video

# Option B: Use get_bbox_colab to display frame and prompt
# BBOX = get_bbox_colab(VIDEO_PATH)

# Validate inputs before loading SAM, so mistakes fail fast with a clear message.
if not os.path.isfile(VIDEO_PATH):
    raise FileNotFoundError(f'Input video not found: {VIDEO_PATH} (upload it or update VIDEO_PATH)')
x1, y1, x2, y2 = BBOX
if not (x1 < x2 and y1 < y2):
    raise ValueError(f'BBOX must be [x1, y1, x2, y2] with x1 < x2 and y1 < y2, got {BBOX}')

run_extraction(
    VIDEO_PATH,
    OUTPUT_DIR,
    BBOX,
    frame_load_cap=FRAME_CAP,
)

print('Tracking complete. Masks and poses in:', OUTPUT_DIR)

## Cell 5: ComfyUI Bootstrap + Generation

In [None]:
# Clone ComfyUI, install its (and its custom nodes') Python requirements,
# download the Wan VACE weights, then start the server and wait until it answers.
import os
import subprocess
import sys
import time
import urllib.request

COMFY_DIR = '/content/ComfyUI'
COMFY_PORT = 8188
COMFY_URL = f'http://127.0.0.1:{COMFY_PORT}'
COMFY_LOG = '/content/comfyui.log'
STARTUP_TIMEOUT_S = 300  # first launch can take much longer than a fixed 15 s

# (directory name under custom_nodes/, git URL, required?)
CUSTOM_NODES = [
    ('ComfyUI-VideoHelperSuite', 'https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite.git', True),
    # ComfyUI-GGUF for Wan VACE GGUF
    ('ComfyUI-GGUF', 'https://github.com/city96/ComfyUI-GGUF.git', True),
    # WanVideoWrapper (if available)
    ('ComfyUI-WanVideoWrapper', 'https://github.com/kijai/ComfyUI-WanVideoWrapper.git', False),
]

# (target directory relative to COMFY_DIR, URL): Wan VACE Q4_K_S and VAE.
MODEL_FILES = [
    ('models/unet', 'https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/resolve/main/Wan2.1-VACE-14B-Q4_K_S.gguf'),
    ('models/vae', 'https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/vae/wan_2.1_vae.safetensors'),
]


def _clone(url, dest, required=True):
    """Clone `url` into `dest` unless it already exists; return True if `dest` is usable."""
    if os.path.isdir(dest):
        return True
    try:
        subprocess.run(['git', 'clone', url, dest], check=True)
    except subprocess.CalledProcessError:
        if required:
            raise
        print(f'WARNING: could not clone optional repo {url}; continuing without it.')
        return False
    return True


def _install_requirements(project_dir):
    """pip-install `project_dir`/requirements.txt into the kernel's interpreter, if present."""
    req = os.path.join(project_dir, 'requirements.txt')
    if os.path.isfile(req):
        subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', '-r', req], check=True)


def _download(url, dest_dir):
    """Download `url` into `dest_dir` unless a completed copy is already there; return its path."""
    os.makedirs(dest_dir, exist_ok=True)
    dest = os.path.join(dest_dir, url.rsplit('/', 1)[-1])
    if os.path.isfile(dest):
        return dest
    # Download under a temporary name. An interrupted transfer then never looks
    # finished; `wget -nc` would skip a truncated file on rerun.
    part = dest + '.part'
    subprocess.run(['wget', '-q', '-O', part, url], check=True)
    os.replace(part, dest)
    return dest


def _comfy_ready():
    """Return True if the ComfyUI HTTP API answers on COMFY_URL."""
    try:
        with urllib.request.urlopen(f'{COMFY_URL}/system_stats', timeout=2) as resp:
            return resp.status == 200
    except OSError:  # connection refused, timeout, HTTP error (URLError is an OSError)
        return False


_clone('https://github.com/comfyanonymous/ComfyUI.git', COMFY_DIR)
os.chdir(COMFY_DIR)
# ComfyUI's own dependencies may not be preinstalled on Colab. Without them,
# main.py exits on import and the server never comes up.
_install_requirements(COMFY_DIR)

for name, url, required in CUSTOM_NODES:
    node_dir = os.path.join(COMFY_DIR, 'custom_nodes', name)
    if _clone(url, node_dir, required=required):
        _install_requirements(node_dir)

# NOTE(review): Wan 2.1 workflows usually also need a text encoder under
# models/text_encoders; confirm the generation workflow's loaders are satisfied.
for subdir, url in MODEL_FILES:
    _download(url, os.path.join(COMFY_DIR, subdir))

# Start ComfyUI in background (lowvram for T4), unless a previous run is still serving.
if _comfy_ready():
    print(f'ComfyUI already running at {COMFY_URL}; not starting another instance.')
else:
    # Log to a file instead of DEVNULL so startup failures can be diagnosed.
    # The child inherits the file descriptor, so the parent's handle can be closed.
    with open(COMFY_LOG, 'ab') as log_file:
        proc = subprocess.Popen(
            [sys.executable, 'main.py', '--listen', '0.0.0.0', '--port', str(COMFY_PORT), '--lowvram'],
            cwd=COMFY_DIR,
            stdout=log_file,
            stderr=subprocess.STDOUT,
        )
    # Poll the API instead of sleeping a fixed time, and detect early crashes.
    deadline = time.monotonic() + STARTUP_TIMEOUT_S
    while not _comfy_ready():
        if proc.poll() is not None:
            raise RuntimeError(f'ComfyUI exited with code {proc.returncode}; see {COMFY_LOG}')
        if time.monotonic() > deadline:
            raise TimeoutError(f'ComfyUI not ready after {STARTUP_TIMEOUT_S} s; see {COMFY_LOG}')
        time.sleep(2)
print('ComfyUI started. Run generation in next cell.')

In [None]:
import os
import sys

REPO_DIR = '/content/skate-physics-preserver'
# Guard so rerunning the cell does not stack duplicate sys.path entries.
if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)

from src.generate_reskin import run_generation

OUTPUT_DIR = f'{REPO_DIR}/output'
MASK_DIR = f'{OUTPUT_DIR}/mask_skateboard'
POSE_DIR = f'{OUTPUT_DIR}/pose_skater'
VIDEO_PATH = f'{REPO_DIR}/sample_input.mp4'
FRAME_CAP = 50  # keep in sync with FRAME_CAP in the tracking cell

# Check the tracking outputs up front, so a missing step fails here with an
# actionable message instead of inside the ComfyUI job.
for required_dir in (MASK_DIR, POSE_DIR):
    if not os.path.isdir(required_dir) or not os.listdir(required_dir):
        raise FileNotFoundError(f'{required_dir} is missing or empty; run the tracking cell first.')
if not os.path.isfile(VIDEO_PATH):
    raise FileNotFoundError(f'Source video not found: {VIDEO_PATH}')

result = run_generation(
    server_addr='127.0.0.1:8188',
    source_video_path=VIDEO_PATH,
    mask_dir=MASK_DIR,
    pose_dir=POSE_DIR,
    prompt='cyberpunk drone, samurai warrior',
    negative_prompt='blur, distortion, artifacts',
    frame_load_cap=FRAME_CAP,
)

print('Generated output:', result)

## Cell 6: Validation (evaluate_iou)

In [None]:
import glob
import os
import sys

REPO_DIR = '/content/skate-physics-preserver'
# Guard so rerunning the cell does not stack duplicate sys.path entries.
if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)

from src.evaluate_iou import load_mask_sequence, calculate_iou
from src.tracking.skateboard_tracker import SkateboardTracker

MASK_DIR = f'{REPO_DIR}/output/mask_skateboard'
GENERATED_VIDEO = '/content/ComfyUI/output/skate_reskin_00001.mp4'  # Adjust to actual output path
BBOX = [100, 200, 300, 280]  # Same as tracking cell
FRAME_CAP = 50  # keep in sync with the tracking cell
IOU_THRESHOLD = 0.90  # a frame passes only if IoU is strictly greater than this

if not os.path.isfile(GENERATED_VIDEO):
    # Presumably the output filename counter increments on each generation run
    # (_00002, ...); fall back to the most recently written matching file.
    candidates = sorted(glob.glob('/content/ComfyUI/output/skate_reskin_*.mp4'), key=os.path.getmtime)
    if not candidates:
        raise FileNotFoundError(f'No generated video found (looked for {GENERATED_VIDEO})')
    GENERATED_VIDEO = candidates[-1]
    print('Using most recent generated video:', GENERATED_VIDEO)

gt_frames = load_mask_sequence(MASK_DIR)
tracker = SkateboardTracker(
    f'{REPO_DIR}/checkpoints/sam2.1_hiera_s.pt',
    f'{REPO_DIR}/configs/sam2.1/sam2.1_hiera_s.yaml',
    frame_load_cap=FRAME_CAP,
)
tracker.init_video(GENERATED_VIDEO)
# NOTE(review): assumes the generated video has the same resolution as the
# source, so BBOX and the ground-truth masks line up; confirm.
tracker.add_initial_prompt(0, BBOX)
gen_masks = list(tracker.propagate_yield())

n_compared = min(len(gt_frames), len(gen_masks))
if len(gt_frames) != len(gen_masks):
    print(f'WARNING: frame count mismatch (ground truth {len(gt_frames)}, '
          f'generated {len(gen_masks)}); comparing the first {n_compared}.')

failed = []
for i, (gt_mask, gen_mask) in enumerate(zip(gt_frames, gen_masks)):
    iou = calculate_iou(gt_mask, gen_mask)
    # Derive both the status and the failure list from one test, so a NaN IoU
    # is reported as FAIL and also counted as a failure.
    passed = iou > IOU_THRESHOLD
    if not passed:
        failed.append((i, iou))
    status = 'PASS' if passed else 'FAIL'
    print(f'Frame {i:04d}: IoU = {iou:.4f} [{status}]')

print('---')
if n_compared == 0:
    # Comparing nothing must not count as a pass.
    print('FAILED: no frames were compared (empty ground-truth or generated mask sequence).')
elif failed:
    print(f'FAILED: {len(failed)} frames below {IOU_THRESHOLD:.2f} IoU')
else:
    print('SUCCESS: Zero-Clipping Benchmark Passed.')