In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import cv2
import numpy as np
import pandas as pd
import csv
import json
import os
from pathlib import Path
from typing import List, Tuple, Dict, Optional
from dataclasses import dataclass
from tqdm.auto import tqdm

In [None]:
# Colab-only setup: mount Google Drive at /content/drive so paths under
# that directory can be read and written by the cells below.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
@dataclass
class ExtractionConfig:
    """Configuration for clip extraction.

    Attributes:
        pad_frames: Number of frames taken on each side of a contact frame;
            every clip / shuttle window is ``2 * pad_frames + 1`` frames long.
        clip_size: Spatial size of each cropped frame. It is handed to
            ``cv2.resize`` as ``dsize``, i.e. it is ``(width, height)``.
        bbox_expansion: Fraction of the bounding-box width/height added on
            each side before cropping.
        min_bbox_size: Frames whose expanded crop is narrower or shorter than
            this many pixels are skipped.
        smooth_window: Moving-average window for the shuttle trajectory.
            Values below 2 disable smoothing; even values are rounded up to
            the next odd number by the extractor.
        output_dir: Directory that receives the ``.npy`` clips and
            ``metadata.json``.

    Raises:
        ValueError: If ``pad_frames`` is negative or ``clip_size`` is not a
            pair of positive numbers.
    """
    pad_frames: int = 15
    clip_size: Tuple[int, int] = (112, 112)
    bbox_expansion: float = 0.3
    min_bbox_size: int = 50
    smooth_window: int = 5
    output_dir: str = "data/clips"

    def __post_init__(self):
        # Fail fast here instead of with an opaque error deep inside
        # cv2.resize / numpy once extraction is already running.
        if self.pad_frames < 0:
            raise ValueError(f"pad_frames must be >= 0, got {self.pad_frames}")
        if len(self.clip_size) != 2 or any(side <= 0 for side in self.clip_size):
            raise ValueError(
                f"clip_size must be a (width, height) pair of positive values, got {self.clip_size}"
            )


