# Skate Physics Preserver - Pipeline Demo

**Human-Object Relational Mapping for V2V Synthesis**

This notebook walks through the three-stage pipeline:
1. **Extract** - DWPose skeleton + SAM 2.1 object mask
2. **Generate** - Headless ComfyUI with Wan 2.1 VACE
3. **Validate** - Frame-by-frame IoU (Zero-Clipping Benchmark)

Optimized for an RTX 3070 with 8 GB of VRAM.

## 0. Setup & Dependency Check

In [None]:
import sys, os

# The previous `os.chdir(os.path.dirname(os.path.abspath('__file__')))` quoted
# `__file__`, so it always resolved to the current working directory and
# changed nothing. The working directory is used as the project root directly.
PROJECT_ROOT = os.getcwd()
SRC_DIR = os.path.join(PROJECT_ROOT, 'src')

# Absolute path + membership test: a later chdir cannot break project imports,
# and re-running this cell does not stack duplicate entries on sys.path.
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

# Verify GPU
import torch
print(f'PyTorch: {torch.__version__}')
print(f'CUDA:    {torch.cuda.is_available()}')
if torch.cuda.is_available():
    gpu = torch.cuda.get_device_name(0)
    # The device-properties field is `total_memory` (bytes); `total_mem`
    # does not exist and raised AttributeError on every CUDA machine.
    vram = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    print(f'GPU:     {gpu} ({vram:.1f} GB)')

## 1. Extract Tracking Data

Run DWPose + SAM 2.1 sequentially to extract skeleton poses and object masks.

In [None]:
# ---------------------------------------------------------------------------
# Pipeline configuration (edit these before running the cells below)
# ---------------------------------------------------------------------------

# Source clip to process.
VIDEO_PATH = 'input.mp4'

# Base directory that receives every artifact produced by the pipeline.
OUTPUT_DIR = 'output'

# SAM 2.1 weights and the matching model config.
SAM_CHECKPOINT = 'checkpoints/sam2.1_hiera_small.pt'
SAM_CONFIG = 'configs/sam2.1/sam2.1_hiera_s.yaml'

# Box around the skateboard in frame 0.
# Set to None for interactive selection.
SKATEBOARD_BBOX = [
    120,  # x1
    340,  # y1
    280,  # x2
    410,  # y2
]

In [None]:
import gc

# Output locations for each extraction artifact, all under OUTPUT_DIR.
masks_dir = os.path.join(OUTPUT_DIR, 'mask_skateboard')
poses_dir = os.path.join(OUTPUT_DIR, 'pose_skater')
json_dir  = os.path.join(OUTPUT_DIR, 'pose_json')

# --- Pass 1: DWPose ---
from tracking.skater_pose import SkaterPoseExtractor

pose = SkaterPoseExtractor(device='cuda', mode='balanced')
try:
    n_pose = pose.process_video(VIDEO_PATH, poses_dir, json_dir)
finally:
    # Always release the extractor, even when processing fails, so a retry
    # does not start with the previous model still resident on the GPU.
    pose.cleanup()
    # Collect first, then empty the cache: empty_cache() only returns blocks
    # that are already unoccupied, so objects kept alive by reference cycles
    # must be collected before their memory can be handed back to the driver.
    gc.collect()
    torch.cuda.empty_cache()
print(f'\nPose frames: {n_pose}')

In [None]:
# --- Pass 2: SAM 2.1 ---
import gc

from tracking.skateboard_tracker import SkateboardTracker

tracker = SkateboardTracker(SAM_CHECKPOINT, SAM_CONFIG, device='cuda')
try:
    tracker.init_video(VIDEO_PATH)
    # Seed the tracker with the skateboard box on the first frame.
    tracker.add_initial_prompt(frame_idx=0, bbox=SKATEBOARD_BBOX)
    n_mask = tracker.propagate_and_save(masks_dir)
finally:
    # Same teardown as the DWPose pass: release the tracker even on failure,
    # then collect garbage before emptying the CUDA cache so the freed blocks
    # are actually returned before the next GPU-heavy stage.
    tracker.cleanup()
    gc.collect()
    torch.cuda.empty_cache()
print(f'\nMask frames: {n_mask}')

In [None]:
# Visualize a sample frame
import cv2
import matplotlib.pyplot as plt
import numpy as np

sample_idx = 0
frame_name = f'frame_{sample_idx:05d}.png'
mask_path = os.path.join(masks_dir, frame_name)
pose_path = os.path.join(poses_dir, frame_name)

mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
pose_img = cv2.imread(pose_path)

# cv2.imread signals a missing/unreadable file by returning None instead of
# raising; fail here with the offending path rather than deep inside
# matplotlib/OpenCV with an unrelated-looking error.
for path, image in ((mask_path, mask), (pose_path, pose_img)):
    if image is None:
        raise FileNotFoundError(
            f'Could not read {path}. Run the extraction cells first '
            f'or choose a different sample_idx.'
        )

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
axes[0].imshow(mask, cmap='gray')
axes[0].set_title('SAM 2.1 Object Mask')
# OpenCV loads colour images as BGR; matplotlib expects RGB.
axes[1].imshow(cv2.cvtColor(pose_img, cv2.COLOR_BGR2RGB))
axes[1].set_title('DWPose Skeleton')
for ax in axes:
    ax.axis('off')
plt.tight_layout()
plt.show()

## 2. Generate Reskin (ComfyUI)

Requires a running ComfyUI server: `python main.py --listen --lowvram`

In [None]:
from generate_reskin import ComfyOrchestrator

orch = ComfyOrchestrator(server_addr='127.0.0.1:8188')

# Only submit the workflow when the ComfyUI server answers; otherwise tell
# the user how to start it.
if not orch.check_server():
    print('ComfyUI server not running. Start with: python main.py --listen --lowvram')
else:
    # Collect the job description in one place, then hand it to the orchestrator.
    v2v_job = dict(
        workflow_path='workflows/vace_template.json',
        source_video_path=VIDEO_PATH,
        masks_dir=masks_dir,
        poses_dir=poses_dir,
        positive_prompt='cyberpunk samurai riding a neon hoverboard, cinematic',
        negative_prompt='blurry, distorted, low quality, deformed',
        output_dir='output/generated',
    )
    output_file = orch.execute_v2v(**v2v_job)
    print(f'Output: {output_file}')

## 3. Validate (IoU)

Re-track the generated video and compute frame-by-frame IoU.

In [None]:
from evaluate_iou import load_mask_sequence, extract_generated_masks, run_validation, print_report

GENERATED_VIDEO = 'output/generated/output.mp4'  # adjust filename as needed
IOU_THRESHOLD = 0.90  # minimum per-frame IoU passed to run_validation

# The generated filename above is a guess; verify it before loading SAM onto
# the GPU for re-tracking, so a wrong path fails immediately and clearly.
if not os.path.isfile(GENERATED_VIDEO):
    raise FileNotFoundError(
        f'Generated video not found: {GENERATED_VIDEO}. '
        'Run the generation step first or point GENERATED_VIDEO at its output file.'
    )

# Load ground truth
gt_masks = load_mask_sequence(masks_dir)

# Reverse-track generated video, seeded with the same first-frame box that
# produced the ground-truth masks.
gen_masks = extract_generated_masks(
    video_path=GENERATED_VIDEO,
    bbox=SKATEBOARD_BBOX,
    sam_checkpoint=SAM_CHECKPOINT,
    sam_config=SAM_CONFIG,
    device='cuda',
)

# Validate
results = run_validation(gt_masks, gen_masks, threshold=IOU_THRESHOLD, verbose=False)
print_report(results)

In [None]:
# Plot IoU over time (skipped when there are no scores)
iou_scores = results['all_scores']
if iou_scores:
    frame_axis = range(len(iou_scores))
    below_threshold = [score < 0.90 for score in iou_scores]

    iou_fig, iou_ax = plt.subplots(figsize=(12, 4))
    iou_ax.plot(iou_scores, 'b-', linewidth=0.8)
    iou_ax.axhline(y=0.90, color='r', linestyle='--', label='Threshold (0.90)')
    # Shade the gap between the curve and the threshold on failing frames.
    iou_ax.fill_between(
        frame_axis,
        iou_scores,
        0.90,
        where=below_threshold,
        color='red',
        alpha=0.3,
        label='Failed frames',
    )
    iou_ax.set_xlabel('Frame')
    iou_ax.set_ylabel('IoU')
    iou_ax.set_title('Zero-Clipping Benchmark: IoU Over Time')
    iou_ax.legend()
    iou_ax.set_ylim(0, 1.05)
    iou_ax.grid(True, alpha=0.3)
    iou_fig.tight_layout()
    plt.show()