# Skate Physics Preserver — Colab Demo

End-to-end pipeline: **YouTube → Tracking → Generation → Validation**

**Target:** Google Colab with T4 GPU (16 GB VRAM)  
**Input:** YouTube video URL  
**Output:** Reskinned video with physics preservation (IoU > 0.90)

---

### How to run
1. **Runtime → Change runtime type → T4 GPU**
2. Run each cell in order (Shift+Enter)
3. Adjust `BBOX` in the second cell of Step 4 after viewing the first frame

## Step 1: Environment Setup

In [None]:
# Enable T4 GPU: Runtime → Change runtime type → GPU (T4)
import os, subprocess, shutil, sys

# Mount Google Drive for model persistence
from google.colab import drive
drive.mount('/content/drive')

# ── Clone or update the repo ──
REPO_URL = 'https://github.com/nsystema/skate-physics-preserver.git'
REPO_DIR = '/content/skate-physics-preserver'
BRANCH   = 'main'

if os.path.isdir(f'{REPO_DIR}/src'):
    print('Repo directory already exists — skipping clone.')
else:
    if os.path.isdir(f'{REPO_DIR}/.git'):
        subprocess.run(['git', '-C', REPO_DIR, 'fetch', 'origin', BRANCH], check=True)
        subprocess.run(['git', '-C', REPO_DIR, 'reset', '--hard', f'origin/{BRANCH}'], check=True)
        subprocess.run(['git', '-C', REPO_DIR, 'clean', '-fdx'], check=True)
    else:
        if os.path.exists(REPO_DIR):
            shutil.rmtree(REPO_DIR)
        subprocess.run(['git', 'clone', '--branch', BRANCH, '--single-branch', REPO_URL, REPO_DIR], check=True)

%cd {REPO_DIR}

gpu_info = subprocess.run(
    ['nvidia-smi', '--query-gpu=name,memory.total', '--format=csv,noheader'],
    capture_output=True, text=True,
).stdout.strip()
print(f'Setup complete. GPU: {gpu_info}')

## Step 2: Install Dependencies

In [None]:
# Install the Python packages used by the remaining steps.
# NOTE(review): torch/torchvision come from the CUDA 11.8 wheel index — confirm this
# still matches the CUDA driver of the current Colab image.
!pip install -q torch torchvision --index-url https://download.pytorch.org/whl/cu118
# Project requirements; the relative path resolves against the repo (cwd after Step 1's %cd).
!pip install -q -r requirements.txt
# SAM 2 package, installed straight from GitHub.
!pip install -q 'git+https://github.com/facebookresearch/sam2.git'
# Swap the CPU onnxruntime for the GPU build (presumably so the two builds are not
# installed side by side — confirm).
!pip uninstall -y onnxruntime 2>/dev/null; pip install -q onnxruntime-gpu
# yt-dlp fetches the input video in Step 4.
!pip install -q yt-dlp

print('Dependencies installed.')

## Step 3: Download Models

In [None]:
import os
import urllib.request
from pathlib import Path

DRIVE_MODELS = Path('/content/drive/MyDrive/skate-physics-models')
DRIVE_MODELS.mkdir(parents=True, exist_ok=True)

# ── SAM 2.1 Hiera-Small (checkpoint) ──
SAM_CKPT = DRIVE_MODELS / 'sam2.1_hiera_s.pt'
SAM_URL = 'https://huggingface.co/facebook/sam2.1-hiera-small/resolve/main/sam2.1_hiera_s.pt'


def _has_checkpoint(path):
    """Return True if *path* is an existing, non-empty file.

    A plain existence check is not enough: an interrupted download can leave
    a 0-byte file behind that would otherwise be mistaken for a cached checkpoint.
    """
    return path.is_file() and path.stat().st_size > 0


if _has_checkpoint(SAM_CKPT):
    print('SAM 2.1 checkpoint already cached in Drive')
