In [2]:
import os, json, math, glob
from pathlib import Path
from typing import Dict, Tuple, List, Optional
import cv2
import pandas as pd
import numpy as np

In [13]:
# -----------------------------
# User paths (EDIT THESE)
# -----------------------------
# Labels JSON layout (as parsed by extract_all):
#   {"<video filename>": {"contact_<frame>": {"<player_id>": "<label>", ...}, ...}, ...}
JSON_PATH     = "/content/shi_vit_labels.json"     # labels JSON (layout above)
VIDEOS_DIR    = "/content/videos"            # folder containing the video files named in the JSON
TRACKS_DIR    = "/content/tracks"            # searched by resolve_tracks_csv: <stem>_tracks.csv, <stem>-tracks.csv, <stem>.csv, then <stem>*.csv
OUTPUT_ROOT   = "/content/drive/MyDrive/FT3163,3164/SlowFast/05_clips/finals_v1"      # clips go to OUTPUT_ROOT/<label>/*.mp4; manifest.csv is written here too

# Optional: override for track CSVs whose names don't follow the patterns above.
# Keys = video filename exactly as in the JSON, value = exact CSV path
# (returned as-is; extract_all still checks that the file exists).
TRACKS_MAP_OVERRIDE: Dict[str, str] = {
    # "shi_vit_rally_1.mp4": "/content/tracks/shi_vit_rally_1_tracks.csv",
}

# -----------------------------
# Config (tweak as you like)
# -----------------------------
N_BEFORE_AFTER = 15        # n frames before and after contact (clip length = 2n+1); windows crossing the video bounds are skipped
OUT_SIDE       = 256       # side of the square output clip (e.g., 224 or 256); the padded crop is resized to OUT_SIDE x OUT_SIDE WITHOUT preserving aspect ratio
PAD_RATIO      = 1.3      # scale width and height of the window's union bbox by this factor (about its centre) before clipping to the frame
INCLUDE_LABELS = None      # None → no whitelist (every label except SKIP_LABEL is extracted); or a set like {"smash","drop",...}
SKIP_LABEL     = None       # label to drop entirely, e.g. "negative"; None keeps every label, including "negative"
FOURCC         = "mp4v"    # "mp4v" is usually safe in Colab
VERBOSE        = True      # per-video and per-clip [skip]/[+] logging in extract_all

In [7]:
# -----------------------------
# Helpers
# -----------------------------

def resolve_tracks_csv(video_name: str, tracks_dir: Optional[str] = None) -> Optional[str]:
    """Find the tracks CSV for a given video.

    Resolution order:
      1. An explicit entry in TRACKS_MAP_OVERRIDE (returned as-is; existence is
         not checked here).
      2. `<stem>_tracks.csv`, `<stem>-tracks.csv`, `<stem>.csv` inside the
         tracks directory.
      3. Any `<stem>*.csv` whose remainder after the stem does NOT start with a
         letter or digit, taking the first in sorted order. This stops stem
         `rally_1` from picking up `rally_10_tracks.csv`.

    Args:
        video_name: Video filename as it appears in the labels JSON.
        tracks_dir: Directory to search. Defaults to the module-level TRACKS_DIR.

    Returns:
        Path string of the tracks CSV, or None if nothing matches.
    """
    if video_name in TRACKS_MAP_OVERRIDE:
        return TRACKS_MAP_OVERRIDE[video_name]

    base = str(TRACKS_DIR if tracks_dir is None else tracks_dir)
    stem = Path(video_name).stem

    # Common patterns, most specific first.
    for name in (f"{stem}_tracks.csv", f"{stem}-tracks.csv", f"{stem}.csv"):
        candidate = Path(base) / name
        if candidate.exists():
            return str(candidate)

    # Last resort: prefix glob. Escape both parts so characters like [ ] * ? in
    # paths or names are matched literally. Sort because glob order is arbitrary.
    pattern = os.path.join(glob.escape(base), glob.escape(stem) + "*.csv")
    for hit in sorted(glob.glob(pattern)):
        rest = Path(hit).name[len(stem):]
        # Require a separator (e.g. "_", "-", ".") right after the stem.
        if not rest[:1].isalnum():
            return hit
    return None

