In [None]:
import os, cv2, numpy as np, torch, torch.nn as nn, torch.optim as optim
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from tqdm import tqdm


# Pick the compute device once; everything below moves tensors/models onto it.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

if device.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
    # Let cuDNN auto-tune its kernels (worthwhile when input sizes stay fixed).
    torch.backends.cudnn.benchmark = True


class ResNetFeatureExtractor(nn.Module):
    """Pretrained ResNet-34 with its classification head removed.

    Maps a batch of images (N, 3, H, W) to one 512-dimensional feature
    vector per image.
    """

    def __init__(self):
        super().__init__()
        pretrained = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
        # Keep every child module except the last one (the fully connected head).
        stages = list(pretrained.children())
        self.backbone = nn.Sequential(*stages[:-1])

    def forward(self, x):
        """Return the pooled backbone activations flattened to (N, 512)."""
        pooled = self.backbone(x)                  # (N,512,1,1)
        return torch.flatten(pooled, start_dim=1)  # (N,512)

class LSTMWithResNet(nn.Module):
    """Frame-level CNN + sequence-level LSTM classifier.

    Every frame of a clip is encoded by ``ResNetFeatureExtractor``; the
    resulting feature sequence is fed through a unidirectional LSTM and the
    output at the last time step is classified by a small MLP head.

    Args:
        feature_size: Width of the per-frame feature vector, i.e. the LSTM
            input size. Must equal the feature extractor's output width
            (512 for the ResNet-34 backbone used here).
        hidden_size: LSTM hidden state width.
        output_size: Number of classes (logits returned per clip).
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability used between LSTM layers (only when
            ``num_layers > 1``) and inside the classification head.
    """

    def __init__(self,
                 feature_size: int,
                 hidden_size:  int,
                 output_size:  int,
                 num_layers:   int = 3,
                 dropout:      float = 0.3):
        super().__init__()
        self.feature_extractor = ResNetFeatureExtractor()
        self.lstm = nn.LSTM(
            input_size   = feature_size,
            hidden_size  = hidden_size,
            num_layers   = num_layers,
            batch_first  = True,
            # nn.LSTM only applies dropout between layers, so it is switched
            # off for a single-layer network.
            dropout      = dropout if num_layers > 1 else 0.0,
            bidirectional=False
        )
        # NOTE: layer order/indices are part of the checkpoint format
        # (e.g. the key 'fc.3.weight' is read when loading), keep them stable.
        self.fc = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size * 2, output_size)
        )

    def forward(self, clip):                       # (B,T,3,H,W)
        """Classify a batch of clips.

        Args:
            clip: Tensor of shape (B, T, C, H, W).

        Returns:
            Logits of shape (B, output_size).

        Raises:
            ValueError: If ``clip`` is not 5-dimensional.
        """
        if clip.dim() != 5:
            raise ValueError(
                f"expected clip of shape (B,T,C,H,W), got {tuple(clip.shape)}"
            )
        B, T, C, H, W = clip.shape
        # reshape (not view): works for non-contiguous inputs such as
        # permuted/transposed batches, copying only when it has to.
        frames = clip.reshape(B * T, C, H, W)      # (B*T,3,H,W)
        feats = self.feature_extractor(frames)     # (B*T,512)
        feats = feats.reshape(B, T, feats.shape[-1])   # (B,T,512)
        seq_out, _ = self.lstm(feats)              # (B,T,hidden)
        last = seq_out[:, -1, :]                   # (B,hidden) – final time step
        return self.fc(last)                       # (B,output)

class VideoDataset(Dataset):
    """Dataset that decodes whole videos into frame tensors.

    Args:
        video_paths: Sequence of video file paths.
        labels: Sequence of label names, aligned with ``video_paths``.
        label_to_index: Mapping from label name to integer class index.
        resize: Side length; every frame is resized to ``resize x resize``.
        transform: Optional callable applied to each frame after resizing
            and scaling to [0, 1]. It receives and must return an
            (H, W, 3) array, because frames are stacked and permuted to
            channel-first afterwards.
    """

    def __init__(self, video_paths, labels, label_to_index, resize=224, transform=None):
        self.video_paths = video_paths
        self.labels = labels
        self.label_to_index = label_to_index
        self.transform = transform
        self.resize = resize

    def __len__(self):
        return len(self.video_paths)

    def __getitem__(self, idx):
        """Return ``(clip, class_index)`` for the video at ``idx``.

        ``clip`` is a float32 tensor of shape (T, 3, H, W) holding every
        frame of the video, so T differs between videos of different length.

        Raises:
            ValueError: If no frame could be read from the video file.
        """
        path, label = self.video_paths[idx], self.labels[idx]
        frames = []
        cap = cv2.VideoCapture(path)
        try:
            while cap.isOpened():
                ok, frame = cap.read()
                if not ok:
                    break
                # NOTE(review): OpenCV normally decodes to BGR while the
                # ImageNet-pretrained backbone expects RGB – confirm whether
                # a colour conversion is intended here.
                frame = cv2.resize(frame, (self.resize, self.resize)) / 255.0
                if self.transform:
                    frame = self.transform(frame)
                frames.append(frame)
        finally:
            # Always free the decoder, even if resize/transform raised.
            cap.release()

        if not frames:
            # np.stack([]) would fail with an opaque message; name the file.
            raise ValueError(f"No frames could be read from video '{path}'.")

        clip = torch.tensor(np.stack(frames)).permute(0, 3, 1, 2).float()  # (T,3,H,W)
        return clip, self.label_to_index[label]


