## **Build label map & splits**

In [1]:
!pip -q install torchmetrics decord fvcore pytorchvideo

[33mDEPRECATION: Loading egg at /usr/local/lib/python3.12/dist-packages/lightning_utilities-0.12.0-py3.12.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330[0m[33m
[0m[33mDEPRECATION: Loading egg at /usr/local/lib/python3.12/dist-packages/dill-0.3.9-py3.12.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330[0m[33m
[0m[33mDEPRECATION: Loading egg at /usr/local/lib/python3.12/dist-packages/looseversion-1.3.0-py3.12.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330[0m[33m
[0m[33mDEPRECATION: Loading egg at /usr/local/lib/python3.12/dist-packages/opt_ei

In [1]:
import json, random, csv, glob, os
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score, MulticlassConfusionMatrix
import torch.nn.functional as F
from torchvision.transforms import v2
from decord import VideoReader, cpu
import torchvision
import numpy as np
from typing import Dict, Tuple, Optional, List
import pandas as pd
from tqdm import tqdm

from IPython.display import Video

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
import random

def seed_all(seed=1023):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    
seed_all(2310)

## **1. CONFIGURATION**
### This class centralizes all hyperparameters and file paths.

In [81]:
class Config:
    def __init__(self):
        self.root_dir = "."
        self.clips_dir = os.path.join(self.root_dir, "FINALS_DATASET/finals_v1")
        self.splits_dir = os.path.join(self.root_dir, "splits_finals_v1")
        self.models_dir = os.path.join(self.root_dir, "train_finals_v1")
        self.best_model_path = os.path.join(self.models_dir, "best.pt")

        # self.labels = [
        #     "smash", "jump_smash", "block",
        #     "drop", "clear", "lift", "drive",
        #     "straight_net", "cross_net", "serve",
        #     "push", "tap"
        # ]

        self.labels = [
            "block", "clear", "cross_net",
            "drive", "drop", "jump_smash",
            "lift", "push", "serve",
            "smash", "straight_net", "tap",
            "negative" 
        ]

        # Dataset parameters
        self.side = 256             # ori: 224, but realised dataset clips are cropped to 256!!
        self.slow_t = 8             # 8 frames for slow pathway
        self.alpha = 4             # ratio between fast and slow
        self.fast_t = self.slow_t * self.alpha
        self.fast_target = 256      # ori: 224

        self.batch_size = 8
        

# Create a configuration object
cfg = Config()

## **2. DATA PREPARATION**
### This function handles all logic for splitting and saving the dataset.

In [111]:
def prepare_data_splits(config: Config):
    """
    Finds video clips, shuffles them, and splits them into train, val, and test sets.
    Saves the splits as CSV files and the label map as a JSON file.
    """
    os.makedirs(config.splits_dir, exist_ok=True)
    os.makedirs(config.models_dir, exist_ok=True)

    labels_map = {lab: i for i, lab in enumerate(config.labels)}
    with open(os.path.join(config.splits_dir, "labels_map.json"), "w") as f:
        json.dump(labels_map, f, indent=2)

    items = []
    for label in config.labels:
        # Use glob to find all video files for the current label
        for clip_path in glob.glob(os.path.join(config.clips_dir, label, "*.mp4")):
            items.append((clip_path, labels_map[label]))

    random.seed(1337)
    random.shuffle(items)

    total_items = len(items)
    train_count = int(0.8 * total_items)
    val_count = int(0.1 * total_items)
    print(f"Found {total_items} clips in total, splitting to train ({train_count}) and val ({val_count}).")

    splits = {
        # "train.csv": items[:train_count],
        "train.csv": items[:],
        "val.csv": items[train_count:train_count + val_count],
        "test.csv": items[train_count + val_count:]
    }

    for name, data in splits.items():
        with open(os.path.join(config.splits_dir, name), "w", newline="") as f:
            csv_writer = csv.writer(f)
            csv_writer.writerows(data)

    print({k: len(v) for k, v in splits.items()})

In [112]:
prepare_data_splits(cfg)

Found 660 clips in total, splitting to train (528) and val (66).
{'train.csv': 660, 'val.csv': 66, 'test.csv': 66}


## **3. DATASET**
### The ClipDataset class handles video loading and preprocessing.

In [125]:
class ClipDataset(Dataset):
    def __init__(self, csv_path: str, config: Config, train: bool = True):
        self.items = [(p, int(y)) for p, y in csv.reader(open(csv_path))]
        self.config = config
        self.train = train

        # Pre-compute normalization tensors
        self.mean = torch.tensor([0.45, 0.45, 0.45]).view(3, 1, 1)
        self.std = torch.tensor([0.225, 0.225, 0.225]).view(3, 1, 1)

        # Define a composed transform for training
        if self.train:
            self.train_transforms = v2.Compose([
                # v2.RandomResizedCrop(
                #     size=self.config.side,
                #     scale=(0.85, 1.0),
                #     ratio=(0.85, 1.2),
                #     antialias=True
                # ),
                # random affine is very expensive...
                v2.RandomAffine(           # tiny pure-translate + mild scale jitter
                    degrees=0,
                    translate=(0.06, 0.06),   # up to ±6% of W/H
                    scale=(0.85, 1.15)
                ),
                v2.RandomHorizontalFlip(p=0.3),
                # v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.5),
                # v2.RandomGrayscale(p=0.2),
            ])

    def _get_frame_indices(self, num_frames: int):
        """
        TSN-style: take 'need' samples uniformly over the whole [0, num_frames-1],
        with small per-segment jitter. Optional speed-jitter + temporal-dropout for robustness.
        """
        need = self.config.fast_t
        if num_frames <= 1:
            fast_idx = [0] * need
        else:
            # Uniform positions
            base = np.linspace(0, num_frames - 1, num=need)
            if self.train:
                # (a) Segment jitter
                seg = max((num_frames - 1) / max(need, 1), 1.0)
                jitter = np.random.uniform(-0.5, 0.5, size=need) * seg
                base = np.clip(base + jitter, 0, num_frames - 1)
    
                # (b) Speed jitter (time-warp)
                s = np.random.uniform(0.8, 1.25)
                center = (num_frames - 1) / 2.0
                base = (base - center) * s + center
                base = np.clip(base, 0, num_frames - 1)
    
                # (c) Temporal dropout (optional)
                if np.random.rand() < 0.30 and need >= 8:
                    drop = np.random.randint(1, need // 4 + 1)
                    keep_mask = np.ones(need, dtype=bool)
                    keep_mask[np.random.choice(need, drop, replace=False)] = False
                    kept = base[keep_mask]
                    if kept.size == 0:
                        kept = np.array([0.0])
                    pad = np.full(need - kept.size, kept[-1])
                    base = np.sort(np.concatenate([kept, pad]))
    
            fast_idx = base.astype(int).tolist()
    
        slow_idx = fast_idx[::self.config.alpha]
        if len(slow_idx) == 0:
            slow_idx = [fast_idx[0]]  # safety
        return slow_idx, fast_idx

    def _read_and_process_frames(self, vr: VideoReader, indices: List[int]) -> torch.Tensor:
        """
        Returns (C, T, H, W) normalized to kinetics-style mean/std.
        """
        try:
            frames = vr.get_batch([min(i, len(vr)-1) for i in indices]).asnumpy()
        except Exception:
            frames = np.stack([vr[min(i, len(vr)-1)].asnumpy() for i in indices], axis=0)

        # Convert to tensor and permute dimensions
        x = torch.from_numpy(frames).permute(0, 3, 1, 2).float() / 255.0  # (T, C, H, W)

        # Apply data augmentation only for training
        if self.train:
            # Apply the same random transform to all frames
            x = self.train_transforms(x)

        # Resize to the required size if necessary
        x = F.interpolate(x, size=self.config.side, mode="bilinear", align_corners=False) # (T, C, 224, 224)

        # Normalize
        mean = self.mean.to(x)
        std = self.std.to(x)
        x = (x - mean) / std

        return x.permute(1, 0, 2, 3) # (C, T, H, W)

    def __getitem__(self, i: int) -> Tuple[Tuple[torch.Tensor, torch.Tensor], int]:
        """Loads and preprocesses a single clip and its label."""
        path, label = self.items[i]
        vr = VideoReader(path, ctx=cpu(0))

        # Randomly choose frames from the entire video
        slow_indices, fast_indices = self._get_frame_indices(len(vr))

        # Get and process clips
        slow_clip = self._read_and_process_frames(vr, slow_indices)
        fast_clip = self._read_and_process_frames(vr, fast_indices)

        return (slow_clip, fast_clip), label

    def __len__(self) -> int:
        return len(self.items)

#### **Generate datasets and loaders for training, validation, and testing**

In [126]:
def slowfast_collate(batch):
    # batch: list of [((slow, fast), y), ...]
    slows, fasts, ys = [], [], []
    for (s, f), y in batch:
        slows.append(s)
        fasts.append(f)
        ys.append(y)
    slow = torch.stack(slows, dim=0)  # (B,C,T,H,W)
    fast = torch.stack(fasts, dim=0)  # (B,C,T,H,W)
    y = torch.tensor(ys, dtype=torch.long)
    return [slow, fast], y

train_csv = os.path.join(cfg.splits_dir, "train.csv")
val_csv   = os.path.join(cfg.splits_dir, "val.csv")
test_csv  = os.path.join(cfg.splits_dir, "test.csv")

train_ds = ClipDataset(train_csv, cfg, train=True)
val_ds   = ClipDataset(val_csv,   cfg, train=False)
test_ds  = ClipDataset(test_csv,  cfg, train=False)

test_loader = DataLoader(
    test_ds, batch_size=max(1, cfg.batch_size), shuffle=False,
    num_workers=4, pin_memory=True, collate_fn=slowfast_collate, persistent_workers=False
)

train_loader = DataLoader(
    train_ds, batch_size=cfg.batch_size, shuffle=True,
    num_workers=4, pin_memory=True, collate_fn=slowfast_collate, persistent_workers=False
)
val_loader = DataLoader(
    val_ds, batch_size=max(1, cfg.batch_size), shuffle=False,
    num_workers=4, pin_memory=True, collate_fn=slowfast_collate, persistent_workers=False
)

num_classes = len(cfg.labels)
print("Classes:", num_classes, cfg.labels)

Classes: 13 ['block', 'clear', 'cross_net', 'drive', 'drop', 'jump_smash', 'lift', 'push', 'serve', 'smash', 'straight_net', 'tap', 'negative']


In [26]:
cheras_ds = ClipDataset('splits_cheras/train.csv', cfg, train=True)

cheras_loader = DataLoader(
    cheras_ds, batch_size=cfg.batch_size, shuffle=True,
    num_workers=4, pin_memory=True, collate_fn=slowfast_collate, persistent_workers=False
)

In [86]:
count_class = {}
for _,outs in train_ds:
    labels = cfg.labels[outs]
    if labels not in count_class:
        count_class[labels] = 0
    count_class[labels] += 1
    
print(count_class)

print(f"total clips: {sum(count_class.values())}")

{'jump_smash': 23, 'cross_net': 38, 'lift': 88, 'push': 27, 'block': 36, 'straight_net': 44, 'negative': 172, 'drop': 39, 'drive': 17, 'serve': 6, 'smash': 15, 'tap': 4, 'clear': 19}
total clips: 528


In [45]:
# without negatives

class_counts = torch.tensor([
    20, 35, 10, 
    50, 20, 15, 
    40, 20, 40, 
    40, 8, 5
], dtype=torch.float32) # Order must match cfg.labels!!!!

# [
#             "block", "clear", "cross_net",
#             "drive", "drop", "jump_smash",
#             "lift", "push", "serve",
#             "smash", "straight_net", "tap"
#  ]

class_weights = 1.0 / class_counts
class_weights = class_weights / class_weights.sum() * len(class_counts)
class_weights = class_weights.to(device)

In [127]:
# WITH negatives

class_counts = torch.tensor([
    33, 27, 36, 
    16, 41, 22, 
    85, 23, 5, 
    19, 41, 3,
    10  # important to classify negatives correctly 
], dtype=torch.float32) # Order must match cfg.labels!!!!

# [
#             "block", "clear", "cross_net",
#             "drive", "drop", "jump_smash",
#             "lift", "push", "serve",
#             "smash", "straight_net", "tap",
#             "negatives"
#  ]

class_weights = 1.0 / class_counts
class_weights = class_weights / class_weights.sum() * len(class_counts)
class_weights = class_weights.to(device)

## **4. TRAINING AND EVALUATION**
### This function orchestrates the entire training process.

#### **Load pre-trained model from hub**

In [246]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SlowFastLSTMHead(nn.Module):
    """
    Replace ResNetBasicHead for SlowFast.
    Accepts x as [x_slow, x_fast] each shaped (B, C, T, H, W) or a single tensor.
    - Spatial global average pool -> keep T
    - Temporal BiLSTM + attention pooling -> logits
    """
    def __init__(self,
                 in_dim: int,                  # Cs + Cf from original head
                 num_classes: int,
                 lstm_hidden: int = 512,
                 lstm_layers: int = 1,
                 bidirectional: bool = True,
                 lstm_dropout: float = 0.3,    # used if lstm_layers > 1
                 head_hidden: int | None = None,
                 head_dropout: float = 0.25):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=in_dim,
            hidden_size=lstm_hidden,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=(lstm_dropout if lstm_layers > 1 else 0.0),
        )
        out_dim = lstm_hidden * (2 if bidirectional else 1)

        # simple additive attention over time
        self.attn = nn.Sequential(
            nn.LayerNorm(out_dim),
            nn.Linear(out_dim, out_dim // 2),
            nn.Tanh(),
            nn.Linear(out_dim // 2, 1),
        )

        if head_hidden is None:
            self.classifier = nn.Sequential(
                nn.LayerNorm(out_dim),
                nn.Dropout(head_dropout),
                nn.Linear(out_dim, num_classes),
            )
        else:
            self.classifier = nn.Sequential(
                nn.LayerNorm(out_dim),
                nn.Linear(out_dim, head_hidden),
                nn.ReLU(inplace=True),
                nn.Dropout(head_dropout),
                nn.Linear(head_hidden, num_classes),
            )

    @staticmethod
    def _spatial_pool_keep_time(x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T, H, W) -> (B, T, C) by spatial mean
        x = x.mean(dim=(-2, -1))      # (B, C, T)
        x = x.permute(0, 2, 1)        # (B, T, C)
        return x

    def forward(self, x):
        # x can be a list/tuple [slow, fast] or a single tensor
        if isinstance(x, (list, tuple)):
            xs, xf = x  # (B, Cs, Ts, H, W), (B, Cf, Tf, H, W)
            Ts, Tf = xs.shape[2], xf.shape[2]
            # Align fast to slow along time if needed
            if Tf != Ts:
                if Tf % Ts == 0:
                    r = Tf // Ts
                    xf = F.avg_pool3d(xf, kernel_size=(r, 1, 1), stride=(r, 1, 1))
                else:
                    xf = F.interpolate(xf, size=(Ts, xf.shape[-2], xf.shape[-1]),
                                       mode="trilinear", align_corners=False)
            xs = self._spatial_pool_keep_time(xs)   # (B, Ts, Cs)
            xf = self._spatial_pool_keep_time(xf)   # (B, Ts, Cf)
            xseq = torch.cat([xs, xf], dim=-1)      # (B, Ts, Cs+Cf)
        else:
            xseq = self._spatial_pool_keep_time(x)   # (B, T, C)

        lstm_out, _ = self.lstm(xseq)                # (B, T, H*)
        scores = self.attn(lstm_out).squeeze(-1)     # (B, T)
        w = torch.softmax(scores, dim=1).unsqueeze(-1)
        pooled = (lstm_out * w).sum(dim=1)           # (B, H*)
        logits = self.classifier(pooled)             # (B, num_classes)
        return logits


In [128]:
# =========================
# 3) Model: load hub, replace head
# =========================
# 1) Load SlowFast
torch.hub._validate_not_a_forked_repo = lambda a,b,c: True
model = torch.hub.load('facebookresearch/pytorchvideo', 'slowfast_r101', pretrained=True)

# 2) Compute input dim for the new head (Cs + Cf)
in_dim = model.blocks[-1].proj.in_features   # same sum of channels used by the original linear head

# 3) Replace the entire head (not just proj)
num_classes = len(cfg.labels)  # your label count
model.blocks[-1].proj = nn.Sequential(
    nn.Dropout(0.3),
    nn.Linear(in_dim, 128),
    nn.Dropout(p=0.3), # Add a dropout layer
    nn.Linear(128, num_classes)
)

# Optional: freeze early blocks for faster convergence at small data sizes
for p in model.blocks[:-1].parameters():
    p.requires_grad = False

model = model.to(device)

Using cache found in /root/.cache/torch/hub/facebookresearch_pytorchvideo_main


In [133]:
"""
Unfreeze last n blocks
"""

# for p in model.parameters():
#     p.requires_grad = False

for p in model.blocks[:-6].parameters():
    p.requires_grad = False

In [13]:
len(model.blocks)

7

#### **Optional: load weights from checkpoint**

In [None]:
checkpoint_path = 'models_slowfast/slowfast_full_1b.pt'

# Load the saved checkpoint
checkpoint = torch.load(checkpoint_path, map_location=device)

# Load the model's state_dict from the checkpoint
model.load_state_dict(checkpoint['model'])
print(f"Model weights loaded successfully from {checkpoint_path}")

#### **Define training components**

In [134]:
# =========================
# 4) Optimizer, loss, metrics
# =========================
from torch.optim.lr_scheduler import CosineAnnealingLR

num_epochs = 20
learning_rate = 0.0002
weight_decay = 0.002
early_stopping_patience = 5

criterion = nn.CrossEntropyLoss(label_smoothing=0.02, weight=class_weights)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=10)
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-5)
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

acc = MulticlassAccuracy(num_classes=num_classes, average='micro').to(device)
f1  = MulticlassF1Score(num_classes=num_classes, average='macro').to(device)

  scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))


#### **Main training loop**

In [135]:
# =========================
# 5) Train / validate
# =========================
best_f1 = -1.0
best_train_f1 = -1.0
os.makedirs(cfg.models_dir, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"using device: {device}")

for epoch in range(num_epochs):
    model.train()
    acc.reset(); f1.reset()
    total_loss = 0.0

    first = True
    for (slow_fast, y) in tqdm(train_loader, desc=f'Training epoch {epoch+1}/{num_epochs}'):
        if first:
            s, f = slow_fast
            # print("slow:", tuple(s.shape), "fast:", tuple(f.shape))
            # Expect slow=(B,3,8,224,224) and fast=(B,3,32,224,224)
            first = False

        # slow_fast is [slow, fast]
        slow_fast = [t.to(device, non_blocking=True) for t in slow_fast]
        y = y.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):
            logits = model(slow_fast)     # (B, num_classes)
            loss = criterion(logits, y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item() * y.size(0)
        acc.update(logits, y)
        f1.update(logits, y)

    train_loss = total_loss / len(train_ds)
    train_acc  = acc.compute().item()
    train_f1   = f1.compute().item()

    # --- Validation ---
    model.eval()
    acc.reset(); f1.reset()
    val_loss = 0.0
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=(device == "cuda")):
        for (slow_fast, y) in val_loader:
            slow_fast = [t.to(device, non_blocking=True) for t in slow_fast]
            y = y.to(device, non_blocking=True)
            logits = model(slow_fast)
            loss = criterion(logits, y)
            val_loss += loss.item() * y.size(0)
            acc.update(logits, y)
            f1.update(logits, y)

    val_loss /= len(val_ds)
    val_acc = acc.compute().item()
    val_f1  = f1.compute().item()

    # scheduler.step(val_f1)  # for reduceLRonPlateau
    scheduler.step()   # for cosineannealing

    print(f"\n[{epoch+1:02d}/{num_epochs}] "
          f"train_loss={train_loss:.4f} acc={train_acc*100:.2f}% f1={train_f1:.3f} | "
          f"val_loss={val_loss:.4f} acc={val_acc*100:.2f}% f1={val_f1:.3f}")

    if val_f1 > best_f1:
        best_f1 = val_f1
        torch.save({"model": model.state_dict(), "labels": cfg.labels}, cfg.best_model_path)
        print(f"  ↳ saved new best to {cfg.best_model_path} (val_f1={best_f1:.3f})")

    if train_f1 > best_train_f1:
        best_train_f1 = train_f1
        torch.save({"model": model.state_dict(), "labels": cfg.labels}, f'{cfg.models_dir}/best_train_f1.pt')
        print(f"  ↳ saved new best to '{cfg.models_dir}/best_train_f1.pt' (train_f1={best_train_f1:.3f})")

torch.save({"model": model.state_dict(), "labels": cfg.labels}, 'train_cheras/last.pt')

print("Best val F1:", best_f1)

using device: cuda


  with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):
Training epoch 1/20: 100%|██████████| 83/83 [00:20<00:00,  3.97it/s]
  with torch.no_grad(), torch.cuda.amp.autocast(enabled=(device == "cuda")):



[01/20] train_loss=0.6573 acc=68.94% f1=0.712 | val_loss=0.4117 acc=86.36% f1=0.876
  ↳ saved new best to ./train_finals_v1/best.pt (val_f1=0.876)
  ↳ saved new best to './train_finals_v1/best_train_f1.pt' (train_f1=0.712)


Training epoch 2/20: 100%|██████████| 83/83 [00:19<00:00,  4.23it/s]



[02/20] train_loss=0.6473 acc=71.36% f1=0.756 | val_loss=0.4022 acc=78.79% f1=0.781
  ↳ saved new best to './train_finals_v1/best_train_f1.pt' (train_f1=0.756)


Training epoch 3/20: 100%|██████████| 83/83 [00:19<00:00,  4.28it/s]



[03/20] train_loss=0.6170 acc=73.48% f1=0.760 | val_loss=0.4111 acc=81.82% f1=0.812
  ↳ saved new best to './train_finals_v1/best_train_f1.pt' (train_f1=0.760)


Training epoch 4/20: 100%|██████████| 83/83 [00:19<00:00,  4.22it/s]



[04/20] train_loss=0.5982 acc=73.03% f1=0.754 | val_loss=0.3986 acc=80.30% f1=0.800


Training epoch 5/20: 100%|██████████| 83/83 [00:19<00:00,  4.19it/s]



[05/20] train_loss=0.5998 acc=74.70% f1=0.776 | val_loss=0.3965 acc=80.30% f1=0.792
  ↳ saved new best to './train_finals_v1/best_train_f1.pt' (train_f1=0.776)


Training epoch 6/20: 100%|██████████| 83/83 [00:19<00:00,  4.27it/s]



[06/20] train_loss=0.6305 acc=74.24% f1=0.766 | val_loss=0.3975 acc=80.30% f1=0.800


Training epoch 7/20: 100%|██████████| 83/83 [00:19<00:00,  4.17it/s]



[07/20] train_loss=0.6586 acc=70.00% f1=0.725 | val_loss=0.3810 acc=87.88% f1=0.882
  ↳ saved new best to ./train_finals_v1/best.pt (val_f1=0.882)


Training epoch 8/20: 100%|██████████| 83/83 [00:19<00:00,  4.19it/s]



[08/20] train_loss=0.6073 acc=73.64% f1=0.780 | val_loss=0.3738 acc=83.33% f1=0.837
  ↳ saved new best to './train_finals_v1/best_train_f1.pt' (train_f1=0.780)


Training epoch 9/20: 100%|██████████| 83/83 [00:19<00:00,  4.23it/s]



[09/20] train_loss=0.6029 acc=70.91% f1=0.751 | val_loss=0.3621 acc=87.88% f1=0.885
  ↳ saved new best to ./train_finals_v1/best.pt (val_f1=0.885)


Training epoch 10/20: 100%|██████████| 83/83 [00:19<00:00,  4.17it/s]



[10/20] train_loss=0.6340 acc=73.64% f1=0.769 | val_loss=0.4185 acc=89.39% f1=0.875


Training epoch 11/20: 100%|██████████| 83/83 [00:19<00:00,  4.21it/s]



[11/20] train_loss=0.6113 acc=71.52% f1=0.750 | val_loss=0.3719 acc=86.36% f1=0.866


Training epoch 12/20: 100%|██████████| 83/83 [00:19<00:00,  4.22it/s]



[12/20] train_loss=0.5735 acc=74.85% f1=0.762 | val_loss=0.3639 acc=87.88% f1=0.885


Training epoch 13/20: 100%|██████████| 83/83 [00:19<00:00,  4.21it/s]



[13/20] train_loss=0.5656 acc=75.91% f1=0.775 | val_loss=0.3707 acc=86.36% f1=0.868


Training epoch 14/20: 100%|██████████| 83/83 [00:19<00:00,  4.25it/s]



[14/20] train_loss=0.5678 acc=75.91% f1=0.786 | val_loss=0.3604 acc=83.33% f1=0.823
  ↳ saved new best to './train_finals_v1/best_train_f1.pt' (train_f1=0.786)


Training epoch 15/20: 100%|██████████| 83/83 [00:19<00:00,  4.22it/s]



[15/20] train_loss=0.5855 acc=74.24% f1=0.771 | val_loss=0.3632 acc=86.36% f1=0.871


Training epoch 16/20: 100%|██████████| 83/83 [00:19<00:00,  4.26it/s]



[16/20] train_loss=0.5881 acc=74.39% f1=0.790 | val_loss=0.3650 acc=83.33% f1=0.828
  ↳ saved new best to './train_finals_v1/best_train_f1.pt' (train_f1=0.790)


Training epoch 17/20: 100%|██████████| 83/83 [00:19<00:00,  4.21it/s]



[17/20] train_loss=0.6127 acc=72.42% f1=0.762 | val_loss=0.3593 acc=83.33% f1=0.840


Training epoch 18/20: 100%|██████████| 83/83 [00:20<00:00,  4.10it/s]



[18/20] train_loss=0.5944 acc=73.18% f1=0.779 | val_loss=0.3590 acc=83.33% f1=0.840


Training epoch 19/20: 100%|██████████| 83/83 [00:19<00:00,  4.21it/s]



[19/20] train_loss=0.5772 acc=76.36% f1=0.799 | val_loss=0.3664 acc=86.36% f1=0.871
  ↳ saved new best to './train_finals_v1/best_train_f1.pt' (train_f1=0.799)


Training epoch 20/20: 100%|██████████| 83/83 [00:19<00:00,  4.17it/s]



[20/20] train_loss=0.5786 acc=76.06% f1=0.801 | val_loss=0.3655 acc=84.85% f1=0.860
  ↳ saved new best to './train_finals_v1/best_train_f1.pt' (train_f1=0.801)
Best val F1: 0.8853994607925415


In [136]:
# SAVE TO MODELS_SLOWFAST FOLDER

!cp train_finals_v1/best_train_f1.pt models_slowfast/slowfast_finals_4b.pt

## **Evaluate on test set**

In [656]:
def load_canonical_labels(splits_dir: str):
    import json, os
    with open(os.path.join(splits_dir, "labels_map.json"), "r") as f:
        m = json.load(f)  # {"block":0, "clear":1, ...}
    return [k for k,_ in sorted(m.items(), key=lambda kv: kv[1])]

In [657]:
cfg.labels = load_canonical_labels(cfg.splits_dir)
cfg.labels

['smash',
 'jump_smash',
 'block',
 'drop',
 'clear',
 'lift',
 'drive',
 'straight_net',
 'cross_net',
 'serve',
 'push',
 'tap']

In [60]:
def get_labels_from_ckpt(ckpt):
    # Try common keys
    for k in ("labels", "label_names", "classes"):
        if k in ckpt:
            return list(ckpt[k])
    # Sometimes nested in meta/config
    for outer in ("meta", "config", "hparams"):
        if outer in ckpt:
            for k in ("labels", "label_names", "classes"):
                if k in ckpt[outer]:
                    return list(ckpt[outer][k])
    return None

def make_permutation(src_order, tgt_order):
    # src_order: order used to train the checkpoint head
    # tgt_order: canonical (test) order
    if set(src_order) != set(tgt_order):
        missing = set(tgt_order) - set(src_order)
        extra   = set(src_order) - set(tgt_order)
        raise ValueError(f"Label set mismatch.\nMissing in ckpt: {missing}\nExtra in ckpt: {extra}")
    return [src_order.index(c) for c in tgt_order]  # index in src for each tgt slot

def realign_linear_head_to_canonical(model, perm):
    """
    Permute the classifier head rows so its output index ordering == canonical.
    Works when your classifier is model.blocks[-1].proj[-1] (nn.Linear).
    """
    head = model.blocks[-1].proj[-1]
    import torch
    with torch.no_grad():
        head.weight.data = head.weight.data[perm, :]
        if head.bias is not None:
            head.bias.data = head.bias.data[perm]
    return model

In [65]:
import time
class TestManager:
    """
    Manages the evaluation process for a SlowFast model on a test set.
    """
    def __init__(self, config: 'Config', device: str, model_path: str, loader: DataLoader = test_loader):
        self.config = config
        self.device = device
        self.num_classes = len(config.labels)
        self.model = self._load_model(model_path)
        self.loader = loader
        self.metrics = self._initialize_metrics()
        self.softmax = nn.Softmax(dim=1)

    def _load_model(self, model_path):
        """Loads the pre-trained SlowFast model and the fine-tuned checkpoint."""
        print("Loading model and best checkpoint...")

        # Disable the internal hub check for local loading
        torch.hub._validate_not_a_forked_repo = lambda a,b,c: True

        model = torch.hub.load('facebookresearch/pytorchvideo', 'slowfast_r101', pretrained=True)
        in_dim = model.blocks[-1].proj.in_features
        # model.blocks[-1] = SlowFastLSTMHead(
        #     in_dim=in_dim,
        #     num_classes=num_classes,
        #     lstm_hidden=512,
        #     lstm_layers=3,          # try 1–2; >2 rarely helps at T<=16
        #     bidirectional=True,
        #     head_hidden=None,       # or 256 for a tiny MLP before the final linear
        #     head_dropout=0.25
        # )

        model.blocks[-1].proj = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(in_dim, 128),
            nn.Dropout(p=0.3), # Add a dropout layer
            nn.Linear(128, num_classes)
        )

        # Load the state dictionary from the checkpoint file
        ckpt = torch.load(model_path, map_location=self.device)
        model.load_state_dict(ckpt["model"])

        src_order = get_labels_from_ckpt(ckpt)
        if src_order is None:
            print("⚠️  No labels found in checkpoint; outputs assumed already canonical.")
        else:
            tgt_order = list(self.config.labels)  # canonical test order
            perm = make_permutation(src_order, tgt_order)  # src index per tgt slot
            if perm != list(range(len(perm))):
                # Permanently realign head to canonical order
                model = realign_linear_head_to_canonical(model, perm)
                print("✅ Classifier head realigned to canonical label order.")
            else:
                print("ℹ️ Head already matches canonical order.")
            
        model = model.to(self.device)
        model.eval()
        return model

    def _create_dataloader(self):
        """Creates and returns the DataLoader for the test set."""
        test_ds = ClipDataset(os.path.join(self.config.splits_dir, "test.csv"), self.config, train=False)
        return DataLoader(
            test_ds,
            batch_size=max(1, self.config.batch_size),
            shuffle=False,
            num_workers=4,
            pin_memory=True,
            # collate_fn=slowfast_collate,  # Make sure this is imported if needed
            persistent_workers=False
        )

    def _initialize_metrics(self):
        """Initializes all the evaluation metrics."""
        return {
            'top1': MulticlassAccuracy(num_classes=self.num_classes, average="micro").to(self.device),
            'top3': MulticlassAccuracy(num_classes=self.num_classes, top_k=3).to(self.device),
            'f1_macro': MulticlassF1Score(num_classes=self.num_classes, average="macro").to(self.device),
            'f1_micro': MulticlassF1Score(num_classes=self.num_classes, average="micro").to(self.device),
            'f1_perclass': MulticlassF1Score(num_classes=self.num_classes, average=None).to(self.device),
            'cm': MulticlassConfusionMatrix(num_classes=self.num_classes).to(self.device)
        }

    def run_inference(self):
        """Runs the inference loop and computes all metrics and predictions."""
        print("Starting inference on the test set...")
        test_loss = 0.0
        criterion = nn.CrossEntropyLoss()
        all_predictions = []

        with torch.no_grad(), torch.amp.autocast(self.device, enabled=(self.device == "cuda")):
            for batch_idx, (slow_fast, y) in enumerate(self.loader):
                # Ensure input tensors are lists
                if not isinstance(slow_fast, list):
                    slow_fast = [slow_fast]

                slow_fast = [t.to(self.device, non_blocking=True) for t in slow_fast]
                y = y.to(self.device, non_blocking=True)

                logits = self.model(slow_fast)
                loss = criterion(logits, y)
                test_loss += loss.item() * y.size(0)

                # Update metrics
                for metric in self.metrics.values():
                    metric.update(logits, y)

                # Collect per-sample predictions for later saving
                probs = self.softmax(logits)
                conf, pred = probs.max(dim=1)
                topk_conf, topk_idx = probs.topk(3, dim=1)

                start_idx = batch_idx * self.loader.batch_size

                for i in range(y.size(0)):
                    idx = start_idx + i
                    path = self.loader.dataset.items[idx][0]
                    row = {
                        "path": path,
                        "file": os.path.basename(path),
                        "true_idx": int(y[i]),
                        "true_label": self.config.labels[int(y[i])],
                        "pred_idx": int(pred[i]),
                        "pred_label": self.config.labels[int(pred[i])],
                        "pred_prob": float(conf[i]),
                        "top1_label": self.config.labels[int(topk_idx[i,0])],
                        "top1_prob":  float(topk_conf[i,0]),
                        "top2_label": self.config.labels[int(topk_idx[i,1])],
                        "top2_prob":  float(topk_conf[i,1]),
                        "top3_label": self.config.labels[int(topk_idx[i,2])],
                        "top3_prob":  float(topk_conf[i,2]),
                    }
                    all_predictions.append(row)

        test_loss /= len(self.loader.dataset)
        return test_loss, all_predictions

    def compute_and_print_results(self, test_loss):
        acc1 = self.metrics['top1'].compute().item()
        acc3 = self.metrics['top3'].compute().item()
        f1M = self.metrics['f1_macro'].compute().item()
        f1m = self.metrics['f1_micro'].compute().item()
        percls = self.metrics['f1_perclass'].compute().detach().cpu().tolist()
        confmat = self.metrics['cm'].compute().detach().cpu().numpy()
    
        print(f"\nTEST: loss={test_loss:.4f} | acc@1={acc1*100:.2f}% | acc@3={acc3*100:.2f}% | macro-F1={f1M:.3f} | micro-F1={f1m:.3f}")
        print("\nPer-class F1:")
        for lab, s in sorted(zip(self.config.labels, percls), key=lambda x: x[1], reverse=True):
            print(f"  {lab:15s} {s:.3f}")
    
        print("\nConfusion Matrix (rows=true, cols=predicted):")
        print(confmat)
    
        # NEW: return everything needed for plotting/report
        return {
            "test_loss": float(test_loss),
            "acc1": float(acc1),
            "acc3": float(acc3),
            "f1_macro": float(f1M),
            "f1_micro": float(f1m),
            "f1_perclass": {lab: float(s) for lab, s in zip(self.config.labels, percls)},
            "confmat": confmat,  # keep as numpy array for plotting
        }

    def save_predictions(self, predictions: list, print_n: int=10, filepath: str = '/opt/NeMo/eval_plots/csv_1.csv'):
        """Saves the list of predictions to a CSV file."""
        df = pd.DataFrame(predictions)
        save_path = os.path.join(self.config.models_dir, "test_predictions.csv")
        df.to_csv(filepath, index=False)
        print(f"\nSaved per-sample predictions to: {save_path}")
        print("\nQuick peek at the predictions:")
        with pd.option_context('display.max_rows', None):
            print(df.head(print_n)[[
                # "file", 
                "true_label", "pred_label", "pred_prob", 
                "top2_label", "top2_prob"]]
                 )

    def render_report(self, predictions: list, results: dict, out_dir: str = None):
        """
        Build charts + HTML report from your results and predictions.
        """
        ts = time.strftime("%Y%m%d-%H%M%S")
        out_dir = out_dir or os.path.join(self.config.models_dir, f"test_report_{ts}")
        os.makedirs(out_dir, exist_ok=True)
    
        # Predictions to DataFrame for calibration plot
        df = pd.DataFrame(predictions)  # from save_predictions() structure
        # In case you didn't save yet, still fine: df has pred_prob/true_idx/pred_idx
    
        rep = EvalReport(self.config.labels, out_dir)
    
        # Plots
        cm_raw_path = rep.plot_confusion(results["confmat"], normalize=False, name="cm_raw.png")
        cm_norm_path = rep.plot_confusion(results["confmat"], normalize=True,  name="cm_norm.png")
        f1_path      = rep.plot_perclass_bars(results["f1_perclass"], "Per-class F1 (sorted)", "f1_perclass.png")
        support_path = rep.plot_support_bars(results["confmat"], name="support_true_pred.png")
        topconf_path = rep.plot_top_confusions(results["confmat"], k=5, name="top_confusions.png") or ""
        calib_path   = rep.plot_calibration(df, name="calibration.png", bins=10)
    
        # Summary numbers for the HTML header
        summary = {
            "test_loss": results["test_loss"],
            "acc@1": results["acc1"],
            "acc@3": results["acc3"],
            "macro-F1": results["f1_macro"],
            "num_samples": int(len(df)),
            "num_classes": int(len(self.config.labels)),
        }
    
        images = {
            "Confusion Matrix (Normalized)": cm_norm_path,
            "Confusion Matrix (Raw Counts)": cm_raw_path,
            "Per-class F1": f1_path,
            "Class Support (True vs Pred)": support_path,
            "Reliability Diagram (Calibration)": calib_path,
        }
        if topconf_path:
            images["Top Confusions (True→Pred)"] = topconf_path
    
        html_path = rep.write_html(summary, images, name="report.html")
        print(f"\n✅ Saved evaluation report to: {html_path}")
        print(f"   (and PNGs in: {out_dir})")

In [538]:
Video('cheras/clear/291351dfbe.mp4', embed=True, width=640, height=480)

In [663]:
fixed_test_ds  = ClipDataset('/opt/NeMo/splits_final/test.csv',  cfg, train=False)

fixed_test_loader = DataLoader(
    fixed_test_ds, batch_size=128, shuffle=False,
    num_workers=4, pin_memory=True, collate_fn=slowfast_collate, persistent_workers=False
)

In [137]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# test_manager = TestManager(cfg, device, cfg.best_model_path, loader=test_loader)
test_manager = TestManager(cfg, device, 'models_slowfast/slowfast_finals_4b.pt', loader=train_loader)

test_loss, all_predictions = test_manager.run_inference()
res = test_manager.compute_and_print_results(test_loss)
test_manager.save_predictions(all_predictions, print_n=len(all_predictions), filepath='/opt/NeMo/eval_plots/preds_high.csv')

Loading model and best checkpoint...


Using cache found in /root/.cache/torch/hub/facebookresearch_pytorchvideo_main


ℹ️ Head already matches canonical order.
Starting inference on the test set...

TEST: loss=0.4639 | acc@1=86.21% | acc@3=99.47% | macro-F1=0.891 | micro-F1=0.862

Per-class F1:
  jump_smash      1.000
  serve           1.000
  drive           0.973
  clear           0.966
  smash           0.960
  drop            0.957
  block           0.927
  straight_net    0.914
  negative        0.858
  push            0.839
  cross_net       0.795
  tap             0.750
  lift            0.645

Confusion Matrix (rows=true, cols=predicted):
[[ 38   0   0   0   0   0   0   0   0   0   0   0   4]
 [  0  28   0   1   0   0   0   0   0   0   0   0   1]
 [  0   0  35   0   0   0   0   0   0   0   0   0   7]
 [  0   0   0  18   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0  44   0   0   0   0   1   0   0   3]
 [  0   0   0   0   0  30   0   0   0   0   0   0   0]
 [  1   0   8   0   0   0  49   0   0   0   1   4  39]
 [  0   0   0   0   0   0   0  26   0   0   0   0  10]
 [  0   0   0   0   0   

In [665]:
res['confmat']
# res['f1_perclass']

array([[ 2,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0, 13,  0,  1,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  6,  0,  0,  0,  1,  0,  0,  0,  0,  0],
       [ 0,  2,  0, 11,  0,  0,  0,  1,  0,  0,  1,  0],
       [ 1,  2,  0,  4,  3,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  1, 19,  0,  9,  2,  0,  0,  0],
       [ 0,  0,  0,  2,  0,  1,  7,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  1,  0, 19,  0,  0,  1,  0],
       [ 0,  0,  1,  0,  0,  2,  0,  1,  9,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0,  1,  0, 16,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0,  4,  3,  0,  4,  0],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]])

In [666]:
import torch, os, time, json

def save_results_pickle(results: dict, out_path: str):
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    payload = {
        "schema": "eval_results.v1",
        "saved_at": time.strftime("%Y-%m-%d %H:%M:%S"),
        **results,  # includes confmat (numpy), f1_perclass (dict), etc.
    }
    torch.save(payload, out_path)   # uses pickle under the hood
    print(f"✓ Saved: {out_path}")

def load_results_pickle(path: str) -> dict:
    obj = torch.load(path, map_location="cpu")
    assert obj.get("schema") == "eval_results.v1"
    return obj

save_results_pickle(res, "/opt/NeMo/eval_plots/eval_results_high.pt")

✓ Saved: /opt/NeMo/eval_plots/eval_results_high.pt


### **If found a good model, save into proper folder**

In [637]:
!cp train_full/best.pt models_slowfast/slowfast_full_1f.pt

In [627]:
!cp train_cheras/best_train_f1.pt models_slowfast/slowfast_full_1e.pt

# **Print metrics**

In [644]:
from report_utils import EvalReport

In [645]:
results = test_manager.compute_and_print_results(test_loss)
test_manager.render_report(all_predictions, results)


TEST: loss=1.3414 | acc@1=56.21% | acc@3=81.01% | macro-F1=0.481

Per-class F1:
  serve           0.839
  jump_smash      0.703
  straight_net    0.600
  lift            0.596
  block           0.571
  drop            0.519
  clear           0.435
  cross_net       0.400
  drive           0.400
  push            0.381
  smash           0.333
  tap             0.000

Confusion Matrix (rows=true, cols=predicted):
[[ 6  0  0  1  1  0  0  1  0  0  0  0]
 [ 0  5  0  0  3  2  1  0  0  0  0  1]
 [ 1  0  2  0  0  0  1  1  0  0  1  0]
 [ 2  1  0  4  3  0  0  0  1  1  2  0]
 [ 0  0  0  0  7  2  0  0  0  0  0  0]
 [ 0  1  0  0  1 13  0  0  0  0  0  0]
 [ 1  0  1  1  0  0 14  1  0  0  8  0]
 [ 1  0  1  0  1  0  1  4  0  0  5  0]
 [ 1  0  0  0  0  0  2  0 13  0  1  0]
 [ 0  4  0  0  2  5  0  0  0  3  0  0]
 [ 0  0  0  0  0  0  2  1  0  0 15  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0]]


ValueError: keyword ha is not recognized; valid keywords are ['size', 'width', 'color', 'tickdir', 'pad', 'labelsize', 'labelcolor', 'labelfontfamily', 'zorder', 'gridOn', 'tick1On', 'tick2On', 'label1On', 'label2On', 'length', 'direction', 'left', 'bottom', 'right', 'top', 'labelleft', 'labelbottom', 'labelright', 'labeltop', 'labelrotation', 'grid_agg_filter', 'grid_alpha', 'grid_animated', 'grid_antialiased', 'grid_clip_box', 'grid_clip_on', 'grid_clip_path', 'grid_color', 'grid_dash_capstyle', 'grid_dash_joinstyle', 'grid_dashes', 'grid_data', 'grid_drawstyle', 'grid_figure', 'grid_fillstyle', 'grid_gapcolor', 'grid_gid', 'grid_in_layout', 'grid_label', 'grid_linestyle', 'grid_linewidth', 'grid_marker', 'grid_markeredgecolor', 'grid_markeredgewidth', 'grid_markerfacecolor', 'grid_markerfacecoloralt', 'grid_markersize', 'grid_markevery', 'grid_mouseover', 'grid_path_effects', 'grid_picker', 'grid_pickradius', 'grid_rasterized', 'grid_sketch_params', 'grid_snap', 'grid_solid_capstyle', 'grid_solid_joinstyle', 'grid_transform', 'grid_url', 'grid_visible', 'grid_xdata', 'grid_ydata', 'grid_zorder', 'grid_aa', 'grid_c', 'grid_ds', 'grid_ls', 'grid_lw', 'grid_mec', 'grid_mew', 'grid_mfc', 'grid_mfcalt', 'grid_ms']

## **end-to-end match inference & overlay**

In [None]:
!pip -q install ultralytics opencv-python-headless

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━[0m [32m1.0/1.1 MB[0m [31m29.7 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m21.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import os, cv2, numpy as np, torch
from collections import deque, defaultdict
from ultralytics import YOLO
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def load_slowfast_classifier(cfg, ckpt_path):
    torch.hub._validate_not_a_forked_repo = lambda a,b,c: True
    model = torch.hub.load('facebookresearch/pytorchvideo', 'slowfast_r101', pretrained=True)
    in_dim = model.blocks[-1].proj.in_features
    model.blocks[-1].proj = torch.nn.Linear(in_dim, len(cfg.labels))
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt["model"], strict=True)
    model.eval().to(device)
    return model

In [None]:
def resize_pad_square(img_rgb: np.ndarray, side: int = 224) -> np.ndarray:
    """Keep aspect ratio; resize the longer side to `side`, then pad to (side, side)."""
    h, w = img_rgb.shape[:2]
    if h == 0 or w == 0:
        return np.zeros((side, side, 3), dtype=img_rgb.dtype)
    scale = side / max(h, w)
    nh, nw = int(round(h * scale)), int(round(w * scale))
    resized = cv2.resize(img_rgb, (nw, nh), interpolation=cv2.INTER_LINEAR)
    top  = (side - nh) // 2
    bottom = side - nh - top
    left = (side - nw) // 2
    right = side - nw - left
    out = cv2.copyMakeBorder(resized, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(128,128,128))
    return out

def expand_box(x1, y1, x2, y2, scale: float, W: int, H: int):
    """Optionally enlarge the bbox to keep some context (e.g., racket)."""
    cx, cy = (x1 + x2) / 2.0, (y1 + y2) / 2.0
    bw, bh = (x2 - x1) * scale, (y2 - y1) * scale
    nx1, ny1 = int(max(0, cx - bw / 2)), int(max(0, cy - bh / 2))
    nx2, ny2 = int(min(W - 1, cx + bw / 2)), int(min(H - 1, cy + bh / 2))
    return nx1, ny1, nx2, ny2

In [None]:
class SlowFastPredictor:
    def __init__(self, cfg, model):
        self.cfg = cfg
        self.model = model
        self.mean = torch.tensor([0.45, 0.45, 0.45]).view(3,1,1).to(device)
        self.std  = torch.tensor([0.225, 0.225, 0.225]).view(3,1,1).to(device)

    def _prep(self, frames_rgb_list):
        """
        frames_rgb_list: list of 32 frames, each HxWx3 in RGB
        Returns: [slow, fast] tensors shaped (1,C,T,H,W)
        """
        # Stack to (T,H,W,3) -> (T,C,H,W)
        x = torch.from_numpy(np.stack(frames_rgb_list)).permute(0,3,1,2).float() / 255.0  # (T,C,H,W)
        # Resize treating T as batch
        x = F.interpolate(x, size=self.cfg.side, mode="bilinear", align_corners=False)    # (T,C,224,224)
        # Normalize
        mean = self.mean.to(device=x.device, dtype=x.dtype)
        std  = self.std.to(device=x.device, dtype=x.dtype)
        x = (x - mean) / std                                                   # (T,C,224,224)
        # (C,T,H,W)
        x = x.permute(1,0,2,3)
        fast = x.unsqueeze(0).to(device)             # (1,C,32,224,224)
        slow = x[:, ::self.cfg.alpha, :, :].unsqueeze(0).to(device)  # stride-4 -> (1,C,8,224,224)
        return [slow, fast]

    @torch.no_grad()
    def predict_probs(self, frames_rgb_list):
        assert len(frames_rgb_list) == self.cfg.fast_t  # 32
        with torch.amp.autocast('cuda', enabled=(device.type == "cuda")):
            inp = self._prep(frames_rgb_list)
            logits = self.model(inp)                  # (1, num_classes)
            probs = torch.softmax(logits, dim=1)[0].detach().cpu().numpy()
        return probs  # (C,)

In [None]:
def annotate_match_video(
    cfg,
    video_path,
    out_path,
    yolo_weights="yolo11n.pt", # change to your custom weights if you have them
    person_class=0,            # COCO 'person'
    det_conf=0.5,
    iou=0.5,
    pred_thr=0.60,             # minimum prob to show label
    cooldown=12                # frames to cool after showing a shot to reduce spam
):
    # Get video props for the writer
    cap = cv2.VideoCapture(video_path)
    fps = max(1.0, cap.get(cv2.CAP_PROP_FPS))
    W   = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    H   = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    cap.release()

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(out_path, fourcc, fps, (W, H))

    # Load detector+tracker
    yolo = YOLO(yolo_weights)

    # Load classifier
    clf_model = load_slowfast_classifier(cfg, cfg.best_model_path)
    clf = SlowFastPredictor(cfg, clf_model)

    # Per-track state
    buffers = defaultdict(lambda: deque(maxlen=cfg.fast_t))            # 32-frame RGB crops per track
    last_shown_frame = defaultdict(lambda: -99999)                     # cooldown control
    hist = defaultdict(lambda: deque(maxlen=5))                        # small temporal smoothing buffer

    frame_idx = 0
    for res in yolo.track(source=video_path, stream=True, persist=True,
                          classes=[person_class], conf=det_conf, iou=iou, verbose=False):
        frame_bgr = res.orig_img  # BGR
        h, w = frame_bgr.shape[:2]

        # If no boxes/ids in this frame, just write it
        if res.boxes is None or res.boxes.id is None:
            writer.write(frame_bgr)
            frame_idx += 1
            continue

        ids = res.boxes.id.int().cpu().numpy()
        xyxy = res.boxes.xyxy.int().cpu().numpy()  # (N,4)

        to_draw = []  # (x1,y1,x2,y2,label,prob,tid)

        for j, tid in enumerate(ids):
            x1, y1, x2, y2 = xyxy[j]
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(w-1, x2), min(h-1, y2)
            if x2 <= x1 or y2 <= y1:
                continue

            # NEW: enlarge a bit for context (optional, try 1.2–1.4)
            x1, y1, x2, y2 = expand_box(x1, y1, x2, y2, scale=1.25, W=w, H=h)

            # Crop -> RGB -> letterbox to fixed square
            crop = frame_bgr[y1:y2, x1:x2, :]
            if crop.size == 0:
                continue
            crop_rgb = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
            crop_rgb = resize_pad_square(crop_rgb, side=cfg.side)  # now every frame is 224x224

            buffers[tid].append(crop_rgb)

            label_to_show = None
            prob_to_show  = 0.0

            # Classify when we have a full 32-frame clip
            if len(buffers[tid]) == cfg.fast_t:
                probs = clf.predict_probs(list(buffers[tid]))  # (C,)
                ci = int(probs.argmax())
                pi = float(probs[ci])
                hist[tid].append((ci, pi))

                # Small smoothing: require at least 2 of the last 3 agreeing + prob >= thr
                if len(hist[tid]) >= 3:
                    last3 = list(hist[tid])[-3:]
                else:
                    last3 = list(hist[tid])

                # Choose the label with the highest mean prob among last3
                if last3:
                    classes = [c for c, p in last3 if cfg.labels[c] != "average_joe" and p >= pred_thr]
                    if classes:
                        # pick the most common; break ties by highest avg prob
                        uniq = set(classes)
                        best_c, best_score = None, -1.0
                        for u in uniq:
                            avgp = np.mean([p for (c, p) in last3 if c == u])
                            score = (classes.count(u), avgp)  # (count, avgp)
                            if score > (classes.count(best_c) if best_c is not None else -1, best_score):
                                best_c, best_score = u, avgp
                        if best_c is not None and (frame_idx - last_shown_frame[tid] >= cooldown):
                            label_to_show = cfg.labels[best_c]
                            prob_to_show = float(best_score)
                            last_shown_frame[tid] = frame_idx

            # Queue drawing if we have a confident non-background label
            if label_to_show is not None:
                to_draw.append((x1, y1, x2, y2, label_to_show, prob_to_show, int(tid)))

        # ---- Draw all overlays on this frame ----
        for (x1, y1, x2, y2, lab, p, tid) in to_draw:
            color = (0, 220, 0)
            cv2.rectangle(frame_bgr, (x1, y1), (x2, y2), color, 2)
            txt = f"#{tid} {lab} {p*100:.1f}%"
            cv2.putText(frame_bgr, txt, (x1, max(20, y1-10)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2, cv2.LINE_AA)

        writer.write(frame_bgr)
        frame_idx += 1

    writer.release()
    print(f"Saved annotated video to: {out_path}")

In [None]:
in_video  = "/content/drive/MyDrive/FIT3163,3164/SlowFast/01_raw/lcw_ld_2016_short/1/master.mp4"
out_video = "/content/match_annotated.mp4"
yolo_weights = "/content/drive/MyDrive/FIT3163,3164/YOLO/my_yolov8_1.pt"

annotate_match_video(cfg, in_video, out_video,
                     yolo_weights=yolo_weights,  # swap if you have a better person/badminton model
                     det_conf=0.35, iou=0.5,
                     pred_thr=0.60, cooldown=12)

Using cache found in /root/.cache/torch/hub/facebookresearch_pytorchvideo_main


Saved annotated video to: /content/match_annotated.mp4