def load_tracks_df(csv_path: str) -> pd.DataFrame:
    """Load a per-frame tracking CSV and normalise the dtypes of the columns used downstream.

    Expected columns:
      frame, id, x1, y1, x2, y2, conf, cls
    Only frame, id, x1, y1, x2, y2 are required by this loader. Any other
    columns (e.g. conf, cls) pass through unchanged.

    Args:
        csv_path: Path to the CSV (anything `pd.read_csv` accepts).

    Returns:
        DataFrame with `frame` and `id` cast to int and the bbox corners
        x1, y1, x2, y2 cast to float.

    Raises:
        KeyError: If any required column is missing. The message lists the
            missing and available columns.
    """
    df = pd.read_csv(csv_path)

    required = ["frame", "id", "x1", "y1", "x2", "y2"]
    missing = [c for c in required if c not in df.columns]
    if missing:
        raise KeyError(
            f"{csv_path}: tracks CSV is missing required column(s) {missing}; "
            f"found {list(df.columns)}"
        )

    # Bbox corners stay float; integer rounding happens later, when the crop box is built.
    return df.astype({
        "frame": int,
        "id": int,
        "x1": float,
        "y1": float,
        "x2": float,
        "y2": float,
    })

def frame_bbox_lookup(df: pd.DataFrame, player_id: int) -> Dict[int, Tuple[float,float,float,float]]:
    """Map each frame where `player_id` appears to its (x1, y1, x2, y2) box.

    If the same player has several rows for one frame, the last row wins.
    Returns an empty dict when the player never appears.
    """
    rows = df.loc[df["id"] == player_id, ["frame", "x1", "y1", "x2", "y2"]]
    lookup: Dict[int, Tuple[float, float, float, float]] = {}
    for frame, left, top, right, bottom in rows.itertuples(index=False, name=None):
        lookup[int(frame)] = (float(left), float(top), float(right), float(bottom))
    return lookup

def union_bbox(bboxes: List[Tuple[float,float,float,float]]) -> Tuple[float,float,float,float]:
    """Smallest axis-aligned box containing every (x1, y1, x2, y2) in `bboxes`.

    Raises ValueError (from min/max) when `bboxes` is empty.
    """
    left = min(box[0] for box in bboxes)
    top = min(box[1] for box in bboxes)
    right = max(box[2] for box in bboxes)
    bottom = max(box[3] for box in bboxes)
    return (left, top, right, bottom)

def expand_bbox(b: Tuple[float,float,float,float], pad_ratio: float, W: int, H: int) -> Tuple[int,int,int,int]:
    """Scale a box about its centre by `pad_ratio` and clip it to a W x H frame.

    The result is meant to be used as half-open slice bounds,
    `frame[y1:y2, x1:x2]`. It therefore satisfies 0 <= x1 < x2 <= W and
    0 <= y1 < y2 <= H (assuming W, H >= 1). The crop is always non-empty,
    even when the input box lies partly or entirely outside the frame.

    Args:
        b: (x1, y1, x2, y2) box in pixel coordinates (floats allowed).
        pad_ratio: Factor applied to both width and height.
        W, H: Frame width and height in pixels.

    Returns:
        Integer (x1, y1, x2, y2): start coordinates floored, end coordinates ceiled.
    """
    x1, y1, x2, y2 = b
    cx = 0.5 * (x1 + x2)
    cy = 0.5 * (y1 + y2)
    w = (x2 - x1) * pad_ratio
    h = (y2 - y1) * pad_ratio

    # Starts are clamped into [0, W-1] / [0, H-1] so at least one pixel remains.
    nx1 = min(max(0, math.floor(cx - 0.5 * w)), W - 1)
    ny1 = min(max(0, math.floor(cy - 0.5 * h)), H - 1)
    # Ends are EXCLUSIVE slice bounds, so clamp to W/H (not W-1/H-1); otherwise
    # the last column/row of the frame could never be included.
    nx2 = min(W, math.ceil(cx + 0.5 * w))
    ny2 = min(H, math.ceil(cy + 0.5 * h))

    # Degenerate or off-frame box: fall back to a 1-pixel span (nx1 + 1 <= W).
    if nx2 <= nx1:
        nx2 = nx1 + 1
    if ny2 <= ny1:
        ny2 = ny1 + 1
    return int(nx1), int(ny1), int(nx2), int(ny2)

def fixed_window_indices(center: int, n: int) -> List[int]:
    """Consecutive frame indices center-n .. center+n inclusive (2n+1 items; empty if n < 0)."""
    return [center + offset for offset in range(-n, n + 1)]