else:
    try:
        from huggingface_hub import hf_hub_download
        hf_hub_download(repo_id='facebook/sam2.1-hiera-small',
                        filename='sam2.1_hiera_s.pt', local_dir=str(DRIVE_MODELS))
    except Exception as err:
        # Fallback: fetch directly. Write to a side file and rename it only
        # after a complete download, so a failure never leaves a truncated
        # checkpoint at SAM_CKPT.
        print(f'huggingface_hub download failed ({err!r}); retrying via direct URL...')
        part = SAM_CKPT.with_name(SAM_CKPT.name + '.part')
        urllib.request.urlretrieve(SAM_URL, str(part))
        part.replace(SAM_CKPT)
    if not _has_checkpoint(SAM_CKPT):
        raise RuntimeError(f'SAM 2.1 checkpoint download failed: {SAM_CKPT} is missing or empty')
    print('Downloaded SAM 2.1 Hiera-Small')

# Symlink into repo checkpoints/
CKPT_DIR = '/content/skate-physics-preserver/checkpoints'
os.makedirs(CKPT_DIR, exist_ok=True)
sam_link = os.path.join(CKPT_DIR, 'sam2.1_hiera_s.pt')
# os.path.exists() is False for a dangling symlink, but os.symlink() would still
# raise FileExistsError on it. Remove stale links so the link can be recreated.
if os.path.islink(sam_link) and not os.path.exists(sam_link):
    os.remove(sam_link)
if not os.path.lexists(sam_link):
    os.symlink(str(SAM_CKPT), sam_link)

print('Models ready. Wan VACE + CLIP will download in the ComfyUI step.')

## Step 4: Download Input Video & Configure BBox

Run the cell below to download the YouTube video.  
Then **adjust `BBOX`** in the following cell so the green rectangle covers the skateboard.

In [None]:
import subprocess, cv2, os
from PIL import Image
from IPython.display import display

# ═══════════════════════════════════════════════════════════
#  INPUT VIDEO — change the URL to use a different clip
# ═══════════════════════════════════════════════════════════
VIDEO_URL  = 'https://www.youtube.com/shorts/wPpXDBMk8-E'
VIDEO_PATH = '/content/skate-physics-preserver/input_video.mp4'
# ═══════════════════════════════════════════════════════════

# Prefer an H.264 ("avc1") video stream. YouTube can also serve AV1 inside MP4,
# which the installed OpenCV build may be unable to decode. The later
# selectors keep the previous fallbacks for when no H.264 stream exists.
YTDLP_FORMAT = (
    'bestvideo[ext=mp4][vcodec^=avc1]+bestaudio[ext=m4a]'
    '/bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best'
)

if not os.path.exists(VIDEO_PATH):
    print(f'Downloading: {VIDEO_URL}')
    subprocess.run([
        'yt-dlp',
        '-f', YTDLP_FORMAT,
        '--merge-output-format', 'mp4',
        '-o', VIDEO_PATH,
        VIDEO_URL,
    ], check=True)
    print(f'Downloded → {VIDEO_PATH}'.replace('Downloded', 'Downloaded'))
else:
    print(f'Video already exists: {VIDEO_PATH}')

# Show first frame for bbox reference
cap = cv2.VideoCapture(VIDEO_PATH)
try:
    ret, frame = cap.read()
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
finally:
    cap.release()

# Fail loudly: every later step depends on this video being decodable.
if not ret:
    raise RuntimeError(
        f'Could not decode the first frame of {VIDEO_PATH}. '
        'Delete the file and re-run this cell, or try a different VIDEO_URL.'
    )

h, w = frame.shape[:2]
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
display(Image.fromarray(frame_rgb))
print(f'\nVideo: {w} x {h}, {total_frames} frames, {fps:.1f} fps')
print('Look at the frame above and set BBOX in the next cell.')

In [None]:
# ═══════════════════════════════════════════════════════════
#  BOUNDING BOX — [x1, y1, x2, y2] around the SKATEBOARD
#  x: left→right,  y: top→bottom
#  Adjust these after viewing the first frame above.
# ═══════════════════════════════════════════════════════════
BBOX      = [200, 600, 450, 680]   # ← ADJUST for your video
FRAME_CAP = 50                     # Max frames (T4 safe limit)
OUTPUT_DIR = '/content/skate-physics-preserver/output'
# ═══════════════════════════════════════════════════════════

