# Task Classifier - Video Frame Classification

Train a ResNet50 classifier to identify tasks from video frames and process videos to extract task timestamps.

## Features
- Transfer learning with pretrained ResNet50
- Class imbalance handling with weighted loss
- Temporal smoothing for video inference
- CSV export of predictions and time ranges

## Task Classes
CameraTarget, ChickenThigh, CystModel, GloveCut, Idle, MovingIndividualAxes, RingRollercoaster, SeaSpikes, Suture

## Imports

In [None]:
import os

from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

import torch
from fastai.vision.all import *

import matplotlib.pyplot as plt


## Configuration

In [None]:
import random

# Paths (relative to the notebook's working directory)
DATASET_PATH = "../../../datasets/tasks_classified/"
MODEL_DIR = "../../../processing/models/"

# Training
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 64
NUM_EPOCHS = 15
LEARNING_RATE = 0.0001
RANDOM_SEED = 42

# Parallel Processing
NUM_WORKERS = 8  # DataLoader worker processes used for parallel data loading
USE_MIXED_PRECISION = True  # Faster training on GPU (FP16)

# Setup
os.makedirs(MODEL_DIR, exist_ok=True)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG the pipeline may draw from. Seeding torch alone is not
# enough for repeatable runs: fastai's random augmentations and DataLoader
# shuffling also consume Python's `random` module.
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(RANDOM_SEED)

print(f"Device: {DEVICE}")
print(f"GPU Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Count: {torch.cuda.device_count()}")


## Create DataLoaders with FastAI

In [None]:
import numpy as np

# Sample a reproducible 1/5 subset of the dataset to shorten training runs.
# Each image's label is the name of its parent folder (parent_label).
all_files = get_image_files(DATASET_PATH)
subset_size = len(all_files) // 5
if subset_size == 0:
    # Fail early with a clear message instead of an obscure IndexError
    # from ImageDataLoaders.from_lists on an empty label list.
    raise ValueError(
        f"Need at least 5 images under {DATASET_PATH!r} to build the 1/5 subset, "
        f"found {len(all_files)}"
    )

np.random.seed(RANDOM_SEED)
subset_idx = np.random.choice(len(all_files), subset_size, replace=False)
subset_files = [all_files[i] for i in sorted(subset_idx)]
subset_labels = [parent_label(f) for f in subset_files]

print(f"Using {len(subset_files)}/{len(all_files)} images (1/5 subset)")

