<a href="https://colab.research.google.com/github/maryzhang1028/project-0/blob/main/demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Project 0 - Automated Dance Formation Translator
========================================================
A computer vision system for tracking dancer positions in videos and generating
formation visualizations.
Claude and OpenAI models were used to help structure, generate, and debug the code in this project.


Author: Mary Zhang
Date: 2025

Instructions:
1. Download any of the mp4 files from the sample_videos_upload folder onto your local computer
2. Run all cells (Runtime → Run all)
3. When prompted in the interface, upload one of the sample files or any dance choreography video of your own (under 30 seconds)
4. View the results in the gallery or download them as a batch ZIP file

In [5]:
#@title Download Sample Videos (Please feel free to select videos to upload from this folder)
# Download the sample videos from Google Drive (no Drive mount is needed)

"""
Cell: Download Sample Videos from Google Drive
This ensures all users get the same sample files
"""

import gdown
import os

SAMPLE_DIR = 'sample_videos_upload'

# Create sample_videos directory
os.makedirs(SAMPLE_DIR, exist_ok=True)

# Google Drive *file* IDs (the part after /d/ in each file's shareable link)
sample_videos = {
    'dance_performance_1.mp4': '10lEHcVoCWwSeXuskXELWyMn2PCY8q2d8',
    'dance_performance_2.mp4': '1KYm3EJWdUnAhlqd6MicRqnIesU2yPuaq',
    # NOTE(review): this ID is one character shorter than the other two;
    # confirm it was copied completely from the shareable link.
    'dance_performance_3.mp4': '1MBOe_nNmeOHcTye2G7r4eYZZERG4TBG',
}

print("📥 Downloading sample videos...")
failed = []
for filename, file_id in sample_videos.items():
    # Single-file download endpoint. A folder URL with the ID appended (as used
    # previously) does not point at the file, so gdown cannot fetch it.
    url = f'https://drive.google.com/uc?id={file_id}'
    output_path = os.path.join(SAMPLE_DIR, filename)

    if os.path.exists(output_path):
        print(f"✓ {filename} already exists")
        continue

    # Depending on the gdown version, a failed download either returns None
    # or raises, so handle both outcomes.
    try:
        result = gdown.download(url, output_path, quiet=False)
    except Exception as exc:
        print(f"⚠️ Could not download {filename}: {exc}")
        result = None

    if result is None or not os.path.exists(output_path):
        failed.append(filename)
        print(f"⚠️ {filename} was not downloaded; check that the file is shared publicly.")
    else:
        print(f"✅ Downloaded {filename}")

if failed:
    print(f"\n⚠️ {len(failed)} sample video(s) missing from '{SAMPLE_DIR}': {', '.join(failed)}")
else:
    print("\n✅ All sample videos ready in 'sample_videos_upload' folder!")
print("Users can also upload their own videos using the interface.")

📥 Downloading sample videos...
✓ dance_performance_1.mp4 already exists
✓ dance_performance_2.mp4 already exists
✓ dance_performance_3.mp4 already exists

✅ All sample videos ready in 'sample_videos_upload' folder!
Users can also upload their own videos using the interface.


In [6]:
#@title Install Packages

"""
Run this cell first to install required packages.
"""