def has_complete_window(indices: List[int], frame_count: int) -> bool:
    """True when the first and last index both fall inside [0, frame_count).

    Assumes `indices` is non-empty and ascending (raises IndexError if empty).
    """
    first, last = indices[0], indices[-1]
    return first >= 0 and last < frame_count

def nearest_bbox(ff: int, lookup: Dict[int, Tuple[float,float,float,float]], span: int=3) -> Optional[Tuple[float,float,float,float]]:
    """Return the bbox for frame `ff`, or the closest one within +/- `span` frames.

    Probe order is ff, ff-1, ff+1, ff-2, ff+2, ... so at equal distance the
    earlier frame wins. Returns None if no frame in that range has a bbox.
    """
    def probe_frames():
        yield ff
        for dist in range(1, span + 1):
            yield ff - dist
            yield ff + dist

    return next((lookup[frame] for frame in probe_frames() if frame in lookup), None)

In [5]:
def write_clip(
    video_path: str,
    frame_indices: List[int],
    crop_box: Tuple[int,int,int,int],
    out_path: str,
    out_side: int,
    fps: float
):
    """Crop `crop_box` from each frame in `frame_indices` and write an out_side x out_side clip.

    Args:
        video_path: Source video.
        frame_indices: Frame indices to read, in output order.
        crop_box: Half-open (x1, y1, x2, y2) pixel bounds, used as frame[y1:y2, x1:x2].
        out_path: Destination file. Its parent directory is created if needed.
        out_side: Output width and height. Each crop is resized to this square
            (aspect ratio is not preserved).
        fps: Frame rate written into the output file.

    Raises:
        ValueError: crop_box is empty (x2 <= x1 or y2 <= y1).
        RuntimeError: The video or the writer cannot be opened, a frame cannot
            be read, or the crop falls outside a frame.

    On any error, both OpenCV handles are released and the partially written
    `out_path` is removed.
    """
    x1, y1, x2, y2 = crop_box
    if x2 <= x1 or y2 <= y1:
        raise ValueError(f"Empty crop box {crop_box} for {video_path}")

    os.makedirs(Path(out_path).parent, exist_ok=True)
    fourcc = cv2.VideoWriter_fourcc(*FOURCC)
    writer = cv2.VideoWriter(out_path, fourcc, float(fps), (out_side, out_side))
    cap = cv2.VideoCapture(video_path)
    completed = False
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Could not open video {video_path}")
        if not writer.isOpened():
            raise RuntimeError(f"Could not open VideoWriter for {out_path} (fourcc={FOURCC!r})")

        # Seek only when the requested frame is not the one the next read() returns.
        # A contiguous window therefore costs a single seek instead of one per frame.
        next_pos = None
        for ff in frame_indices:
            if ff != next_pos:
                cap.set(cv2.CAP_PROP_POS_FRAMES, ff)
            ok, frame = cap.read()
            if not ok:
                raise RuntimeError(f"Failed to read frame {ff} from {video_path}")
            next_pos = ff + 1

            crop = frame[y1:y2, x1:x2]
            if crop.size == 0:
                raise RuntimeError(
                    f"Crop box {crop_box} lies outside frame {ff} "
                    f"({frame.shape[1]}x{frame.shape[0]}) of {video_path}"
                )
            writer.write(cv2.resize(crop, (out_side, out_side), interpolation=cv2.INTER_AREA))
        completed = True
    finally:
        writer.release()
        cap.release()
        if not completed:
            # Don't leave a truncated clip in the label folder: it has no manifest
            # row and would be picked up by any loader that scans directories.
            try:
                os.remove(out_path)
            except OSError:
                pass

In [10]:
# -----------------------------
# Main extraction
# -----------------------------

def _probe_video(video_path: str) -> Optional[Tuple[int, float, int, int]]:
    """Read (frame_count, fps, width, height) from a video; None if OpenCV cannot open it."""
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return None
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # A 0 fps reading (rate unknown) falls back to 30.
        fps = float(cap.get(cv2.CAP_PROP_FPS)) or 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        return frame_count, fps, width, height
    finally:
        cap.release()


def _window_bboxes(
    frame_indices: List[int],
    lookup: Dict[int, Tuple[float, float, float, float]],
    span: int,
) -> Optional[List[Tuple[float, float, float, float]]]:
    """One bbox per frame (gaps filled by nearest_bbox within +/- span), or None if any frame can't be filled."""
    bboxes = []
    for ff in frame_indices:
        bb = nearest_bbox(ff, lookup, span=span)
        if bb is None:
            return None
        bboxes.append(bb)
    return bboxes


