# ChatCut Colab Server

**Run cells 1-7 in order, then copy the URL into Premiere Pro.**

| Cell | What it does |
|------|-------------|
| 1 | Install dependencies |
| 2 | Imports & config |
| 3 | Tracking functions |
| 4 | Effect parser |
| 5 | Keyframe planner |
| 6 | Renderers |
| 7 | **START SERVER** |

In [1]:
#@title 0. API Keys (Optional)
# Set your Gemini API key for this runtime.\n
import os
GOOGLE_API_KEY = ""  # @param {type:"string"}
if GOOGLE_API_KEY:
    os.environ['GOOGLE_API_KEY'] = YOUR_GEMINI_API_KEY
    print('‚úÖ GOOGLE_API_KEY set for this runtime')
else:
    print('‚ÑπÔ∏è Using existing GOOGLE_API_KEY env var (if set)')


‚ÑπÔ∏è Using existing GOOGLE_API_KEY env var (if set)


In [2]:
#@title 1. Install Dependencies
!pip install -q ultralytics opencv-python moviepy numpy pandas tqdm scipy fastapi uvicorn python-multipart lapx pyngrok nest-asyncio transformers pillow google-genai

In [3]:
#@title 2. Imports & Config
from __future__ import annotations
import json, math, os, re, tempfile, traceback
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any

import numpy as np
import pandas as pd
import cv2
from tqdm import tqdm
from scipy.signal import savgol_filter
from moviepy.editor import VideoFileClip
import torch
from ultralytics import YOLO
from PIL import Image

# SigLIP for semantic object matching
from transformers import CLIPModel, CLIPProcessor

#============================================================
# CONFIGURATION
#============================================================
TEST_MODE = False  # Set True for fast 480p rendering, False for production quality
#============================================================

BASE_DIR = Path('/content')
EXPORT_DIR = BASE_DIR / 'exports'
EXPORT_DIR.mkdir(parents=True, exist_ok=True)

DEFAULT_MODELS = {'det': 'yolo11n.pt', 'seg': 'yolo11n-seg.pt'}
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

print(f"‚úÖ Device: {DEVICE.upper()}")
print(f"‚úÖ Exports: {EXPORT_DIR}")
print(f"üß™ TEST_MODE: {TEST_MODE}" + (" (480p, ultrafast)" if TEST_MODE else " (full quality)"))
if DEVICE == 'cpu':
    print('‚ö†Ô∏è No GPU - enable T4 runtime for better performance')

# Load SigLIP model for semantic object selection
print("üì• Loading CLIP model (ViT-L/14)...")
SIGLIP_MODEL = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(DEVICE)
SIGLIP_PROCESSOR = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
SIGLIP_MODEL.eval()
print("‚úÖ CLIP loaded (semantic object matching enabled)")

# Optional: Gemini 2.5 Flash Lite for visual reranking
GEMINI_CLIENT = None
GEMINI_TYPES = None
GEMINI_MODEL_ID = "gemini-2.5-flash"

# Hardcoded API key (preferred here due to VS Code Colab env limitations).
# Replace "YOUR_GEMINI_API_KEY" with your actual key. Leave empty to fall back to env var.
GEMINI_API_KEY = "YOUR_GEMINI_API_KEY"

try:
    from google import genai
    from google.genai import types as genai_types
    # Allow user to paste key after the placeholder or set env vars.
    raw_key = (GEMINI_API_KEY or "").strip()
    placeholder = "YOUR_GEMINI_API_KEY"
    if raw_key.startswith(placeholder):
        raw_key = raw_key[len(placeholder):].strip()
    if not raw_key:
        # Fallback to common env var names from AI Studio docs
        raw_key = (os.environ.get("GEMINI_API_KEY") or
                   os.environ.get("GOOGLE_GENERATIVE_AI_API_KEY") or
                   os.environ.get("GOOGLE_API_KEY"))
    if not raw_key:
        raise RuntimeError("Gemini API key not set. Set GEMINI_API_KEY or GEMINI_API_KEY/GOOGLE_GENERATIVE_AI_API_KEY env vars.")
    GEMINI_CLIENT = genai.Client(api_key=raw_key)
    GEMINI_TYPES = genai_types
    print(f"‚úÖ Gemini client initialized ({GEMINI_MODEL_ID})")
except Exception as e:
    print("‚ÑπÔ∏è Gemini not available (optional):", e)


‚úÖ Device: CUDA
‚úÖ Exports: /content/exports
üß™ TEST_MODE: False (full quality)
üì• Loading CLIP model (ViT-L/14)...


Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


‚úÖ CLIP loaded (semantic object matching enabled)
‚úÖ Gemini client initialized (gemini-2.5-flash)


In [4]:
#@title 3. Tracking Functions
MODEL_CACHE = {}
CLASS_NAME_CACHE = {}

def load_model(use_seg=False):
    name = DEFAULT_MODELS['seg'] if use_seg else DEFAULT_MODELS['det']
    if name not in MODEL_CACHE:
        print(f'Loading {name}...')
        MODEL_CACHE[name] = YOLO(name)
    return MODEL_CACHE[name]

def _build_name_map(model):
    key = id(model)
    if key not in CLASS_NAME_CACHE:
        names = getattr(model, 'names', {}) or {}
        CLASS_NAME_CACHE[key] = {int(k): str(v) for k, v in names.items()} if isinstance(names, dict) else {i: str(v) for i, v in enumerate(names)}
    return CLASS_NAME_CACHE[key]

def _normalize_label(text): return text.strip().lower()

def encode_mask(mask):
    mask = (mask > 0.5).astype(np.uint8).flatten(order='F')
    counts, last, run = [], 0, 0
    for v in mask:
        if v == last: run += 1
        else: counts.append(run); run = 1; last = v
    counts.append(run)
    if mask.size and mask[0] == 1: counts = [0] + counts
    return {'size': [int(mask.shape[0]), 1], 'counts': counts}

def decode_mask(rle):
    h, w = rle['size']
    vals = []
    cur = 0
    for c in rle['counts']:
        vals.extend([cur] * c)
        cur = 1 - cur
    return np.array(vals, dtype=np.uint8).reshape((h, w), order='F')

def detect_and_track(video_path, use_seg=True, frame_stride=1, conf=0.25, iou=0.45, imgsz=960, save_json=False):
    cap = cv2.VideoCapture(str(video_path))
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
    w, h = int(cap.get(3)), int(cap.get(4))
    duration = total / fps if fps else 0
    cap.release()

    model = load_model(use_seg)
    name_map = _build_name_map(model)
    
    print(f'Tracking {video_path} @ {fps:.1f}fps | {w}x{h}')
    stream = model.track(source=str(video_path), imgsz=imgsz, tracker='bytetrack.yaml', stream=True,
                         conf=conf, iou=iou, vid_stride=frame_stride, device=DEVICE, verbose=False, persist=True)
    
    frames, cursor = [], 0
    for result in tqdm(stream, desc='Tracking', total=math.ceil(total/frame_stride)):
        dets = []
        if result.boxes is not None and result.boxes.id is not None:
            ids = result.boxes.id.int().cpu().tolist()
            xyxy = result.boxes.xyxy.cpu().tolist()
            confs = result.boxes.conf.cpu().tolist()
            clss = result.boxes.cls.int().cpu().tolist()
            masks = result.masks.data.cpu().numpy() if use_seg and result.masks else None
            for i, tid in enumerate(ids):
                dets.append({'id': int(tid), 'cls': name_map.get(clss[i], str(clss[i])),
                            'conf': float(confs[i]), 'bbox_xyxy': [float(v) for v in xyxy[i]],
                            'mask_rle': encode_mask(masks[i]) if masks is not None else None})
        frames.append({'frame_index': cursor, 't': cursor/fps, 'detections': dets})
        cursor += frame_stride
    
    return {'video_path': str(video_path), 'fps': fps, 'size': [w, h], 'duration': duration, 'frames': frames}

# Pre-download models to avoid delay on first request
print("üì• Pre-downloading YOLO models...")
_ = load_model(use_seg=False)  # Detection model
_ = load_model(use_seg=True)   # Segmentation model
print("‚úÖ Models cached and ready!")
print('‚úÖ Tracking functions loaded')

üì• Pre-downloading YOLO models...
Loading yolo11n.pt...
Loading yolo11n-seg.pt...
‚úÖ Models cached and ready!
‚úÖ Tracking functions loaded


In [5]:
#@title 4. Effect Parser (Gemini only)
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any

EFFECT_DEFAULTS = {
    'ZoomFollow': {'margin': 0.10}, 'Spotlight': {'strength': 0.7, 'feather': 45},
    'BlurBackground': {'ksize': 21}, 'PixelateObject': {'block': 20},
    'AutoReframe': {'aspect': '9:16'}, 'Callout': {'label': 'object'},
    'PiPMagnifier': {'scale': 1.5, 'radius': 120}, 'PathOverlay': {}
}

EFFECT_KEYWORDS = {
    'ZoomFollow': ['zoom', 'punch in', 'follow'], 'Spotlight': ['spotlight', 'highlight'],
    'BlurBackground': ['blur background', 'background blur'], 'PixelateObject': ['pixelate'],
    'AutoReframe': ['reframe', 'vertical'], 'Callout': ['callout', 'label'],
    'PiPMagnifier': ['pip', 'magnifier'], 'PathOverlay': ['path', 'trajectory']
}

