# Video Classification with ResNet and LSTM

This notebook demonstrates a workflow for video classification using a combination of a ResNet-34 model for feature extraction and an LSTM (Long Short-Term Memory) for temporal modeling.

## Overview

1. **Environment Setup**
2. **Model Architecture**
3. **Dataset Preparation**
4. **Training Pipeline**
5. **Validation and Evaluation**
6. **Saving the Final Model**

## Output
- The notebook outputs the trained model, saved checkpoints, and validation accuracy. The final model is saved as `resnet_lstm_highlight_model.pth`.

In [None]:
import os
import json
import cv2
from datetime import datetime
from typing import List

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from tqdm import tqdm


# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if device.type == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# ResNet-34 feature extractor
class ResNetFeatureExtractor(nn.Module):
    def __init__(self, fine_tune: bool = False):
        super().__init__()
        backbone = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
        self.backbone = nn.Sequential(*list(backbone.children())[:-1])  # strip FC
        self.fine_tune = fine_tune
        if not fine_tune:
            for p in self.backbone.parameters():
                p.requires_grad = False

    def forward(self, x):  # (N, 3, 224, 224)
        if self.fine_tune:
            feats = self.backbone(x)
        else:
            with torch.no_grad():
                feats = self.backbone(x)
        return feats.view(x.size(0), -1)  # (N, 512)

# ResNet-34 + LSTM classifier
class LSTMWithResNet(nn.Module):
    def __init__(self, feature_size: int, hidden_size: int, output_size: int,
                 num_layers: int = 2, dropout: float = 0.3, fine_tune_cnn: bool = False):
        super().__init__()
        self.feature_extractor = ResNetFeatureExtractor(fine_tune=fine_tune_cnn) # Create ResNet feature extractor
        self.lstm = nn.LSTM(feature_size,
                            hidden_size,
                            num_layers=num_layers,
                            batch_first=True,
                            dropout=dropout if num_layers > 1 else 0.0) # Add LSTM layer
        # Create a sequential classifier with two linear layers and ReLU activation
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size * 2, output_size)
        )

    def forward(self, x):  # (B, T, C, H, W)
        b, t, c, h, w = x.size()
        x = x.view(-1, c, h, w)                      # (B*T, C, H, W)
        feats = self.feature_extractor(x)            # (B*T, 512)
        feats = feats.view(b, t, -1)                 # (B, T, 512)
        lstm_out, _ = self.lstm(feats)               # (B, T, H)
        logits = self.classifier(lstm_out[:, -1, :]) # last step
        return logits

# Video dataset class
class VideoDataset(Dataset):
    def __init__(self,
                 video_paths: List[str],
                 labels: List[str],
                 label_to_index: dict,
                 max_frames: int = 64,
                 transform=None):
        self.video_paths = video_paths # list of video file paths
        self.labels = labels          # list of labels
        self.label_to_index = label_to_index # mapping from labels to indices
        self.max_frames = max_frames  # maximum number of frames to sample
        self.transform = transform   # optional transform to apply to each frame

    # Get length of dataset
    def __len__(self):
        return len(self.video_paths)

    # Sample frames from the video
    def _sample_frames(self, frames: List[np.ndarray]) -> List[np.ndarray]:
        if len(frames) <= self.max_frames:
            return frames
        # uniform sampling
        idxs = np.linspace(0, len(frames) - 1, self.max_frames, dtype=int)
        return [frames[i] for i in idxs]

    # Get item from dataset
    def __getitem__(self, idx):
        video_path = self.video_paths[idx]
        label = self.labels[idx]

        frames = []
        cap = cv2.VideoCapture(video_path) # Open video file
        while cap.isOpened():
            ret, frame = cap.read() # Read frame
            if not ret:
                break
            frame = cv2.resize(frame, (224, 224))
            frame = frame[:, :, ::-1]  # BGR→RGB
            frame = frame / 255.0 # Normalize
            if self.transform:
                frame = self.transform(frame)
            frames.append(frame)
        cap.release()

        if not frames:
            raise RuntimeError(f"Could not read frames from {video_path}")

        frames = self._sample_frames(frames)
        frames = np.stack(frames)                      # (T, H, W, C)
        frames_tensor = torch.tensor(frames).permute(0, 3, 1, 2).float()
        label_idx = self.label_to_index[label] # convert label to index
        return frames_tensor, label_idx