def load_data(root_dir, exts=(".mp4", ".avi", ".mov", ".mkv")):
    """Collect video paths and their labels from a class-folder layout.

    Expects ``root_dir/<label>/<video files>``: each immediate sub-directory
    of ``root_dir`` is a label and only files directly inside it are
    collected (the scan is one level deep, not recursive). Non-directory
    entries in ``root_dir`` are ignored.

    Args:
        root_dir: Directory containing one sub-directory per label.
        exts: Iterable of accepted file extensions, matched
            case-insensitively against the end of the file name.

    Returns:
        ``(paths, labels)`` – two parallel lists. Entries are ordered by
        label and then by file name, so the result (and any seeded split
        derived from it) does not depend on the OS directory order.
    """
    suffixes = tuple(e.lower() for e in exts)      # built once, not per file
    paths, lbls = [], []
    for lbl in sorted(os.listdir(root_dir)):
        class_dir = os.path.join(root_dir, lbl)
        if not os.path.isdir(class_dir):
            continue
        for fname in sorted(os.listdir(class_dir)):
            if fname.lower().endswith(suffixes):
                paths.append(os.path.join(class_dir, fname))
                lbls.append(lbl)
    return paths, lbls

# ---- data location & hyper-parameters ----
# These module-level names are read by the training loop and by the
# evaluation / prediction helpers further down.
root_dir     = "/kaggle/input/match-highlight-extracted/extracted"  # expects root_dir/<label>/<videos>
frame_size   = 160            # frames are resized to frame_size x frame_size by VideoDataset
feature_size = 512            # LSTM input width; the ResNet-34 extractor emits 512-D vectors
hidden_size  = 512            # LSTM hidden state width
num_layers   = 3              # stacked LSTM layers
dropout      = 0.3            # dropout in the LSTM (between layers) and in the classifier head
epochs       = 20
batch_size   = 1              # clips per DataLoader batch; presumably 1 because clips keep all
                              # their frames and may differ in length – TODO confirm
accum_steps  = 4              # batches whose gradients are accumulated per optimiser step
lr           = 1e-4           # Adam learning rate

# ---- checkpoint settings ----
checkpoint_dir = "./checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
save_every = 2          # save a full checkpoint every N epochs
best_acc = 0.0          # best validation accuracy seen so far (updated by the training loop)


# ---- collect the videos; fail early if the folder layout is wrong ----
paths, labels = load_data(root_dir)
if len(paths) == 0:
    raise ValueError(f"No video files found under '{root_dir}'. Check path/structure/extensions.")

# Class indices follow the alphabetical order of the label names.
unique = sorted(set(labels))
label_to_index = {l:i for i,l in enumerate(unique)}
index_to_label = {i:l for l,i in label_to_index.items()}

# 80/20 train/validation split, stratified by label and seeded for repeatability.
tr_p, val_p, tr_l, val_l = train_test_split(paths, labels, test_size=0.2, random_state=42, stratify=labels)
train_ds = VideoDataset(tr_p, tr_l, label_to_index, resize=frame_size)
val_ds   = VideoDataset(val_p, val_l, label_to_index, resize=frame_size)
train_ld = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_ld   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

# ---- model, loss and optimiser ----
model = LSTMWithResNet(feature_size, hidden_size, len(unique), num_layers, dropout).to(device)
criterion = nn.CrossEntropyLoss()
optimiser = optim.Adam(model.parameters(), lr=lr)

# Mixed-precision support: autocast and gradient scaling are enabled only
# when running on CUDA; on CPU both are constructed in disabled mode.
amp_enabled = device.type == "cuda"
scaler = torch.cuda.amp.GradScaler(enabled=amp_enabled)