@dataclass
class EffectCommand:
    effect: str
    object: str
    t_in: float
    t_out: float
    params: Dict[str, Any] = field(default_factory=dict)
    hints: List[str] = field(default_factory=list)
    ordinal: Optional[int] = None


def parse_nl_to_dsl(cmd: str, duration: float):
    """Parse natural language to EffectCommand list using Gemini only."""
    if GEMINI_CLIENT is None or GEMINI_TYPES is None:
        raise ValueError("Gemini client not initialized; set GEMINI_API_KEY.")

    schema_hint = """
You are ChatCut's command parser. Return ONLY JSON.
Format: {\"effects\": [ { \"effect\": one of [ZoomFollow,Spotlight,BlurBackground,PixelateObject,AutoReframe,Callout,PathOverlay],
\"object\": string, \"spatial_hint\": leftmost/rightmost/center/null, \"ordinal\": integer or null,
\"t_in\": float, \"t_out\": float, \"label\": string or null } ] }.
Ensure 0 <= t_in < t_out <= duration. No extra text or markdown.
"""

    prompt = f"""
Video duration: {duration:.3f} seconds.
User command: '{cmd}'.
Return ONLY the JSON object, nothing else.
"""

    content = GEMINI_TYPES.Content(parts=[
        GEMINI_TYPES.Part(text=schema_hint + "\n\n" + prompt)
    ])

    import json as _json
    import re as _re

    # Prefer JSON mime type when available
    try:
        config = GEMINI_TYPES.GenerateContentConfig(response_mime_type="application/json")
    except Exception:
        config = None

    try:
        if config:
            resp = GEMINI_CLIENT.models.generate_content(model=GEMINI_MODEL_ID, contents=content, config=config)
        else:
            resp = GEMINI_CLIENT.models.generate_content(model=GEMINI_MODEL_ID, contents=content)
    except Exception as e:
        raise ValueError(f"Gemini request failed: {e}")

    def _extract_text(resp_obj):
        raw_text = getattr(resp_obj, 'text', '') or ''
        raw_text = raw_text.strip()
        if raw_text:
            return raw_text
        try:
            parts = []
            for cand in getattr(resp_obj, 'candidates', []) or []:
                c_content = getattr(cand, 'content', None)
                if c_content is None:
                    continue
                for part in getattr(c_content, 'parts', []) or []:
                    t = getattr(part, 'text', None)
                    if t:
                        parts.append(t)
            return "\n".join(parts).strip()
        except Exception:
            return ''

    raw = _extract_text(resp)
    if not raw:
        raise ValueError("Gemini returned empty response")

    if raw.startswith('```'):
        stripped = raw.strip('`')
        if stripped.lower().startswith('json'):
            stripped = stripped[4:].lstrip()
        raw = stripped

    try:
        data = _json.loads(raw)
    except Exception:
        m = _re.search(r"{.*}", raw, _re.S)
        if not m:
            raise ValueError("Gemini response not JSON")
        data = _json.loads(m.group(0))

    effects = data.get('effects') or []
    if not effects:
        raise ValueError("Gemini returned no effects")

    commands = []
    for eff in effects:
        effect_name = eff.get('effect')
        if effect_name not in EFFECT_DEFAULTS:
            continue
        obj = eff.get('object') or 'person'
        spatial = eff.get('spatial_hint') or None
        ordinal = eff.get('ordinal')
        if isinstance(ordinal, float):
            ordinal = int(ordinal)
        if not isinstance(ordinal, int):
            ordinal = None
        t_in = float(max(0.0, min(duration, eff.get('t_in', 0.0))))
        t_out = float(max(t_in + 1e-3, min(duration, eff.get('t_out', duration))))
        label = eff.get('label')

        hints = []
        if spatial in ('leftmost', 'rightmost', 'center'):
            hints.append(spatial)

        params = dict(EFFECT_DEFAULTS.get(effect_name, {}))
        if effect_name == 'Callout' and label:
            params['label'] = label

        commands.append(EffectCommand(
            effect=effect_name,
            object=obj,
            t_in=t_in,
            t_out=t_out,
            params=params,
            hints=hints,
            ordinal=ordinal,
        ))

    if not commands:
        raise ValueError("Gemini produced no valid commands")
    return commands

print('‚úÖ Effect parser loaded (Gemini only)')
print('   ‚Ä¢ Available effects:', list(EFFECT_KEYWORDS.keys()))


‚úÖ Effect parser loaded (Gemini only)
   ‚Ä¢ Available effects: ['ZoomFollow', 'Spotlight', 'BlurBackground', 'PixelateObject', 'AutoReframe', 'Callout', 'PiPMagnifier', 'PathOverlay']


In [6]:
#@title 5. Keyframe Planner (Fixed SigLIP + Callout + Smoothing)
tracks_df_cached = None

# Minimum continuity thresholds
MIN_CONTINUITY = 0.5          # General: at least 50% of frames
MIN_CALLOUT_CONTINUITY = 0.7  # Callout needs higher stability (70%)

def tracks_to_df(tracks):
    recs = []
    for f in tracks['frames']:
        for d in f['detections']:
            x1,y1,x2,y2 = d['bbox_xyxy']
            recs.append({'t': f['t'], 'frame_index': f['frame_index'], 'id': d['id'],
                        'cls': _normalize_label(d['cls']), 'conf': d['conf'],
                        'x1': x1, 'y1': y1, 'x2': x2, 'y2': y2, 'mask_rle': d.get('mask_rle')})
    return pd.DataFrame(recs)

def get_tracks_df(tracks):
    global tracks_df_cached
    if tracks_df_cached is None: tracks_df_cached = tracks_to_df(tracks)
    return tracks_df_cached

def smooth(vals, wl=25, po=2):
    """Ultra-smooth motion for stable camera tracking."""
    vals = np.asarray(vals)
    if len(vals) < wl:
        out = [vals[0]]
        for v in vals[1:]:
            out.append(0.05*v + 0.95*out[-1])
        return np.array(out)
    
    smoothed = savgol_filter(vals, wl, po)
    wl2 = min(15, len(smoothed))
    if wl2 >= 5:
        smoothed = savgol_filter(smoothed, wl2, po)
    out = [smoothed[0]]
    for v in smoothed[1:]:
        out.append(0.2*v + 0.8*out[-1])
    return np.array(out)

def clean_object_for_siglip(obj):
    """Transform parsed object into clean SigLIP query."""
    stop_phrases = [
        'and put', 'and add', 'and apply', 'put a', 'add a', 'put the',
        'callout', 'zoom', 'spotlight', 'blur', 'pixelate', 'label',
        'with the text', 'saying', 'with text', 'the text', 'text saying',
        'in on', 'on the', 'on a', 'the the', 'and a', 'and the',
    ]
    
    result = obj.lower().strip()
    for phrase in stop_phrases:
        result = result.replace(phrase, ' ')
    
    result = ' '.join(result.split())
    
    if not result or result in ['and', 'the', 'a', 'an', 'on', 'in', 'to', 'with']:
        return 'a person'
    
    if not result.startswith(('a ', 'an ', 'the ')):
        if result[0] in 'aeiou':
            result = f'an {result}'
        else:
            result = f'a {result}'
    
    return result


def build_clip_query(obj_clean, effect=None):
    """Expand a clean object phrase into a richer CLIP text query."""
    base = obj_clean.strip()
    lower = base.lower()
    details = []

    if any(k in lower for k in ['newspaper', 'paper', 'magazine']):
        details.append('holding a newspaper in front of their body')
    if any(k in lower for k in ['man', 'woman', 'guy', 'person']):
        details.append('main person in the shot')

    if effect == 'ZoomFollow':
        details.append('framed from the waist up, clearly visible')
    elif effect == 'Callout':
        details.append('good candidate for a label, unobstructed')

    details.append('not on a TV screen, not a projected image')
    details.append('not a tiny background figure, not far away')

    full = base
    if details:
        full = base + ', ' + ', '.join(details)
    return full