# Draw bbox on first frame so you can verify visually
import cv2, numpy as np
from PIL import Image
from IPython.display import display

# Catch malformed boxes here, before they reach the tracker.
if len(BBOX) != 4:
    raise ValueError(f'BBOX must have 4 values [x1, y1, x2, y2], got {BBOX}')
# cv2 drawing functions need integer pixel coordinates.
bx1, by1, bx2, by2 = (int(v) for v in BBOX)
if bx1 >= bx2 or by1 >= by2:
    raise ValueError(f'BBOX must satisfy x1 < x2 and y1 < y2, got {BBOX}')

cap = cv2.VideoCapture(VIDEO_PATH)
try:
    ret, frame = cap.read()
finally:
    cap.release()

if not ret:
    print(f'Could not read the first frame of {VIDEO_PATH} — cannot preview BBOX.')
else:
    frame_h, frame_w = frame.shape[:2]
    if bx1 < 0 or by1 < 0 or bx2 > frame_w or by2 > frame_h:
        print(f'WARNING: BBOX {BBOX} extends outside the {frame_w} x {frame_h} frame.')
    vis = frame.copy()
    cv2.rectangle(vis, (bx1, by1), (bx2, by2), (0, 255, 0), 2)
    # Put the label above the box. If the box hugs the top edge, put it below
    # the box instead so the text is not drawn off-canvas.
    label_y = by1 - 10 if by1 >= 30 else by2 + 25
    cv2.putText(vis, 'Skateboard BBOX', (bx1, label_y),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
    display(Image.fromarray(cv2.cvtColor(vis, cv2.COLOR_BGR2RGB)))
    print(f'BBOX = {BBOX}')
    print('If the green box does NOT cover the skateboard, change BBOX above and re-run this cell.')

## Step 5: Extract Tracking Data (Masks + Poses)

In [None]:
import sys, os
sys.path.insert(0, '/content/skate-physics-preserver')

from src.extract_physics import run_extraction

# Track the skateboard (seeded with BBOX) and extract skater poses for up to
# FRAME_CAP frames. The returned dict gives the directories the per-frame
# masks and pose images were written to.
result = run_extraction(
    video_path=VIDEO_PATH,
    output_dir=OUTPUT_DIR,
    bbox=BBOX,
    frame_load_cap=FRAME_CAP,
)

# Preview outputs
mask_files = sorted(os.listdir(result['masks_dir']))
pose_files = sorted(os.listdir(result['poses_dir']))
print(f'\nMasks: {len(mask_files)} frames  → {result["masks_dir"]}')
print(f'Poses: {len(pose_files)} frames  → {result["poses_dir"]}')

# Show sample mask + pose
import cv2
from PIL import Image
from IPython.display import display

# cv2.imread returns None (it does not raise) for unreadable/non-image files.
# Guard against that so the preview reports the bad file instead of crashing
# inside Image.fromarray.
if mask_files:
    mask_path = os.path.join(result['masks_dir'], mask_files[0])
    m = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    print(f'\nSample mask — {mask_files[0]}:')
    if m is None:
        print(f'  (could not decode {mask_path} as an image)')
    else:
        display(Image.fromarray(m))
if pose_files:
    pose_path = os.path.join(result['poses_dir'], pose_files[0])
    p = cv2.imread(pose_path)
    print(f'Sample pose — {pose_files[0]}:')
    if p is None:
        print(f'  (could not decode {pose_path} as an image)')
    else:
        display(Image.fromarray(cv2.cvtColor(p, cv2.COLOR_BGR2RGB)))

## Step 6: ComfyUI Setup & Generation

This step:
1. Clones ComfyUI + required custom nodes
2. Downloads Wan 2.1 VACE (GGUF), VAE, and CLIP text encoder
3. Starts ComfyUI in `--lowvram` mode
4. Queues the V2V generation workflow

In [None]:
import os, sys, subprocess, time, urllib.request

# Clone ComfyUI
!git clone https://github.com/comfyanonymous/ComfyUI.git /content/ComfyUI 2>/dev/null || true

# Custom nodes
!git clone https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite.git \
    /content/ComfyUI/custom_nodes/ComfyUI-VideoHelperSuite 2>/dev/null || true
!git clone https://github.com/city96/ComfyUI-GGUF.git \
    /content/ComfyUI/custom_nodes/ComfyUI-GGUF 2>/dev/null || true
!git clone https://github.com/kijai/ComfyUI-WanVideoWrapper.git \
    /content/ComfyUI/custom_nodes/ComfyUI-WanVideoWrapper 2>/dev/null || true

# Download models (Wan VACE GGUF + VAE + CLIP text encoder)
os.makedirs('/content/ComfyUI/models/unet', exist_ok=True)
os.makedirs('/content/ComfyUI/models/vae',  exist_ok=True)
os.makedirs('/content/ComfyUI/models/clip', exist_ok=True)

!wget -q -nc -P /content/ComfyUI/models/unet \
    https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/resolve/main/Wan2.1-VACE-14B-Q4_K_S.gguf
!wget -q -nc -P /content/ComfyUI/models/vae \
    https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/vae/wan_2.1_vae.safetensors
!wget -q -nc -P /content/ComfyUI/models/clip \
    https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors

# Start ComfyUI (--lowvram for T4)
proc = subprocess.Popen(
    [sys.executable, 'main.py', '--listen', '0.0.0.0', '--port', '8188', '--lowvram'],
    cwd='/content/ComfyUI',
    stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
)

# Health-check: poll until ComfyUI is responsive
print('Waiting for ComfyUI to start...')
for attempt in range(90):  # up to 3 min
    time.sleep(2)
    try:
        urllib.request.urlopen('http://127.0.0.1:8188/system_stats')
        print(f'ComfyUI ready after ~{(attempt + 1) * 2}s')
        break
    except Exception:
        if attempt % 10 == 9:
            print(f'  Still loading... ({(attempt + 1) * 2}s)')
else:
    print('WARNING: ComfyUI may not have started. Check /content/ComfyUI logs.')

In [None]:
import os
import sys

sys.path.insert(0, '/content/skate-physics-preserver')

from src.generate_reskin import run_generation

# Conditioning inputs written by the tracking/pose extraction step.
MASK_DIR = os.path.join(OUTPUT_DIR, 'mask_skateboard')
POSE_DIR = os.path.join(OUTPUT_DIR, 'pose_skater')

# Every generation setting in one place, so the prompts are easy to tweak.
generation_settings = {
    'server_addr': '127.0.0.1:8188',
    'source_video_path': VIDEO_PATH,
    'mask_dir': MASK_DIR,
    'pose_dir': POSE_DIR,
    'prompt': 'cyberpunk drone, samurai warrior',
    'negative_prompt': 'blur, distortion, artifacts',
    'frame_load_cap': FRAME_CAP,
}
generated_filename = run_generation(**generation_settings)

# The returned name is treated as relative to ComfyUI's output folder; a falsy
# value means no output file was reported.
GENERATED_VIDEO = (
    f'/content/ComfyUI/output/{generated_filename}' if generated_filename else None
)
if GENERATED_VIDEO is None:
    print('WARNING: Generation returned no output filename.')
else:
    print(f'Generated video: {GENERATED_VIDEO}')

## Step 7: Validation (IoU Benchmark)

Re-tracks the skateboard in the **generated** video using SAM 2.1,  
then compares those masks to the **original** masks frame-by-frame.  
Passing threshold: **IoU > 0.90** on every compared frame (the frames present in both the original and generated mask sequences).

In [None]:
import sys, os, cv2
import numpy as np
sys.path.insert(0, '/content/skate-physics-preserver')

from src.evaluate_iou import load_mask_sequence, calculate_iou
from src.tracking.skateboard_tracker import SkateboardTracker

MASK_DIR = os.path.join(OUTPUT_DIR, 'mask_skateboard')
IOU_THRESHOLD = 0.90  # a frame passes only if its IoU is strictly above this


def _video_size(path):
    """Return (width, height) of a video as reported by OpenCV ((0, 0) if unknown)."""
    cap = cv2.VideoCapture(path)
    try:
        return (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    finally:
        cap.release()


def _scale_bbox(bbox, src_size, dst_size):
    """Map an [x1, y1, x2, y2] box from src_size=(w, h) pixels to dst_size=(w, h) pixels.

    The box is returned unchanged when the sizes match or either size is unknown.
    """
    (src_w, src_h), (dst_w, dst_h) = src_size, dst_size
    if not (src_w and src_h and dst_w and dst_h) or (src_w, src_h) == (dst_w, dst_h):
        return list(bbox)
    sx, sy = dst_w / src_w, dst_h / src_h
    x1, y1, x2, y2 = bbox
    return [round(x1 * sx), round(y1 * sy), round(x2 * sx), round(y2 * sy)]


def _resize_mask(mask, shape):
    """Nearest-neighbour resize of a 2-D mask to *shape* (h, w), keeping its dtype.

    cv2.resize does not accept bool arrays, so bool masks go through uint8.
    """
    if mask.shape == shape:
        return mask
    is_bool = mask.dtype == bool
    src = mask.astype(np.uint8) if is_bool else mask
    out = cv2.resize(src, (shape[1], shape[0]), interpolation=cv2.INTER_NEAREST)
    return out.astype(bool) if is_bool else out


if GENERATED_VIDEO is None or not os.path.exists(str(GENERATED_VIDEO)):
    print('No generated video found — skipping validation.')
else:
    # 1. Load ground-truth masks
    gt_masks = load_mask_sequence(MASK_DIR)
    print(f'Ground truth: {len(gt_masks)} frames')

    # 2. Re-track skateboard in the generated video.
    # BBOX was drawn in the source video's pixel coordinates. The generated clip
    # may be rendered at another resolution (the mask resize below allows for
    # that), so rescale the prompt box to match it.
    gen_bbox = _scale_bbox(BBOX, _video_size(VIDEO_PATH), _video_size(GENERATED_VIDEO))
    if gen_bbox != list(BBOX):
        print(f'Scaled BBOX {list(BBOX)} → {gen_bbox} for the generated video resolution')

    tracker = SkateboardTracker(
        '/content/skate-physics-preserver/checkpoints/sam2.1_hiera_s.pt',
        '/content/skate-physics-preserver/configs/sam2.1/sam2.1_hiera_s.yaml',
        frame_load_cap=FRAME_CAP,
    )
    tracker.init_video(GENERATED_VIDEO)
    tracker.add_initial_prompt(0, gen_bbox)
    gen_masks = list(tracker.propagate_yield())
    print(f'Generated:    {len(gen_masks)} frames')

    # 3. Per-frame IoU comparison over the frames both sequences share
    n = min(len(gt_masks), len(gen_masks))
    if len(gt_masks) != len(gen_masks):
        print(f'WARNING: frame count mismatch (ground truth {len(gt_masks)}, '
              f'generated {len(gen_masks)}); comparing the first {n} frames.')
    failed = []
    for i, (gt, pred) in enumerate(zip(gt_masks, gen_masks)):
        iou = calculate_iou(gt, _resize_mask(pred, gt.shape))
        passed = iou > IOU_THRESHOLD
        if not passed:
            failed.append((i, iou))
        tag = 'PASS' if passed else 'FAIL'
        print(f'Frame {i:04d}: IoU = {iou:.4f}  [{tag}]')

    print('---')
    if n == 0:
        # Nothing was compared, so the benchmark cannot be reported as passed.
        print('FAILED: no frames to compare')
    elif failed:
        print(f'FAILED: {len(failed)}/{n} frames below {IOU_THRESHOLD:.2f} IoU')
    else:
        print(f'SUCCESS: Zero-Clipping Benchmark PASSED — all {n} frames above {IOU_THRESHOLD:.2f} IoU')