def extract_all(
    json_path: str,
    videos_dir: str,
    tracks_dir: str,
    out_root: str,
    n_before_after: int = N_BEFORE_AFTER,
    out_side: int = OUT_SIDE,
    pad_ratio: float = PAD_RATIO,
    include_labels = INCLUDE_LABELS,
    skip_label: Optional[str] = SKIP_LABEL,
    verbose: bool = VERBOSE,
    bbox_fill_span: int = 3,
) -> pd.DataFrame:
    """Cut one fixed-length, player-centred clip per labelled (contact, player) pair.

    For each video in the labels JSON, the function first loads its track CSV.
    Then, for every contact whose window [c-n, c+n] lies inside the video, it
    writes one clip per player label that passes the include/skip filters. Each
    clip uses a single crop box: the padded union of that player's boxes over
    the whole window.

    Args:
        json_path: Labels JSON, {video_filename: {"contact_<frame>": {"<player_id>": label}}}.
        videos_dir: Directory containing the video files named in the JSON.
        tracks_dir: Intended tracks directory. NOTE: track CSVs are located by
            resolve_tracks_csv(video_name), which searches the module-level
            TRACKS_DIR. A warning is printed when this argument differs from it.
        out_root: Clips go to out_root/<label>/<stem>_f<frame:05d>_p<pid>.mp4.
            The manifest goes to out_root/manifest.csv.
        n_before_after: Frames kept on each side of the contact (clip length 2n+1).
        out_side: Output clip side in pixels.
        pad_ratio: Padding factor for the window's union bbox (see expand_bbox).
        include_labels: Optional container of labels to keep; None keeps all.
        skip_label: Label to drop entirely; None drops nothing.
        verbose: Print per-video and per-clip progress.
        bbox_fill_span: How many frames away a missing per-frame bbox may be
            borrowed from (passed to nearest_bbox as `span`).

    Returns:
        Manifest DataFrame with one row per written clip (empty if none).

    Raises:
        ValueError: If n_before_after is negative.
    """
    if n_before_after < 0:
        raise ValueError(f"n_before_after must be >= 0, got {n_before_after}")

    with open(json_path, "r") as f:
        ann = json.load(f)

    # resolve_tracks_csv() is called with the video name only, so the lookup
    # uses the global TRACKS_DIR. Surface a mismatch instead of ignoring it silently.
    if Path(tracks_dir).resolve() != Path(TRACKS_DIR).resolve():
        print(f"[WARN] tracks_dir={tracks_dir!r} differs from TRACKS_DIR={TRACKS_DIR!r}; "
              f"track CSVs are looked up in TRACKS_DIR.")

    records = []
    os.makedirs(out_root, exist_ok=True)

    for video_name, contacts in ann.items():
        video_path = str(Path(videos_dir) / video_name)
        if not Path(video_path).exists():
            print(f"[WARN] Video missing: {video_path} — skipping.")
            continue

        tracks_csv = resolve_tracks_csv(video_name)
        if tracks_csv is None or not Path(tracks_csv).exists():
            print(f"[WARN] Tracks CSV not found for {video_name} — skipping.")
            continue

        if verbose:
            print(f"\n[Video] {video_name}")
            print(f"  Video:  {video_path}")
            print(f"  Tracks: {tracks_csv}")

        meta = _probe_video(video_path)
        if meta is None:
            print(f"[WARN] OpenCV could not open {video_path} — skipping.")
            continue
        frame_count, fps, W, H = meta
        if frame_count <= 0:
            print(f"[WARN] {video_path} reports {frame_count} frames — skipping.")
            continue

        df = load_tracks_df(tracks_csv)
        # frame -> bbox lookup per player id, built once per video.
        lookup_by_player: Dict[int, Dict[int, Tuple[float, float, float, float]]] = {
            int(pid): frame_bbox_lookup(df, int(pid)) for pid in df["id"].unique()
        }

        for key, per_player in contacts.items():
            # key = "contact_<frame>"; the contact frame is the last "_"-separated field.
            try:
                cframe = int(str(key).split("_")[-1])
            except ValueError:
                print(f"  [WARN] Bad contact key {key} — skipping")
                continue

            frame_indices = fixed_window_indices(cframe, n_before_after)
            if not has_complete_window(frame_indices, frame_count):
                if verbose:
                    print(f"  [skip] contact {cframe}: outside video bounds ({frame_indices[0]}..{frame_indices[-1]} not in [0..{frame_count-1}])")
                continue

            # per_player is like {"1": "serve", "2": "negative"}
            for pid_str, label in per_player.items():
                if label == skip_label:
                    continue
                if include_labels is not None and label not in include_labels:
                    continue

                try:
                    pid = int(pid_str)
                except ValueError:
                    if verbose:
                        print(f"  [WARN] bad player id '{pid_str}' at contact {cframe} — skipping")
                    continue

                bboxes = _window_bboxes(frame_indices, lookup_by_player.get(pid, {}), bbox_fill_span)
                if bboxes is None:
                    if verbose:
                        print(f"  [skip] contact {cframe}, pid {pid}: missing bbox in window")
                    continue

                # One fixed crop for the whole window (union of per-frame boxes,
                # padded and clipped to the frame), so the player doesn't jitter.
                crop_box = expand_bbox(union_bbox(bboxes), pad_ratio, W, H)

                clip_name = f"{Path(video_name).stem}_f{cframe:05d}_p{pid}.mp4"
                out_path = str(Path(out_root) / str(label) / clip_name)

                try:
                    write_clip(
                        video_path=video_path,
                        frame_indices=frame_indices,
                        crop_box=crop_box,
                        out_path=out_path,
                        out_side=out_side,
                        fps=fps
                    )
                except Exception as e:
                    # Best-effort: one unreadable window must not abort the whole dataset build.
                    print(f"  [ERR] Writing clip failed for {video_name} c{cframe} p{pid}: {e}")
                    continue

                if verbose:
                    print(f"  [+] {label!s:<12} {clip_name}")

                records.append({
                    "path": out_path,
                    "label": label,
                    "video": video_name,
                    "player_id": pid,
                    "contact_frame": cframe,
                    "start_frame": frame_indices[0],
                    "end_frame": frame_indices[-1],
                    "fps": fps,
                    "width": out_side,
                    "height": out_side
                })

    manifest = pd.DataFrame.from_records(records)
    man_path = str(Path(out_root) / "manifest.csv")
    if len(manifest):
        manifest.to_csv(man_path, index=False)
        print(f"\nSaved manifest with {len(manifest)} clips → {man_path}")
    else:
        print("\nNo clips were produced.")
    return manifest