def gemini_rerank_tracks(video_path, win_df, candidate_df, user_description, top_k=3):
    """Use Gemini 2.5 Flash to rerank top candidate tracks visually."""
    if GEMINI_CLIENT is None or GEMINI_TYPES is None:
        return None
    if candidate_df.empty:
        return None

    print(f"\n   {'‚îÄ'*50}")
    print(f"   ü§ñ GEMINI VISUAL RERANK")
    print(f"   {'‚îÄ'*50}")
    print(f"      Query: '{user_description}'")
    print(f"      Candidates: {min(top_k, len(candidate_df))}")

    try:
        import cv2
        import numpy as np
        import re as _re

        if 'final_score' in candidate_df.columns:
            cand = candidate_df.nlargest(min(top_k, len(candidate_df)), 'final_score')
        else:
            cand = candidate_df.copy()

        crops = []
        tids = []

        cap = cv2.VideoCapture(video_path)
        for _, row in cand.iterrows():
            tid = int(row['id'])
            t_data = win_df[win_df['id'] == tid]
            if t_data.empty:
                continue
            frame_idx = int(t_data['frame_index'].median())
            x1 = int(t_data['x1'].median())
            y1 = int(t_data['y1'].median())
            x2 = int(t_data['x2'].median())
            y2 = int(t_data['y2'].median())

            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame = cap.read()
            if not ret:
                continue
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            crop = frame_rgb[max(0, y1):max(0, min(frame_rgb.shape[0], y2)),
                             max(0, x1):max(0, min(frame_rgb.shape[1], x2))]
            if crop.size == 0:
                continue
            ok, buf = cv2.imencode('.jpg', cv2.cvtColor(crop, cv2.COLOR_RGB2BGR))
            if not ok:
                continue
            crops.append(buf.tobytes())
            tids.append(tid)
            print(f"      ‚Ä¢ Track {tid}: frame {frame_idx}, bbox [{x1},{y1},{x2},{y2}]")

        cap.release()

        if len(crops) < 2:
            print(f"      ‚ö†Ô∏è  Not enough crops for rerank ({len(crops)})")
            return None

        prompt = (
            f"You will see {len(crops)} images in order. "
            "Image 1 is the first, image 2 the second, etc. "
            f"Which image best matches this description: '{user_description}'? "
            "Reply with just the number (1, 2, or 3)."
        )

        parts = [GEMINI_TYPES.Part(text=prompt)]
        for data in crops:
            parts.append(GEMINI_TYPES.Part(inline_data=GEMINI_TYPES.Blob(data=data, mime_type='image/jpeg')))

        content = GEMINI_TYPES.Content(parts=parts)
        
        import time as _time
        start_time = _time.time()
        resp = GEMINI_CLIENT.models.generate_content(model=GEMINI_MODEL_ID, contents=content)
        api_time = _time.time() - start_time
        
        text = getattr(resp, 'text', '') or ''
        print(f"      ‚ö° API time: {api_time:.2f}s")
        print(f"      üì® Response: '{text.strip()}'")
        
        m = _re.search(r'[1-9]', text)
        if not m:
            print(f"      ‚ö†Ô∏è  Could not parse selection from response")
            return None
        idx = int(m.group(0)) - 1
        if idx < 0 or idx >= len(tids):
            print(f"      ‚ö†Ô∏è  Invalid index {idx+1} (have {len(tids)} candidates)")
            return None
        
        selected_tid = int(tids[idx])
        print(f"      ‚úÖ Gemini selected: Track {selected_tid} (image {idx+1})")
        print(f"   {'‚îÄ'*50}")
        return selected_tid
        
    except Exception as e:
        print(f"      ‚ùå Gemini rerank failed: {e}")
        print(f"   {'‚îÄ'*50}")
        return None


def choose_track_by_layout(df, t_in, t_out, frame_size, hints=None, ordinal=None, fps=30.0, min_continuity=MIN_CONTINUITY):
    """Deterministic selection based on left/right/center and ordinal."""
    hints = hints or []
    W, H = frame_size

    print(f"\n   {'‚îÄ'*50}")
    print(f"   üìê LAYOUT-BASED SELECTION")
    print(f"   {'‚îÄ'*50}")
    print(f"      Hints: {hints}")
    print(f"      Ordinal: {ordinal}")

    win = df[(df['t'] >= t_in) & (df['t'] <= t_out)]
    if win.empty:
        raise ValueError(f'No detections found in time {t_in:.1f}s-{t_out:.1f}s')

    win = win[win['cls'] == 'person']
    if win.empty:
        raise ValueError('No person detections found in the selected time window')

    track_stats = []
    time_window = t_out - t_in
    expected_frames = max(time_window * fps, 1)

    for tid in win['id'].unique():
        t_data = win[win['id'] == tid]
        frame_count = len(t_data)
        continuity = min(frame_count / expected_frames, 1.0)
        x1 = float(t_data['x1'].median())
        y1 = float(t_data['y1'].median())
        x2 = float(t_data['x2'].median())
        y2 = float(t_data['y2'].median())
        track_stats.append({
            'id': int(tid),
            'continuity': continuity,
            'avg_x': (x1 + x2) / 2,
            'avg_y': (y1 + y2) / 2,
            'avg_width': x2 - x1,
            'avg_height': y2 - y1,
        })

    stats_df = pd.DataFrame(track_stats)
    if stats_df.empty:
        raise ValueError('No valid person tracks for layout selection')

    print(f"      Found {len(stats_df)} person tracks")

    valid_df = stats_df[stats_df['continuity'] >= min_continuity]
    if valid_df.empty:
        print(f"      ‚ö†Ô∏è  No tracks with ‚â•{min_continuity*100:.0f}% continuity, using all")
        valid_df = stats_df
    else:
        print(f"      Filtered to {len(valid_df)} tracks (‚â•{min_continuity*100:.0f}% continuity)")

    sort_left = True
    if 'rightmost' in hints:
        sort_left = False

    if ordinal is not None:
        ordered = valid_df.sort_values('avg_x', ascending=sort_left).reset_index(drop=True)
        idx = max(0, min(int(ordinal), len(ordered) - 1))
        best = ordered.iloc[idx]
        print(f"      ‚úÖ Ordinal selection: #{ordinal} from {'left' if sort_left else 'right'}")
        print(f"      ‚Üí Track {int(best['id'])} at x={best['avg_x']:.0f}")
        print(f"   {'‚îÄ'*50}")
        return int(best['id'])

    if 'leftmost' in hints or 'rightmost' in hints:
        ordered = valid_df.sort_values('avg_x', ascending=sort_left)
        best = ordered.iloc[0]
        print(f"      ‚úÖ Spatial selection: {'leftmost' if sort_left else 'rightmost'}")
        print(f"      ‚Üí Track {int(best['id'])} at x={best['avg_x']:.0f}")
        print(f"   {'‚îÄ'*50}")
        return int(best['id'])

    if 'center' in hints:
        cx, cy = W / 2, H / 2
        valid_df = valid_df.copy()
        valid_df['dist_to_center'] = np.sqrt((valid_df['avg_x'] - cx)**2 + (valid_df['avg_y'] - cy)**2)
        best = valid_df.nsmallest(1, 'dist_to_center').iloc[0]
        print(f"      ‚úÖ Center selection: closest to ({cx:.0f}, {cy:.0f})")
        print(f"      ‚Üí Track {int(best['id'])} (dist={best['dist_to_center']:.0f}px)")
        print(f"   {'‚îÄ'*50}")
        return int(best['id'])

    valid_df = valid_df.copy()
    valid_df['size'] = valid_df['avg_width'] * valid_df['avg_height']
    cx, cy = W / 2, H / 2
    valid_df['dist_to_center'] = np.sqrt((valid_df['avg_x'] - cx)**2 + (valid_df['avg_y'] - cy)**2)
    valid_df['score'] = valid_df['size'] / (1 + valid_df['dist_to_center'])
    best = valid_df.nlargest(1, 'score').iloc[0]
    print(f"      ‚úÖ Fallback: largest central person")
    print(f"      ‚Üí Track {int(best['id'])} (size={best['size']:.0f}px¬≤)")
    print(f"   {'‚îÄ'*50}")
    return int(best['id'])




def choose_track_gemini_primary(df, user_description, t_in, t_out, video_path, frame_size, fps=30.0, max_tracks=5):
    """Use Gemini to pick the best matching person track.
    Returns track id or None on failure.
    """
    if GEMINI_CLIENT is None or GEMINI_TYPES is None:
        return None

    # Time window and person filter
    win = df[(df['t'] >= t_in) & (df['t'] <= t_out)]
    win = win[win['cls'] == 'person']
    if win.empty:
        return None

    W, H = frame_size
    time_window = t_out - t_in
    expected_frames = max(time_window * fps, 1)

    # Score tracks by size * continuity to pick top candidates
    stats = []
    for tid in win['id'].unique():
        t_data = win[win['id'] == tid]
        frame_count = len(t_data)
        continuity = min(frame_count / expected_frames, 1.0)
        x1 = float(t_data['x1'].median())
        y1 = float(t_data['y1'].median())
        x2 = float(t_data['x2'].median())
        y2 = float(t_data['y2'].median())
        size = max(1.0, (x2 - x1) * (y2 - y1))
        stats.append({'id': int(tid), 'continuity': continuity, 'size': size,
                      'x1': x1, 'y1': y1, 'x2': x2, 'y2': y2,
                      'frame_index': int(t_data['frame_index'].median())})

    if not stats:
        return None

    import cv2
    import numpy as np

    # Pick top candidates
    stats = sorted(stats, key=lambda s: s['size'] * s['continuity'], reverse=True)[:max_tracks]

    cap = cv2.VideoCapture(video_path)
    parts = []
    tids = []
    for idx_s, s in enumerate(stats, start=1):
        cap.set(cv2.CAP_PROP_POS_FRAMES, s['frame_index'])
        ret, frame = cap.read()
        if not ret:
            continue
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        x1,y1,x2,y2 = map(int, [s['x1'], s['y1'], s['x2'], s['y2']])
        crop = frame_rgb[max(0,y1):min(frame_rgb.shape[0],y2), max(0,x1):min(frame_rgb.shape[1],x2)]
        if crop.size == 0:
            continue
        ok, buf = cv2.imencode('.jpg', cv2.cvtColor(crop, cv2.COLOR_RGB2BGR))
        if not ok:
            continue
        tids.append(s['id'])
        parts.append(GEMINI_TYPES.Part(inline_data=GEMINI_TYPES.Blob(data=buf.tobytes(), mime_type='image/jpeg')))
    cap.release()

    if len(parts) < 1:
        return None

    prompt = (
        f"You will see {len(parts)} images. Image 1 is first, image 2 is second, etc. "
        f"Which image best matches this description: '{user_description}'? "
        "Reply with just the number (1, 2, ...)."
    )

    content = GEMINI_TYPES.Content(parts=[GEMINI_TYPES.Part(text=prompt)] + parts)

    try:
        resp = GEMINI_CLIENT.models.generate_content(model=GEMINI_MODEL_ID, contents=content)
        txt = getattr(resp, 'text', '') or ''
    except Exception:
        return None

    import re as _re
    m = _re.search(r'[1-9]', txt)
    if not m:
        return None
    idx_choice = int(m.group(0)) - 1
    if idx_choice < 0 or idx_choice >= len(tids):
        return None
    return tids[idx_choice]