# Load video data
def load_data(root_directory):
    video_paths, labels = [], []
    for folder_name in os.listdir(root_directory):
        folder_path = os.path.join(root_directory, folder_name) # Get path to folder
        if os.path.isdir(folder_path):
            for fname in os.listdir(folder_path): # Get all files in folder
                if fname.lower().endswith((".mp4", ".avi", ".mov")): # Check if file is a video
                    video_paths.append(os.path.join(folder_path, fname)) # Add video path to list
                    labels.append(folder_name) # Add label to list
    return video_paths, labels


root_directory = r""  # Path to the dataset directory
video_paths, labels = load_data(root_directory) # Load video data

unique_labels = sorted(set(labels)) # Get unique labels
label_to_index = {lbl: i for i, lbl in enumerate(unique_labels)}
index_to_label = {i: lbl for lbl, i in label_to_index.items()}

# Split data into training and validation sets
train_paths, val_paths, train_labels, val_labels = train_test_split(
    video_paths, labels, test_size=0.2, random_state=42, stratify=labels)

# Hyperparameters
feature_size = 512
hidden_size = 512
output_size = len(unique_labels) # number of unique labels
num_layers = 2
num_epochs = 20
batch_size = 2  # reduce if still OOM
learning_rate = 1e-4
checkpoint_dir = "checkpoints" # directory to save checkpoints
os.makedirs(checkpoint_dir, exist_ok=True)

# Create datasets and dataloaders
train_dataset = VideoDataset(train_paths, train_labels, label_to_index)
val_dataset = VideoDataset(val_paths,   val_labels,   label_to_index)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,  pin_memory=True)
val_loader = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, pin_memory=True)

# Create model, loss function, optimizer
model = LSTMWithResNet(feature_size, hidden_size, output_size, num_layers).to(device)
criterion = nn.CrossEntropyLoss() # Cross entropy loss
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate) # Adam optimizer
scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")

# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for vids, lbls in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        vids, lbls = vids.to(device, non_blocking=True), lbls.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=device.type == "cuda"):
            outputs = model(vids)
            loss = criterion(outputs, lbls)
        scaler.scale(loss).backward() # Backpropagation
        scaler.step(optimizer) # Update weights
        scaler.update() # Update scaler
        running_loss += loss.item()

    avg_loss = running_loss / len(train_loader) # Average loss
    print(f"Epoch {epoch+1:02d} | train loss: {avg_loss:.4f}")

    ckpt_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1:02d}.pth")
    torch.save({
        "epoch": epoch + 1,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "train_loss": avg_loss,
        # "val_acc": val_acc,
        "timestamp": datetime.utcnow().isoformat()
    }, ckpt_path) # Save checkpoint
    print(f"✔ Saved checkpoint → {ckpt_path}\n")

torch.save(model.state_dict(), "resnet_lstm_highlight_model.pth") # Save final model
print("Training complete – final model saved as 'resnet_lstm_highlight_model.pth'.")


model.eval() # Set model to evaluation mode
correct = total = 0
with torch.no_grad():
    for vids, lbls in val_loader:
        vids, lbls = vids.to(device, non_blocking=True), lbls.to(device, non_blocking=True)
        with torch.cuda.amp.autocast(enabled=device.type == "cuda"):
            outputs = model(vids)
        preds = outputs.argmax(dim=1) # Get predicted labels
        correct += (preds == lbls).sum().item() # Count correct predictions
        total   += lbls.size(0) # Count total predictions
val_acc = correct / total if total else 0.0 # Calculate accuracy
print(f"Epoch {epoch+1:02d} | val acc : {val_acc:.4f}")

Using device: cuda
GPU: Tesla T4


Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth
100%|██████████| 83.3M/83.3M [00:00<00:00, 186MB/s] 
  scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")
  with torch.cuda.amp.autocast(enabled=device.type == "cuda"):
Epoch 1/20: 100%|██████████| 562/562 [21:59<00:00,  2.35s/it]


Epoch 01 | train loss: 2.5090
✔ Saved checkpoint → checkpoints/checkpoint_epoch_01.pth



Epoch 2/20: 100%|██████████| 562/562 [21:18<00:00,  2.27s/it]


Epoch 02 | train loss: 2.2383
✔ Saved checkpoint → checkpoints/checkpoint_epoch_02.pth



