In [None]:
import os
from pathlib import Path
import cv2
import torch
import numpy as np
from tqdm import tqdm
from torchvision import transforms
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset, DataLoader
from model import GestureRecognitionModel

In [None]:
class GestureDataset(Dataset):
    """Fixed-length clips of gesture frames read from a directory tree.

    The expected layout is ``root_dir/<gesture_name>/*.jpg``. Each gesture
    directory contributes at most one sample: its first ``num_frames`` JPEG
    files in sorted filename order. Directories holding fewer than
    ``num_frames`` JPEGs are skipped. Plain files and hidden entries (names
    starting with ``.``, e.g. ``.ipynb_checkpoints``) inside ``root_dir`` are
    ignored, so they neither become classes nor trigger a lookup failure.

    Args:
        root_dir: Directory containing one sub-directory per gesture class.
        num_frames: Number of frames per clip; must be positive.
        transform: Optional callable applied to every RGB ``PIL.Image`` frame.
            It must return tensors of a common shape so they can be stacked.
            When ``None``, frames are converted to float tensors in ``[0, 1]``
            with channels first.
        gesture_to_label: Optional mapping from gesture directory name to an
            integer label. When omitted (or empty), labels are assigned by
            sorted directory name.

    Raises:
        ValueError: If ``num_frames`` is not positive.
        KeyError: If a gesture directory is missing from ``gesture_to_label``.
    """

    def __init__(self, root_dir, num_frames=20, transform=None, gesture_to_label=None):
        if num_frames <= 0:
            raise ValueError(f"num_frames must be positive, got {num_frames}")

        self.root_dir = Path(root_dir)
        self.num_frames = num_frames
        self.transform = transform
        self.samples = []

        # Only visible sub-directories are gesture classes. Sorting makes both
        # the auto-generated labels and the sample order deterministic.
        gesture_dirs = sorted(
            entry for entry in self.root_dir.iterdir()
            if entry.is_dir() and not entry.name.startswith(".")
        )

        self.gesture_to_label = gesture_to_label or {
            gesture.name: idx for idx, gesture in enumerate(gesture_dirs)
        }

        for gesture in gesture_dirs:
            if gesture.name not in self.gesture_to_label:
                raise KeyError(
                    f"Gesture directory '{gesture.name}' in {self.root_dir} "
                    f"has no entry in gesture_to_label"
                )
            label = self.gesture_to_label[gesture.name]
            frames = sorted(gesture.glob("*.jpg"))
            if len(frames) >= num_frames:
                self.samples.append((frames[:num_frames], label))

    def __len__(self):
        """Return the number of clips that had enough frames."""
        return len(self.samples)

    @staticmethod
    def _frame_to_tensor(frame):
        """Convert an RGB PIL image to a float tensor [3, H, W] scaled to [0, 1]."""
        pixels = torch.from_numpy(np.array(frame))  # [H, W, 3], uint8
        return pixels.permute(2, 0, 1).float().div(255.0)

    def __getitem__(self, idx):
        """Load one clip.

        Returns:
            A ``(clip, label)`` tuple where ``clip`` stacks the per-frame
            tensors along a new leading dimension of size ``num_frames``.
        """
        frame_paths, label = self.samples[idx]

        # Replay the same torch RNG state for every frame so that random
        # transforms drawing from torch's global CPU generator make the same
        # decision (e.g. flip / no flip) for the whole clip instead of per
        # frame. The state left behind is the one after a single transform
        # call, so successive clips still receive different draws.
        # NOTE(review): transforms using Python's `random` or numpy RNGs are
        # not synchronised by this.
        rng_state = torch.get_rng_state()

        frames = []
        for path in frame_paths:
            with Image.open(path) as img:
                frame = img.convert("RGB")
            if self.transform is None:
                frames.append(self._frame_to_tensor(frame))
            else:
                torch.set_rng_state(rng_state)
                frames.append(self.transform(frame))

        clip = torch.stack(frames)  # shape: [num_frames, *frame_shape]
        return clip, label