def choose_track_id_siglip(df, user_description, t_in, t_out, video_path, frame_size, 
                           hints=None, fps=30.0, min_continuity=MIN_CONTINUITY):
    """Use SigLIP/CLIP to semantically match detections to user description."""
    hints = hints or []
    W, H = frame_size
    
    print(f"\n   {'‚îÄ'*50}")
    print(f"   üîç CLIP SEMANTIC MATCHING")
    print(f"   {'‚îÄ'*50}")
    print(f"      Query: '{user_description}'")
    print(f"      Time window: {t_in:.2f}s ‚Üí {t_out:.2f}s")
    
    win = df[(df['t'] >= t_in) & (df['t'] <= t_out)]
    if win.empty:
        raise ValueError(f'No detections found in time {t_in:.1f}s-{t_out:.1f}s')
    
    win = win[win['cls'] == 'person']
    if win.empty:
        raise ValueError('No person detections found in the selected time window')
    
    track_ids = win['id'].unique()
    print(f"      Found {len(track_ids)} person tracks to score")
    
    track_stats = []
    time_window = t_out - t_in
    expected_frames = max(time_window * fps, 1)
    
    for tid in track_ids:
        t_data = win[win['id'] == tid]
        frame_count = len(t_data)
        continuity = min(frame_count / expected_frames, 1.0)
        
        x1 = int(t_data['x1'].median())
        y1 = int(t_data['y1'].median())
        x2 = int(t_data['x2'].median())
        y2 = int(t_data['y2'].median())

        frame_indices = sorted(t_data['frame_index'].unique())
        if len(frame_indices) >= 3:
            sample_idxs = [frame_indices[0], frame_indices[len(frame_indices)//2], frame_indices[-1]]
        else:
            sample_idxs = frame_indices

        scores = []
        for frame_idx in sample_idxs:
            cap = cv2.VideoCapture(video_path)
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_idx))
            ret, frame = cap.read()
            cap.release()
            if not ret:
                continue
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            crop = frame_rgb[max(0,y1):min(frame_rgb.shape[0],y2),
                            max(0,x1):min(frame_rgb.shape[1],x2)]
            if crop.size == 0:
                continue

            crop_pil = Image.fromarray(crop)
            inputs = SIGLIP_PROCESSOR(
                images=crop_pil,
                text=[user_description],
                return_tensors="pt",
                padding="max_length"
            ).to(DEVICE)

            with torch.no_grad():
                outputs = SIGLIP_MODEL(**inputs)

            logit = float(outputs.logits_per_image.item())
            score = 1.0 / (1.0 + math.exp(-logit / 4.0))
            scores.append(score)

        if not scores:
            continue

        siglip_score = float(sum(scores) / len(scores))
        cls = t_data['cls'].iloc[0]
        avg_conf = t_data['conf'].mean()

        stats = {
            'id': tid,
            'cls': cls,
            'siglip_score': siglip_score,
            'continuity': continuity,
            'avg_conf': avg_conf,
            'count': frame_count,
            'avg_x': (x1 + x2) / 2,
            'avg_y': (y1 + y2) / 2,
            'avg_width': x2 - x1,
            'avg_height': y2 - y1,
        }
        track_stats.append(stats)
    
    if not track_stats:
        raise ValueError("No valid detections found")
    
    stats_df = pd.DataFrame(track_stats)
    
    valid_df = stats_df[stats_df['continuity'] >= min_continuity]
    if valid_df.empty:
        print(f"      ‚ö†Ô∏è  No tracks with ‚â•{min_continuity*100:.0f}% continuity, using top 3")
        valid_df = stats_df.nlargest(3, 'continuity')
    else:
        print(f"      Filtered to {len(valid_df)} tracks (‚â•{min_continuity*100:.0f}% continuity)")
    
    if 'leftmost' in hints:
        valid_df = valid_df.nsmallest(min(3, len(valid_df)), 'avg_x')
        print(f"      Applied 'leftmost' filter: {len(valid_df)} candidates")
    
    if 'rightmost' in hints:
        valid_df = valid_df.nlargest(min(3, len(valid_df)), 'avg_x')
        print(f"      Applied 'rightmost' filter: {len(valid_df)} candidates")
    
    if 'center' in hints:
        cx, cy = W / 2, H / 2
        valid_df = valid_df.copy()
        valid_df['dist_to_center'] = np.sqrt((valid_df['avg_x'] - cx)**2 + (valid_df['avg_y'] - cy)**2)
        valid_df = valid_df.nsmallest(min(3, len(valid_df)), 'dist_to_center')
        print(f"      Applied 'center' filter: {len(valid_df)} candidates")
    
    valid_df = valid_df.copy()
    cx, cy = W / 2, H / 2
    valid_df['dist_to_center'] = np.sqrt((valid_df['avg_x'] - cx)**2 + (valid_df['avg_y'] - cy)**2)
    valid_df['dist_norm'] = valid_df['dist_to_center'] / np.sqrt(cx**2 + cy**2)
    valid_df['center_norm'] = 1 / (1 + valid_df['dist_norm'])
    valid_df['size_norm'] = np.clip((valid_df['avg_width'] * valid_df['avg_height']) / (W * H) * 5.0, 0.5, 2.0)
    valid_df['final_score'] = (valid_df['siglip_score'] ** 2) * np.sqrt(valid_df['continuity']) * valid_df['size_norm'] * valid_df['center_norm']

    print(f"\n      üìä CLIP SCORES (top {min(5, len(valid_df))}):")
    for _, row in valid_df.nlargest(min(5, len(valid_df)), 'final_score').iterrows():
        print(f"         Track {int(row['id']):3d}: CLIP={row['siglip_score']:.3f}, cont={row['continuity']:.2f}, final={row['final_score']:.4f}")

    max_siglip = float(valid_df['siglip_score'].max())

    if max_siglip < 0.25:
        best = valid_df.loc[valid_df['dist_to_center'].idxmin()]
        print(f"\n      ‚ö†Ô∏è  Low CLIP confidence ({max_siglip:.3f} < 0.25)")
        print(f"      ‚Üí Fallback to most central: Track {int(best['id'])}")
    else:
        best = valid_df.loc[valid_df['final_score'].idxmax()]
        print(f"\n      ‚úÖ CLIP selection: Track {int(best['id'])} (score={best['final_score']:.4f})")

        if GEMINI_CLIENT is not None and max_siglip < 0.6 and len(valid_df) > 1:
            print(f"      üîÑ CLIP confidence middling ({max_siglip:.3f} < 0.6), trying Gemini rerank...")
            gem_tid = gemini_rerank_tracks(video_path, win, valid_df, user_description)
            if gem_tid is not None and gem_tid in valid_df['id'].values:
                best = valid_df[valid_df['id'] == gem_tid].iloc[0]
                print(f"      ü§ñ Gemini override: Track {int(best['id'])}")

    print(f"   {'‚îÄ'*50}")
    return int(best['id'])


