In [9]:
%cd /home/hice1/mdoutre3/CS7643_Project_1

/home/hice1/mdoutre3/CS7643_Project_1


In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from pathlib import Path
from glob import glob
import math

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on: {device}")

Running on: cuda


In [6]:
import json

In [None]:

# =============================================================================
# MODEL COMPONENTS
# =============================================================================

class PositionalEncoding(nn.Module):
    """Positional encoding for temporal sequences."""
    def __init__(self, d_model, max_len=600, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


class MultimodalFusionTransformer(nn.Module):
    def __init__(self, input_dim=768, d_model=256, nhead=4, num_layers=2,
                 num_classes=15, dropout=0.1, max_seq_len=100):
        super().__init__()
        
        # Much simpler projection with heavy dropout
        self.input_proj = nn.Sequential(
            nn.Linear(input_dim, d_model),
            nn.Dropout(dropout)
        )
        
        self.pos_encoder = PositionalEncoding(d_model, max_seq_len + 2, dropout=dropout)
        
        # Single lightweight transformer layer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model,  # Same size as d_model (minimal capacity)
            dropout=dropout,
            batch_first=True,
            norm_first=True  # Pre-norm for better training
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        
        # Simple classification head
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(d_model, num_classes)
        )
    
    def forward(self, video_x, video_mask, text_emb):
        """
        video_x: (B, T, 768)
        video_mask: (B, T) - True for valid positions
        text_emb: (B, 768)
        """
        B = video_x.size(0)
        
        # Project inputs
        video_x = self.input_proj(video_x)
        text_x = self.input_proj(text_emb).unsqueeze(1)
        
        # Concatenate text as CLS token
        x = torch.cat([text_x, video_x], dim=1)  # (B, T+1, d_model)
        
        # Create full mask
        text_mask = torch.ones(B, 1, dtype=torch.bool, device=video_mask.device)
        full_mask = torch.cat([text_mask, video_mask], dim=1)
        
        # Add positional encoding
        x = self.pos_encoder(x)
        
        # Transformer encoding
        x = self.encoder(x, src_key_padding_mask=~full_mask)
        
        # Use CLS token for classification
        cls = x[:, 0]
        return self.classifier(cls)


# =============================================================================
# AUGMENTED DATASET WITH MIXUP
# =============================================================================

class SoccerFusionDataset(Dataset):
    """Dataset for video+text fusion using matched pairs."""
    
    def __init__(self, video_paths, text_paths, labels, max_seq_len=100, 
                 augment=False):
        self.video_paths = video_paths
        self.text_paths = text_paths
        self.labels = labels
        self.max_seq_len = max_seq_len
        self.augment = augment
    
    def __len__(self):
        return len(self.video_paths)
    
    def __getitem__(self, idx):
        # Load video features
        video = np.load(self.video_paths[idx])  # (T, 768)
        video = torch.from_numpy(video).float()
        T = video.shape[0]
        
        # Random temporal augmentation during training
        if self.augment and T > 20:
            # Randomly drop up to 30% of frames
            keep_ratio = np.random.uniform(0.7, 1.0)
            keep_frames = max(20, int(T * keep_ratio))
            start_idx = np.random.randint(0, T - keep_frames + 1)
            video = video[start_idx:start_idx + keep_frames]
            T = keep_frames
        
        # Pad or truncate video
        if T > self.max_seq_len:
            # Random crop instead of always taking first frames
            if self.augment:
                start = np.random.randint(0, T - self.max_seq_len + 1)
                video = video[start:start + self.max_seq_len]
            else:
                video = video[:self.max_seq_len]
            mask = torch.ones(self.max_seq_len, dtype=torch.bool)
        else:
            pad = torch.zeros(self.max_seq_len - T, 768)
            video = torch.cat([video, pad], dim=0)
            mask = torch.cat([
                torch.ones(T, dtype=torch.bool),
                torch.zeros(self.max_seq_len - T, dtype=torch.bool)
            ])
        
        # Add Gaussian noise for augmentation
        if self.augment:
            video = video + torch.randn_like(video) * 0.01
        
        # Load text embedding
        text_data = torch.load(self.text_paths[idx])
        text_emb = text_data["embedding"].squeeze(0)  # (768,)
        
        if self.augment:
            text_emb = text_emb + torch.randn_like(text_emb) * 0.01
        
        return video, mask, text_emb, self.labels[idx]