class ClipPreExtractor:
    """Pre-extract per-player video clips and shuttle features and save them to disk.

    For every (video, contact frame, player) triple that has a label, two
    arrays are written to ``cfg.output_dir``:

    * ``<clip_id>.npy``          - uint8 video clip, shape (T, H, W, C)
    * ``<clip_id>_shuttle.npy``  - float32 shuttle features, shape (T, 10)

    with ``T = 2 * cfg.pad_frames + 1``. Both arrays are indexed by frame
    offset inside the window, so row ``cfg.pad_frames`` always corresponds to
    the contact frame in both streams.
    """

    # Number of per-frame shuttle features produced by _extract_shuttle_features.
    NUM_SHUTTLE_FEATURES = 10

    def __init__(self, config: "ExtractionConfig"):
        """Store the configuration and make sure the output directory exists."""
        self.cfg = config
        Path(self.cfg.output_dir).mkdir(parents=True, exist_ok=True)

    def extract_all_clips(
        self,
        video_paths: List[str],
        tracks_csv_paths: List[str],
        shuttle_csv_paths: List[str],
        contact_frames: Dict[str, List[int]],
        labels: Dict[str, Dict[str, Dict[int, str]]]
    ) -> Dict[str, str]:
        """
        Extract all clips and save to disk.

        Args:
            video_paths: Video files to process.
            tracks_csv_paths: Player-track CSVs, one per video, same order.
            shuttle_csv_paths: Shuttle CSVs, one per video, same order.
            contact_frames: ``{video file name: [contact frame, ...]}``.
            labels: ``{video file name: {"contact_<frame>": {player_id: label}}}``.

        Returns:
            The full content of ``metadata.json`` after the update, i.e. the
            entries that were already on disk merged with ``{clip_id: label}``
            for every clip written by this call.

        Raises:
            ValueError: If the three path lists do not have the same length.
        """
        # zip() would silently drop the unmatched tail and could pair a video
        # with the wrong CSVs, so reject mismatched inputs up front.
        if not (len(video_paths) == len(tracks_csv_paths) == len(shuttle_csv_paths)):
            raise ValueError(
                "video_paths, tracks_csv_paths and shuttle_csv_paths must have the same length "
                f"(got {len(video_paths)}, {len(tracks_csv_paths)}, {len(shuttle_csv_paths)})"
            )

        metadata = {}

        for video_path, tracks_csv, shuttle_csv in tqdm(
            zip(video_paths, tracks_csv_paths, shuttle_csv_paths),
            desc="Processing videos",
            total=len(video_paths)
        ):
            video_name = Path(video_path).name

            if video_name not in contact_frames:
                print(f"video_name {video_name} not found in contact_frames.")
                continue

            # Checked once per video (it does not depend on the contact frame).
            if video_name not in labels:
                print(f"video name {video_name} not in labels.")
                continue
            video_labels = labels[video_name]

            # Load tracks and shuttle
            tracks_df = pd.read_csv(tracks_csv)
            shuttle_data = self._load_shuttle(shuttle_csv)

            # Process each contact frame
            for cf in tqdm(contact_frames[video_name], desc=f"  {video_name}", leave=False):
                cf_key = f"contact_{cf}"

                if cf_key not in video_labels:
                    print(f"contact frame key {cf_key} not in labels[{video_name}].")
                    continue

                player_labels = video_labels[cf_key]

                # Shuttle features only depend on the contact frame, so they
                # are computed lazily once and shared by all players.
                shuttle_features = None

                # Extract for each player
                for player_id, shot_label in player_labels.items():
                    clip_id = f"{Path(video_name).stem}_frame{cf}_player{player_id}"

                    # Extract video clip
                    video_clip = self._extract_video_clip(
                        video_path, tracks_df, cf, player_id
                    )

                    if video_clip is None:
                        print(f"Unable to extract video clip for {player_id}, {shot_label}")
                        continue

                    if shuttle_features is None:
                        shuttle_features = self._extract_shuttle_features(
                            shuttle_data, cf
                        )

                    # Save to disk
                    video_path_out = os.path.join(self.cfg.output_dir, f"{clip_id}.npy")
                    shuttle_path_out = os.path.join(self.cfg.output_dir, f"{clip_id}_shuttle.npy")

                    np.save(video_path_out, video_clip)
                    np.save(shuttle_path_out, shuttle_features)

                    # Store metadata
                    metadata[clip_id] = shot_label

        # Merge with metadata from earlier runs so clips extracted previously
        # (e.g. from other videos) stay registered.
        metadata_path = os.path.join(self.cfg.output_dir, "metadata.json")
        existing_metadata = {}
        if os.path.exists(metadata_path):
            try:
                with open(metadata_path, 'r') as f:
                    existing_metadata = json.load(f)
                    print("Loaded existing metadata.")
                if not isinstance(existing_metadata, dict):
                    print("Overwriting metadata.")
                    existing_metadata = {}
            except json.JSONDecodeError:
                print(f"Warning: {metadata_path} was corrupted. Overwriting.")
                existing_metadata = {}

        existing_metadata.update(metadata)

        with open(metadata_path, 'w') as f:
            json.dump(existing_metadata, f, indent=2)

        print(f"\n✓ Saved/Updated. Total clips in metadata.json: {len(existing_metadata)}")
        return existing_metadata

    def _load_shuttle(self, csv_path: str):
        """Load a shuttle CSV with columns ``Frame``, ``Visibility``, ``X``, ``Y``.

        Returns:
            Tuple ``(frames, vis, xs, ys)`` of equally long arrays sorted by
            frame number (int64, int64, float32, float32).
        """
        frames, vis, xs, ys = [], [], [], []
        # newline='' is what the csv module expects for files it reads.
        with open(csv_path, 'r', newline='') as f:
            for row in csv.DictReader(f):
                frames.append(int(row["Frame"]))
                vis.append(int(row["Visibility"]))
                xs.append(float(row["X"]))
                ys.append(float(row["Y"]))

        # Explicit dtypes keep the arrays well-typed even for an empty CSV.
        frames = np.array(frames, dtype=np.int64)
        vis = np.array(vis, dtype=np.int64)
        xs = np.array(xs, dtype=np.float32)
        ys = np.array(ys, dtype=np.float32)

        order = np.argsort(frames, kind="stable")
        return frames[order], vis[order], xs[order], ys[order]

    @staticmethod
    def _fill_missing_slots(slots):
        """Return a copy of ``slots`` with every ``None`` replaced by a neighbour.

        Leading ``None`` entries take the first non-``None`` item; every other
        ``None`` takes the closest preceding item. If all entries are ``None``
        the list is returned unchanged (as a copy).
        """
        filled = list(slots)
        current = next((item for item in filled if item is not None), None)
        for i, item in enumerate(filled):
            if item is None:
                filled[i] = current
            else:
                current = item
        return filled

    def _extract_video_clip(
        self,
        video_path: str,
        tracks_df: pd.DataFrame,
        contact_frame: int,
        player_id: int
    ) -> Optional[np.ndarray]:
        """Extract the cropped clip of one player around a contact frame.

        The clip covers frames ``contact_frame - pad_frames`` to
        ``contact_frame + pad_frames``. Each usable frame is cropped to the
        player's expanded bounding box and stored at its own offset inside the
        window. Frames that cannot be used (before the start of the video, no
        bounding box yet, crop too small, read failure) leave a gap that is
        filled with the nearest earlier crop, or the first available crop for
        leading gaps. This keeps the contact frame at index ``pad_frames``
        instead of shifting it when frames are dropped.

        Returns:
            uint8 array of shape (T, H, W, C), or ``None`` if the video cannot
            be opened or fewer than 80% of the window's frames were usable.
        """
        target_len = self.cfg.pad_frames * 2 + 1
        window_start = contact_frame - self.cfg.pad_frames  # may be negative
        first_read = max(0, window_start)
        window_end = contact_frame + self.cfg.pad_frames + 1

        slots = [None] * target_len
        last_bbox = None

        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                print(f"Couldn't open video {video_path}")
                return None

            # Seek once, then decode sequentially: each read() advances to the
            # next frame, so re-seeking on every iteration is unnecessary.
            cap.set(cv2.CAP_PROP_POS_FRAMES, first_read)

            for frame_idx in range(first_read, window_end):
                ret, frame = cap.read()
                if not ret:
                    print(f"Couldn't read frame {frame_idx} from video {video_path}")
                    break

                # Get bbox, falling back to the most recent one for this player
                bbox = self._get_bbox(tracks_df, int(frame_idx), int(player_id))
                if bbox is None:
                    bbox = last_bbox
                if bbox is None:
                    print(f"Couldn't get bounding box for player {player_id} at frame {frame_idx}")
                    continue

                last_bbox = bbox

                # Expand and crop
                x1, y1, x2, y2 = self._expand_bbox(bbox, frame.shape[:2])

                if (x2 - x1) < self.cfg.min_bbox_size or (y2 - y1) < self.cfg.min_bbox_size:
                    continue

                cropped = frame[y1:y2, x1:x2]
                slots[frame_idx - window_start] = cv2.resize(cropped, self.cfg.clip_size)
        finally:
            # Always give the decoder back, even if cropping/resizing raised.
            cap.release()

        usable = sum(slot is not None for slot in slots)
        min_required = target_len * 0.8
        if usable < min_required:
            print(
                f"len(frames) {usable} < required {min_required:.1f} "
                f"(80% of target_len {target_len})"
            )
            return None

        clip = np.stack(self._fill_missing_slots(slots), axis=0)  # (T, H, W, C)
        return clip.astype(np.uint8)

    def _extract_shuttle_features(
        self,
        shuttle_data: Tuple,
        contact_frame: int
    ) -> np.ndarray:
        """Extract per-frame shuttle features for the window around a contact.

        Args:
            shuttle_data: ``(frames, vis, xs, ys)`` as returned by
                ``_load_shuttle``.
            contact_frame: Centre frame of the window.

        Returns:
            float32 array of shape ``(2 * pad_frames + 1, 10)`` with columns
            x, y, vx, vy, speed, ax, ay, direction, direction_change and
            height_change (currently identical to vy). Row ``i`` belongs to
            frame ``contact_frame - pad_frames + i``. All zeros if the CSV has
            no row inside the window.
        """
        frames, vis, xs, ys = shuttle_data

        window_size = self.cfg.pad_frames * 2 + 1
        start_frame = contact_frame - self.cfg.pad_frames
        end_frame = contact_frame + self.cfg.pad_frames + 1

        # Find indices
        mask = (frames >= start_frame) & (frames < end_frame)

        if not mask.any():
            return np.zeros((window_size, self.NUM_SHUTTLE_FEATURES), dtype=np.float32)

        # Build a dense window addressed by frame offset. Frames without a CSV
        # row (or before frame 0) stay "invisible" and are filled below, so
        # rows never shift relative to the video clip.
        offsets = (frames[mask] - start_frame).astype(np.intp)
        xs_win = np.zeros(window_size, dtype=np.float32)
        ys_win = np.zeros(window_size, dtype=np.float32)
        vis_win = np.zeros(window_size, dtype=vis.dtype)
        xs_win[offsets] = xs[mask]
        ys_win[offsets] = ys[mask]
        vis_win[offsets] = vis[mask]

        # Fill invisible
        xs_filled, ys_filled = self._fill_invisible(xs_win, ys_win, vis_win)

        # Smooth
        xs_smooth = self._smooth(xs_filled)
        ys_smooth = self._smooth(ys_filled)

        # Compute features (prepend=first value makes the first difference 0)
        vx = np.diff(xs_smooth, prepend=xs_smooth[0])
        vy = np.diff(ys_smooth, prepend=ys_smooth[0])
        speed = np.sqrt(vx**2 + vy**2)

        ax = np.diff(vx, prepend=vx[0])
        ay = np.diff(vy, prepend=vy[0])

        direction = np.arctan2(vy, vx)
        raw_change = np.diff(direction, prepend=direction[0])
        # Wrap into [-pi, pi): a heading going from +179 deg to -179 deg is a
        # 2 deg turn, not a -358 deg one.
        direction_change = (raw_change + np.pi) % (2 * np.pi) - np.pi

        features = np.stack([
            xs_smooth, ys_smooth,
            vx, vy, speed,
            ax, ay,
            direction, direction_change,
            vy  # height_change
        ], axis=1)

        return features.astype(np.float32)

    def _get_bbox(self, tracks_df, frame, player_id):
        """Return ``(x1, y1, x2, y2)`` as ints for ``player_id`` at ``frame``.

        Uses the first matching row of ``tracks_df`` (columns ``frame``,
        ``id``, ``x1``, ``y1``, ``x2``, ``y2``); returns ``None`` if there is
        no such row.
        """
        matches = tracks_df[(tracks_df['frame'] == frame) & (tracks_df['id'] == player_id)]
        if len(matches) == 0:
            return None
        first = matches.iloc[0]
        return (int(first['x1']), int(first['y1']), int(first['x2']), int(first['y2']))

    def _expand_bbox(self, bbox, img_shape):
        """Grow ``bbox`` by ``cfg.bbox_expansion`` of its size on every side.

        Args:
            bbox: ``(x1, y1, x2, y2)``.
            img_shape: ``(height, width)`` of the image; the result is clamped
                to these bounds.
        """
        x1, y1, x2, y2 = bbox
        h, w = img_shape
        dx = int((x2 - x1) * self.cfg.bbox_expansion)
        dy = int((y2 - y1) * self.cfg.bbox_expansion)
        return (max(0, x1 - dx), max(0, y1 - dy), min(w, x2 + dx), min(h, y2 + dy))

    def _fill_invisible(self, xs, ys, vis):
        """Replace invisible shuttle positions with neighbouring valid ones.

        A point is valid when ``vis == 1`` and it is not at the origin (0, 0).
        Invalid points are forward-filled from the previous valid point;
        invalid points before the first valid one are back-filled from it.
        If no point is valid, copies of the inputs are returned unchanged.

        Returns:
            ``(xs_filled, ys_filled)`` as new arrays; the inputs are not
            modified.
        """
        valid = (vis == 1) & ~((xs == 0) & (ys == 0))

        if not valid.any():
            return xs.copy(), ys.copy()

        positions = np.arange(len(xs))
        # Index of the most recent valid point at or before each position
        # (-1 while still inside the leading gap).
        last_valid = np.maximum.accumulate(np.where(valid, positions, -1))
        first_valid = int(np.argmax(valid))
        source = np.where(last_valid >= 0, last_valid, first_valid)

        # Fancy indexing already yields fresh arrays.
        return xs[source], ys[source]

    def _smooth(self, arr):
        """Centred moving average with edge replication.

        Returns ``arr`` itself when ``cfg.smooth_window < 2``. Even window
        sizes are increased by one so the window is centred. The output has
        the same length as the input and dtype float32.
        """
        if self.cfg.smooth_window < 2:
            return arr
        window = self.cfg.smooth_window
        if window % 2 == 0:
            window += 1
        half = window // 2
        padded = np.pad(arr, half, mode='edge')
        kernel = np.ones(window) / window
        return np.convolve(padded, kernel, mode='valid').astype(np.float32)

