In [None]:
import os
import torch
import numpy as np
import cv2
import pandas as pd
import random
from PIL import Image
from torchvision import transforms as T
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18, ResNet18_Weights
import torch.nn as nn
import scipy.stats

# CONFIGURATION
# --- Input / output locations ---
VIDEO_DIR = "./data/videos/test"          # directory scanned for video files
FRAME_TMP_DIR = "./inference_frames"      # scratch directory for extracted frames
EXCEL_FP = "./data/OSATS_task1.xlsx"      # spreadsheet holding the ground-truth scores
MODEL_PATH = "best_model_task1_epoch5_val0.8305.pt"  # state dict loaded into the model

# --- Sampling ---
NUM_VIDEOS = 24    # upper bound on how many videos are randomly picked for evaluation
NUM_BLOCKS = 12    # number of blocks each video's frame list is divided into
WINDOW_SEC = 5     # seconds of frames taken from the start of every block
FPS = 1            # frame extraction rate (frames per second)

# --- Runtime device ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- Prediction targets ---
TARGET_COLS = ["GRS"]              # spreadsheet columns used as targets
NUM_OUTPUTS = len(TARGET_COLS)     # one classification output per target column
NUM_CLASSES = 4                    # number of classes per output

# MODEL DEFINITION
class OSATSVideoResNetModel(nn.Module):
    """ResNet-18 frame encoder followed by an LSTM and a per-step classifier.

    Every frame of a clip is encoded by a ResNet-18 whose final ``fc`` layer
    is replaced with an identity, the resulting feature sequence is passed
    through an LSTM, and a small MLP head emits ``num_outputs * num_classes``
    logits for every time step.

    Args:
        num_outputs: Number of independent classification targets.
        num_classes: Number of classes per target.
        hidden_dim: LSTM hidden size (per direction).
        lstm_layers: Number of stacked LSTM layers.
        dropout_rate: Dropout used in the head, and between LSTM layers when
            ``lstm_layers > 1``.
        bidirectional: Whether the LSTM runs in both directions.
    """

    def __init__(self, num_outputs=NUM_OUTPUTS, num_classes=NUM_CLASSES,
                 hidden_dim=128, lstm_layers=1, dropout_rate=0.5, bidirectional=True):
        super().__init__()
        self.num_outputs = num_outputs
        self.num_classes = num_classes

        # Frame encoder: ResNet-18 with its classification layer stripped.
        backbone = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        backbone.fc = nn.Identity()
        self.base = backbone

        # Dropout between LSTM layers is only requested for stacked LSTMs.
        inter_layer_dropout = dropout_rate if lstm_layers > 1 else 0
        self.lstm = nn.LSTM(
            input_size=512,
            hidden_size=hidden_dim,
            num_layers=lstm_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
            bidirectional=bidirectional,
        )

        # A bidirectional LSTM concatenates both directions' hidden states.
        num_directions = 2 if bidirectional else 1
        head_layers = [
            nn.Linear(hidden_dim * num_directions, 128),
            nn.ReLU(inplace=True),
            nn.LayerNorm(128),
            nn.Dropout(dropout_rate),
            nn.Linear(128, num_outputs * num_classes),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Compute per-step logits for a batch of clips.

        Args:
            x: Tensor whose shape unpacks as ``(B, S, C, H, W)``.

        Returns:
            Tensor of shape ``(B, S, num_outputs, num_classes)``.
        """
        batch, steps, channels, height, width = x.shape
        # Fold the time axis into the batch axis so the CNN sees single frames.
        frames = x.view(batch * steps, channels, height, width)
        frame_feats = self.base(frames).view(batch, steps, -1)   # [B, S, 512]
        temporal, _ = self.lstm(frame_feats)
        logits = self.head(temporal)
        return logits.view(batch, steps, self.num_outputs, self.num_classes)

# TRANSFORM
# Inference-time preprocessing applied to every frame.
# The pipeline is kept deterministic: the previous T.ColorJitter step drew
# random brightness/contrast factors on every call, so the same frame produced
# different tensors (and potentially different predictions) from run to run.
# GaussianBlur stays because its sigma range (0.5, 0.5) collapses to a single
# fixed value.
frame_transform = T.Compose([
    T.Resize((256, 256)),
    T.CenterCrop(224),
    T.GaussianBlur(kernel_size=9, sigma=(0.5, 0.5)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# FRAME EXTRACTION
def extract_frames(video_path, out_dir, fps=1):
    """Sample frames from a video and write them to ``out_dir`` as JPEG files.

    Frames are written as ``frame_0000.jpg``, ``frame_0001.jpg``, ... in the
    order they are read. Every ``interval``-th decoded frame is kept, where
    ``interval`` is the video's reported frame rate floor-divided by ``fps``.

    Args:
        video_path: Path of the video to read.
        out_dir: Directory that receives the frames (created if missing).
        fps: Target number of saved frames per second of video.
    """
    os.makedirs(out_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    try:
        vid_fps = cap.get(cv2.CAP_PROP_FPS)
        if fps > 0 and vid_fps > 0:
            # Clamp to 1: when the requested rate exceeds the video's own
            # rate the floor division yields 0, and ``% 0`` below would raise
            # ZeroDivisionError. In that case keep every frame instead.
            interval = max(1, int(vid_fps // fps))
        else:
            # Unknown/invalid rate: fall back to keeping every frame.
            interval = 1

        frame_count = 0   # frames decoded so far
        frame_idx = 0     # frames written so far (used in the file name)
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % interval == 0:
                out_path = os.path.join(out_dir, f"frame_{frame_idx:04d}.jpg")
                cv2.imwrite(out_path, frame)
                frame_idx += 1
            frame_count += 1
    finally:
        # Release the capture handle even if reading or writing fails.
        cap.release()

# INFERENCE DATASET
class InferenceVideoDataset(Dataset):
    """Dataset that yields one fixed-length clip per temporal block of a video.

    The ``.jpg`` files in ``frame_dir`` (sorted by name) are divided into
    ``num_blocks`` consecutive blocks of equal size. From the start of each
    block, up to ``window_sec * fps`` frames are taken; clips that come up
    short are padded with black frames.

    Args:
        frame_dir: Directory containing the extracted ``.jpg`` frames.
        transform: Callable applied to every PIL image (may be falsy to skip).
        num_blocks: Number of blocks, i.e. number of clips produced.
        window_sec: Length of each clip in seconds.
        fps: Frame rate the frames were extracted at.
    """

    def __init__(self, frame_dir, transform, num_blocks=8, window_sec=5, fps=1):
        files = sorted(os.listdir(frame_dir))
        self.frame_paths = [os.path.join(frame_dir, f) for f in files if f.endswith('.jpg')]
        self.transform = transform
        self.num_blocks = num_blocks
        self.window_sec = window_sec
        self.fps = fps
        self.num_frames_per_block = window_sec * fps

        total = len(self.frame_paths)
        # Clamp the block size to at least one frame. Without the clamp a
        # video with fewer frames than blocks gets block_size == 0, and every
        # clip would consist solely of black padding although real frames
        # exist. With it, the first ``total`` blocks each receive one frame.
        # Behaviour is unchanged whenever total >= num_blocks.
        block_size = max(1, total // num_blocks) if num_blocks > 0 else total

        self.samples = []
        for b in range(num_blocks):
            start = b * block_size
            # Never read past the clip length, the block end, or the video end.
            end = min(start + self.num_frames_per_block, (b + 1) * block_size, total)
            paths = [self.frame_paths[i] for i in range(start, end)]
            # ``None`` marks a padding slot that becomes a black frame later.
            while len(paths) < self.num_frames_per_block:
                paths.append(None)
            self.samples.append(paths)

    def __len__(self):
        """Return the number of clips (one per block)."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load, transform and stack the frames of clip ``idx``.

        Returns:
            The transformed frames stacked along a new leading dimension.

        Raises:
            RuntimeError: If a frame file cannot be decoded.
        """
        paths = self.samples[idx]
        clips = []
        for p in paths:
            if p is None:
                # Padding slot: substitute a black RGB frame.
                img = np.zeros((256, 256, 3), np.uint8)
            else:
                img = cv2.imread(p)
                if img is None:
                    # cv2.imread signals failure by returning None; fail with
                    # a clear message instead of an opaque cvtColor error.
                    raise RuntimeError(f"Could not read frame image: {p}")
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = Image.fromarray(img)
            if self.transform:
                img = self.transform(img)
            clips.append(img)
        clip_tensor = torch.stack(clips, dim=0)
        return clip_tensor

def run_inference_on_video(video_path, model, device, tmp_dir, num_blocks, window_sec, fps):
    """Extract frames from one video and return the model's class predictions.

    ``tmp_dir`` is emptied (or created), filled with frames extracted at
    ``fps``, and wrapped in an ``InferenceVideoDataset``. Each clip is fed
    through ``model`` with gradients disabled and the arg-max over the last
    (class) dimension is taken.

    Returns:
        A NumPy array with the per-clip predictions concatenated along axis 0.
        With ``OSATSVideoResNetModel``, which emits one prediction per frame,
        the shape is ``[num_blocks * frames_per_block, num_outputs]``.
    """
    # Start from an empty scratch directory so frames of a previous video
    # cannot leak into this one.
    if not os.path.exists(tmp_dir):
        os.makedirs(tmp_dir)
    else:
        for stale_name in os.listdir(tmp_dir):
            os.remove(os.path.join(tmp_dir, stale_name))

    extract_frames(video_path, tmp_dir, fps=fps)
    dataset = InferenceVideoDataset(
        tmp_dir,
        frame_transform,
        num_blocks=num_blocks,
        window_sec=window_sec,
        fps=fps,
    )
    clip_loader = DataLoader(dataset, batch_size=1, shuffle=False)

    with torch.no_grad():
        # batch_size is 1, so squeeze(0) drops the batch axis of each result.
        per_clip_preds = [
            model(clip.to(device)).argmax(-1).squeeze(0).cpu().numpy()
            for clip in clip_loader
        ]
    return np.concatenate(per_clip_preds, axis=0)

def aggregate_predictions(preds):
    """Collapse a stack of predictions into one value per output column.

    Takes the most frequent value (mode) along axis 0 of ``preds`` and
    returns the result converted to plain Python integers.
    """
    vote = scipy.stats.mode(preds, axis=0, keepdims=False)
    return vote.mode.astype(int).tolist()

def main():
    """Evaluate the saved model on a random sample of videos.

    Lists the video files in ``VIDEO_DIR``, randomly picks up to
    ``NUM_VIDEOS`` of them, looks up each video's ground truth in the Excel
    sheet by file name (without extension), runs inference, reduces the
    predictions to one value per target by majority vote, and finally prints
    per-target and overall accuracy. Videos without a matching Excel row are
    skipped with a warning.
    """
    # Collect candidate video files
    video_exts = ('.mp4', '.avi', '.mov', '.mkv')
    all_videos = []
    for fname in os.listdir(VIDEO_DIR):
        if fname.lower().endswith(video_exts):
            all_videos.append(os.path.join(VIDEO_DIR, fname))
    print(f"Found {len(all_videos)} videos in {VIDEO_DIR}")

    # Random subset of at most NUM_VIDEOS videos
    sample_size = min(NUM_VIDEOS, len(all_videos))
    selected_videos = random.sample(all_videos, sample_size)

    # Ground-truth table and trained model
    df = pd.read_excel(EXCEL_FP)
    model = OSATSVideoResNetModel(num_outputs=NUM_OUTPUTS, num_classes=NUM_CLASSES)
    state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model = model.to(DEVICE)
    model.eval()

    # Accumulators for the video-level metrics
    pred_rows = []
    gt_rows = []

    for i, video_path in enumerate(selected_videos):
        print(f"\n========== Video {i+1}: {os.path.basename(video_path)} ==========")

        # The file name without its extension is matched against the VIDEO column.
        base_video_id = os.path.splitext(os.path.basename(video_path))[0]
        video_ids = df['VIDEO'].astype(str)
        row = df[video_ids == str(base_video_id)]
        if row.empty:
            print(f"[WARNING] No ground-truth found in Excel for VIDEO={base_video_id}. Skipping.")
            continue
        first_match = row.iloc[0]
        gt_scores = first_match[TARGET_COLS].values.astype(int)

        # Predictions for this video, reduced to one value per target
        preds = run_inference_on_video(
            video_path, model, DEVICE, FRAME_TMP_DIR,
            num_blocks=NUM_BLOCKS, window_sec=WINDOW_SEC, fps=FPS,
        )
        video_pred_mode = aggregate_predictions(preds)
        pred_rows.append(video_pred_mode)
        gt_rows.append(gt_scores)

        print("Ground truth:")
        for j, col in enumerate(TARGET_COLS):
            print(f"  {col}: {int(gt_scores[j])}")
        print("Aggregated prediction (majority vote):")
        for j, col in enumerate(TARGET_COLS):
            print(f"  {col}: {video_pred_mode[j]} (GT: {int(gt_scores[j])})")

    if not pred_rows:
        print("No videos processed!")
        return

    # One row per processed video, one column per target
    all_video_preds = np.stack(pred_rows, axis=0)
    all_video_gts = np.stack(gt_rows, axis=0)
    matches = all_video_preds == all_video_gts

    # Accuracy per target column
    print("\n=== Per-target accuracy across all videos ===")
    for j, col in enumerate(TARGET_COLS):
        acc = np.mean(matches[:, j])
        print(f"  {col}: {acc:.2f}")

    # A video counts as correct only when every target matches.
    overall_acc = np.mean(np.all(matches, axis=1))
    print(f"\nOverall video-level accuracy (all OSATS correct): {overall_acc:.2f}")


if __name__ == "__main__":
    main()


Found 24 videos in ./data/videos/val

Ground truth:
  GRS: 1
Aggregated prediction (majority vote):
  GRS: 2 (GT: 1)

Ground truth:
  GRS: 2
Aggregated prediction (majority vote):
  GRS: 2 (GT: 2)

Ground truth:
  GRS: 2
Aggregated prediction (majority vote):
  GRS: 2 (GT: 2)

Ground truth:
  GRS: 1
Aggregated prediction (majority vote):
  GRS: 0 (GT: 1)

Ground truth:
  GRS: 2
Aggregated prediction (majority vote):
  GRS: 2 (GT: 2)

Ground truth:
  GRS: 1
Aggregated prediction (majority vote):
  GRS: 1 (GT: 1)

Ground truth:
  GRS: 2
Aggregated prediction (majority vote):
  GRS: 2 (GT: 2)

Ground truth:
  GRS: 2
Aggregated prediction (majority vote):
  GRS: 2 (GT: 2)

Ground truth:
  GRS: 0
Aggregated prediction (majority vote):
  GRS: 0 (GT: 0)

Ground truth:
  GRS: 1
Aggregated prediction (majority vote):
  GRS: 1 (GT: 1)

Ground truth:
  GRS: 0
Aggregated prediction (majority vote):
  GRS: 0 (GT: 0)

Ground truth:
  GRS: 1
Aggregated prediction (majority vote):
  GRS: 1 (GT: 1)

Gr