# =============================================================================
# DATA LOADING & PREPROCESSING
# =============================================================================

def match_video_text_pairs(video_paths, text_paths):
    """Match video files with corresponding text embeddings."""
    text_lookup = {Path(t).stem: t for t in text_paths}
    
    def get_base_key(video_path):
        stem = Path(video_path).stem
        parts = stem.split("_")
        return "_".join(parts[:-1])  # Remove event name
    
    matched_pairs = []
    for v in video_paths:
        base = get_base_key(v)
        matches = [t for t in text_paths if Path(t).stem.startswith(base)]
        if matches:
            matched_pairs.append((v, matches[0]))
    
    return matched_pairs


def extract_event_label(video_path):
    """Extract event type from video filename."""
    stem = Path(video_path).stem
    return stem.split("_")[-1]


def prepare_dataset(video_dir="fusion/embeddings 2", 
                   text_dir="fusion/text_embeddings_events",
                   max_seq_len=100):
    """Prepare matched dataset with labels."""
    
    video_paths = sorted(glob(f"{video_dir}/*.npy"))
    text_paths = sorted(glob(f"{text_dir}/*.pt"))
    
    print(f"Found {len(video_paths)} video files")
    print(f"Found {len(text_paths)} text files")
    
    # Match pairs
    #matched_pairs = match_video_text_pairs(video_paths, text_paths)
    

    if Path("fusion/matched_pairs.json").exists():
        print("Loading cached matched pairs...")
        with open("fusion/matched_pairs.json", "r") as f:
            matched_pairs = [tuple(x) for x in json.load(f)]

    else:
        print("Computing matched pairs ...")
        matched_pairs = match_video_text_pairs(video_paths, text_paths)
        print(f"Matched {len(matched_pairs)} pairs")
        
        print("Saving matched pairs...")
        with open("fusion/matched_pairs.json", "w") as f:
            json.dump([(v, t) for v, t in matched_pairs], f)

    
    # Extract labels
    video_paths_matched = [v for v, t in matched_pairs]
    text_paths_matched = [t for v, t in matched_pairs]
    event_types = [extract_event_label(v) for v in video_paths_matched]
    
    # Create label mapping
    unique_events = sorted(set(event_types))
    event_to_idx = {ev: i for i, ev in enumerate(unique_events)}
    labels = [event_to_idx[ev] for ev in event_types]
    
    print(f"\nEvent classes ({len(unique_events)}):")
    class_counts = {}
    for ev, idx in event_to_idx.items():
        count = sum(1 for l in labels if l == idx)
        class_counts[idx] = count
        print(f"  {idx:2d}: {ev:25s} ({count} samples)")
    
    # Return the base data and metadata
    return (video_paths_matched, text_paths_matched, labels, 
            event_to_idx, class_counts, max_seq_len)


# =============================================================================
# TRAINING WITH MIXUP
# =============================================================================