In [None]:
# Dataset root on the mounted Google Drive (switch to /content for local files).
ROOT = "/content/drive/MyDrive/FIT3163,3164/TwoStreamDataset/data_v1"
# ROOT = "/content"

cfg = ExtractionConfig(
    pad_frames=15,
    clip_size=(112, 112),
    output_dir=f"{ROOT}/clips"
)

extractor = ClipPreExtractor(cfg)

# Rallies to process in this run. The video, tracks and shuttle paths are all
# derived from this single list so the three lists can never get out of step
# when a rally is enabled or disabled.
rally_names = [
    # "shi_vit_rally_1",
    # "shi_vit_rally_2",
    "shi_vit_rally_3",
]

video_paths = [f"{ROOT}/{name}.mp4" for name in rally_names]
tracks_paths = [f"{ROOT}/{name}_tracks.csv" for name in rally_names]
shuttle_paths = [f"{ROOT}/{name}_ball.csv" for name in rally_names]

# Contact frames per video, keyed by video file name.
contact_frames = {
    "shi_vit_rally_1.mp4": [
        94, 114, 137, 159, 192,
        208, 231, 263, 278, 305,
        329, 353, 382, 409, 445,
        460, 486, 514, 527, 555,
        581, 606, 666, 688, 730
    ],
    "shi_vit_rally_2.mp4": [
        62, 86, 107, 130, 172,
        198, 226, 257, 286, 337,
        347, 367, 379, 417, 450,
        471, 502, 539, 566, 591,
        628, 663, 678, 710, 754,
        767, 795, 829, 850, 886,
        901, 940, 968, 989
    ],
    "shi_vit_rally_3.mp4": [
        19, 42, 83, 96, 122,
        148, 172, 203, 215, 249,
        264, 286, 309, 349, 379,
        396, 428
    ]
}

