In [None]:
"""
Two-Stream Badminton Shot Classification
Vision Stream (R(2+1)D) + Shuttle Stream (1D CNN) → Late Fusion

Architecture:
┌─────────────────┐    ┌──────────────────┐
│ Video Clip      │    │ Shuttle Features │
│ (T, H, W, 3)    │    │ (T, F)           │
└────────┬────────┘    └────────┬─────────┘
         │                      │
    ┌────▼────┐            ┌────▼────┐
    │ R(2+1)D │            │ 1D CNN  │
    │ Vision  │            │ Temporal│
    └────┬────┘            └────┬────┘
         │                      │
    ┌────▼────┐            ┌────▼────┐
    │ 512-dim │            │ 128-dim │
    └────┬────┘            └────┬────┘
         │                      │
         └──────────┬───────────┘
                    │
              ┌─────▼─────┐
              │   Fusion  │
              │   Layer   │
              └─────┬─────┘
                    │
              ┌─────▼─────┐
              │ Classifier│
              └───────────┘

Shuttle Features (per frame):
- Position (x, y)
- Velocity (vx, vy, speed)
- Acceleration (ax, ay)
- Direction (angle, angle_change)
- Height change (dy/dt)
"""
pass

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip uninstall -q numpy opencv-python opencv-python-headless

In [None]:
!pip install -U numpy==2.3.0 lightning-thunder thinc numba opencv-python opencv-python-headless

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import cv2
import numpy as np
import pandas as pd
import csv
from typing import List, Tuple, Dict, Optional
from dataclasses import dataclass
import json
import random, os

In [None]:
def seed_all(seed=1023):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

seed_all(2310)

# **1. SHUTTLE FEATURE EXTRACTION**

In [None]:
@dataclass
class ShuttleFeatureConfig:
    """Configuration for shuttle feature extraction"""
    smooth_window: int = 5  # smoothing for coordinates