# ---- training loop ----
# Per epoch: train with gradient accumulation (one optimiser step every
# `accum_steps` batches) under optional mixed precision, measure validation
# accuracy, then write periodic and best-so-far checkpoints.
for epoch in range(1, epochs+1):
    model.train()
    running = 0.0        # sum of un-scaled batch losses
    n_ok = 0             # batches that actually contributed (not skipped for OOM)
    pending = False      # True while gradients are accumulated but not yet applied
    optimiser.zero_grad(set_to_none=True)

    for step, (clips, lbl) in enumerate(tqdm(train_ld, desc=f"Epoch {epoch}/{epochs}")):
        clips, lbl = clips.to(device, non_blocking=True), lbl.to(device, non_blocking=True)
        try:
            with torch.cuda.amp.autocast(enabled=amp_enabled):
                logits = model(clips)
                # Scale down so the accumulated gradient matches the mean
                # over the accumulation window.
                loss = criterion(logits, lbl) / accum_steps
            scaler.scale(loss).backward()
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise
            print("⚠️  OOM at step", step, "– skipping batch")
            torch.cuda.empty_cache()
        else:
            running += loss.item() * accum_steps  # undo division for logging
            n_ok += 1
            pending = True

        # Close the accumulation window even if this very batch was skipped,
        # so an OOM on the boundary neither drops the optimiser step nor
        # leaks gradients into the next window.
        if (step + 1) % accum_steps == 0 or (step + 1) == len(train_ld):
            if pending:
                scaler.step(optimiser)
                scaler.update()
            optimiser.zero_grad(set_to_none=True)
            pending = False

    # Average over the batches that were really processed.
    train_loss = running / max(n_ok, 1)
    print(f"  Train loss: {train_loss:.4f}")

    # ---- validation ----
    model.eval(); correct = tot = 0
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=amp_enabled):
        for clips, lbl in val_ld:
            clips, lbl = clips.to(device, non_blocking=True), lbl.to(device, non_blocking=True)
            preds = model(clips).argmax(1)
            correct += (preds == lbl).sum().item()
            tot += lbl.size(0)
    val_acc = correct / tot if tot else 0.0
    print(f"  Val  acc : {val_acc:.4f}")

    # ---- checkpointing ----
    # Full training state; the same payload is used for periodic and best saves.
    ckpt_state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimiser.state_dict(),
        'val_acc': val_acc,
        'train_loss': train_loss
    }

    if epoch % save_every == 0:
        ckpt_path = os.path.join(checkpoint_dir, f"epoch_{epoch}.pth")
        torch.save(ckpt_state, ckpt_path)
        print(f"  ✓ Checkpoint saved: {ckpt_path}")

    # `best_acc` is a module-level name already; no `global` statement is
    # needed here (and one placed after the earlier assignment is a SyntaxError).
    if val_acc > best_acc:
        best_acc = val_acc
        best_path = os.path.join(checkpoint_dir, "best.pth")
        torch.save(ckpt_state, best_path)
        print(f"  ✓ New best model ({best_acc:.4f}) saved to {best_path}")

# save final model (bare state_dict, unlike the dict checkpoints above)
final_path = os.path.join(checkpoint_dir, "final.pth")
torch.save(model.state_dict(), final_path)
print(f"✓ Training finished & final model saved to {final_path}\n")