# Labels are read from the current working directory (not from ROOT); the
# file must already exist when this cell runs.
with open("shi_vit_labels.json", "r") as f:
    labels = json.load(f)

In [None]:
# Create or extend the label file with default "negative" entries for every
# contact frame of the selected videos. Entries that already exist in the file
# (hand-edited labels, other videos) are kept: only missing contact keys are
# added, so re-running this cell never wipes manual annotations.
output_filename = "shi_vit_labels.json"

labels_data = {}
if os.path.exists(output_filename):
    try:
        with open(output_filename, 'r') as f:
            loaded_labels = json.load(f)
        if isinstance(loaded_labels, dict):
            labels_data = loaded_labels
    except json.JSONDecodeError:
        print(f"Warning: {output_filename} is not valid JSON; regenerating it.")

added_count = 0

for video_path in video_paths:
    # Extract the filename (e.g., "shi_vit_rally_1.mp4") from the path
    video_name = os.path.basename(video_path)

    if video_name in contact_frames:
        video_labels = labels_data.setdefault(video_name, {})

        for frame_num in contact_frames[video_name]:
            contact_key = f"contact_{frame_num}"

            # Default "negative" labels for players 1 and 2, only where no
            # label has been recorded yet.
            if contact_key not in video_labels:
                video_labels[contact_key] = {"1": "negative", "2": "negative"}
                added_count += 1
    else:
        print(f"Warning: No contact frames found for {video_name}, skipping.")