!pip install opencv-python ultralytics numpy matplotlib moviepy gradio scipy --quiet
print("✅ All packages installed successfully!")

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━[0m [32m0.6/1.1 MB[0m [31m17.6 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m17.1 MB/s[0m eta [36m0:00:00[0m
[?25h✅ All packages installed successfully!


In [7]:
#@title Import Libraries
# Import all necessary Python libraries and modules.

import os
import cv2
import zipfile
import tempfile
import numpy as np
import matplotlib
matplotlib.use("Agg")  # Use non-interactive backend for server environments
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
from matplotlib.backends.backend_agg import FigureCanvasAgg
from ultralytics import YOLO
from moviepy.editor import VideoFileClip
import gradio as gr
from scipy.optimize import linear_sum_assignment

# Silence every category of NumPy floating-point warning (divide-by-zero,
# overflow, underflow, invalid) so per-frame geometry does not flood the cell
# output. This setting is process-wide.
np.seterr(divide="ignore", over="ignore", under="ignore", invalid="ignore")

print("✅ All libraries imported successfully!")

Creating new Ultralytics Settings v0.0.6 file ✅ 
View Ultralytics Settings with 'yolo settings' or at '/root/.config/Ultralytics/settings.json'
Update Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings.


  IMAGEMAGICK_BINARY = r"C:\Program Files\ImageMagick-6.8.8-Q16\magick.exe"
  lines_video = [l for l in lines if ' Video: ' in l and re.search('\d+x\d+', l)]
  rotation_lines = [l for l in lines if 'rotate          :' in l and re.search('\d+$', l)]
  match = re.search('\d+$', rotation_line)
  if event.key is 'enter':



✅ All libraries imported successfully!


In [8]:
#@title Helper Functions (Data Processing and Geometric Calculations)
# Helper Functions

"""
Helper functions for data processing and geometric calculations.
"""

def _float(v):
    """
    Safely convert any value to float.

    Args:
        v: Value to convert (can be tensor, numpy array, or scalar)

    Returns:
        float: Converted value
    """
    try:
        return float(v.item()) if hasattr(v, "item") else float(v)
    except Exception:
        return float(v)


def _ensure_dir(path):
    """
    Create directory if it doesn't exist.

    Args:
        path (str): Directory path to create

    Returns:
        str: The same path (for chaining)
    """
    os.makedirs(path, exist_ok=True)
    return path


def iou_xyxy(box_a, box_b):
    """
    Intersection-over-Union of two axis-aligned boxes in [x1, y1, x2, y2] form.

    Ranges from 0 (disjoint or only touching) to 1 (identical boxes). Used to
    decide whether two detections are duplicates of the same dancer.

    Args:
        box_a (list): First box [x1, y1, x2, y2]
        box_b (list): Second box [x1, y1, x2, y2]

    Returns:
        float: IoU score between 0 and 1
    """
    overlap_w = min(box_a[2], box_b[2]) - max(box_a[0], box_b[0])
    overlap_h = min(box_a[3], box_b[3]) - max(box_a[1], box_b[1])
    if overlap_w <= 0 or overlap_h <= 0:
        return 0.0  # no positive-area overlap

    def area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    overlap = overlap_w * overlap_h
    # Floor the union so zero-area boxes cannot cause a division by zero.
    return overlap / max(1e-6, area(box_a) + area(box_b) - overlap)


def nms_merge(candidates, iou_threshold=0.5):
    """
    Greedy Non-Maximum Suppression over detection candidates.

    Candidates are visited from highest to lowest ``conf``. Each surviving
    candidate suppresses every not-yet-suppressed candidate whose box overlaps
    it with IoU >= ``iou_threshold``. This keeps one detection per dancer.

    Args:
        candidates (list): Detection dicts with at least 'bbox' and 'conf' keys
        iou_threshold (float): Minimum IoU at which two boxes count as duplicates

    Returns:
        list: The surviving candidate dicts, highest confidence first
    """
    if not candidates:
        return []

    ranked = sorted(
        range(len(candidates)),
        key=lambda idx: candidates[idx]["conf"],
        reverse=True,
    )
    suppressed = set()
    survivors = []

    for idx in ranked:
        if idx in suppressed:
            continue
        winner = candidates[idx]
        survivors.append(winner)
        suppressed.add(idx)

        for other in ranked:
            if other in suppressed:
                continue
            overlap = iou_xyxy(winner["bbox"], candidates[other]["bbox"])
            if overlap >= iou_threshold:
                suppressed.add(other)

    return survivors


def enhance_frame(rgb_image):
    """
    Brighten dark regions of an RGB frame to help detection.

    CLAHE (contrast-limited adaptive histogram equalization) is applied to the
    luminance channel only, so hues are preserved. A gamma curve (exponent
    1/1.2) then lifts shadows more than highlights.

    Args:
        rgb_image (np.ndarray): Input image in RGB format

    Returns:
        np.ndarray: Enhanced uint8 RGB image
    """
    luma, chroma_r, chroma_b = cv2.split(
        cv2.cvtColor(rgb_image, cv2.COLOR_RGB2YCrCb)
    )

    # clipLimit caps local contrast gain (limits noise amplification);
    # tileGridSize sets the grid of regions equalized independently.
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    boosted = cv2.cvtColor(
        cv2.merge((equalizer.apply(luma), chroma_r, chroma_b)),
        cv2.COLOR_YCrCb2RGB,
    )

    inv_gamma = 1 / 1.2
    normalized = boosted / 255.0
    return np.clip((normalized ** inv_gamma) * 255.0, 0, 255).astype(np.uint8)


def umeyama(src, dst, estimate_scale=True):
    """
    Estimate the similarity transform that best maps ``src`` onto ``dst``.

    Finds scale ``c``, rotation ``R`` and translation ``t`` minimising
    ``sum_i ||dst_i - (c * R @ src_i + t)||^2`` (Umeyama's closed-form
    solution). The tracker uses it to predict how a formation moved between
    sampled frames.

    Args:
        src (array-like): Source points, shape (N, D). The tracker uses D = 2.
        dst (array-like): Destination points, same shape as ``src``.
        estimate_scale (bool): Whether to estimate a scale factor; when False
            the scale is fixed at 1.0.

    Returns:
        tuple: (scale, rotation_matrix, translation_vector). The scale is a
        Python float, the rotation is a (D, D) array and the translation is a
        (D,) array.

    Raises:
        ValueError: If ``src``/``dst`` are not non-empty (N, D) arrays of
            identical shape.
    """
    src = np.asarray(src, dtype=np.float64)
    dst = np.asarray(dst, dtype=np.float64)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if src.ndim != 2 or src.shape != dst.shape or src.shape[0] == 0 or src.shape[1] == 0:
        raise ValueError(
            f"src and dst must be non-empty arrays of identical shape (N, D); "
            f"got {src.shape} and {dst.shape}"
        )

    n, dim = src.shape

    # Center the point sets
    mu_src = src.mean(axis=0)
    mu_dst = dst.mean(axis=0)
    src_centered = src - mu_src
    dst_centered = dst - mu_dst

    # Cross-covariance between destination and source
    covariance = (dst_centered.T @ src_centered) / n
    U, D, Vt = np.linalg.svd(covariance)

    # Flip the last axis if needed so R is a proper rotation (det(R) = +1)
    # rather than a reflection.
    S = np.eye(dim)
    if np.linalg.det(U @ Vt) < 0:
        S[-1, -1] = -1.0

    R = U @ S @ Vt

    if estimate_scale:
        var_src = (src_centered ** 2).sum() / n
        if var_src < 1e-12:
            # All source points (numerically) coincide, so the scale is
            # undefined. The least-squares formula would give ~0, which maps
            # every point onto a single spot; keep identity scale instead.
            scale = 1.0
        else:
            scale = np.trace(np.diag(D) @ S) / (var_src + 1e-12)
    else:
        scale = 1.0

    t = mu_dst - scale * (R @ mu_src)

    return float(scale), R, t

In [9]:
#@title Dancer Formation Rendering Class
# Formation Rendering Class

"""
Class for visualizing dancer positions on a stage diagram.
"""

class FormationRenderer:
    """
    Renders dancer formations on a virtual stage with grid lines and position markers.

    The renderer creates a top-down view of the stage showing:
    - Grid lines for spatial reference
    - Colored circles for each dancer with ID numbers
    - Stage boundaries and orientation labels
    - Center stage marker
    """

    def __init__(self, stage_width=16, stage_height=10):
        """
        Initialize the formation renderer.

        Args:
            stage_width (float): Stage width in grid units (default 16)
            stage_height (float): Stage depth in grid units (default 10)
        """
        self.stage_width = float(stage_width)
        self.stage_height = float(stage_height)

        # Palette of 24 distinct colors, so dancers 1-24 each get their own
        # color. Beyond 24 the palette cycles (see `render`).
        self.colors = [
            "#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4", "#FECA57", "#DDA0DD",
            "#98D8C8", "#F7DC6F", "#BB8FCE", "#85C1E2", "#F8B500", "#FF6F61",
            "#00B894", "#6C5CE7", "#E84393", "#00CEC9", "#FDCB6E", "#74B9FF",
            "#55EFC4", "#A29BFE", "#E17055", "#FD79A8", "#E1BEE7", "#26A69A"
        ]

        # Visual parameters
        self.dancer_radius = 0.35    # Size of dancer circles (grid units)
        self.line_width = 2.5        # Border width for circles
        self.opacity = 0.95          # Alpha of dancer markers

    def render(self, stage_points, target_width, target_height):
        """
        Render formation visualization as an image.

        Args:
            stage_points (list): List of dicts with 'id', 'x', 'y' keys for each dancer
            target_width (int): Output image width in pixels
            target_height (int): Output image height in pixels

        Returns:
            np.ndarray: RGB image array of the rendered formation
        """
        # Setup matplotlib figure with dark background
        dpi = 100
        fig = plt.figure(
            figsize=(target_width / dpi, target_height / dpi),
            dpi=dpi,
            facecolor="#2C2C2C"  # Dark gray background
        )
        try:
            ax = fig.add_subplot(111)
            ax.set_facecolor("#2C2C2C")

            # Draw grid lines for spatial reference
            for x in range(int(self.stage_width) + 1):
                ax.axvline(x, color="#444444", linewidth=0.5, alpha=0.3)
            for y in range(int(self.stage_height) + 1):
                ax.axhline(y, color="#444444", linewidth=0.5, alpha=0.3)

            # Draw stage boundary (red outline)
            stage_x = [0, self.stage_width, self.stage_width, 0, 0]
            stage_y = [0, 0, self.stage_height, self.stage_height, 0]
            ax.plot(stage_x, stage_y, color="#FF1744", linewidth=2)

            # Draw each dancer
            for dancer in stage_points:
                dancer_id = int(dancer["id"])

                # Constrain position to stage boundaries
                x = max(0.0, min(self.stage_width, float(dancer["x"])))
                y = max(0.0, min(self.stage_height, float(dancer["y"])))

                # Select color from palette (cycles if more dancers than colors)
                color = self.colors[(dancer_id - 1) % len(self.colors)]

                circle = Circle(
                    (x, y),
                    self.dancer_radius,
                    facecolor=color,
                    edgecolor="white",
                    linewidth=self.line_width,
                    alpha=self.opacity
                )
                ax.add_patch(circle)

                # Add ID number
                ax.text(
                    x, y, str(dancer_id),
                    color="white",
                    fontsize=12,
                    fontweight="bold",
                    ha="center",
                    va="center",
                    alpha=1.0
                )

            # Draw center stage marker (red X)
            ax.plot(
                self.stage_width / 2,
                self.stage_height / 2,
                "x",
                color="#FF1744",
                markersize=10,
                markeredgewidth=2
            )

            # Add orientation labels (audience at y = 0, backstage at the top)
            ax.text(
                self.stage_width / 2, -0.5,
                "AUDIENCE",
                ha="center", va="top",
                color="white", fontsize=10
            )
            ax.text(
                self.stage_width / 2, self.stage_height + 0.5,
                "BACKSTAGE",
                ha="center", va="bottom",
                color="white", fontsize=10
            )

            # Configure plot appearance
            ax.set_xlim(-0.5, self.stage_width + 0.5)
            ax.set_ylim(-1.0, self.stage_height + 1.0)
            ax.set_aspect("equal")
            ax.axis("off")
            fig.tight_layout(pad=0)

            # Rasterize the figure and convert it to a numpy array
            canvas = FigureCanvasAgg(fig)
            canvas.draw()
            buffer = canvas.buffer_rgba()
            image = np.asarray(buffer)[:, :, :3].copy()  # Drop alpha channel
        finally:
            # Always release the figure, even if drawing fails, so repeated
            # calls (one per sampled frame) don't accumulate open figures.
            plt.close(fig)

        return image


In [10]:
#@title Dancer Position Predictor Class
# Dancer Position Predictor Class

"""
Main tracking class that processes video frames to extract dancer positions.
This entire class must be in ONE CELL to maintain proper indentation.
"""

class DancerPositionEstimator:
    """
    Main class for tracking dancer positions across video frames.

    This class implements a robust multi-stage detection pipeline:
    1. Primary detection using pose estimation (YOLOv8-pose)
    2. Fallback to object detection if insufficient dancers found
    3. Formation-aware tracking across frames
    4. Velocity and rigid transformation prediction

    Tracking state (``last_positions``, ``previous_positions``,
    ``initialized``) persists on the instance between ``process_frame``
    calls, so create a new instance for each independent video.
    """

    def __init__(
        self,
        num_dancers,
        model_size="m",
        stage_width=16,
        stage_height=10,
        size_weight=0.2,
        formation_inertia=0.6,
        rigid_threshold=1.0
    ):
        """
        Initialize the dancer position estimator.

        Args:
            num_dancers (int): Expected number of dancers to track
            model_size (str): YOLO model size - 'n' (nano), 'm' (medium), 'l' (large)
            stage_width (float): Virtual stage width in grid units
            stage_height (float): Virtual stage depth in grid units
            size_weight (float): Weight for size-based depth estimation (0-1)
            formation_inertia (float): How much to trust formation rigidity (0-1)
            rigid_threshold (float): Maximum assignment cost for a match to count
                as an inlier when fitting the rigid transformation
        """
        assert num_dancers > 0, "Number of dancers must be positive"

        self.num_dancers = int(num_dancers)
        self.stage_width = float(stage_width)
        self.stage_height = float(stage_height)
        self.size_weight = float(size_weight)
        self.formation_inertia = float(formation_inertia)
        self.rigid_threshold = float(rigid_threshold)

        # Load YOLO models
        print(f"Loading YOLOv8{model_size} models...")
        self.pose_model = YOLO(f"yolov8{model_size}-pose.pt")
        self.detector_model = YOLO(f"yolov8{model_size}.pt")
        print("✅ Models loaded successfully!")

        # Keypoint indices for ankles (COCO format)
        self.left_ankle_idx = 15
        self.right_ankle_idx = 16

        # Tracking state
        self.last_positions = None      # Current frame positions, (num_dancers, 2)
        self.previous_positions = None  # Previous frame positions (for velocity)
        self.initialized = False

    def _get_pose_candidates(self, image, confidence=0.25, use_flip=False):
        """
        Extract pose-based dancer candidates from an image.

        Each candidate is a dict with:
        - "bbox": [x1, y1, x2, y2] in the coordinates of ``image``
        - "ankle": (x, y) ground-contact point: the mean of the usable ankle
          keypoints, or the bottom-center of the box when none is usable
        - "conf": detection confidence

        Args:
            image (np.ndarray): Frame as an H x W x C array.
            confidence (float): Minimum detection confidence passed to the model.
            use_flip (bool): Also run on the horizontally mirrored frame and map
                the results back to the original coordinates.

        Returns:
            list: Candidate dicts. Passes may overlap; callers de-duplicate them.
        """
        width = image.shape[1]
        candidates = []

        def ankle_from_keypoints(keypoints, index, is_flipped):
            """Return the mean usable ankle (x, y) for detection ``index``, or None."""
            if keypoints is None or keypoints.xy is None:
                return None
            kp_coords = keypoints.xy[index].cpu().numpy().astype(np.float32)
            ankle_coords = []
            for ankle_idx in (self.left_ankle_idx, self.right_ankle_idx):
                if ankle_idx >= kp_coords.shape[0]:
                    continue
                kx, ky = kp_coords[ankle_idx]
                kx, ky = float(kx), float(ky)
                # Validate in the model's own (unmirrored) coordinates. Points
                # at/below 0 are rejected (presumably how the model marks
                # undetected keypoints); mirroring first would turn x = 0 into
                # x = width and let it pass.
                if not (np.isfinite(kx) and np.isfinite(ky) and kx > 0 and ky > 0):
                    continue
                if is_flipped:
                    kx = width - kx
                ankle_coords.append((kx, ky))
            if not ankle_coords:
                return None
            return (
                float(np.mean([a[0] for a in ankle_coords])),
                float(np.mean([a[1] for a in ankle_coords])),
            )

        def process_image(img, is_flipped=False):
            """Run the pose model on ``img`` and append its detections to ``candidates``."""
            results = self.pose_model(img, conf=confidence, verbose=False)
            if not results or results[0].boxes is None:
                return

            result = results[0]
            boxes = result.boxes
            keypoints = getattr(result, "keypoints", None)

            # Visit detections from highest to lowest confidence
            order = sorted(
                range(len(boxes)),
                key=lambda i: _float(boxes.conf[i]),
                reverse=True
            )

            for i in order:
                try:
                    box = boxes[i]
                    conf_score = _float(box.conf)
                    x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(np.float32)
                except Exception:
                    continue  # malformed box: skip it, keep the others

                # Map mirrored x-coordinates back to the original frame
                if is_flipped:
                    x1, x2 = width - x2, width - x1

                # A keypoint failure only costs the ankle estimate, not the
                # whole detection: fall back to the bbox bottom-center.
                try:
                    ankle_point = ankle_from_keypoints(keypoints, i, is_flipped)
                except Exception:
                    ankle_point = None
                if ankle_point is None:
                    ankle_point = ((x1 + x2) * 0.5, y2)

                candidates.append({
                    "bbox": [x1, y1, x2, y2],
                    "ankle": ankle_point,
                    "conf": conf_score
                })

        # Process original image
        process_image(image, is_flipped=False)

        # Process the mirrored image as well (extra chance to catch dancers)
        if use_flip:
            flipped = np.ascontiguousarray(image[:, ::-1, :])
            process_image(flipped, is_flipped=True)

        return candidates

    def _get_all_candidates(self, image, confidence):
        """
        Get pose candidates from both the original and the enhanced image.

        Both images are processed with and without horizontal flipping, and
        duplicates are merged with NMS (IoU >= 0.5).
        """
        # Detect on original image
        candidates_raw = self._get_pose_candidates(image, confidence, use_flip=True)

        # Detect on enhanced image (better for dark areas)
        enhanced = enhance_frame(image)
        candidates_enhanced = self._get_pose_candidates(enhanced, confidence, use_flip=True)

        # Merge and remove duplicates
        return nms_merge(candidates_raw + candidates_enhanced, iou_threshold=0.5)

    def _get_detector_candidates(self, image, confidence=0.2):
        """
        Fallback detection using the general object detector (person class).

        Without keypoints, the ankle is approximated by the bbox bottom-center.
        A failure on one image variant (original / enhanced) skips only that
        variant.
        """
        candidates = []

        # Try both original and enhanced images
        for img in (image, enhance_frame(image)):
            try:
                # Detect class 0 (person) only
                results = self.detector_model(
                    img, conf=confidence, classes=[0], verbose=False
                )
                if results and results[0].boxes is not None:
                    boxes = results[0].boxes

                    order = sorted(
                        range(len(boxes)),
                        key=lambda i: _float(boxes.conf[i]),
                        reverse=True
                    )

                    for i in order:
                        box = boxes[i]
                        conf_score = _float(box.conf)
                        x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(np.float32)

                        # Use bottom-center as ankle approximation
                        ankle_point = ((x1 + x2) * 0.5, y2)

                        candidates.append({
                            "bbox": [x1, y1, x2, y2],
                            "ankle": ankle_point,
                            "conf": conf_score
                        })
            except Exception:
                continue

        return nms_merge(candidates, iou_threshold=0.5)

    def _map_to_stage(self, ankle_xy, bbox_height, img_width, img_height, median_height):
        """
        Map an image pixel position to stage grid coordinates.

        Stage x is the horizontal image position scaled to ``stage_width``.
        Stage y (depth) blends two cues:
        - vertical position: higher in the image means further back
        - relative box height: smaller than the median means further back
        ``size_weight`` controls how much the second cue counts.

        Args:
            ankle_xy (tuple): (x, y) pixel position of the dancer's feet.
            bbox_height (float): Height of this dancer's box in pixels.
            img_width (int): Frame width in pixels.
            img_height (int): Frame height in pixels.
            median_height (float): Median box height over the frame's detections.

        Returns:
            np.ndarray: float32 array [x, y] in stage units.
        """
        ax, ay = ankle_xy

        # Normalize coordinates to [0, 1]
        x_normalized = np.clip(ax / img_width, 0, 1)
        y_normalized = np.clip(ay / img_height, 0, 1)

        # Invert y-axis (top of image = back of stage)
        y_from_position = 1.0 - y_normalized

        # Estimate depth from relative size
        if median_height > 0:
            size_ratio = median_height / max(1.0, bbox_height)
            # Smaller dancers are assumed to be further away
            size_depth = np.clip(size_ratio / 1.5, 0.0, 1.0)
        else:
            size_depth = y_from_position

        # Combine position and size cues
        y_stage = (1.0 - self.size_weight) * y_from_position + self.size_weight * size_depth
        y_stage = float(np.clip(y_stage, 0, 1))

        # Map to stage coordinates
        return np.array([
            x_normalized * self.stage_width,
            y_stage * self.stage_height
        ], dtype=np.float32)

    def _compute_cost_matrix(self, predicted, measured):
        """
        Compute the track-to-measurement assignment cost matrix.

        ``cost[i, j] = ||p_i - m_j|| + 0.1 * |p_i.y - m_j.y|``: the Euclidean
        distance plus a small extra penalty for depth (y) disagreement.

        Args:
            predicted (np.ndarray): (P, 2) predicted stage positions.
            measured (np.ndarray): (M, 2) measured stage positions.

        Returns:
            np.ndarray: (P, M) float32 cost matrix.
        """
        predicted = np.asarray(predicted, dtype=np.float64).reshape(-1, 2)
        measured = np.asarray(measured, dtype=np.float64).reshape(-1, 2)
        # Broadcast (P, 1) against (1, M) to get every pairwise difference.
        dx = predicted[:, None, 0] - measured[None, :, 0]
        dy = predicted[:, None, 1] - measured[None, :, 1]
        return (np.hypot(dx, dy) + 0.1 * np.abs(dy)).astype(np.float32)

    def _predict_with_velocity(self):
        """
        Predict next positions using a constant-velocity model.

        Returns None before any state exists, the last positions when only
        one frame has been seen, and ``last + (last - previous)`` otherwise.
        """
        if self.last_positions is None:
            return None
        if self.previous_positions is None:
            return self.last_positions.copy()

        # Compute velocity from last two frames and extrapolate
        velocity = self.last_positions - self.previous_positions
        return self.last_positions + velocity

    def _initialize_positions(self, measurements):
        """
        Initialize dancer positions for the first frame.

        With no measurements, dancers are spread evenly in a horizontal line
        across the middle 60% of the stage. Otherwise, ``num_dancers`` evenly
        spaced x-slots are laid over the measured x-range. Each slot greedily
        takes the nearest unused measurement, and leftover slots sit at the
        median measured depth.

        Args:
            measurements (np.ndarray): (M, 2) stage positions.
        """
        if measurements.shape[0] == 0:
            # No detections: place dancers in a line across stage
            y = self.stage_height * 0.5
            xs = np.linspace(
                self.stage_width * 0.2,
                self.stage_width * 0.8,
                self.num_dancers,
                dtype=np.float32
            )
            positions = np.stack(
                [xs, np.full_like(xs, y, dtype=np.float32)],
                axis=1
            )
            self.previous_positions = None
            self.last_positions = positions.copy()
            self.initialized = True
            return

        # Sort detections by x-coordinate
        measurements = measurements[np.argsort(measurements[:, 0])]
        count = measurements.shape[0]
        median_y = float(np.median(measurements[:, 1]))
        min_x = float(np.min(measurements[:, 0]))
        max_x = float(np.max(measurements[:, 0]))

        # Ensure reasonable spread
        if max_x - min_x < 1e-3:
            min_x = self.stage_width * 0.2
            max_x = self.stage_width * 0.8

        # Create evenly spaced slots
        slots_x = np.linspace(min_x, max_x, self.num_dancers, dtype=np.float32)
        assigned = [False] * count
        positions = np.zeros((self.num_dancers, 2), dtype=np.float32)

        # Assign detections to nearest slots
        for k in range(self.num_dancers):
            slot_x = float(slots_x[k])
            best_idx = -1
            best_dist = 1e9

            for j in range(count):
                if assigned[j]:
                    continue
                dist = abs(float(measurements[j, 0]) - slot_x)
                if dist < best_dist:
                    best_dist = dist
                    best_idx = j

            if best_idx >= 0:
                positions[k] = measurements[best_idx]
                assigned[best_idx] = True
            else:
                # Fill gap with the slot position at the median depth
                positions[k] = np.array([slot_x, median_y], dtype=np.float32)

        self.previous_positions = None
        self.last_positions = positions
        self.initialized = True

    def _detect_candidates(self, image):
        """
        Run the detection cascade, keeping whatever each stage produced.

        Stages, each run only while fewer than ``num_dancers`` candidates are
        known: pose at conf 0.25, pose at conf 0.10, then the person detector
        at conf 0.20. A stage that raises is reported and skipped; candidates
        from earlier stages are kept.
        """
        candidates = []
        try:
            candidates = self._get_all_candidates(image, confidence=0.25)
        except Exception as exc:
            print(f"⚠️ Pose detection (conf 0.25) failed: {exc}")

        # If not enough dancers found, try lower threshold
        if len(candidates) < self.num_dancers:
            try:
                more = self._get_all_candidates(image, confidence=0.10)
                candidates = nms_merge(candidates + more, iou_threshold=0.5)
            except Exception as exc:
                print(f"⚠️ Pose detection (conf 0.10) failed: {exc}")

        # Last resort: use object detector
        if len(candidates) < self.num_dancers:
            try:
                detector_cands = self._get_detector_candidates(image, confidence=0.20)
                candidates = nms_merge(candidates + detector_cands, iou_threshold=0.5)
            except Exception as exc:
                print(f"⚠️ Person detection fallback failed: {exc}")

        return candidates

    def _candidates_to_stage(self, candidates, width, height):
        """
        Map candidates to an (N, 2) float32 array of stage positions.

        The median box height of this frame is the reference for the
        size-based depth cue. Returns a (0, 2) array when there are no
        candidates.
        """
        if not candidates:
            return np.zeros((0, 2), dtype=np.float32)

        heights = [max(1.0, c["bbox"][3] - c["bbox"][1]) for c in candidates]
        median_height = float(np.median(heights))
        return np.array(
            [
                self._map_to_stage(c["ankle"], h, width, height, median_height)
                for c, h in zip(candidates, heights)
            ],
            dtype=np.float32,
        )

    def _predict_rigid(self, stage_measurements):
        """
        Predict positions by fitting a similarity transform to the formation.

        Tracks are matched to measurements. Matches with cost <=
        ``rigid_threshold`` are inliers, and a similarity transform fitted to
        them is applied to every track. Returns None when fewer than two
        inliers exist or the fit is unusable (failed, non-finite, or
        zero/negative scale, which would collapse all dancers onto one point).
        """
        if stage_measurements.shape[0] < 2:
            return None

        cost_matrix = self._compute_cost_matrix(self.last_positions, stage_measurements)
        rows, cols = linear_sum_assignment(cost_matrix)
        inliers = cost_matrix[rows, cols] <= self.rigid_threshold
        if np.count_nonzero(inliers) < 2:
            return None

        src = self.last_positions[rows[inliers]]
        dst = stage_measurements[cols[inliers]]
        try:
            scale, rotation, translation = umeyama(src, dst, estimate_scale=True)
        except Exception:
            return None

        if not np.isfinite(scale) or scale <= 0:
            return None

        # Apply transformation to all positions
        predicted = (scale * (self.last_positions @ rotation.T)) + translation
        if not np.all(np.isfinite(predicted)):
            return None
        return predicted

    def process_frame(self, image):
        """
        Process a single video frame to extract all dancer positions.

        This is the main entry point that orchestrates the entire detection
        and tracking pipeline for each frame.

        Args:
            image (np.ndarray): Frame as H x W x 3. The enhancement step
                converts from RGB, so RGB order is assumed.

        Returns:
            list: ``num_dancers`` dicts ``{"id", "x", "y"}`` (ids start at 1)
            with positions in stage units, clipped to the stage.
        """
        height, width = image.shape[:2]

        # STAGE 1: Multi-cascade detection
        candidates = self._detect_candidates(image)

        # STAGE 2: Map to stage coordinates
        stage_measurements = self._candidates_to_stage(candidates, width, height)

        # STAGE 3: Handle initialization
        if not self.initialized:
            if stage_measurements.shape[0] > self.num_dancers:
                # Too many detections: keep highest confidence ones
                order = np.argsort([-c["conf"] for c in candidates])[:self.num_dancers]
                stage_measurements = stage_measurements[order]
            self._initialize_positions(stage_measurements)

        # Validate state
        if self.last_positions is None or self.last_positions.shape != (self.num_dancers, 2):
            self._initialize_positions(stage_measurements)

        # STAGE 4: Predict next positions
        # Method 1: Velocity-based prediction
        predicted_velocity = self._predict_with_velocity()
        if predicted_velocity is None:
            predicted_velocity = self.last_positions.copy()

        # Method 2: Formation-based prediction (rigid transformation)
        predicted_rigid = self._predict_rigid(stage_measurements)

        if predicted_rigid is not None:
            # Weighted combination of formation and velocity predictions
            predicted = (
                self.formation_inertia * predicted_rigid +
                (1.0 - self.formation_inertia) * predicted_velocity
            )
        else:
            predicted = predicted_velocity

        # STAGE 5: Assign measurements to tracks
        new_positions = predicted.copy()
        if stage_measurements.shape[0] > 0:
            cost_matrix = self._compute_cost_matrix(predicted, stage_measurements)
            rows, cols = linear_sum_assignment(cost_matrix)

            # Matched tracks jump to their measurement
            new_positions[rows] = stage_measurements[cols]

            # Unmatched tracks follow the median motion of the matched ones
            if len(rows) > 0:
                median_delta = np.median(stage_measurements[cols] - predicted[rows], axis=0)
            else:
                median_delta = np.zeros(2, dtype=np.float32)
            unassigned = np.setdiff1d(np.arange(self.num_dancers), rows)
            new_positions[unassigned] = predicted[unassigned] + median_delta

        # STAGE 6: Constrain to stage boundaries
        new_positions[:, 0] = np.clip(new_positions[:, 0], 0.0, self.stage_width)
        new_positions[:, 1] = np.clip(new_positions[:, 1], 0.0, self.stage_height)

        # Update state for next frame
        self.previous_positions = self.last_positions
        self.last_positions = new_positions

        # Format output
        return [
            {
                "id": k + 1,
                "x": float(self.last_positions[k, 0]),
                "y": float(self.last_positions[k, 1])
            }
            for k in range(self.num_dancers)
        ]


In [11]:
#@title Video Processing Function
# Video Processing Function

"""
Functions for sampling frames and processing complete videos.
"""

def sample_video_frames(video_path, num_frames=10):
    """
    Sample frames evenly from a video file.

    Extracts frames at regular intervals throughout the video duration
    to get a representative sample of the entire performance.

    Args:
        video_path (str): Path to input video file
        num_frames (int): Number of frames to extract

    Returns:
        list: List of RGB image arrays (uint8, C-contiguous copies), in
        timestamp order.

    Raises:
        Whatever opening the clip or decoding a frame raises. The clip is
        closed in every case, so a failed decode does not leak the reader.
    """
    clip = VideoFileClip(video_path)
    try:
        duration = clip.duration

        # Calculate timestamps for even sampling. The last timestamp stays just
        # short of `duration` (presumably so the final request never asks for a
        # frame past the end of the stream). The floor of 0.0001 keeps the range
        # non-degenerate for zero-length clips.
        timestamps = np.linspace(0.0, max(0.0001, duration - 1e-6), num_frames)

        frames = []
        for i, t in enumerate(timestamps):
            print(f"Extracting frame {i+1}/{num_frames} at {t:.2f}s...")
            # Copy so the stored frame does not alias any reader-owned buffer.
            frame = np.array(clip.get_frame(float(t)), dtype=np.uint8, copy=True, order="C")
            frames.append(frame)
    finally:
        # Always release the underlying reader, even if decoding failed midway.
        clip.close()
    return frames


def _render_composite(frame, estimator, renderer):
    """
    Build one BGR composite: the original frame stacked above its formation diagram.

    Args:
        frame (np.ndarray): RGB frame as returned by ``sample_video_frames``.
        estimator: Tracker whose ``process_frame(frame)`` returns dancer positions.
        renderer: Object whose ``render(positions, target_width=..., target_height=...)``
            returns an RGB image.

    Returns:
        np.ndarray: BGR image suitable for ``cv2.imwrite``.
    """
    height, width = frame.shape[:2]

    # Extract dancer positions (the estimator carries tracking state across calls)
    positions = estimator.process_frame(frame)

    # Render the diagram at the frame's own size so the two halves stack cleanly
    formation_image = renderer.render(
        positions,
        target_width=width,
        target_height=height
    )

    # Original on top, formation on bottom
    top_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
    bottom_bgr = cv2.cvtColor(formation_image, cv2.COLOR_RGB2BGR)
    return np.vstack([top_bgr, bottom_bgr])


def _error_placeholder(frame, index):
    """Return a black image, twice the frame's height, labelled with the failing frame index."""
    error_image = np.zeros((frame.shape[0] * 2, frame.shape[1], 3), dtype=np.uint8)
    cv2.putText(
        error_image,
        f"Frame {index} processing error",
        (10, 40),
        cv2.FONT_HERSHEY_SIMPLEX,
        1,
        (255, 255, 255),
        2
    )
    return error_image


def _zip_files(paths, zip_path):
    """Write every file in ``paths`` into a deflated ZIP at ``zip_path`` (flat, by basename)."""
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zip_file:
        for path in paths:
            zip_file.write(path, arcname=os.path.basename(path))
    return zip_path


def process_dance_video(
    video_path,
    expected_dancers,
    model_size="m",
    num_frames=10
):
    """
    Main pipeline to process a dance video and generate formation visualizations.

    This function orchestrates the entire workflow:
    1. Validates inputs
    2. Samples frames from video
    3. Processes each frame to extract positions
    4. Generates formation visualizations
    5. Creates composite images
    6. Packages results in a ZIP file

    Every failure is reported through the returned status message rather than
    raised, so the Gradio UI always receives a (gallery, file, status) triple.

    Args:
        video_path (str): Path to input video file
        expected_dancers (int | float | str): Number of dancers to track;
            converted with ``int(float(...))``, so fractional values truncate
        model_size (str): YOLO model size ('n', 'm', or 'l')
        num_frames (int): Number of frames to sample

    Returns:
        tuple: (composite_paths, zip_path, status_message). On validation or
        video-read failure, the first two items are None.
    """
    # Input validation
    if video_path is None or not os.path.exists(video_path):
        return None, None, "⚠️ Please upload a video file."

    try:
        num_dancers = int(float(expected_dancers))
    except Exception:
        # Covers non-numeric strings, NaN (ValueError) and inf (OverflowError)
        return None, None, "⚠️ Number of dancers must be a valid number."

    if num_dancers <= 0:
        return None, None, "⚠️ Number of dancers must be positive."

    if num_dancers > 20:
        return None, None, "⚠️ Maximum 20 dancers supported."

    print(f"\n🎬 Processing video: {video_path}")
    print(f"👥 Tracking {num_dancers} dancers")
    print(f"🖼️ Sampling {num_frames} frames")

    # Sample frames from video. An unreadable or corrupt file should surface as
    # a status message, not an unhandled exception in the UI.
    try:
        frames = sample_video_frames(video_path, num_frames)
    except Exception as e:
        print(f"⚠️ Error reading video: {e}")
        return None, None, f"⚠️ Could not read video: {e}"

    if not frames:
        return None, None, "⚠️ No frames could be extracted from the video."

    # Setup temporary directory for outputs (left in place so Gradio can serve them)
    temp_dir = tempfile.mkdtemp()
    composites_dir = _ensure_dir(os.path.join(temp_dir, "composites"))

    # Initialize components with fixed parameters
    estimator = DancerPositionEstimator(
        num_dancers=num_dancers,
        model_size=model_size,
        stage_width=16,
        stage_height=10,
        size_weight=0.2,
        formation_inertia=0.6,
        rigid_threshold=1.0
    )

    renderer = FormationRenderer(stage_width=16, stage_height=10)

    composite_paths = []
    failed_frames = 0

    # Process each frame in order; a failing frame gets a placeholder image so
    # the output numbering stays aligned with the sampled frames.
    print("\n🔄 Processing frames...")
    for i, frame in enumerate(frames):
        print(f"Processing frame {i+1}/{len(frames)}...")
        output_file = os.path.join(composites_dir, f"composite_{i:02d}.jpg")
        try:
            composite = _render_composite(frame, estimator, renderer)
            # Save high-quality JPEG
            cv2.imwrite(output_file, composite, [cv2.IMWRITE_JPEG_QUALITY, 95])
        except Exception as e:
            failed_frames += 1
            print(f"⚠️ Error processing frame {i}: {e}")
            cv2.imwrite(output_file, _error_placeholder(frame, i))
        composite_paths.append(output_file)

    # Create ZIP archive
    print("\n📦 Creating ZIP archive...")
    zip_path = _zip_files(composite_paths, os.path.join(temp_dir, "dance_formations.zip"))

    print("✅ Processing complete!")
    if failed_frames:
        # Don't report unqualified success when some outputs are error placeholders
        return (
            composite_paths,
            zip_path,
            f"⚠️ Processing finished, but {failed_frames} of {len(frames)} frame(s) "
            "failed. Download your results below."
        )
    return composite_paths, zip_path, "✅ Processing complete! Download your results below."


In [16]:
#@title GUI Interface
# GUI - Gradio Interface

"""
Creates the web-based user interface with purple theme.
"""

def create_gradio_interface():
    """
    Create the Gradio web interface for the Dance Formation Tracker.

    This function builds a user-friendly web interface with:
    - Video upload capability
    - Parameter controls (dancers, model, number of sampled frames)
    - Results gallery
    - Download functionality
    - Purple-themed styling

    The Generate button calls ``process_dance_video`` and routes its
    (composite_paths, zip_path, status_message) result to the gallery,
    the download widget and the status box respectively.

    Returns:
        gr.Blocks: The assembled (not yet launched) interface.
    """

    # Custom CSS for purple theme
    custom_css = """
    /* Main container styling */
    .gradio-container {
        font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
    }

    /* Purple gradient buttons */
    .gr-button-primary {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        border: none;
        color: white;
        font-weight: 600;
        padding: 12px 24px;
        font-size: 16px;
        transition: all 0.3s ease;
    }

    .gr-button-primary:hover {
        background: linear-gradient(135deg, #764ba2 0%, #667eea 100%);
        transform: translateY(-2px);
        box-shadow: 0 10px 20px rgba(102, 126, 234, 0.4);
    }

    /* Input field styling */
    .gr-input {
        border-color: #667eea;
        transition: all 0.3s ease;
    }

    .gr-input:focus {
        border-color: #764ba2;
        box-shadow: 0 0 0 3px rgba(118, 75, 162, 0.1);
    }

    /* Heading styles */
    h1 {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
        font-size: 2.5em;
        font-weight: 700;
        margin-bottom: 0.5em;
    }

    h2 {
        color: #764ba2;
        font-weight: 600;
    }

    h3 {
        color: #667eea;
        font-weight: 600;
    }

    /* Gallery styling */
    .gallery-container {
        border-radius: 8px;
        overflow: hidden;
    }

    /* Status box styling */
    .gr-textbox {
        border-left: 4px solid #667eea;
    }
    """

    # Create interface with purple theme
    with gr.Blocks(
        title="Dance Formation Tracker",
        theme=gr.themes.Soft(
            primary_hue="purple",
            secondary_hue="pink",
            neutral_hue="slate",
        ),
        css=custom_css
    ) as interface:

        # Header section
        gr.Markdown(
            """
            # 🩰 Dance Formation Tracker

            **Transform dance videos into clear formation diagrams using AI-powered tracking**

            This tool uses advanced computer vision to track dancers throughout a performance
            and generate bird's-eye view formation diagrams showing their positions on stage.
            """
        )

        # Main content in two columns
        with gr.Row():
            # Left column: Input controls
            with gr.Column(scale=1):
                gr.Markdown("### 📤 Video Settings")

                video_input = gr.File(
                    label="Upload Dance Video",
                    file_types=["video"],
                    type="filepath"
                )

                dancers_input = gr.Number(
                    label="Number of Dancers",
                    value=8,
                    minimum=1,
                    maximum=20,
                    step=1,
                    info="How many dancers are performing in the video?"
                )

                # This value is a frame COUNT: it becomes `num_frames`, the
                # number of evenly spaced timestamps sampled across the video.
                frames_slider = gr.Slider(
                    minimum=5,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Number of Frames to Sample",
                    info="Number of frames to analyze (more = detailed but slower)"
                )

                model_selector = gr.Radio(
                    choices=[
                        ("Fast (nano)", "n"),
                        ("Balanced (medium)", "m"),
                        ("Accurate (large)", "l")
                    ],
                    value="m",
                    label="Model Quality",
                    info="Trade-off between speed and accuracy"
                )

                process_button = gr.Button(
                    "🎭 Generate Formation Analysis",
                    variant="primary",
                    size="lg"
                )

                # Tips section
                gr.Markdown(
                    """
                    ---
                    **📝 Tips for Best Results:**
                    - Use videos with all dancers visible
                    - Ensure adequate lighting
                    - Full body shots work better
                    - Steady camera position preferred
                    - Higher frame counts = smoother tracking
                    """
                )

            # Right column: Results display
            with gr.Column(scale=2):
                gr.Markdown("### 📊 Analysis Results")

                gallery_output = gr.Gallery(
                    label="Formation Analysis (Top: Original | Bottom: Stage Formation)",
                    columns=3,
                    rows=2,
                    height="auto",
                    show_label=True,
                    elem_classes="gallery-container"
                )

                with gr.Row():
                    download_output = gr.File(
                        label="📥 Download All Results (ZIP)",
                        visible=True
                    )

                    status_output = gr.Textbox(
                        label="Status",
                        lines=1,
                        interactive=False,
                        value="Ready to process your video..."
                    )

        # Connect processing function
        def run_analysis(video, dancers, frames, model):
            """Forward the UI values to process_dance_video (slider floats become an int count)."""
            return process_dance_video(
                video_path=video,
                expected_dancers=dancers,
                model_size=model,
                num_frames=int(frames)
            )

        process_button.click(
            fn=run_analysis,
            inputs=[video_input, dancers_input, frames_slider, model_selector],
            outputs=[gallery_output, download_output, status_output]
        )

        # Information section
        gr.Markdown(
            """
            ---
            ### 💡 How It Works

            1. **Upload** your dance video file
            2. **Specify** the number of dancers performing
            3. **Choose** analysis settings (number of frames and model quality)
            4. **Click** Generate to start processing
            5. **Review** the formation diagrams for each sampled frame
            6. **Download** all results as a ZIP file

            The system uses state-of-the-art pose estimation to:
            - Track individual dancers throughout the video
            - Map their positions to a virtual stage grid
            - Maintain consistent ID assignment across frames
            - Handle challenging conditions like occlusion and poor lighting

            ---
            *Built with YOLOv8, OpenCV, and Gradio*
            """
        )

    return interface


In [17]:
#@title Execute Code
# Run

"""
Main entry point to create and launch the web interface.
Run this cell to start the application.
"""

if __name__ == "__main__":
    print("🚀 Starting Dance Formation Tracker...")
    print("=" * 50)

    # In a notebook, launch() returns right away and leaves its server running
    # in the background (the prints below execute after launch). Without this,
    # every re-run of the cell would start one more server. Close any apps
    # launched by a previous run first so re-running is clean.
    gr.close_all()

    # Create the interface
    app = create_gradio_interface()

    # Launch with inline display for Colab
    app.launch(
        share=True,           # Still create public URL as backup
        inline=True,          # Display interface inline in notebook
        inbrowser=False,      # Don't try to open a new browser tab
        show_error=True,      # Show detailed error messages
        quiet=False,          # Show launch information
        height=800,           # Set height of inline frame (adjust as needed)
    )

    print("=" * 50)
    print("✅ Application is running!")
    print("📱 Access the app using the public URL above")
    print("🔗 Look for the public URL that starts with https://...gradio.live")
    print("🔁 Re-running this cell closes the previous app and starts a fresh one")

🚀 Starting Dance Formation Tracker...
Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://ae89941bbfeff47a41.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


✅ Application is running!
📱 Access the app using the public URL above
🔗 Look for the public URL that starts with https://...gradio.live
⚠️ If you need to restart, first stop the cell (interrupt execution) then run again