def choose_track_for_callout(df, user_description, t_in, t_out, video_path, frame_size, 
                              hints=None, fps=30.0):
    """Specialized track selection for Callout effect (needs high continuity)."""
    hints = hints or []
    W, H = frame_size
    
    print(f"\n   {'‚îÄ'*50}")
    print(f"   üè∑Ô∏è  CALLOUT TRACK SELECTION")
    print(f"   {'‚îÄ'*50}")
    print(f"      Query: '{user_description}'")
    print(f"      Required continuity: ‚â•{MIN_CALLOUT_CONTINUITY*100:.0f}%")
    
    win = df[(df['t'] >= t_in) & (df['t'] <= t_out)]
    if win.empty:
        raise ValueError(f'No detections found in time {t_in:.1f}s-{t_out:.1f}s')
    
    track_ids = win['id'].unique()
    print(f"      Found {len(track_ids)} tracks")
    
    track_stats = []
    time_window = t_out - t_in
    expected_frames = max(time_window * fps, 1)
    
    for tid in track_ids:
        t_data = win[win['id'] == tid]
        frame_count = len(t_data)
        continuity = min(frame_count / expected_frames, 1.0)
        
        x1 = int(t_data['x1'].median())
        y1 = int(t_data['y1'].median())
        x2 = int(t_data['x2'].median())
        y2 = int(t_data['y2'].median())
        frame_idx = int(t_data['frame_index'].median())
        
        cap = cv2.VideoCapture(video_path)
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = cap.read()
        cap.release()
        if not ret:
            continue
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        
        crop = frame_rgb[max(0,y1):min(frame_rgb.shape[0],y2),
                        max(0,x1):min(frame_rgb.shape[1],x2)]
        
        if crop.size == 0:
            continue
        
        crop_pil = Image.fromarray(crop)
        inputs = SIGLIP_PROCESSOR(
            images=crop_pil, text=[user_description],
            return_tensors="pt", padding="max_length"
        ).to(DEVICE)
        
        with torch.no_grad():
            outputs = SIGLIP_MODEL(**inputs)
        
        logit = float(outputs.logits_per_image.item())
        siglip_score = 1.0 / (1.0 + math.exp(-logit / 4.0))
        
        stats = {
            'id': tid,
            'cls': t_data['cls'].iloc[0],
            'siglip_score': siglip_score,
            'continuity': continuity,
            'avg_conf': t_data['conf'].mean(),
            'count': frame_count,
            'avg_x': (x1 + x2) / 2,
            'avg_y': (y1 + y2) / 2,
            'avg_height': y2 - y1,
        }
        track_stats.append(stats)
    
    if not track_stats:
        raise ValueError("No valid detections")
    
    stats_df = pd.DataFrame(track_stats)
    
    valid_df = stats_df[stats_df['continuity'] >= MIN_CALLOUT_CONTINUITY]
    if valid_df.empty:
        print(f"      ‚ö†Ô∏è  No tracks with ‚â•{MIN_CALLOUT_CONTINUITY*100:.0f}% continuity")
        valid_df = stats_df.nlargest(1, 'continuity')
    else:
        print(f"      {len(valid_df)} high-continuity tracks")
    
    if 'center' in hints:
        cx, cy = W / 2, H / 2
        valid_df = valid_df.copy()
        valid_df['dist_to_center'] = np.sqrt((valid_df['avg_x'] - cx)**2 + (valid_df['avg_y'] - cy)**2)
        valid_df = valid_df.nsmallest(min(3, len(valid_df)), 'dist_to_center')
    
    valid_df = valid_df.copy()
    valid_df['final_score'] = (
        (valid_df['siglip_score'] ** 2) *
        (valid_df['continuity'] ** 2) *
        (valid_df['avg_height'] / 100)
    )

    print(f"\n      üìä CALLOUT SCORES:")
    for _, row in valid_df.nlargest(min(3, len(valid_df)), 'final_score').iterrows():
        print(f"         Track {int(row['id']):3d}: CLIP={row['siglip_score']:.3f}, cont={row['continuity']:.2f}, h={row['avg_height']:.0f}px")

    max_siglip = float(valid_df['siglip_score'].max())

    if max_siglip < 0.25:
        best = valid_df.loc[valid_df['continuity'].idxmax()]
        print(f"\n      ‚ö†Ô∏è  Low CLIP confidence, using most stable track")
        print(f"      ‚Üí Track {int(best['id'])} (continuity={best['continuity']:.2f})")
    else:
        best = valid_df.loc[valid_df['final_score'].idxmax()]
        print(f"\n      ‚úÖ Selected: Track {int(best['id'])} (cont={best['continuity']:.2f})")

        if GEMINI_CLIENT is not None and max_siglip < 0.6 and len(valid_df) > 1:
            gem_tid = gemini_rerank_tracks(video_path, win, valid_df, user_description)
            if gem_tid is not None and gem_tid in valid_df['id'].values:
                best = valid_df[valid_df['id'] == gem_tid].iloc[0]
                print(f"      ü§ñ Gemini override: Track {int(best['id'])}")

    print(f"   {'‚îÄ'*50}")
    return int(best['id'])


def plan_effect(cmd, tracks):
    """Plan effect using SigLIP semantic matching."""
    df = get_tracks_df(tracks)
    fps = tracks.get('fps', 30.0)
    
    print(f"\n{'='*60}")
    print(f"üìä KEYFRAME PLANNER: {cmd.effect}")
    print(f"{'='*60}")
    print(f"   Raw object: '{cmd.object}'")
    
    obj_clean = clean_object_for_siglip(cmd.object)
    clip_query = build_clip_query(obj_clean, cmd.effect)
    print(f"   Cleaned: '{obj_clean}'")
    print(f"   CLIP query: '{clip_query[:80]}{'...' if len(clip_query) > 80 else ''}'")

    hints = getattr(cmd, 'hints', None)
    ordinal = getattr(cmd, 'ordinal', None)
    use_layout = (ordinal is not None) or (hints and any(h in ['leftmost', 'rightmost'] for h in hints))

    if use_layout:
        print(f"   Selection method: LAYOUT (ordinal={ordinal}, hints={hints})")
        tid = choose_track_by_layout(
            df, cmd.t_in, cmd.t_out,
            tracks['size'],
            hints=hints,
            ordinal=ordinal,
            fps=fps
        )
    elif cmd.effect == 'Callout':
        print(f"   Selection method: CALLOUT-SPECIALIZED")
        tid = choose_track_for_callout(
            df, clip_query, cmd.t_in, cmd.t_out,
            tracks['video_path'], tracks['size'],
            hints=hints,
            fps=fps
        )
    else:
        print(f"   Selection method: GEMINI ‚Üí CLIP")
        tid = None
        if GEMINI_CLIENT is not None and GEMINI_TYPES is not None:
            try:
                tid = choose_track_gemini_primary(
                    df, clip_query, cmd.t_in, cmd.t_out,
                    tracks['video_path'], tracks['size'],
                    fps=fps
                )
                if tid is not None:
                    print(f"   ü§ñ Gemini selected track {tid}")
            except Exception as e:
                print(f"   ‚ö†Ô∏è Gemini primary selection failed: {e}")
        if tid is None:
            tid = choose_track_id_siglip(
                df, clip_query, cmd.t_in, cmd.t_out,
                tracks['video_path'], tracks['size'],
                hints=hints,
                fps=fps
            )
    
    win = df[(df['id'] == tid) & (df['t'] >= cmd.t_in) & (df['t'] <= cmd.t_out)]
    
    cx = smooth((win['x1'].values + win['x2'].values) / 2)
    cy = smooth((win['y1'].values + win['y2'].values) / 2)
    W, H = tracks['size']
    widths, heights = win['x2'].values - win['x1'].values, win['y2'].values - win['y1'].values
    margin = cmd.params.get('margin', 0.1)
    scale = smooth(np.maximum(widths/W, heights/H) * (1 + margin))
    
    timeline = [{'t': float(r.t), 'frame': int(r.frame_index),
                 'center': [float(cx[i]), float(cy[i])], 'scale': float(scale[i]),
                 'bbox': [float(r.x1), float(r.y1), float(r.x2), float(r.y2)], 'mask_rle': r.mask_rle}
                for i, r in enumerate(win.itertuples())]
    
    print(f"\n   üìç FINAL PLAN:")
    print(f"      Track ID: {tid}")
    print(f"      Keyframes: {len(timeline)}")
    print(f"      Time range: {cmd.t_in:.2f}s ‚Üí {cmd.t_out:.2f}s")
    print(f"      Params: {cmd.params}")
    print(f"{'='*60}\n")
    
    return {'effect': cmd.effect, 'object': cmd.object, 'track_id': int(tid),
            't_in': cmd.t_in, 't_out': cmd.t_out, 'timeline': timeline,
            'frame_size': tracks['size'], 'fps': fps,
            'video_path': tracks['video_path'], 'params': cmd.params}

print('‚úÖ Keyframe planner loaded (v5 - Enhanced Logging)')
print(f'   ‚Ä¢ Min continuity: {MIN_CONTINUITY*100:.0f}% (general), {MIN_CALLOUT_CONTINUITY*100:.0f}% (callout)')
print(f'   ‚Ä¢ CLIP model: ViT-L/14')
print(f'   ‚Ä¢ Gemini rerank: {"enabled" if GEMINI_CLIENT else "disabled"}')

‚úÖ Keyframe planner loaded (v5 - Enhanced Logging)
   ‚Ä¢ Min continuity: 50% (general), 70% (callout)
   ‚Ä¢ CLIP model: ViT-L/14
   ‚Ä¢ Gemini rerank: enabled


In [7]:
#@title 6. Renderers (with TEST_MODE support + Premiere Pro Compatibility)
import subprocess

