This notebook classifies every frame of a video with a fine-tuned ResNet-18 model, smooths the per-frame predictions, and lets the user step through the frames with the predicted label drawn on each image.

In [1]:
import os
import re

import torch
from torchvision import transforms
import torch.nn as nn
from PIL import Image
import cv2
import numpy as np
import pandas as pd
from scipy.stats import mode
from torchvision import models
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

  Referenced from: <0B7EB158-53DC-3403-8A49-22178CAB4612> /opt/anaconda3/envs/tensorflow/lib/python3.10/site-packages/torchvision/image.so
  warn(


Load the model

In [None]:
# Custom preprocessing step used ahead of the ResNet-18 resize/normalise stages.
class CenterSquareCrop:
    """Crop a PIL image to the largest square centred in the frame.

    The shorter side is kept in full and the longer side is trimmed from both
    ends. When the difference is odd, the extra trimmed pixel comes off the
    right/bottom edge, because the offset is rounded down.
    """

    def __call__(self, img):
        w, h = img.size
        side = w if w < h else h
        x0 = (w - side) // 2
        y0 = (h - side) // 2
        box = (x0, y0, x0 + side, y0 + side)
        return img.crop(box)

# Per-frame preprocessing, in order:
#   1. centre-crop to a square;
#   2. resize to the 224x224 input size;
#   3. convert to a tensor;
#   4. normalise with the ImageNet channel mean/std.
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        CenterSquareCrop(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_imagenet_mean, std=_imagenet_std),
    ]
)

# Rebuild the exact architecture the checkpoint was trained with, so load_state_dict matches.
def build_model(num_classes=5):
    """Create a ResNet-18 with a frozen backbone and a small trainable head.

    The torchvision ResNet-18 (requested with ``pretrained=True``) has every
    parameter frozen. Its ``fc`` layer is then replaced by:
    Linear(in_features -> 256) -> ReLU -> Dropout(0.5) -> Linear(256 -> num_classes).

    Args:
        num_classes: number of output logits produced by the new head.

    Returns:
        The modified ResNet-18 module. Only the head's parameters have
        ``requires_grad=True``.
    """
    # NOTE(review): ``pretrained=True`` is deprecated in newer torchvision in favour of
    # ``weights=...``. The downloaded weights are overwritten anyway when a checkpoint is loaded.
    net = models.resnet18(pretrained=True)
    in_features = net.fc.in_features

    # Freeze the whole backbone first...
    for p in net.parameters():
        p.requires_grad = False

    # ...then attach the new head and mark its parameters trainable explicitly.
    head = nn.Sequential(
        nn.Linear(in_features, 256),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(256, num_classes),
    )
    net.fc = head
    for p in head.parameters():
        p.requires_grad = True

    return net

# Run on a CUDA GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Rebuild the architecture and move it to the chosen device.
model = build_model(num_classes=5)
model = model.to(device)

# Restore the trained parameters. map_location lets a checkpoint saved on one
# device be opened on another (e.g. a GPU checkpoint on a CPU-only machine).
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(torch.load('final_model_v1_weight.pth', map_location=device))

# Inference mode (disables dropout). The returned module is the cell's displayed output.
model.eval()

Making predictions

In [None]:
# Folder of frame-by-frame images produced by extract_frame.py.
# NOTE(review): absolute local path. Set the FRAMES_DIR environment variable to override it.
path = os.environ.get("FRAMES_DIR", "C:/Users/fires/Downloads/russianGPframe-by-frame")


def _natural_sort_key(filename):
    """Sort key that orders embedded integers numerically ('frame_2' < 'frame_10')."""
    parts = re.split(r'(\d+)', filename)
    # With a capture group, re.split alternates text (even index) and digit runs (odd index),
    # so every position compares like with like.
    return [int(part) if idx % 2 else part.lower() for idx, part in enumerate(parts)]


# os.listdir returns entries in arbitrary order. The smoothing and playback cells treat
# `predictions` as a time series, so the frames must be put back in video order.
frame_files = sorted(
    (file for file in os.listdir(path) if file.lower().endswith('.jpg')),
    key=_natural_sort_key,
)
frame_paths = [os.path.join(path, file) for file in frame_files]
assert frame_paths, f"No .jpg frames found in {path}"