In [None]:
# Class names mapped to integer labels; the same mapping is shared by both splits.
gesture_to_label = {
    'Fist': 0,
    'Four': 1,
    'Me': 2,
    'One': 3,
    'Small': 4,
}

# Locations of the extracted frames for each split.
train_dir = "data/frames/train"
val_dir = "data/frames/val"

# Validation frames are only resized and converted to tensors.
val_transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)

# Training frames additionally get a random horizontal flip as augmentation.
train_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

# NOTE(review): training clips use 16 frames while validation clips use only
# 4 — confirm this difference is intentional.
train_dataset = GestureDataset(
    train_dir,
    num_frames=16,
    transform=train_transform,
    gesture_to_label=gesture_to_label,
)
val_dataset = GestureDataset(
    val_dir,
    num_frames=4,
    transform=val_transform,
    gesture_to_label=gesture_to_label,
)

# Batches of 4 clips, loaded in the main process; only training is shuffled.
train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=0,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=0,
)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# One output class per entry in the label mapping.
num_classes = len(gesture_to_label)
model = GestureRecognitionModel(num_classes=num_classes).to(device)

# Cross-entropy loss, optimised with Adam.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Multiply the learning rate by 0.3 once the monitored value (stepped with the
# validation loss in the training loop) has not decreased for more than 2 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.3,
    patience=2,
)

In [None]:
import copy

# Training loop with validation-accuracy-based checkpointing and early stopping.
#
# Each epoch runs a training pass followed by a validation pass. Whenever the
# validation accuracy improves, the weights are snapshotted in memory and
# written to "gesture_model.pth". Training stops early once `patience`
# consecutive epochs bring no improvement. After the loop the best snapshot is
# loaded back into `model`, so later cells evaluate the best weights rather
# than whatever the last epoch produced.

best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0
no_improve_epochs = 0
patience = 5
num_epochs = 30

# NOTE(review): this directory is created but nothing in this cell writes to
# it; the best weights are saved to "gesture_model.pth" in the working directory.
os.makedirs("checkpoints", exist_ok=True)

# Fail early with a clear message instead of a ZeroDivisionError when a split
# ended up with no samples.
for split_name, split_loader in (("train", train_loader), ("val", val_loader)):
    if len(split_loader.dataset) == 0:
        raise ValueError(f"The {split_name} dataset is empty; nothing to iterate over.")

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print("-" * 30)

    for phase in ['train', 'val']:
        is_train = phase == 'train'
        model.train(is_train)  # train(False) is equivalent to eval()
        dataloader = train_loader if is_train else val_loader

        running_loss = 0.0
        all_preds, all_labels = [], []

        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Gradients only need clearing when a backward pass will follow.
            if is_train:
                optimizer.zero_grad()

            with torch.set_grad_enabled(is_train):
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                if is_train:
                    loss.backward()
                    optimizer.step()

            # Weight the batch loss by batch size so the epoch loss is a
            # per-sample average even when the last batch is smaller.
            running_loss += loss.item() * inputs.size(0)
            preds = outputs.argmax(dim=1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

        epoch_loss = running_loss / len(dataloader.dataset)
        epoch_acc = accuracy_score(all_labels, all_preds)

        print(f"{phase.capitalize()} Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.4f}")

        if phase == "val":
            # The LR scheduler follows validation loss; checkpointing and
            # early stopping follow validation accuracy.
            scheduler.step(epoch_loss)
            if epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(), "gesture_model.pth")
                print("Validation accuracy improved. Model saved.")
                no_improve_epochs = 0
            else:
                no_improve_epochs += 1

    if no_improve_epochs >= patience:
        print("\nEarly stopping triggered.")
        break

# Restore the best-performing weights. best_acc only rises above 0.0 when a
# snapshot was taken, so this never rolls the model back to its initial state.
if best_acc > 0.0:
    model.load_state_dict(best_model_wts)

print(f"\nBest Validation Accuracy: {best_acc:.4f}")