def get_video_encoding_params(video_path):
    """Read source video FPS and calculate optimal encoding parameters."""
    try:
        result = subprocess.run([
            'ffprobe', '-v', 'quiet',
            '-select_streams', 'v:0',
            '-show_entries', 'stream=r_frame_rate',
            '-of', 'csv=p=0',
            str(video_path)
        ], capture_output=True, text=True)
        r_frame = result.stdout.strip()
        if '/' in r_frame:
            num, den = map(int, r_frame.split('/'))
            fps = num / den if den else 30.0
        else:
            fps = float(r_frame) if r_frame else 30.0
        gop_size = int(round(fps))  # 1 keyframe per second
        print(f"Source video: {fps:.2f}fps, GOP={gop_size}")
        return {'fps': fps, 'gop_size': gop_size, 'keyint_min': gop_size}
    except Exception as e:
        print(f"Warning: Could not read video params: {e}, using defaults")
        return {'fps': 30.0, 'gop_size': 30, 'keyint_min': 30}


def decode_mask_rle(rle, frame_shape):
    if not rle:
        return None
    mask = decode_mask(rle)
    if mask.shape != frame_shape:
        mask = cv2.resize(mask, (frame_shape[1], frame_shape[0]), interpolation=cv2.INTER_NEAREST)
    return mask


def timeline_sampler(timeline):
    times = np.array([item['t'] for item in timeline], dtype=np.float32)
    centers = np.array([item['center'] for item in timeline], dtype=np.float32)
    scales = np.array([item['scale'] for item in timeline], dtype=np.float32)
    bboxes = np.array([item['bbox'] for item in timeline], dtype=np.float32)
    masks = [item.get('mask_rle') for item in timeline]

    def sample(t):
        if t <= times[0]:
            return {'center': centers[0], 'scale': scales[0], 'bbox': bboxes[0], 'mask': masks[0]}
        if t >= times[-1]:
            return {'center': centers[-1], 'scale': scales[-1], 'bbox': bboxes[-1], 'mask': masks[-1]}
        idx = np.searchsorted(times, t, side='right')
        i0 = max(idx - 1, 0)
        i1 = min(idx, len(times) - 1)
        span = (times[i1] - times[i0]) or 1e-6
        alpha = (t - times[i0]) / span
        center = centers[i0] * (1 - alpha) + centers[i1] * alpha
        scale = scales[i0] * (1 - alpha) + scales[i1] * alpha
        bbox = bboxes[i0] * (1 - alpha) + bboxes[i1] * alpha
        mask = masks[i0 if alpha <= 0.5 else i1]
        return {'center': center, 'scale': scale, 'bbox': bbox, 'mask': mask}

    return sample


def ensure_mask(state, frame_shape, feather=25):
    mask = decode_mask_rle(state.get('mask'), frame_shape)
    if mask is None:
        x1, y1, x2, y2 = state['bbox']
        temp = np.zeros(frame_shape, dtype=np.uint8)
        cv2.ellipse(
            temp,
            center=(int((x1 + x2) / 2), int((y1 + y2) / 2)),
            axes=(int(max((x2 - x1) / 2, 1)), int(max((y2 - y1) / 2, 1))),
            angle=0, startAngle=0, endAngle=360,
            color=255, thickness=-1
        )
        mask = temp
    if feather > 0:
        mask = cv2.GaussianBlur(mask, (0, 0), sigmaX=feather)
    return np.clip(mask.astype(np.float32) / 255.0, 0, 1)[..., None]


def clamp_window(center, scale, frame_size):
    W, H = frame_size
    crop_w = max(W * scale, 64)
    crop_h = max(H * scale, 64)
    x1 = np.clip(center[0] - crop_w / 2, 0, W - crop_w)
    y1 = np.clip(center[1] - crop_h / 2, 0, H - crop_h)
    x2 = x1 + crop_w
    y2 = y1 + crop_h
    return int(x1), int(y1), int(x2), int(y2)


def run_moviepy(plan, frame_fn, output_path, codec='libx264'):
    """Render video with Premiere Pro compatible encoding."""
    clip = VideoFileClip(plan['video_path'])
    sampler = timeline_sampler(plan['timeline'])
    H = int(plan['frame_size'][1])
    W = int(plan['frame_size'][0])

    def processor(get_frame, t):
        frame = get_frame(t)
        if t < plan['t_in'] or t > plan['t_out']:
            return frame
        state = sampler(t)
        return frame_fn(frame, state, (H, W))

    processed = clip.fl(processor)
    
    # Get encoding params from source video
    enc_params = get_video_encoding_params(plan['video_path'])
    
    if TEST_MODE:
        # Fast preview mode
        processed = processed.resize(height=480)
        processed.write_videofile(
            str(output_path),
            codec=codec,
            audio=True,
            audio_codec='aac',
            fps=clip.fps,
            preset='ultrafast',
            logger=None
        )
    else:
        # Production mode with Premiere Pro compatibility
        ffmpeg_params = [
            '-pix_fmt', 'yuv420p',
            '-profile:v', 'high',
            '-level', '4.1',
            '-g', str(enc_params['gop_size']),
            '-keyint_min', str(enc_params['keyint_min']),
            '-bf', '2',
            '-movflags', '+faststart',
            # Premiere Pro compatibility flags
            '-vsync', 'cfr',                    # Force constant frame rate
            '-video_track_timescale', '30000',  # Standard time base
        ]
        processed.write_videofile(
            str(output_path),
            codec=codec,
            audio=True,
            audio_codec='aac',
            fps=clip.fps,
            ffmpeg_params=ffmpeg_params,
            logger=None
        )
    
    clip.close()
    processed.close()
    return output_path


def render_zoom_follow(plan, output):
    W, H = plan['frame_size']
    def fn(frame, state, shape):
        cx, cy = state['center']
        scale = np.clip(state['scale'], 0.2, 1.0)
        x1, y1, x2, y2 = clamp_window((cx, cy), scale, (W, H))
        cropped = frame[int(y1):int(y2), int(x1):int(x2)]
        return cv2.resize(cropped, (W, H), interpolation=cv2.INTER_CUBIC)
    return run_moviepy(plan, fn, output)


def render_spotlight(plan, output):
    feather = int(plan['params'].get('feather', 45))
    strength = plan['params'].get('strength', 0.7)
    def fn(frame, state, shape):
        mask = ensure_mask(state, shape, feather)
        dimmed = (frame * (1 - strength)).astype(np.uint8)
        return (frame * mask + dimmed * (1 - mask)).astype(np.uint8)
    return run_moviepy(plan, fn, output)


def render_blur_background(plan, output):
    ksize = int(plan['params'].get('ksize', 21))
    if ksize % 2 == 0:
        ksize += 1
    def fn(frame, state, shape):
        mask = ensure_mask(state, shape, 25)
        blurred = cv2.GaussianBlur(frame, (ksize, ksize), 0)
        return (frame * mask + blurred * (1 - mask)).astype(np.uint8)
    return run_moviepy(plan, fn, output)