predicted_classes = []
with torch.no_grad():  # inference only: skip autograd bookkeeping for every frame
    for frame in frame_paths:
        # Use a context manager so each image file is closed after it is decoded.
        with Image.open(frame) as img:
            image = img.convert('RGB')
        batch = transform(image).unsqueeze(0).to(device)  # add a leading batch dimension
        output = model(batch)
        predicted_classes.append(int(output.argmax(dim=1).item()))

# Build the array once. np.append inside the loop copies the whole array on every frame.
predictions = np.array(predicted_classes, dtype=int)
print(predictions)

[1 1 3 ... 2 1 3]


Sliding window with majority vote to smooth prediction noise

In [5]:
def mode_filter(arr, window_size):
    """Smooth a 1-D label sequence with a centred sliding-window majority vote.

    Each interior position ``i`` is replaced by the most frequent value in the
    window ``arr[i - h : i + h + 1]``, where ``h = window_size // 2``. The window
    therefore holds ``2*h + 1`` elements, and an even ``window_size`` is
    effectively rounded up to the next odd size.

    - Ties: when two or more values share the highest count, the original
      value at ``i`` is kept.
    - Edges: the first and last ``h`` positions have no full window and keep
      their original values.

    Args:
        arr: 1-D array-like of labels (e.g. integer class ids). It is not modified.
        window_size: window length; must be >= 1.

    Returns:
        A new NumPy array with the same length and dtype as ``np.asarray(arr)``.

    Raises:
        ValueError: if ``window_size`` is less than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    values = np.asarray(arr)
    half = window_size // 2
    # Start from a copy so edge positions (and tie positions) keep their original
    # values. The old np.empty_like left the edges as uninitialised memory.
    results = values.copy()
    for i in range(half, len(values) - half):
        # +1 makes the slice include i + half, giving a symmetric, odd-sized window.
        window = values[i - half:i + half + 1]
        labels, counts = np.unique(window, return_counts=True)
        winners = labels[counts == counts.max()]
        if len(winners) == 1:
            results[i] = winners[0]
        # Otherwise it is a tie: keep the original value already stored in results[i].
    return results

In [None]:
# Two smoothing passes: a 7-frame window removes larger bursts of noise, then a
# 3-frame window cleans up individual misclassified frames.
# NOTE(review): this overwrites `predictions`, so re-running the cell smooths the
# already-smoothed array again. Re-run the prediction cell first.
for window_size in (7, 3):
    predictions = mode_filter(predictions, window_size)

In [None]:
# Map numeric class indices to readable labels.
# NOTE(review): presumably this matches the class order used at training time — confirm.
label_map = {0: "Distant or No Car", 1: "Front", 2: "Inside", 3: "Rear", 4: "Side"}
predictions_mapped = np.vectorize(lambda x: label_map.get(x, f'unknown: {x}'))(predictions)

# Labels are paired with frames by position. A length mismatch means the two come
# from different runs, and every overlay would be misaligned.
assert len(predictions_mapped) == len(frame_paths), (
    f"{len(predictions_mapped)} predictions for {len(frame_paths)} frames"
)

# Display settings
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 1
color = (0, 255, 0)  # Green text (OpenCV images are BGR; green is the same either way)
thickness = 2
position = (50, 50)  # Top-left corner
window_name = 'Labeled Video'
NEXT_KEY, QUIT_KEY = ord('n'), ord('q')

# Step through the frames one at a time: 'n' advances and 'q' quits.
for frame_path, label in zip(frame_paths, predictions_mapped):
    image = cv2.imread(frame_path)
    if image is None:
        print(f"⚠️ Warning: Could not read {frame_path}")
        continue

    cv2.putText(image, f"Label: {label}", position, font, font_scale, color, thickness, cv2.LINE_AA)
    cv2.imshow(window_name, image)

    # Wait for a recognised key instead of advancing on any keypress.
    # waitKey returns -1 when there is no window to receive input (e.g. it was closed);
    # treat that as quit so the loop cannot spin forever.
    key = cv2.waitKey(0)
    while key not in (NEXT_KEY, QUIT_KEY, -1):
        print("Press 'n' for next frame or 'q' to quit.")
        key = cv2.waitKey(0)
    if key != NEXT_KEY:
        break

cv2.destroyAllWindows()
cv2.waitKey(1)  # pump the GUI event loop once so the window actually closes on some platforms