print("--- Generated JSON ---")
print(f"Added {added_count} default entries; existing labels were preserved.")

with open(output_filename, 'w') as f:
    json.dump(labels_data, f, indent=4)

print(f"\nSuccessfully generated and saved to {output_filename}")

--- Generated JSON ---

Successfully generated and saved to shi_vit_labels.json


In [None]:
# Run the extraction for the selected videos; the result is the merged
# content of metadata.json ({clip_id: label}).
metadata = extractor.extract_all_clips(
    video_paths=video_paths,
    tracks_csv_paths=tracks_paths,
    shuttle_csv_paths=shuttle_paths,
    contact_frames=contact_frames,
    labels=labels,
)

Processing videos:   0%|          | 0/1 [00:00<?, ?it/s]

   frame  id   x1   y1   x2   y2      conf  cls
0     11   1  678  216  775  368  0.910234    0
1     11   2  563  321  680  531  0.833823    0
2     12   1  678  216  775  368  0.909839    0
3     12   2  563  320  680  530  0.838796    0
4     13   1  678  216  774  368  0.908550    0


  shi_vit_rally_3.mp4:   0%|          | 0/17 [00:00<?, ?it/s]

Couldn't get bounding box for player 1 at frame 4
Couldn't get bounding box for player 1 at frame 5
Couldn't get bounding box for player 1 at frame 6
Couldn't get bounding box for player 1 at frame 7
Couldn't get bounding box for player 1 at frame 8
Couldn't get bounding box for player 1 at frame 9
Couldn't get bounding box for player 1 at frame 10
len(frames) 24 < target_len 31
Unable to extract video clip for 1, negative
Couldn't get bounding box for player 2 at frame 4
Couldn't get bounding box for player 2 at frame 5
Couldn't get bounding box for player 2 at frame 6
Couldn't get bounding box for player 2 at frame 7
Couldn't get bounding box for player 2 at frame 8
Couldn't get bounding box for player 2 at frame 9
Couldn't get bounding box for player 2 at frame 10
len(frames) 24 < target_len 31
Unable to extract video clip for 2, serve
Loaded existing metadata.

✓ Saved/Updated. Total clips in metadata.json: 98