In [14]:
# Run the extraction with the config-cell settings; returns the clip manifest.
manifest = extract_all(
    JSON_PATH,
    VIDEOS_DIR,
    TRACKS_DIR,
    OUTPUT_ROOT,
    n_before_after=N_BEFORE_AFTER,
    out_side=OUT_SIDE,
    pad_ratio=PAD_RATIO,
    include_labels=INCLUDE_LABELS,
    skip_label=SKIP_LABEL,
    verbose=VERBOSE,
)
# Clip count per label, most frequent first.
if not manifest.empty:
    print(manifest.groupby("label").size().sort_values(ascending=False))


[Video] shi_vit_rally_1.mp4
  Video:  /content/videos/shi_vit_rally_1.mp4
  Tracks: /content/tracks/shi_vit_rally_1_tracks.csv
  [+] serve        shi_vit_rally_1_f00094_p1.mp4
  [+] negative     shi_vit_rally_1_f00094_p2.mp4
  [+] negative     shi_vit_rally_1_f00114_p1.mp4
  [+] cross_net    shi_vit_rally_1_f00114_p2.mp4
  [+] cross_net    shi_vit_rally_1_f00137_p1.mp4
  [+] negative     shi_vit_rally_1_f00137_p2.mp4
  [+] negative     shi_vit_rally_1_f00159_p1.mp4
  [+] lift         shi_vit_rally_1_f00159_p2.mp4
  [+] drop         shi_vit_rally_1_f00192_p1.mp4
  [+] negative     shi_vit_rally_1_f00192_p2.mp4
  [+] negative     shi_vit_rally_1_f00208_p1.mp4
  [+] push         shi_vit_rally_1_f00208_p2.mp4
  [+] lift         shi_vit_rally_1_f00231_p1.mp4
  [+] negative     shi_vit_rally_1_f00231_p2.mp4
  [+] negative     shi_vit_rally_1_f00263_p1.mp4
  [+] drop         shi_vit_rally_1_f00263_p2.mp4
  [+] push         shi_vit_rally_1_f00278_p1.mp4
  [+] negative     shi_vit_rally_1_f002