In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import cv2
import pandas as pd
import os
from sklearn.model_selection import train_test_split
import albumentations
import matplotlib.pyplot as plt
import random

##### Loading Data

In [3]:
INPUT_SIZE = 224 #TODO: change this depending on what model we are using
BATCH_SIZE = 32
NUM_CLASSES = 5
EPOCHS = 10
LEARNING_RATE = 0.001
MODEL_SAVE_PATH = 'best_model.pth'

# Create a class for the hand gestures
class GestureDataset(Dataset):
    def __init__(self, annotations, root_dir, transform=None, augmentations=None):
        self.annotations = annotations
        self.root_dir = root_dir
        self.transform = transform
        self.augmentations = augmentations

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        # From the CSV, read in the row for the current image
        row = self.annotations.iloc[idx]
        img_path = os.path.join(self.root_dir, row['image_path'])
        image = cv2.imread(img_path)

        # Get the coordinates of the bounding box from the annotations.csv
        x_start, y_start, x_end, y_end = row['x_start'], row['y_start'], row['x_end'], row['y_end']
        hand_crop = image[int(y_start):int(y_end), int(x_start):int(x_end)]

        hand_crop = cv2.cvtColor(hand_crop, cv2.COLOR_BGR2RGB)

        # Augment the training data
        if self.augmentations:
            hand_crop = self.augmentations(image=hand_crop)['image']

        if self.transform:
            hand_crop = self.transform(hand_crop)

        label = row['class_label']  # Gesture class
        return hand_crop, label


# Augmentations for training
augmentations = albumentations.Compose([
    albumentations.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5), # Changes brightness and constrast
    albumentations.Rotate(limit=15, p=0.5), # Small rotations
    albumentations.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=0, p=0.5), # Random scaling and shifts
    albumentations.GaussianBlur(blur_limit=(3, 5), p=0.3),  # Add blur
    albumentations.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=15, val_shift_limit=10, p=0.4) # Modify hue and saturation
])

# Data transformations
transform = transforms.Compose([
    transforms.ToPILImage(), 
    transforms.Resize((INPUT_SIZE, INPUT_SIZE)), # Resize to the input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # TODO: change this depending on the model
])

# Load dataset
annotations = pd.read_csv('annotations.csv')

# Split into train, validation, and test sets. 80%, 10%, 10% split
train_annotations, temp_annotations = train_test_split(annotations, test_size=0.2, random_state=42)
val_annotations, test_annotations = train_test_split(temp_annotations, test_size=0.5, random_state=42)