class ShuttleFeatureExtractor:
    """Extract temporal features from shuttle trajectory"""

    def __init__(self, config: ShuttleFeatureConfig):
        self.cfg = config

    def load_shuttle_csv(self, csv_path: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Load shuttle CSV (Frame, Visibility, X, Y)"""
        frames, vis, xs, ys = [], [], [], []
        with open(csv_path, 'r') as f:
            reader = csv.DictReader(f)
            for r in reader:
                frames.append(int(r["Frame"]))
                vis.append(int(r["Visibility"]))
                xs.append(float(r["X"]))
                ys.append(float(r["Y"]))

        frames = np.array(frames)
        vis = np.array(vis)
        xs = np.array(xs, dtype=np.float32)
        ys = np.array(ys, dtype=np.float32)

        idx = np.argsort(frames)
        return frames[idx], vis[idx], xs[idx], ys[idx]

    def _smooth(self, arr: np.ndarray) -> np.ndarray:
        """Moving average smoothing"""
        if self.cfg.smooth_window < 2:
            return arr
        window = self.cfg.smooth_window
        if window % 2 == 0:
            window += 1
        pad = window // 2
        padded = np.pad(arr, pad, mode='edge')
        kernel = np.ones(window) / window
        return np.convolve(padded, kernel, mode='valid').astype(np.float32)

    def _fill_invisible(self, xs: np.ndarray, ys: np.ndarray, vis: np.ndarray):
        """Interpolate invisible points"""
        valid = (vis == 1) & ~((xs == 0) & (ys == 0))
        xs_filled = xs.copy()
        ys_filled = ys.copy()

        if not valid.any():
            return xs_filled, ys_filled

        # Forward fill
        for i in range(1, len(xs)):
            if not valid[i]:
                xs_filled[i] = xs_filled[i-1]
                ys_filled[i] = ys_filled[i-1]

        # Backward fill
        for i in range(len(xs)-2, -1, -1):
            if not valid[i]:
                xs_filled[i] = xs_filled[i+1]
                ys_filled[i] = ys_filled[i+1]

        return xs_filled, ys_filled

    def extract_features_for_window(
        self,
        frames: np.ndarray,
        vis: np.ndarray,
        xs: np.ndarray,
        ys: np.ndarray,
        center_frame: int,
        window_size: int
    ) -> np.ndarray:
        """
        Extract shuttle features for a temporal window.

        Returns: (window_size, num_features) array
        Features per frame: [x, y, vx, vy, speed, ax, ay, direction, dir_change, height_change]
        """
        # Find frame indices
        start_frame = center_frame - window_size // 2
        end_frame = start_frame + window_size

        # Get indices in arrays
        frame_mask = (frames >= start_frame) & (frames < end_frame)

        if not frame_mask.any():
            # Return zero features if no data
            return np.zeros((window_size, 10), dtype=np.float32)

        # Extract window data
        frames_win = frames[frame_mask]
        vis_win = vis[frame_mask]
        xs_win = xs[frame_mask]
        ys_win = ys[frame_mask]

        # Fill and smooth
        xs_filled, ys_filled = self._fill_invisible(xs_win, ys_win, vis_win)
        xs_smooth = self._smooth(xs_filled)
        ys_smooth = self._smooth(ys_filled)

        # Compute derivatives
        vx = np.diff(xs_smooth, prepend=xs_smooth[0])
        vy = np.diff(ys_smooth, prepend=ys_smooth[0])
        speed = np.sqrt(vx**2 + vy**2)

        ax = np.diff(vx, prepend=vx[0])
        ay = np.diff(vy, prepend=vy[0])

        direction = np.arctan2(vy, vx)
        direction_change = np.diff(direction, prepend=direction[0])

        height_change = vy  # positive = downward (assuming y+ is down)

        # Stack features
        features = np.stack([
            xs_smooth, ys_smooth,
            vx, vy, speed,
            ax, ay,
            direction, direction_change,
            height_change
        ], axis=1)  # (T, 10)

        # Pad or truncate to exact window_size
        if len(features) < window_size:
            padding = np.zeros((window_size - len(features), 10), dtype=np.float32)
            features = np.vstack([features, padding])
        elif len(features) > window_size:
            features = features[:window_size]

        return features

# **2. DATASET (Two-Stream)**

In [None]:
@dataclass
class ClipExtractionConfig:
    """Configuration for clip extraction"""
    pad_frames: int = 15  # frames before/after contact (0.5s at 30fps)
    clip_size: Tuple[int, int] = (112, 112)  # spatial size for model input
    bbox_expansion: float = 0.3  # expand bbox by 30% to include racket/motion
    min_bbox_size: int = 50  # minimum bbox dimension

class ClipExtractor:
    """Extract player-centered clips from video given tracks and contacts"""

    def __init__(self, config: ClipExtractionConfig):
        self.cfg = config

    def load_tracks(self, track_csv_path: str) -> pd.DataFrame:
        """Load player tracks CSV"""
        df = pd.read_csv(track_csv_path)
        return df

    def get_player_bbox(
        self,
        tracks_df: pd.DataFrame,
        frame: int,
        player_id: int
    ) -> Optional[Tuple[int, int, int, int]]:
        """Get bbox for player at specific frame"""
        row = tracks_df[(tracks_df['frame'] == frame) & (tracks_df['id'] == player_id)]
        if len(row) == 0:
            return None
        row = row.iloc[0]
        return (int(row['x1']), int(row['y1']), int(row['x2']), int(row['y2']))

    def expand_bbox(
        self,
        bbox: Tuple[int, int, int, int],
        img_shape: Tuple[int, int]
    ) -> Tuple[int, int, int, int]:
        """Expand bbox to include motion/racket"""
        x1, y1, x2, y2 = bbox
        h, w = img_shape

        # Expand
        bw, bh = x2 - x1, y2 - y1
        dx = int(bw * self.cfg.bbox_expansion)
        dy = int(bh * self.cfg.bbox_expansion)

        x1 = max(0, x1 - dx)
        y1 = max(0, y1 - dy)
        x2 = min(w, x2 + dx)
        y2 = min(h, y2 + dy)

        return (x1, y1, x2, y2)

    def interpolate_missing_bbox(
        self,
        tracks_df: pd.DataFrame,
        frame: int,
        player_id: int,
        search_window: int = 5
    ) -> Optional[Tuple[int, int, int, int]]:
        """Interpolate bbox if missing at exact frame"""
        # Try exact frame first
        bbox = self.get_player_bbox(tracks_df, frame, player_id)
        if bbox is not None:
            return bbox

        # Search nearby frames
        for offset in range(1, search_window + 1):
            # Try before
            bbox = self.get_player_bbox(tracks_df, frame - offset, player_id)
            if bbox is not None:
                return bbox
            # Try after
            bbox = self.get_player_bbox(tracks_df, frame + offset, player_id)
            if bbox is not None:
                return bbox

        return None

    def extract_clip(
        self,
        video_path: str,
        tracks_df: pd.DataFrame,
        contact_frame: int,
        player_id: int
    ) -> Optional[np.ndarray]:
        """
        Extract clip for one player around contact frame.

        Returns: (T, H, W, C) array or None if failed
        """
        cap = cv2.VideoCapture(video_path)

        start_frame = max(0, contact_frame - self.cfg.pad_frames)
        end_frame = contact_frame + self.cfg.pad_frames + 1

        frames = []
        for frame_idx in range(start_frame, end_frame):
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame = cap.read()
            if not ret:
                break

            # Get bbox for this player at this frame
            bbox = self.interpolate_missing_bbox(tracks_df, frame_idx, player_id)
            if bbox is None:
                # Use previous bbox if available
                if len(frames) > 0:
                    bbox = self._last_bbox
                else:
                    continue

            self._last_bbox = bbox

            # Expand and crop
            bbox_exp = self.expand_bbox(bbox, frame.shape[:2])
            x1, y1, x2, y2 = bbox_exp

            # Check minimum size
            if (x2 - x1) < self.cfg.min_bbox_size or (y2 - y1) < self.cfg.min_bbox_size:
                continue

            cropped = frame[y1:y2, x1:x2]

            # Resize to fixed size
            resized = cv2.resize(cropped, self.cfg.clip_size)
            frames.append(resized)

        cap.release()

        if len(frames) < (self.cfg.pad_frames * 2 + 1) * 0.8:  # at least 80% frames
            return None

        # Pad if needed
        target_len = self.cfg.pad_frames * 2 + 1
        while len(frames) < target_len:
            frames.append(frames[-1])  # repeat last frame

        return np.stack(frames[:target_len], axis=0)  # (T, H, W, C)

    def extract_all_clips_for_contact(
        self,
        video_path: str,
        tracks_df: pd.DataFrame,
        contact_frame: int
    ) -> Dict[int, np.ndarray]:
        """
        Extract clips for ALL players around a contact frame.

        Returns: {player_id: clip_array}
        """
        player_ids = tracks_df['id'].unique()
        clips = {}

        for pid in player_ids:
            clip = self.extract_clip(video_path, tracks_df, contact_frame, pid)
            if clip is not None:
                clips[int(pid)] = clip

        return clips

In [None]:
class TwoStreamDataset(Dataset):
    """
    Two-stream dataset: Video clips + Shuttle features

    Each sample contains:
    - Video clip: (T, H, W, C)
    - Shuttle features: (T, F)
    - Label: shot class
    """

    def __init__(
        self,
        video_paths: List[str],
        tracks_csv_paths: List[str],
        shuttle_csv_paths: List[str],
        contact_frames: Dict[str, List[int]],
        labels: Dict[str, Dict[str, Dict[int, str]]],
        clip_extractor,
        shuttle_extractor: ShuttleFeatureExtractor,
        shot_classes: List[str],
        clip_cfg: ClipExtractionConfig,
        augment: bool = True
    ):
        self.clip_extractor = clip_extractor
        self.shuttle_extractor = shuttle_extractor
        self.shot_classes = shot_classes
        self.class_to_idx = {cls: idx for idx, cls in enumerate(shot_classes)}
        self.augment = augment
        self.clip_cfg = clip_cfg

        # Cache shuttle data
        self.shuttle_cache = {}
        for shuttle_csv in shuttle_csv_paths:
            video_name = shuttle_csv.split('/')[-1].replace('_shuttle', '').replace('.csv', '.mp4')
            frames, vis, xs, ys = shuttle_extractor.load_shuttle_csv(shuttle_csv)
            self.shuttle_cache[video_name] = (frames, vis, xs, ys)

        # Build samples
        self.samples = []
        for video_path, tracks_csv in zip(video_paths, tracks_csv_paths):
            video_name = video_path.split('/')[-1]
            if video_name not in contact_frames:
                print(f"video_name {video_name} not found in contact_frames {contact_frames}.\n")
                continue

            for cf in contact_frames[video_name]:
                cf_key = f"contact_{cf}"

                if video_name not in labels:
                    print(f"video name {video_name} not in labels.\n")
                    continue

                if cf_key not in labels[video_name]:
                    print(f"contact frame key {cf_key} not in {labels}[{video_name}].\n")
                    continue

                player_labels = labels[video_name][cf_key]
                for player_id, shot_label in player_labels.items():
                    if shot_label not in self.class_to_idx:
                        continue

                    self.samples.append((
                        video_path,
                        tracks_csv,
                        video_name,
                        cf,
                        int(player_id),
                        self.class_to_idx[shot_label]
                    ))

        print(f"Two-Stream Dataset: {len(self.samples)} samples")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        video_path, tracks_csv, video_name, contact_frame, player_id, label_idx = self.samples[idx]

        # 1. Extract video clip
        tracks_df = pd.read_csv(tracks_csv)
        clip = self.clip_extractor.extract_clip(video_path, tracks_df, contact_frame, player_id)

        if clip is None:
            return self.__getitem__((idx + 1) % len(self))

        # 2. Extract shuttle features
        if video_name in self.shuttle_cache:
            frames, vis, xs, ys = self.shuttle_cache[video_name]
            shuttle_features = self.shuttle_extractor.extract_features_for_window(
                frames, vis, xs, ys, contact_frame,
                window_size=self.clip_cfg.pad_frames * 2 + 1
            )
        else:
            # Fallback: zero features
            shuttle_features = np.zeros((self.clip_cfg.pad_frames * 2 + 1, 10), dtype=np.float32)

        # 3. Augmentation (apply same to both streams)
        if self.augment:
            clip, shuttle_features = self._augment(clip, shuttle_features)

        # 4. Normalize and convert to tensors
        clip = clip.astype(np.float32) / 255.0
        clip_tensor = torch.from_numpy(clip).permute(3, 0, 1, 2)  # (C, T, H, W)

        shuttle_tensor = torch.from_numpy(shuttle_features).T  # (F, T)

        label = torch.tensor(label_idx, dtype=torch.long)

        return clip_tensor, shuttle_tensor, label

    def _augment(self, clip: np.ndarray, shuttle_features: np.ndarray):
        """Synchronized augmentation"""
        # Horizontal flip (also flip shuttle x-coordinates and vx)
        if np.random.rand() > 0.5:
            clip = np.flip(clip, axis=2).copy()
            # Flip x, vx, ax (indices 0, 2, 5)
            shuttle_features[:, 0] *= -1  # x
            shuttle_features[:, 2] *= -1  # vx
            shuttle_features[:, 5] *= -1  # ax

        # Brightness (video only)
        if np.random.rand() > 0.5:
            factor = np.random.uniform(0.8, 1.2)
            clip = np.clip(clip * factor, 0, 255)

        return clip, shuttle_features

# **3. MODEL: Two-Stream Architecture**

In [None]:
class SpatioTemporalConv(nn.Module):
    """R(2+1)D block"""
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()

        self.spatial = nn.Conv3d(
            in_channels, out_channels,
            kernel_size=(1, kernel_size, kernel_size),
            stride=(1, stride, stride),
            padding=(0, padding, padding),
            bias=False
        )
        self.bn_spatial = nn.BatchNorm3d(out_channels)

        self.temporal = nn.Conv3d(
            out_channels, out_channels,
            kernel_size=(kernel_size, 1, 1),
            stride=(stride, 1, 1),
            padding=(padding, 0, 0),
            bias=False
        )
        self.bn_temporal = nn.BatchNorm3d(out_channels)

    def forward(self, x):
        x = F.relu(self.bn_spatial(self.spatial(x)))
        x = F.relu(self.bn_temporal(self.temporal(x)))
        return x

In [None]:
class VisionStream(nn.Module):
    """R(2+1)D for video understanding"""
    def __init__(self, input_channels: int = 3, hidden_dim: int = 512):
        super().__init__()

        self.stem = SpatioTemporalConv(input_channels, 64, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))

        self.layer1 = self._make_layer(64, 64, num_blocks=2)
        self.layer2 = self._make_layer(64, 128, num_blocks=2, stride=2)
        self.layer3 = self._make_layer(128, 256, num_blocks=2, stride=2)
        self.layer4 = self._make_layer(256, hidden_dim, num_blocks=2, stride=2)

        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))

    def _make_layer(self, in_channels, out_channels, num_blocks, stride=1):
        layers = []
        layers.append(SpatioTemporalConv(in_channels, out_channels, 3, stride, 1))
        for _ in range(1, num_blocks):
            layers.append(SpatioTemporalConv(out_channels, out_channels, 3, 1, 1))
        return nn.Sequential(*layers)

    def forward(self, x):
        # x: (B, C, T, H, W)
        x = self.stem(x)
        x = self.pool1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        return x.view(x.size(0), -1)  # (B, hidden_dim)

In [None]:
class ShuttleStream(nn.Module):
    """1D CNN for shuttle trajectory features"""
    def __init__(self, input_features: int = 10, hidden_dim: int = 128):
        super().__init__()

        # Temporal convolutions
        self.conv1 = nn.Conv1d(input_features, 64, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm1d(64)

        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(128)

        self.conv3 = nn.Conv1d(128, hidden_dim, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(hidden_dim)

        self.pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        # x: (B, F, T)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool(x)
        return x.squeeze(-1)  # (B, hidden_dim)

In [None]:
class TwoStreamFusionModel(nn.Module):
    """
    Two-stream fusion model for shot classification

    Combines visual appearance and shuttle trajectory
    """
    def __init__(
        self,
        num_classes: int,
        vision_dim: int = 512,
        shuttle_dim: int = 128,
        fusion_dim: int = 256,
        dropout: float = 0.5
    ):
        super().__init__()

        self.vision_stream = VisionStream(input_channels=3, hidden_dim=vision_dim)
        self.shuttle_stream = ShuttleStream(input_features=10, hidden_dim=shuttle_dim)

        # Fusion layer
        self.fusion = nn.Sequential(
            nn.Linear(vision_dim + shuttle_dim, fusion_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fusion_dim, fusion_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        # Classifier
        self.classifier = nn.Linear(fusion_dim, num_classes)

    def forward(self, video, shuttle_features):
        # Extract features from both streams
        vision_feat = self.vision_stream(video)  # (B, vision_dim)
        shuttle_feat = self.shuttle_stream(shuttle_features)  # (B, shuttle_dim)

        # Concatenate and fuse
        combined = torch.cat([vision_feat, shuttle_feat], dim=1)  # (B, vision_dim + shuttle_dim)
        fused = self.fusion(combined)  # (B, fusion_dim)

        # Classify
        logits = self.classifier(fused)  # (B, num_classes)

        return logits

# **4. TRAINING**

In [None]:
class TwoStreamTrainer:
    """Trainer for two-stream model"""

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        lr: float = 1e-3,
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    ):
        self.model = model.to(device)
        self.device = device
        self.train_loader = train_loader
        self.val_loader = val_loader

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=50
        )

    def train_epoch(self):
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0

        for batch_idx, (video, shuttle, target) in enumerate(self.train_loader):
            video = video.to(self.device)
            shuttle = shuttle.to(self.device)
            target = target.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(video, shuttle)
            loss = self.criterion(output, target)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()

            total_loss += loss.item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)

        return total_loss / len(self.train_loader), correct / total

    def validate(self):
        self.model.eval()
        total_loss = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for video, shuttle, target in self.val_loader:
                video = video.to(self.device)
                shuttle = shuttle.to(self.device)
                target = target.to(self.device)

                output = self.model(video, shuttle)
                loss = self.criterion(output, target)

                total_loss += loss.item()
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()
                total += target.size(0)

        return total_loss / len(self.val_loader), correct / total

    def train(self, epochs: int):
        best_val_acc = 0

        for epoch in range(epochs):
            train_loss, train_acc = self.train_epoch()
            val_loss, val_acc = self.validate()
            self.scheduler.step()

            print(f"Epoch {epoch+1}/{epochs} : Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} .. | .. Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}\n")

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(self.model.state_dict(), 'best_twostream_model.pth')
                print(f"  ✓ Saved best model (val_acc: {val_acc:.4f})")

.

.

.

.

.

.

In [None]:
shot_classes = [
    "block", "clear", "cross_net",
    "drive", "drop", "jump_smash",
    "lift", "push", "serve",
    "smash", "straight_net", "tap"
]

In [None]:
# 1. Setup extractors

clip_cfg = ClipExtractionConfig(pad_frames=15, clip_size=(112, 112))
clip_extractor = ClipExtractor(clip_cfg)

shuttle_cfg = ShuttleFeatureConfig(smooth_window=5)
shuttle_extractor = ShuttleFeatureExtractor(shuttle_cfg)

In [None]:
# 2. Load data

video_paths = [
    "data/videos/shi_vit_rally_1.mp4",
    "data/videos/shi_vit_rally_2.mp4",
    "data/videos/shi_vit_rally_3.mp4"
]

tracks_paths = [
    "data/player_tracks/shi_vit_rally_1_tracks.csv",
    "data/player_tracks/shi_vit_rally_2_tracks.csv",
    "data/player_tracks/shi_vit_rally_3_tracks.csv"
]

shuttle_paths = [
    "data/shuttle_tracks/shi_vit_rally_1_ball.csv",
    "data/shuttle_tracks/shi_vit_rally_2_ball.csv",
    "data/shuttle_tracks/shi_vit_rally_3_ball.csv"
]

contact_frames = {
    "shi_vit_rally_1.mp4": [
        94, 114, 137, 159, 192,
        208, 231, 263, 278, 305,
        329, 353, 382, 409, 445,
        460, 486, 514, 527, 555,
        581, 606, 666, 688, 730
    ],
    "shi_vit_rally_2.mp4": [
        62, 86, 107, 130, 172,
        198, 226, 257, 286, 337,
        347, 367, 379, 417, 450,
        471, 502, 539, 566, 591,
        628, 663, 678, 710, 754,
        767, 795, 829, 850, 886,
        901, 940, 968, 989
    ],
    "shi_vit_rally_3.mp4": [
        19, 42, 83, 96, 122,
        148, 172, 203, 215, 249,
        264, 286, 309, 349, 379,
        396, 428
    ]
}

labels = {
    "shi_vit_rally_1.mp4": {
        "contact_94": {"1": "serve", "2": "negative"},
        "contact_114": {"1": "negative", "2": "cross_net"},
        "contact_137": {"1": "cross_net", "2": "negative"},
        "contact_159": {"1": "negative", "2": "lift"},
        "contact_192": {"1": "drop", "2": "negative"},
        "contact_208": {"1": "negative", "2": "push"},
        "contact_231": {"1": "lift", "2": "negative"},
        "contact_263": {"1": "negative", "2": "drop"},
        "contact_278": {"1": "push", "2": "negative"},
        "contact_305": {"1": "negative", "2": "push"},
        "contact_329": {"1": "straight_net", "2": "negative"},
        "contact_353": {"1": "negative", "2": "cross_net"},
        "contact_382": {"1": "lift", "2": "negative"},
        "contact_409": {"1": "negative", "2": "clear"},
        "contact_445": {"1": "jump_smash", "2": "negative"},
        "contact_460": {"1": "negative", "2": "block"},
        "contact_486": {"1": "lift", "2": "negative"},
        "contact_514": {"1": "negative", "2": "smash"},
        "contact_527": {"1": "block", "2": "negative"},
        "contact_555": {"1": "negative", "2": "cross_net"},
        "contact_581": {"1": "straight_net", "2": "negative"},
        "contact_606": {"1": "negative", "2": "lift"},
        "contact_666": {"1": "drop", "2": "negative"},
        "contact_688": {"1": "negative", "2": "lift"},
        "contact_730": {"1": "drive", "2": "negative"}
    }
}

In [None]:
# 3. Create dataset

dataset = TwoStreamDataset(
    video_paths, tracks_paths, shuttle_paths,
    contact_frames, labels,
    clip_extractor, shuttle_extractor,
    shot_classes, clip_cfg, augment=True
)

train_loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=2)
val_loader = DataLoader(dataset, batch_size=8, shuffle=False, num_workers=2)

video name shi_vit_rally_2.mp4 not in labels.

video name shi_vit_rally_2.mp4 not in labels.

video name shi_vit_rally_2.mp4 not in labels.

video name shi_vit_rally_2.mp4 not in labels.

video name shi_vit_rally_2.mp4 not in labels.

video name shi_vit_rally_2.mp4 not in labels.

video name shi_vit_rally_2.mp4 not in labels.

video name shi_vit_rally_2.mp4 not in labels.

video name shi_vit_rally_2.mp4 not in labels.

video name shi_vit_rally_2.mp4 not in labels.

video name shi_vit_rally_2.mp4 not in labels.

video name shi_vit_rally_2.mp4 not in labels.

video name shi_vit_rally_2.mp4 not in labels.

video name shi_vit_rally_2.mp4 not in labels.

video name shi_vit_rally_2.mp4 not in labels.

video name shi_vit_rally_2.mp4 not in labels.

video name shi_vit_rally_2.mp4 not in labels.

video name shi_vit_rally_2.mp4 not in labels.

video name shi_vit_rally_2.mp4 not in labels.

video name shi_vit_rally_2.mp4 not in labels.

video name shi_vit_rally_2.mp4 not in labels.

video name sh

In [None]:
# 4. Create model

model = TwoStreamFusionModel(
    num_classes=len(shot_classes),
    vision_dim=512,
    shuttle_dim=128,
    fusion_dim=256
)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

Model parameters: 7,156,748


In [None]:
# 5. Train

trainer = TwoStreamTrainer(model, train_loader, val_loader, lr=1e-3)
trainer.train(epochs=50)

Epoch 1/50 : Train Loss: 2.3785, Train Acc: 0.1200 .. | .. Val Loss: 2.4070, Val Acc: 0.2400

  ✓ Saved best model (val_acc: 0.2400)
Epoch 2/50 : Train Loss: 2.3957, Train Acc: 0.2000 .. | .. Val Loss: 2.4856, Val Acc: 0.2400

Epoch 3/50 : Train Loss: 2.7000, Train Acc: 0.1600 .. | .. Val Loss: 5.3684, Val Acc: 0.2400

Epoch 4/50 : Train Loss: 2.2383, Train Acc: 0.2400 .. | .. Val Loss: 2.4033, Val Acc: 0.2400

Epoch 5/50 : Train Loss: 2.6668, Train Acc: 0.2000 .. | .. Val Loss: 2.5562, Val Acc: 0.2400

Epoch 6/50 : Train Loss: 2.2048, Train Acc: 0.1200 .. | .. Val Loss: 2.3230, Val Acc: 0.2400

Epoch 7/50 : Train Loss: 2.7638, Train Acc: 0.1600 .. | .. Val Loss: 2.3577, Val Acc: 0.2400

Epoch 8/50 : Train Loss: 2.0411, Train Acc: 0.2400 .. | .. Val Loss: 3.0675, Val Acc: 0.2400

Epoch 9/50 : Train Loss: 2.5960, Train Acc: 0.2400 .. | .. Val Loss: 3.3220, Val Acc: 0.2400

Epoch 10/50 : Train Loss: 2.3080, Train Acc: 0.2000 .. | .. Val Loss: 2.9196, Val Acc: 0.2400

Epoch 11/50 : Train 

KeyboardInterrupt: 

In [None]:
!pip show opencv-python-headless

Name: opencv-python-headless
Version: 4.12.0.88
Summary: Wrapper package for OpenCV python bindings.
Home-page: https://github.com/opencv/opencv-python
Author: 
Author-email: 
License: Apache 2.0
Location: /usr/local/lib/python3.12/dist-packages
Requires: numpy
Required-by: albucore, albumentations