def render_pixelate(plan, output):
    block = int(plan['params'].get('block', 20))
    def fn(frame, state, shape):
        mask = ensure_mask(state, shape, 5)
        h, w, _ = frame.shape
        small = cv2.resize(frame, (max(1, w // block), max(1, h // block)), interpolation=cv2.INTER_LINEAR)
        pixelated = cv2.resize(small, (w, h), interpolation=cv2.INTER_NEAREST)
        return (frame * (1 - mask) + pixelated * mask).astype(np.uint8)
    return run_moviepy(plan, fn, output)


def render_callout(plan, output):
    label = plan['params'].get('label', plan['object'])
    def fn(frame, state, shape):
        frame_out = frame.copy()
        cx, cy = map(int, state['center'])
        anchor = (np.clip(cx + 100, 0, frame.shape[1] - 1), np.clip(cy - 100, 0, frame.shape[0] - 1))
        cv2.line(frame_out, (cx, cy), anchor, (255, 255, 255), 2)
        cv2.circle(frame_out, (cx, cy), 6, (0, 255, 0), -1)
        box_w, box_h = 160, 60
        x1 = np.clip(anchor[0], 0, frame.shape[1] - box_w - 1)
        y1 = np.clip(anchor[1], 0, frame.shape[0] - box_h - 1)
        x2, y2 = x1 + box_w, y1 + box_h
        overlay = frame_out.copy()
        cv2.rectangle(overlay, (x1, y1), (x2, y2), (0, 0, 0), -1)
        cv2.addWeighted(overlay, 0.65, frame_out, 0.35, 0, frame_out)
        cv2.rectangle(frame_out, (x1, y1), (x2, y2), (255, 255, 255), 2)
        cv2.putText(frame_out, label, (x1 + 12, y1 + 35), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (255, 255, 255), 2)
        return frame_out
    return run_moviepy(plan, fn, output)


def render_path(plan, output):
    points = np.array([item['center'] for item in plan['timeline']], dtype=np.int32)
    def fn(frame, state, shape):
        frame_out = frame.copy()
        cv2.polylines(frame_out, [points], False, (0, 255, 255), 4)
        return frame_out
    return run_moviepy(plan, fn, output)


RENDERERS = {
    'ZoomFollow': render_zoom_follow,
    'Spotlight': render_spotlight,
    'BlurBackground': render_blur_background,
    'PixelateObject': render_pixelate,
    'Callout': render_callout,
    'PathOverlay': render_path
}

print('‚úÖ Renderers loaded (v3 - Premiere Pro Compatible)')
print(f'   üß™ TEST_MODE: {TEST_MODE}', '‚Üí 480p, ultrafast' if TEST_MODE else '‚Üí full quality + Premiere Pro flags')
print(f'   Effects: {list(RENDERERS.keys())}')

‚úÖ Renderers loaded (v3 - Premiere Pro Compatible)
   üß™ TEST_MODE: False ‚Üí full quality + Premiere Pro flags
   Effects: ['ZoomFollow', 'Spotlight', 'BlurBackground', 'PixelateObject', 'Callout', 'PathOverlay']


In [None]:
#@title 7. START SERVER (Multi-Effect + Progress Polling)
# ‚ö†Ô∏è GET YOUR TOKEN: https://ngrok.com ‚Üí sign up ‚Üí Your Authtoken
# Kill any existing ngrok tunnels
try:
    from pyngrok import ngrok
    ngrok.kill()
except:
    pass
NGROK_TOKEN = "YOUR_NGROK_TOKEN"  # <-- PASTE YOUR TOKEN

#============================================================
from pyngrok import ngrok
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import threading
import json as json_lib
import uuid
import shutil
import time as _time

if NGROK_TOKEN == "YOUR_TOKEN_HERE":
    raise ValueError("‚ùå Paste your ngrok token above! Get it free at https://ngrok.com")

ngrok.set_auth_token(NGROK_TOKEN)

app = FastAPI(title="ChatCut")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

# Store completed files for download
completed_files = {}

# Job progress storage - key: job_id, value: progress dict
job_progress = {}

@app.get("/health")
def health():
    return {"status": "ok", "gpu": DEVICE, "test_mode": TEST_MODE, "gemini_model": GEMINI_MODEL_ID}

@app.get("/effects")
def effects():
    return {"effects": list(RENDERERS.keys())}

def update_progress(job_id, stage, progress, message, **extra):
    """Update progress for a job."""
    job_progress[job_id] = {
        "status": "processing" if stage not in ["complete", "error"] else stage,
        "stage": stage,
        "progress": progress,
        "message": message,
        **extra
    }
    # Enhanced logging with timestamp
    timestamp = _time.strftime("%H:%M:%S")
    print(f"[{timestamp}] [Job {job_id}] {stage}: {progress}% - {message}")

def process_job(job_id, file_path, filename, prompt):
    """Background job to process video - supports MULTIPLE EFFECTS."""
    temp_files = []
    current_video = file_path
    job_start_time = _time.time()
    
    try:
        print(f"\n{'#'*70}")
        print(f"# JOB {job_id} STARTED")
        print(f"{'#'*70}")
        print(f"   File: {filename}")
        print(f"   Prompt: '{prompt}'")
        print(f"   TEST_MODE: {TEST_MODE}")
        print(f"   Device: {DEVICE}")
        print(f"{'#'*70}\n")
        
        update_progress(job_id, "tracking", 0, "Starting object tracking...")
        global tracks_df_cached
        tracks_df_cached = None

        # Get video info
        cap = cv2.VideoCapture(current_video)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        w, h = int(cap.get(3)), int(cap.get(4))
        duration = total_frames / fps if fps else 0
        cap.release()

        print(f"üìπ VIDEO INFO:")
        print(f"   Resolution: {w}x{h}")
        print(f"   FPS: {fps:.2f}")
        print(f"   Frames: {total_frames}")
        print(f"   Duration: {duration:.2f}s\n")

        # Run tracking with progress updates
        track_start = _time.time()
        model = load_model(use_seg=True)
        name_map = _build_name_map(model)

        print(f"üîç TRACKING PHASE:")
        print(f"   Model: {DEFAULT_MODELS['seg']}")
        print(f"   Tracker: ByteTrack")
        
        stream = model.track(source=current_video, imgsz=960, tracker='bytetrack.yaml', stream=True,
                            conf=0.25, iou=0.45, vid_stride=1, device=DEVICE, verbose=False, persist=True)

        frames_data = []
        cursor = 0
        last_progress = -1
        unique_track_ids = set()

        for result in stream:
            dets = []
            if result.boxes is not None and result.boxes.id is not None:
                ids = result.boxes.id.int().cpu().tolist()
                xyxy = result.boxes.xyxy.cpu().tolist()
                confs = result.boxes.conf.cpu().tolist()
                clss = result.boxes.cls.int().cpu().tolist()
                masks = result.masks.data.cpu().numpy() if result.masks else None
                for i, tid in enumerate(ids):
                    unique_track_ids.add(tid)
                    dets.append({'id': int(tid), 'cls': name_map.get(clss[i], str(clss[i])),
                                'conf': float(confs[i]), 'bbox_xyxy': [float(v) for v in xyxy[i]],
                                'mask_rle': encode_mask(masks[i]) if masks is not None else None})
            frames_data.append({'frame_index': cursor, 't': cursor/fps, 'detections': dets})
            cursor += 1

            track_progress = int((cursor / total_frames) * 40) if total_frames > 0 else 0
            if track_progress >= last_progress + 5:
                last_progress = track_progress
                update_progress(job_id, "tracking", track_progress,
                              f"Tracking frame {cursor}/{total_frames}...")

        track_time = _time.time() - track_start
        tracks = {
            'video_path': current_video,
            'fps': fps,
            'size': [w, h],
            'duration': duration,
            'frames': frames_data
        }

        print(f"   ‚úÖ Tracking complete in {track_time:.1f}s")
        print(f"   Frames processed: {len(frames_data)}")
        print(f"   Unique tracks: {len(unique_track_ids)}")
        
        update_progress(job_id, "tracking", 40, f"Tracked {len(frames_data)} frames, {len(unique_track_ids)} objects")

        # Parse command with Gemini
        parse_start = _time.time()
        update_progress(job_id, "parsing", 45, "Parsing command with Gemini...")
        cmds = parse_nl_to_dsl(prompt, tracks['duration'])
        parse_time = _time.time() - parse_start
        
        num_effects = len(cmds)
        effect_names = [cmd.effect for cmd in cmds]
        
        print(f"\n‚è±Ô∏è  Gemini parsing took {parse_time:.2f}s")
        
        update_progress(job_id, "parsing", 50, 
                       f"Found {num_effects} effect(s): {', '.join(effect_names)}")
        
        # Process each effect sequentially
        for i, cmd in enumerate(cmds):
            effect_num = i + 1
            progress_base = 50 + int((i / num_effects) * 45)
            
            print(f"\n{'‚îÄ'*60}")
            print(f"üé¨ PROCESSING EFFECT {effect_num}/{num_effects}: {cmd.effect}")
            print(f"{'‚îÄ'*60}")
            
            if i > 0:
                tracks_df_cached = None
                update_progress(job_id, f"effect_{effect_num}", progress_base, 
                              f"[{effect_num}/{num_effects}] Re-tracking for {cmd.effect}...")
                
                retrack_start = _time.time()
                tracks_rerun = detect_and_track(current_video, use_seg=True, frame_stride=1)
                tracks = tracks_rerun
                print(f"   Re-tracking took {_time.time() - retrack_start:.1f}s")
            
            update_progress(job_id, f"effect_{effect_num}", progress_base + 5, 
                          f"[{effect_num}/{num_effects}] Planning {cmd.effect}...")
            
            plan_start = _time.time()
            plan = plan_effect(cmd, tracks)
            plan['video_path'] = current_video
            print(f"   Planning took {_time.time() - plan_start:.1f}s")
            
            update_progress(job_id, f"effect_{effect_num}", progress_base + 10, 
                          f"[{effect_num}/{num_effects}] Rendering {cmd.effect}...")
            
            if i < num_effects - 1:
                tmp_out = tempfile.NamedTemporaryFile(
                    delete=False, suffix='.mp4', 
                    prefix=f'effect_{effect_num}_'
                )
                tmp_out.close()
                out_path = tmp_out.name
                temp_files.append(out_path)
            else:
                out_name = f"processed_{filename}"
                out_path = str(EXPORT_DIR / out_name)
            
            effect_name = plan['effect']
            if effect_name not in RENDERERS:
                raise ValueError(f"Unknown effect: {effect_name}. Available: {list(RENDERERS.keys())}")
            
            render_start = _time.time()
            RENDERERS[effect_name](plan, out_path)
            render_time = _time.time() - render_start
            
            current_video = out_path
            
            print(f"   ‚úÖ Rendering took {render_time:.1f}s")
            update_progress(job_id, f"effect_{effect_num}", progress_base + 15, 
                          f"[{effect_num}/{num_effects}] {cmd.effect} complete ({render_time:.1f}s)")
        
        # Final output
        out_name = f"processed_{filename}"
        final_path = str(EXPORT_DIR / out_name)
        
        if current_video != final_path:
            shutil.copy2(current_video, final_path)

        completed_files[out_name] = final_path

        total_time = _time.time() - job_start_time
        
        print(f"\n{'#'*70}")
        print(f"# JOB {job_id} COMPLETE")
        print(f"{'#'*70}")
        print(f"   Total time: {total_time:.1f}s")
        print(f"   Effects applied: {', '.join(effect_names)}")
        print(f"   Output: {final_path}")
        print(f"{'#'*70}\n")
        
        update_progress(job_id, "complete", 100, 
                       f"Complete in {total_time:.1f}s! Applied: {', '.join(effect_names)}",
                       file_ready=True,
                       filename=out_name,
                       output_path=final_path,
                       effects_applied=effect_names,
                       total_time_seconds=round(total_time, 1),
                       download_url=f"/download/{out_name}")

    except Exception as e:
        total_time = _time.time() - job_start_time
        print(f"\n{'!'*70}")
        print(f"! JOB {job_id} FAILED after {total_time:.1f}s")
        print(f"{'!'*70}")
        print(f"   Error: {e}")
        traceback.print_exc()
        print(f"{'!'*70}\n")
        
        update_progress(job_id, "error", 0, f"Error after {total_time:.1f}s: {str(e)}", 
                       error=str(e), total_time_seconds=round(total_time, 1))
    finally:
        for tmp in temp_files:
            if os.path.exists(tmp):
                try:
                    os.unlink(tmp)
                except:
                    pass
        
        if os.path.exists(file_path) and file_path != current_video:
            try:
                os.unlink(file_path)
            except:
                pass

@app.post("/start-job")
async def start_job(file: UploadFile = File(...), prompt: str = Form(...)):
    """Start a video processing job - returns job_id immediately."""
    job_id = str(uuid.uuid4())[:8]

    tmp_dir = tempfile.mkdtemp()
    file_path = os.path.join(tmp_dir, file.filename)

    with open(file_path, 'wb') as f:
        content = await file.read()
        f.write(content)

    file_size_mb = len(content) / (1024 * 1024)
    
    print(f"\nüì• NEW JOB RECEIVED")
    print(f"   Job ID: {job_id}")
    print(f"   File: {file.filename} ({file_size_mb:.1f} MB)")
    print(f"   Prompt: '{prompt}'")

    update_progress(job_id, "upload", 5, f"Received {file.filename} ({file_size_mb:.1f} MB)")

    thread = threading.Thread(
        target=process_job,
        args=(job_id, file_path, file.filename, prompt),
        daemon=True
    )
    thread.start()

    return {
        "job_id": job_id,
        "status": "started",
        "message": f"Processing started for {file.filename}",
        "file_size_mb": round(file_size_mb, 1),
        "test_mode": TEST_MODE
    }

@app.get("/progress/{job_id}")
def get_progress(job_id: str):
    """Get progress for a job."""
    if job_id not in job_progress:
        return {"status": "not_found", "error": f"Job {job_id} not found"}
    return job_progress[job_id]

@app.get("/download/{filename}")
async def download(filename: str):
    """Download a processed video file."""
    if filename not in completed_files:
        raise HTTPException(404, f"File not found: {filename}")

    path = completed_files[filename]
    if not os.path.exists(path):
        raise HTTPException(404, f"File no longer exists: {filename}")

    return FileResponse(str(path), filename=filename, media_type="video/mp4")

# Keep old /process endpoint for backwards compatibility
@app.post("/process")
async def process(file: UploadFile = File(...), prompt: str = Form(...)):
    tmp = None
    try:
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1] or '.mp4')
        tmp.write(await file.read())
        tmp.close()
        
        print(f"\nüì• SYNC REQUEST: {file.filename}")
        print(f"üìù Prompt: '{prompt}'")

        global tracks_df_cached
        tracks_df_cached = None
        tracks = detect_and_track(tmp.name, use_seg=True, frame_stride=1)

        cmds = parse_nl_to_dsl(prompt, tracks['duration'])
        
        current_video = tmp.name
        temp_outputs = []
        
        for i, cmd in enumerate(cmds):
            plan = plan_effect(cmd, tracks)
            plan['video_path'] = current_video
            
            if i < len(cmds) - 1:
                tmp_out = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
                tmp_out.close()
                out_path = tmp_out.name
                temp_outputs.append(out_path)
            else:
                out_name = f"processed_{file.filename}"
                out_path = str(EXPORT_DIR / out_name)
            
            RENDERERS[plan['effect']](plan, out_path)
            current_video = out_path
            
            if i < len(cmds) - 1:
                tracks_df_cached = None
                tracks = detect_and_track(current_video, use_seg=True, frame_stride=1)
        
        print(f"‚úÖ Sync request complete: {out_path}")

        for tmp_out in temp_outputs:
            if os.path.exists(tmp_out):
                os.unlink(tmp_out)

        return FileResponse(str(out_path), filename=out_name, media_type="video/mp4")

    except Exception as e:
        print(f"‚ùå Sync request failed: {e}")
        traceback.print_exc()
        raise HTTPException(500, str(e))
    finally:
        if tmp and os.path.exists(tmp.name): os.unlink(tmp.name)

# Start server
def run_server():
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="warning")