train_dataset = GestureDataset(annotations=train_annotations, root_dir='images/', transform=transform, augmentations=augmentations) # augment training data only
val_dataset = GestureDataset(annotations=val_annotations, root_dir='images/', transform=transform)
test_dataset = GestureDataset(annotations=test_annotations, root_dir='images/', transform=transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


##### Mobile Net

In [4]:
from torchvision import models
import torch.nn as nn

In [5]:
model = models.mobilenet_v2(pretrained=True)

model.classifier[1] = nn.Linear(model.last_channel, NUM_CLASSES)  

Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /Users/katielee/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth
100.0%


In [6]:
model = models.mobilenet_v3_large(pretrained=True)

model.classifier[3] = nn.Linear(model.classifier[0].out_features, NUM_CLASSES)  # NUM_CLASSES = 5

import torch.optim as optim

criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

Downloading: "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth" to /Users/katielee/.cache/torch/hub/checkpoints/mobilenet_v3_large-8738ca79.pth
100.0%


In [10]:
class GestureDataset(Dataset):
    def __init__(self, annotations, root_dir, transform=None, augmentations=None):
        self.annotations = annotations
        self.root_dir = root_dir
        self.transform = transform
        self.augmentations = augmentations

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        row = self.annotations.iloc[idx]
        img_path = os.path.join(self.root_dir, row['image_path'])
        image = cv2.imread(img_path)

        if image is None:  
            print(f"Skipping unreadable image: {img_path}")
            raise ValueError(f"Unreadable image at index {idx}: {img_path}")

        x_start, y_start, x_end, y_end = row['x_start'], row['y_start'], row['x_end'], row['y_end']
        hand_crop = image[int(y_start):int(y_end), int(x_start):int(x_end)]

        hand_crop = cv2.cvtColor(hand_crop, cv2.COLOR_BGR2RGB)

        if self.augmentations:
            hand_crop = self.augmentations(image=hand_crop)['image']

        if self.transform:
            hand_crop = self.transform(hand_crop)

        label = row['class_label']
        return hand_crop, label

In [16]:
def filter_invalid_annotations(dataset):
    valid_annotations = []
    for idx in range(len(dataset.annotations)):
        row = dataset.annotations.iloc[idx]
        img_path = os.path.join(dataset.root_dir, row['image_path'])

        image = cv2.imread(img_path)
        if image is None:
            print(f"Skipping unreadable image: {img_path}")
            continue

        valid_annotations.append(row)

    dataset.annotations = pd.DataFrame(valid_annotations).reset_index(drop=True)
    return dataset

In [17]:
train_dataset = filter_invalid_annotations(train_dataset)
val_dataset = filter_invalid_annotations(val_dataset)
test_dataset = filter_invalid_annotations(test_dataset)

Skipping unreadable image: images/images/left_shoot/IMG_3044.jpg
Skipping unreadable image: images/images/left/left.jpg
Skipping unreadable image: images/images/left_shoot/Screenshot 2024-11-20 at 10.46.05 PM.png
Skipping unreadable image: images/images/left/IMG_3013.jpg
Skipping unreadable image: images/images/left/IMG_3043.jpg
Skipping unreadable image: images/images/right_shoot/15396C2E-A537-4EE2-B96D-58485BE10491_4_5005_c.jpeg
Skipping unreadable image: images/images/left_shoot/IMG_5046.JPG
Skipping unreadable image: images/images/right_shoot/IMG_4909.JPG
Skipping unreadable image: images/images/right_shoot/IMG_3078.jpg
Skipping unreadable image: images/images/left/IMG_2959.jpg
Skipping unreadable image: images/images/right/IMG_1320.jpg
Skipping unreadable image: images/images/left_shoot/IMG_2975.jpg
Skipping unreadable image: images/images/right_shoot/IMG_2911.jpg
Skipping unreadable image: images/images/right/IMG_2943.jpg
Skipping unreadable image: images/images/left/IMG_3096.jpg

[ WARN:0@405.162] global loadsave.cpp:241 findDecoder imread_('images/images/left_shoot/IMG_3044.jpg'): can't open/read file: check file path/integrity
[ WARN:0@405.163] global loadsave.cpp:241 findDecoder imread_('images/images/left/left.jpg'): can't open/read file: check file path/integrity
[ WARN:0@405.163] global loadsave.cpp:241 findDecoder imread_('images/images/left_shoot/Screenshot 2024-11-20 at 10.46.05 PM.png'): can't open/read file: check file path/integrity
[ WARN:0@405.163] global loadsave.cpp:241 findDecoder imread_('images/images/left/IMG_3013.jpg'): can't open/read file: check file path/integrity
[ WARN:0@405.163] global loadsave.cpp:241 findDecoder imread_('images/images/left/IMG_3043.jpg'): can't open/read file: check file path/integrity
[ WARN:0@405.164] global loadsave.cpp:241 findDecoder imread_('images/images/right_shoot/15396C2E-A537-4EE2-B96D-58485BE10491_4_5005_c.jpeg'): can't open/read file: check file path/integrity
[ WARN:0@405.164] global loadsave.cpp:241 f

In [18]:
for epoch in range(EPOCHS):
    model.train()  # Set model to training mode
    running_loss = 0.0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)  # Move data to GPU if available

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch + 1}/{EPOCHS}, Loss: {running_loss / len(train_loader)}")

    model.eval() 
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            val_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f"Validation Loss: {val_loss / len(val_loader)}, Accuracy: {100 * correct / total}%")

torch.save(model.state_dict(), MODEL_SAVE_PATH)

ZeroDivisionError: integer division or modulo by zero