dls = ImageDataLoaders.from_lists(
    DATASET_PATH,
    subset_files,
    subset_labels,
    valid_pct=0.15,
    seed=RANDOM_SEED,
    # Pass the full configured size (not just IMAGE_SIZE[0]) so a
    # non-square IMAGE_SIZE in the config is actually honoured.
    item_tfms=Resize(IMAGE_SIZE),
    batch_tfms=[
        *aug_transforms(
            size=IMAGE_SIZE,
            flip_vert=False,
            max_rotate=10.0,
            max_lighting=0.2,
            max_warp=0.0,
            p_affine=0.5,
            p_lighting=0.5,
        ),
        Normalize.from_stats(*imagenet_stats),
    ],
    bs=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

# Get class names from the dataloader (vocab order = model output order)
CLASS_NAMES = list(dls.vocab)

print(f"Classes: {CLASS_NAMES}")
print(f"Training: {len(dls.train_ds)} images")
print(f"Validation: {len(dls.valid_ds)} images")
print(f"Using {NUM_WORKERS} worker processes for data loading")


## Dataset Samples

In [None]:
# Display up to SAMPLES_PER_CLASS sample images per task class, one row per class.
SAMPLES_PER_CLASS = 4

# squeeze=False keeps `axes` 2-D even when there is only a single class.
fig, axes = plt.subplots(
    len(CLASS_NAMES),
    SAMPLES_PER_CLASS,
    figsize=(10, 1.5 * len(CLASS_NAMES)),
    squeeze=False,
)

for row, class_name in enumerate(CLASS_NAMES):
    # Hide every cell first so rows with fewer images than columns do not
    # leave empty framed axes behind.
    for ax in axes[row]:
        ax.axis('off')

    # get_image_files only returns image files, unlike Path.ls(), which
    # would also yield stray entries (e.g. .DS_Store) that PILImage.create
    # cannot open.
    class_path = Path(DATASET_PATH) / class_name
    sample_images = get_image_files(class_path)[:SAMPLES_PER_CLASS]

    for col, img_path in enumerate(sample_images):
        axes[row, col].imshow(PILImage.create(img_path))

    # Label every row, even one whose class folder yielded no images.
    axes[row, 0].text(-0.1, 0.5, class_name,
                      transform=axes[row, 0].transAxes,
                      fontsize=10, fontweight='bold',
                      ha='right', va='center')

plt.subplots_adjust(left=0.15, right=1.0, hspace=0.02, wspace=0.02)
plt.show()

## Training

In [None]:
from collections import Counter

# Class-balanced loss weights: weight_c = N / (K * count_c). Rarer task
# classes therefore contribute proportionally more to the loss, which is the
# "weighted loss" imbalance handling this notebook advertises.
# Counts come from the sampled subset; the random train/valid split keeps
# roughly the same class mix. Weights are ordered like CLASS_NAMES
# (dls.vocab), so index i lines up with output logit i. Every vocab entry
# comes from subset_labels, so no count is zero.
label_counts = Counter(subset_labels)
class_weights = torch.tensor(
    [len(subset_labels) / (len(CLASS_NAMES) * label_counts[name]) for name in CLASS_NAMES],
    dtype=torch.float32,
    device=DEVICE,  # keep the weights on the same device as the logits
)
print(f"Class weights: {dict(zip(CLASS_NAMES, [round(w, 3) for w in class_weights.tolist()]))}")

# Create vision learner with ResNet50 (pretrained backbone, new classifier head)
learn = vision_learner(
    dls,
    resnet50,
    metrics=[accuracy, error_rate],
    loss_func=CrossEntropyLossFlat(weight=class_weights),
)

# Enable mixed precision training for faster GPU training
if USE_MIXED_PRECISION and torch.cuda.is_available():
    learn = learn.to_fp16()
    print("Mixed precision training enabled (FP16)")

# Train! fine_tune runs `freeze_epochs` with the pretrained body frozen,
# then NUM_EPOCHS with all layers trainable.
print("Training with FastAI...")
print(f"Batch size: {BATCH_SIZE}")
print(f"Workers: {NUM_WORKERS}")
learn.fine_tune(NUM_EPOCHS, base_lr=LEARNING_RATE, freeze_epochs=3)


## Training Loss

In [None]:
# Plot the loss history recorded by the learner's Recorder during training,
# then title and lay out the current figure before showing it.
learn.recorder.plot_loss()
plt.gca().set_title('Training Loss Over Time')
plt.gcf().tight_layout()
plt.show()


## Resume Training (Optional)

In [None]:
# Optionally continue training the current learner for more epochs.
# Set REMAINING_EPOCHS to the number of epochs still to run; 0 skips this.
REMAINING_EPOCHS = 0  # Adjust based on how many epochs you've completed

if REMAINING_EPOCHS > 0:
    print(f"Resuming training for {REMAINING_EPOCHS} more epochs...")
    learn.fit_one_cycle(REMAINING_EPOCHS, lr_max=LEARNING_RATE)
    print("Training completed")

# Learner.export writes to `learn.path / fname`, and learn.path is otherwise
# inherited from the DataLoaders (DATASET_PATH here). Reset it to the working
# directory so the relative MODEL_DIR path resolves as intended.
learn.path = Path(".")
learn.export(os.path.join(MODEL_DIR, "task_classifier.pkl"))


## Save Model

In [None]:
# Export the trained learner (model + transforms) for inference.
# Learner.export saves to `learn.path / fname`. learn.path is the dataset
# directory unless an earlier cell reset it, so a relative MODEL_DIR path
# could land in the wrong place. An absolute fname makes the pathlib join
# return fname unchanged, so the target no longer depends on learn.path.
fastai_model_path = os.path.join(MODEL_DIR, "task_classifier.pkl")
learn.export(os.path.abspath(fastai_model_path))
print(f"Model exported: {fastai_model_path}")


## Evaluate on Validation Set

In [None]:
# Get validation predictions from the FastAI learner. get_preds yields one
# score column per class plus the integer targets; the predicted class is
# the highest-scoring column of each row.
preds, targets = learn.get_preds(dl=dls.valid)
true_labels = targets.numpy()
pred_labels = torch.argmax(preds, dim=1).numpy()

val_acc = accuracy_score(y_true=true_labels, y_pred=pred_labels)
print(f"Validation Accuracy: {100*val_acc:.2f}%")


In [None]:
# Per-class precision/recall/F1 on the validation split.
# `labels` pins the report to every vocab index. Otherwise sklearn infers the
# classes from the data, and a class absent from this subsampled validation
# split would no longer match `target_names` and raise an error.
# zero_division=0 reports such classes as 0 without a warning.
print(classification_report(
    true_labels,
    pred_labels,
    labels=list(range(len(CLASS_NAMES))),
    target_names=CLASS_NAMES,
    zero_division=0,
))


In [None]:
# Build fastai's interpretation helper from the trained learner (it computes
# its own predictions) and draw the confusion matrix at a fixed size and DPI.
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix(dpi=80, figsize=(10, 8))
plt.gcf().tight_layout()
plt.show()