print("üöÄ Starting ChatCut server...")
thread = threading.Thread(target=run_server, daemon=True)
thread.start()

import time
time.sleep(2)

url = ngrok.connect(8000)
print("")
print("=" * 60)
print("üéâ CHATCUT SERVER READY")
print("=" * 60)
print(f"")
print(f"üì° Copy this URL into Premiere Pro:")
print(f"")
print(f"   {url}")
print(f"")
print("=" * 60)
print("")
print("‚öôÔ∏è  CONFIGURATION:")
print(f"   ‚Ä¢ TEST_MODE: {TEST_MODE}" + (" (480p, ultrafast)" if TEST_MODE else " (full quality)"))
print(f"   ‚Ä¢ Device: {DEVICE.upper()}")
print(f"   ‚Ä¢ Gemini: {GEMINI_MODEL_ID}")
print(f"   ‚Ä¢ CLIP: ViT-L/14")
print("")
print("üîß PIPELINE:")
print("   1. Upload ‚Üí YOLO tracking (ByteTrack)")
print("   2. Parse ‚Üí Gemini NL‚ÜíDSL")
print("   3. Plan ‚Üí CLIP semantic match + Gemini rerank")
print("   4. Render ‚Üí MoviePy + OpenCV")
print("")
print("üìä AVAILABLE EFFECTS:")
for eff in RENDERERS.keys():
    keywords = EFFECT_KEYWORDS.get(eff, [])
    print(f"   ‚Ä¢ {eff}: {', '.join(keywords)}")
print("")
print("üåê ENDPOINTS:")
print("   POST /start-job       - Async processing (returns job_id)")
print("   GET  /progress/{id}   - Poll progress (0-100%)")
print("   GET  /download/{file} - Download result")
print("   POST /process         - Sync endpoint (legacy)")
print("   GET  /health          - Health check")
print("   GET  /effects         - List available effects")
print("")
print("‚úÖ Server running! Logs will appear below.")
print("   To stop: Runtime ‚Üí Restart runtime")
print("=" * 60)

üöÄ Starting ChatCut server...

üéâ CHATCUT SERVER READY

üì° Copy this URL into Premiere Pro:

   NgrokTunnel: "https://eac976108fd9.ngrok-free.app" -> "http://localhost:8000"


‚öôÔ∏è  CONFIGURATION:
   ‚Ä¢ TEST_MODE: False (full quality)
   ‚Ä¢ Device: CUDA
   ‚Ä¢ Gemini: gemini-2.5-flash
   ‚Ä¢ CLIP: ViT-L/14

üîß PIPELINE:
   1. Upload ‚Üí YOLO tracking (ByteTrack)
   2. Parse ‚Üí Gemini NL‚ÜíDSL
   3. Plan ‚Üí CLIP semantic match + Gemini rerank
   4. Render ‚Üí MoviePy + OpenCV

üìä AVAILABLE EFFECTS:
   ‚Ä¢ ZoomFollow: zoom, punch in, follow
   ‚Ä¢ Spotlight: spotlight, highlight
   ‚Ä¢ BlurBackground: blur background, background blur
   ‚Ä¢ PixelateObject: pixelate
   ‚Ä¢ Callout: callout, label
   ‚Ä¢ PathOverlay: path, trajectory

üåê ENDPOINTS:
   POST /start-job       - Async processing (returns job_id)
   GET  /progress/{id}   - Poll progress (0-100%)
   GET  /download/{file} - Download result
   POST /process         - Sync endpoint (legacy)
   GET  /health        


üì• NEW JOB RECEIVED
   Job ID: cbc6ead4
   File: Jay_Prakash_Guiding_at_Wikimedia_Hackathon_Kochi_2024_0.00_5.83.mp4 (3.4 MB)
   Prompt: 'zoom in on the person in green and center them'
[07:45:46] [Job cbc6ead4] upload: 5% - Received Jay_Prakash_Guiding_at_Wikimedia_Hackathon_Kochi_2024_0.00_5.83.mp4 (3.4 MB)

######################################################################
# JOB cbc6ead4 STARTED
######################################################################
   File: Jay_Prakash_Guiding_at_Wikimedia_Hackathon_Kochi_2024_0.00_5.83.mp4
   Prompt: 'zoom in on the person in green and center them'
   TEST_MODE: False
   Device: cuda
######################################################################

[07:45:46] [Job cbc6ead4] tracking: 0% - Starting object tracking...
üìπ VIDEO INFO:
   Resolution: 608x1080
   FPS: 30.00
   Frames: 175
   Duration: 5.83s

üîç TRACKING PHASE:
   Model: yolo11n-seg.pt
   Tracker: ByteTrack
[07:45:51] [Job cbc6ead4] tracking: 4% - Trackin