Epoch 3/20: 100%|██████████| 562/562 [21:01<00:00,  2.24s/it]


Epoch 03 | train loss: 2.0851
✔ Saved checkpoint → checkpoints/checkpoint_epoch_03.pth



Epoch 4/20: 100%|██████████| 562/562 [21:03<00:00,  2.25s/it]


Epoch 04 | train loss: 2.0570
✔ Saved checkpoint → checkpoints/checkpoint_epoch_04.pth



Epoch 5/20: 100%|██████████| 562/562 [20:58<00:00,  2.24s/it]


Epoch 05 | train loss: 1.9392
✔ Saved checkpoint → checkpoints/checkpoint_epoch_05.pth



Epoch 6/20: 100%|██████████| 562/562 [20:55<00:00,  2.23s/it]


Epoch 06 | train loss: 1.9263
✔ Saved checkpoint → checkpoints/checkpoint_epoch_06.pth



Epoch 7/20: 100%|██████████| 562/562 [21:03<00:00,  2.25s/it]


Epoch 07 | train loss: 1.8315
✔ Saved checkpoint → checkpoints/checkpoint_epoch_07.pth



Epoch 8/20: 100%|██████████| 562/562 [21:04<00:00,  2.25s/it]


Epoch 08 | train loss: 1.7468
✔ Saved checkpoint → checkpoints/checkpoint_epoch_08.pth



Epoch 9/20: 100%|██████████| 562/562 [21:17<00:00,  2.27s/it]


Epoch 09 | train loss: 1.6805
✔ Saved checkpoint → checkpoints/checkpoint_epoch_09.pth



Epoch 10/20: 100%|██████████| 562/562 [21:08<00:00,  2.26s/it]


Epoch 10 | train loss: 1.6074
✔ Saved checkpoint → checkpoints/checkpoint_epoch_10.pth



Epoch 11/20: 100%|██████████| 562/562 [20:59<00:00,  2.24s/it]


Epoch 11 | train loss: 1.5755
✔ Saved checkpoint → checkpoints/checkpoint_epoch_11.pth



Epoch 12/20: 100%|██████████| 562/562 [21:04<00:00,  2.25s/it]


Epoch 12 | train loss: 1.5000
✔ Saved checkpoint → checkpoints/checkpoint_epoch_12.pth



Epoch 13/20: 100%|██████████| 562/562 [21:08<00:00,  2.26s/it]


Epoch 13 | train loss: 1.4526
✔ Saved checkpoint → checkpoints/checkpoint_epoch_13.pth



Epoch 14/20: 100%|██████████| 562/562 [21:22<00:00,  2.28s/it]


Epoch 14 | train loss: 1.4135
✔ Saved checkpoint → checkpoints/checkpoint_epoch_14.pth



Epoch 15/20: 100%|██████████| 562/562 [21:07<00:00,  2.26s/it]


Epoch 15 | train loss: 1.3621
✔ Saved checkpoint → checkpoints/checkpoint_epoch_15.pth



Epoch 16/20: 100%|██████████| 562/562 [21:17<00:00,  2.27s/it]


Epoch 16 | train loss: 1.2911
✔ Saved checkpoint → checkpoints/checkpoint_epoch_16.pth



Epoch 17/20: 100%|██████████| 562/562 [21:16<00:00,  2.27s/it]


Epoch 17 | train loss: 1.2369
✔ Saved checkpoint → checkpoints/checkpoint_epoch_17.pth



Epoch 18/20: 100%|██████████| 562/562 [21:04<00:00,  2.25s/it]


Epoch 18 | train loss: 1.1705
✔ Saved checkpoint → checkpoints/checkpoint_epoch_18.pth



Epoch 19/20: 100%|██████████| 562/562 [21:09<00:00,  2.26s/it]


Epoch 19 | train loss: 1.1798
✔ Saved checkpoint → checkpoints/checkpoint_epoch_19.pth



Epoch 20/20: 100%|██████████| 562/562 [21:19<00:00,  2.28s/it]


Epoch 20 | train loss: 1.1702
✔ Saved checkpoint → checkpoints/checkpoint_epoch_20.pth

Training complete – final model saved as 'resnet_lstm_highlight_model.pth'.


  with torch.cuda.amp.autocast(enabled=device.type == "cuda"):


Epoch 20 | val acc : 0.5638