In [None]:
def evaluate_model(model_path, test_dir):
    """Evaluate a saved model on a labelled folder of test videos.

    Prints the accuracy and a per-class classification report, then shows a
    confusion-matrix heat map.

    Relies on notebook-level names: ``device``, ``load_data``,
    ``VideoDataset``, ``LSTMWithResNet``, ``feature_size``, ``hidden_size``,
    ``num_layers``, ``dropout``, ``frame_size`` and, when available,
    ``label_to_index``.

    Args:
        model_path: Either a full checkpoint (dict holding
            ``'model_state_dict'``) or a bare ``state_dict`` file such as
            ``final.pth``.
        test_dir: Folder laid out as ``test_dir/<label>/<video files>``.

    Raises:
        ValueError: If no video files are found under ``test_dir``.
    """
    import matplotlib.pyplot as plt, seaborn as sns
    from sklearn.metrics import confusion_matrix, classification_report

    test_paths, test_labels = load_data(test_dir)
    if not test_paths:
        raise ValueError(f"No video files found under '{test_dir}'. Check path/structure/extensions.")

    # Prefer the label->index mapping used for training so class indices line
    # up with the model's outputs even when the test folder lacks some
    # classes; otherwise fall back to a mapping derived from the test folders.
    test_classes = sorted(set(test_labels))
    train_map = globals().get("label_to_index") or {}
    if train_map and set(test_classes) <= set(train_map):
        lbl2idx = dict(train_map)
    else:
        if train_map:
            print("Warning: test labels are not a subset of the training labels; "
                  "class indices may not match the model outputs.")
        lbl2idx = {l: i for i, l in enumerate(test_classes)}
    idx2lbl = {i: l for l, i in lbl2idx.items()}

    # Same frame size as in training. batch_size=1 because clips keep all
    # their frames, so clips of different length cannot be stacked by the
    # default collate function.
    test_ds = VideoDataset(test_paths, test_labels, lbl2idx, resize=frame_size)
    test_ld = DataLoader(test_ds, batch_size=1, shuffle=False)

    # A bare state_dict is an OrderedDict (i.e. also a dict), so the two file
    # formats are told apart by the presence of the 'model_state_dict' key.
    ckpt = torch.load(model_path, map_location=device)
    if isinstance(ckpt, dict) and 'model_state_dict' in ckpt:
        state = ckpt['model_state_dict']
    else:
        state = ckpt
    n_classes = state['fc.3.weight'].shape[0]   # rows of the last Linear layer

    net = LSTMWithResNet(feature_size, hidden_size, n_classes, num_layers, dropout).to(device)
    net.load_state_dict(state)
    net.eval()

    y_true, y_pred = [], []
    with torch.no_grad():
        for clips, lbl in test_ld:
            clips = clips.to(device)
            y_true.extend(lbl.tolist())
            y_pred.extend(net(clips).argmax(1).cpu().tolist())

    # Fix the class list explicitly so the report and the matrix keep one
    # row/column per class even if some class never occurs in y_true/y_pred.
    class_ids = list(range(max(n_classes, len(idx2lbl))))
    class_names = [str(idx2lbl.get(i, i)) for i in class_ids]

    print("\nAccuracy:", (np.array(y_true) == np.array(y_pred)).mean())
    print("\nReport:\n", classification_report(y_true, y_pred, labels=class_ids, target_names=class_names))

    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    plt.figure(figsize=(8,6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=class_names,
                yticklabels=class_names)
    plt.xlabel("Pred"); plt.ylabel("True"); plt.tight_layout(); plt.show()

def predict_video(video_path, model_path, resize=None):
    """Classify a single video with a saved model.

    Relies on notebook-level names: ``device``, ``LSTMWithResNet``,
    ``feature_size``, ``hidden_size``, ``num_layers``, ``dropout``,
    ``frame_size`` and ``index_to_label``.

    Args:
        video_path: Path of the video to classify.
        model_path: Either a full checkpoint (dict holding
            ``'model_state_dict'``) or a bare ``state_dict`` file such as
            ``final.pth``.
        resize: Side length frames are resized to. Defaults to the training
            ``frame_size`` so inference sees the same resolution as training.

    Returns:
        ``(label, probs)`` – the predicted label (or the raw class index if
        it is missing from ``index_to_label``) and a NumPy array with the
        softmax probability of every class.

    Raises:
        ValueError: If no frame could be read from ``video_path``.
    """
    if resize is None:
        resize = frame_size

    # A bare state_dict is an OrderedDict (i.e. also a dict), so the two file
    # formats are told apart by the presence of the 'model_state_dict' key.
    ckpt = torch.load(model_path, map_location=device)
    if isinstance(ckpt, dict) and 'model_state_dict' in ckpt:
        state = ckpt['model_state_dict']
    else:
        state = ckpt
    n_classes = state['fc.3.weight'].shape[0]   # rows of the last Linear layer

    net = LSTMWithResNet(feature_size, hidden_size, n_classes, num_layers, dropout).to(device)
    net.load_state_dict(state)
    net.eval()

    frames = []
    cap = cv2.VideoCapture(video_path)
    try:
        while cap.isOpened():
            ok, f = cap.read()
            if not ok:
                break
            frames.append(cv2.resize(f, (resize, resize)) / 255.0)
    finally:
        # Always free the decoder, even if resizing raised.
        cap.release()

    if not frames:
        raise ValueError(f"No frames could be read from video '{video_path}'.")

    # (T,H,W,3) -> (1,T,3,H,W): channel-first frames plus a batch axis.
    clip = torch.tensor(np.stack(frames)).permute(0,3,1,2).float().unsqueeze(0).to(device)
    with torch.no_grad():
        probs = torch.softmax(net(clip), 1)[0]
    idx = probs.argmax().item()
    label = index_to_label.get(idx, idx)
    print(f"Predicted: {label}  (conf {probs[idx]:.2f})")
    return label, probs.cpu().numpy()