def mixup_data(x_video, x_mask, x_text, y, alpha=0.2):
    """Mixup augmentation."""
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x_video.size(0)
    index = torch.randperm(batch_size, device=x_video.device)

    mixed_video = lam * x_video + (1 - lam) * x_video[index]
    mixed_text = lam * x_text + (1 - lam) * x_text[index]
    # Keep original mask for simplicity
    
    y_a, y_b = y, y[index]
    return mixed_video, x_mask, mixed_text, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Mixup loss calculation."""
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)


def train_epoch(model, loader, criterion, optimizer, device, use_mixup=True):
    """Train for one epoch with optional mixup."""
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for video, mask, text_emb, y in loader:
        video = video.to(device)
        mask = mask.to(device)
        text_emb = text_emb.to(device)
        y = y.to(device)
        
        optimizer.zero_grad()
        
        # Apply mixup with 50% probability
        if use_mixup and np.random.rand() > 0.5:
            video, mask, text_emb, y_a, y_b, lam = mixup_data(
                video, mask, text_emb, y, alpha=0.2
            )
            logits = model(video, mask, text_emb)
            loss = mixup_criterion(criterion, logits, y_a, y_b, lam)
        else:
            logits = model(video, mask, text_emb)
            loss = criterion(logits, y)
        
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
        
        optimizer.step()
        
        total_loss += loss.item()
        correct += (logits.argmax(1) == y).sum().item()
        total += y.size(0)
    
    return total_loss / len(loader), correct / total


def validate(model, loader, criterion, device):
    """Validate model."""
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for video, mask, text_emb, y in loader:
            video = video.to(device)
            mask = mask.to(device)
            text_emb = text_emb.to(device)
            y = y.to(device)
            
            logits = model(video, mask, text_emb)
            loss = criterion(logits, y)
            
            total_loss += loss.item()
            correct += (logits.argmax(1) == y).sum().item()
            total += y.size(0)
    
    return total_loss / len(loader), correct / total


def train_model(model, train_loader, val_loader, epochs=100, lr=5e-4, 
                patience=15, device='cuda', class_counts=None):
    """Train with early stopping and class balancing."""
    
    model.to(device)
    
    # Weighted loss for class imbalance
    if class_counts:
        # Correct way to extract class frequencies
        freqs = [class_counts[i] for i in range(len(class_counts))]

        weights = torch.tensor(freqs, dtype=torch.float)
        weights = 1.0 / weights            # inverse frequency
        weights = weights / weights.sum()  # normalize
        weights = weights.to(device)

        #criterion = nn.CrossEntropyLoss(weight=weights, label_smoothing=0.1)
        criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    else:
        criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
    
    # Cosine annealing with warm restarts
    #scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    #    optimizer, T_0=10, T_mult=2, eta_min=1e-6
    #)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=30,  # the total number of epochs
        eta_min=1e-6
    )
    
    best_val_loss = float('inf')
    patience_counter = 0
    
    for epoch in range(epochs):
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, device, use_mixup=False
        )
        val_loss, val_acc = validate(model, val_loader, criterion, device)
        
        print(f"[Epoch {epoch+1:03d}] "
              f"Train Loss={train_loss:.4f}, Acc={train_acc:.3f} | "
              f"Val Loss={val_loss:.4f}, Acc={val_acc:.3f} | "
              f"LR={optimizer.param_groups[0]['lr']:.2e}")
        
        scheduler.step()
        
        # Early stopping based on validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), 'best_model.pt')
            print(f"  → Saved best model (val_loss={val_loss:.4f})")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"\nEarly stopping at epoch {epoch+1}")
                break
    
    # Load best model
    model.load_state_dict(torch.load('best_model.pt'))
    return model



In [12]:
if __name__ == "__main__":
    # Prepare data
    (video_paths, text_paths, labels, event_to_idx, 
     class_counts, max_seq_len) = prepare_dataset(max_seq_len=100)
    
    # Split indices
    indices = list(range(len(labels)))
    np.random.seed(42)
    np.random.shuffle(indices)
    
    train_size = int(0.8 * len(indices))
    train_indices = indices[:train_size]
    val_indices = indices[train_size:]
    
    # Create separate datasets for train and val
    train_dataset = SoccerFusionDataset(
        video_paths=[video_paths[i] for i in train_indices],
        text_paths=[text_paths[i] for i in train_indices],
        labels=[labels[i] for i in train_indices],
        max_seq_len=max_seq_len,
        augment=False
    )
    
    val_dataset = SoccerFusionDataset(
        video_paths=[video_paths[i] for i in val_indices],
        text_paths=[text_paths[i] for i in val_indices],
        labels=[labels[i] for i in val_indices],
        max_seq_len=max_seq_len,
        augment=False
    )
    
    train_subset = train_dataset
    val_subset = val_dataset
    
    train_loader = DataLoader(train_subset, batch_size=32, shuffle=True, 
                             num_workers=4, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_subset, batch_size=32, shuffle=False, 
                           num_workers=4, pin_memory=True)
    
    print(f"\nTrain: {len(train_subset)} samples ({len(train_loader)} batches)")
    print(f"Val:   {len(val_subset)} samples ({len(val_loader)} batches)")
    
    # =========================================================================
    # BASELINE EXPERIMENTS: Test text-only vs video-only
    # =========================================================================
    print("\n" + "="*70)
    print("BASELINE EXPERIMENTS")
    print("="*70)
      
    # TEXT-ONLY MODEL
    print("\n[1/3] Training TEXT-ONLY baseline...")
    class TextOnlyModel(nn.Module):
        def __init__(self, input_dim=768, hidden_dim=256, num_classes=15, dropout=0.1):
            super().__init__()
            self.classifier = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.LayerNorm(hidden_dim // 2),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim // 2, num_classes)
            )
        
        def forward(self, video_x, video_mask, text_emb):
            return self.classifier(text_emb)
    
    text_model = TextOnlyModel(num_classes=len(event_to_idx), dropout=0.1)
    text_model = train_model(text_model, train_loader, val_loader, epochs=50, 
                            lr=1e-3, patience=10, device=device, 
                            class_counts=class_counts)
    text_val_loss, text_val_acc = validate(text_model, val_loader, 
                                           nn.CrossEntropyLoss(), device)
    print(f"TEXT-ONLY Val Acc: {text_val_acc:.3f}")
    
    # VIDEO-ONLY MODEL
    print("\n[2/3] Training VIDEO-ONLY baseline...")
    class VideoOnlyModel(nn.Module):
        def __init__(self, input_dim=768, d_model=256, nhead=4, num_layers=2,
                     num_classes=15, dropout=0.1, max_seq_len=100):
            super().__init__()
            self.input_proj = nn.Linear(input_dim, d_model)
            self.pos_encoder = PositionalEncoding(d_model, max_seq_len, dropout=0.1)
            
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=d_model, nhead=nhead, dim_feedforward=d_model * 2,
                dropout=dropout, batch_first=True
            )
            self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)
            self.classifier = nn.Sequential(
                nn.Linear(d_model, d_model // 2),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(d_model // 2, num_classes)
            )
        
        def forward(self, video_x, video_mask, text_emb):
            x = self.input_proj(video_x)
            x = self.pos_encoder(x)
            x = self.encoder(x, src_key_padding_mask=~video_mask)
            x = x.mean(dim=1)  # Average pooling
            return self.classifier(x)
    
    video_model = VideoOnlyModel(num_classes=len(event_to_idx), dropout=0.1)
    video_model = train_model(video_model, train_loader, val_loader, epochs=50, 
                             lr=1e-3, patience=15, device=device,
                             class_counts=class_counts)
    video_val_loss, video_val_acc = validate(video_model, val_loader, 
                                            nn.CrossEntropyLoss(), device)
    print(f"VIDEO-ONLY Val Acc: {video_val_acc:.3f}")
    
    # FUSION MODEL
    print("\n[3/3] Training FUSION model...")
    model = MultimodalFusionTransformer(
        input_dim=768,
        d_model=256,
        nhead=4,
        num_layers=2,
        num_classes=len(event_to_idx),
        dropout=0.1,
        max_seq_len=100
    )
    
    num_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {num_params:,}")
    
    model = train_model(
        model, train_loader, val_loader,
        epochs=30,
        lr=1e-3,
        patience=15,
        device=device,
        class_counts=class_counts
    )
    
    fusion_val_loss, fusion_val_acc = validate(
        model, val_loader, nn.CrossEntropyLoss(), device
    )
    
    # SUMMARY
    print("\n" + "="*70)
    print("RESULTS SUMMARY")
    print("="*70)
    print(f"Text-only:   {text_val_acc:.3f}")
    print(f"Video-only:  {video_val_acc:.3f}")
    print(f"Fusion:      {fusion_val_acc:.3f}")
    print(f"\nFusion improvement over text: {(fusion_val_acc - text_val_acc):.3f}")
    print(f"Fusion improvement over video: {(fusion_val_acc - video_val_acc):.3f}")
    print("="*70)

Found 5820 video files
Found 5832 text files
Loading cached matched pairs...

Event classes (16):
   0: Ball out of play          (1676 samples)
   1: Clearance                 (433 samples)
   2: Corner                    (249 samples)
   3: Direct free-kick          (105 samples)
   4: Foul                      (611 samples)
   5: Goal                      (118 samples)
   6: Indirect free-kick        (487 samples)
   7: Kick-off                  (136 samples)
   8: Offside                   (140 samples)
   9: Penalty                   (6 samples)
  10: Red card                  (7 samples)
  11: Shots off target          (265 samples)
  12: Shots on target           (299 samples)
  13: Substitution              (146 samples)
  14: Throw-in                  (1017 samples)
  15: Yellow card               (121 samples)





Train: 4652 samples (145 batches)
Val:   1164 samples (37 batches)

BASELINE EXPERIMENTS

[1/3] Training TEXT-ONLY baseline...




[Epoch 001] Train Loss=1.9528, Acc=0.451 | Val Loss=1.7545, Acc=0.500 | LR=1.00e-03
  → Saved best model (val_loss=1.7545)
[Epoch 002] Train Loss=1.6858, Acc=0.541 | Val Loss=1.7462, Acc=0.489 | LR=9.97e-04
  → Saved best model (val_loss=1.7462)
[Epoch 003] Train Loss=1.6095, Acc=0.574 | Val Loss=1.7354, Acc=0.515 | LR=9.89e-04
  → Saved best model (val_loss=1.7354)
[Epoch 004] Train Loss=1.5171, Acc=0.608 | Val Loss=1.7654, Acc=0.517 | LR=9.76e-04
[Epoch 005] Train Loss=1.4145, Acc=0.649 | Val Loss=1.8337, Acc=0.487 | LR=9.57e-04
[Epoch 006] Train Loss=1.3090, Acc=0.701 | Val Loss=1.9198, Acc=0.486 | LR=9.33e-04
[Epoch 007] Train Loss=1.2110, Acc=0.748 | Val Loss=1.9650, Acc=0.494 | LR=9.05e-04
[Epoch 008] Train Loss=1.1199, Acc=0.794 | Val Loss=2.0060, Acc=0.496 | LR=8.72e-04
[Epoch 009] Train Loss=1.0219, Acc=0.835 | Val Loss=2.1035, Acc=0.493 | LR=8.35e-04
[Epoch 010] Train Loss=0.9693, Acc=0.856 | Val Loss=2.1662, Acc=0.475 | LR=7.94e-04
[Epoch 011] Train Loss=0.9153, Acc=0.882 | 

  output = torch._nested_tensor_from_mask(


[Epoch 001] Train Loss=2.2441, Acc=0.340 | Val Loss=2.0704, Acc=0.405 | LR=1.00e-03
  → Saved best model (val_loss=2.0704)
[Epoch 002] Train Loss=1.8682, Acc=0.440 | Val Loss=1.9057, Acc=0.415 | LR=9.97e-04
  → Saved best model (val_loss=1.9057)
[Epoch 003] Train Loss=1.7074, Acc=0.500 | Val Loss=1.8079, Acc=0.487 | LR=9.89e-04
  → Saved best model (val_loss=1.8079)
[Epoch 004] Train Loss=1.6458, Acc=0.531 | Val Loss=1.7626, Acc=0.524 | LR=9.76e-04
  → Saved best model (val_loss=1.7626)
[Epoch 005] Train Loss=1.5834, Acc=0.558 | Val Loss=1.7508, Acc=0.536 | LR=9.57e-04
  → Saved best model (val_loss=1.7508)
[Epoch 006] Train Loss=1.5548, Acc=0.569 | Val Loss=1.6981, Acc=0.536 | LR=9.33e-04
  → Saved best model (val_loss=1.6981)
[Epoch 007] Train Loss=1.5081, Acc=0.598 | Val Loss=1.6798, Acc=0.549 | LR=9.05e-04
  → Saved best model (val_loss=1.6798)
[Epoch 008] Train Loss=1.4689, Acc=0.609 | Val Loss=1.6354, Acc=0.582 | LR=8.72e-04
  → Saved best model (val_loss=1.6354)
[Epoch 009] Trai



[Epoch 001] Train Loss=2.0547, Acc=0.402 | Val Loss=1.8137, Acc=0.482 | LR=1.00e-03
  → Saved best model (val_loss=1.8137)
[Epoch 002] Train Loss=1.6391, Acc=0.554 | Val Loss=1.6118, Acc=0.570 | LR=9.97e-04
  → Saved best model (val_loss=1.6118)
[Epoch 003] Train Loss=1.4738, Acc=0.622 | Val Loss=1.5136, Acc=0.637 | LR=9.89e-04
  → Saved best model (val_loss=1.5136)
[Epoch 004] Train Loss=1.4061, Acc=0.653 | Val Loss=1.4444, Acc=0.637 | LR=9.76e-04
  → Saved best model (val_loss=1.4444)
[Epoch 005] Train Loss=1.3441, Acc=0.680 | Val Loss=1.4441, Acc=0.631 | LR=9.57e-04
  → Saved best model (val_loss=1.4441)
[Epoch 006] Train Loss=1.2721, Acc=0.709 | Val Loss=1.4113, Acc=0.667 | LR=9.33e-04
  → Saved best model (val_loss=1.4113)
[Epoch 007] Train Loss=1.2123, Acc=0.740 | Val Loss=1.3313, Acc=0.692 | LR=9.05e-04
  → Saved best model (val_loss=1.3313)
[Epoch 008] Train Loss=1.1796, Acc=0.755 | Val Loss=1.3545, Acc=0.688 | LR=8.72e-04
[Epoch 009] Train Loss=1.1272, Acc=0.775 | Val Loss=1.3

In [13]:
import numpy as np
from sklearn.metrics import confusion_matrix

def per_class_accuracy(model, loader, device, num_classes):
    all_preds = []
    all_labels = []
    
    model.eval()
    with torch.no_grad():
        for video_x, video_mask, text_emb, labels in loader:
            video_x = video_x.to(device)
            video_mask = video_mask.to(device)
            text_emb = text_emb.to(device)
            labels = labels.to(device)

            logits = model(video_x, video_mask, text_emb)
            preds = logits.argmax(dim=1)

            all_preds.append(preds.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)

    # Confusion matrix
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(num_classes)))
    
    # Per-class accuracy
    per_class = cm.diagonal() / cm.sum(axis=1).clip(min=1)
    
    return per_class, cm


In [14]:
print("\nComputing per-class performance...\n")

num_classes = len(event_to_idx)
idx_to_event = {v: k for k, v in event_to_idx.items()}

# TEXT-ONLY
text_pc, text_cm = per_class_accuracy(text_model, val_loader, device, num_classes)

# VIDEO-ONLY
video_pc, video_cm = per_class_accuracy(video_model, val_loader, device, num_classes)

# FUSION
fusion_pc, fusion_cm = per_class_accuracy(model, val_loader, device, num_classes)

# Pretty print
print(f"{'Class':25s} | Text  | Video | Fusion")
print("-"*55)
for i in range(num_classes):
    cls = idx_to_event[i][:23]  # shorten
    print(f"{cls:25s} | {text_pc[i]:.3f} | {video_pc[i]:.3f} | {fusion_pc[i]:.3f}")



Computing per-class performance...





Class                     | Text  | Video | Fusion
-------------------------------------------------------
Ball out of play          | 0.809 | 0.834 | 0.868
Clearance                 | 0.247 | 0.588 | 0.600
Corner                    | 0.312 | 0.729 | 0.833
Direct free-kick          | 0.143 | 0.429 | 0.381
Foul                      | 0.504 | 0.568 | 0.799
Goal                      | 0.333 | 0.500 | 0.292
Indirect free-kick        | 0.260 | 0.375 | 0.583
Kick-off                  | 0.125 | 0.469 | 0.344
Offside                   | 0.393 | 0.000 | 0.321
Penalty                   | 0.000 | 0.000 | 0.000
Red card                  | 0.000 | 0.000 | 0.000
Shots off target          | 0.471 | 0.412 | 0.686
Shots on target           | 0.311 | 0.426 | 0.279
Substitution              | 0.643 | 0.643 | 0.857
Throw-in                  | 0.512 | 0.808 | 0.837
Yellow card               | 0.667 | 0.667 | 0.762
