# Gesture Recognition CNN Fine-tuning on Colab

## 1. Mount Google Drive
This cell mounts your Google Drive to the Colab environment, allowing access to datasets and saving models. You'll be prompted for authorization.

In [None]:
# Mount Google Drive at /content/drive so that later cells can read the
# dataset from it and write checkpoints to it.
from google.colab import drive
drive.mount('/content/drive')

## 2. Imports and Setup

In [None]:
import glob
import os.path as osp
import random
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import os 
import pandas as pd 

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torchvision import models, transforms

# Seed the PyTorch, NumPy and stdlib random number generators with one value.
for seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    seed_rng(1234)

## 3. Configuration
Set your training parameters here. These were previously handled by `argparse`.

In [None]:
# --- Training configuration ---
# TODO: Adjust MODEL_NAME and other parameters as needed.
MODEL_NAME = "vgg16"  # One of: "vgg16", "vgg19", "mobilenet", "mobilenet_v2"
NUM_CLASSES = 14
BATCH_SIZE = 32
NUM_EPOCHS = 10  # Adjust as needed for Colab training times
LEARNING_RATE = 1e-3

# --- Locations on the mounted Drive ---
DRIVE_PROJECT_ROOT = "/content/drive/MyDrive"

# Dataset root; images are looked up below this folder.
DATASET_ROOT_PATH = osp.join(DRIVE_PROJECT_ROOT, "data")

# Per-epoch checkpoints and metrics go into subfolders of this directory.
CHECKPOINTS_BASE_PATH = osp.join(DRIVE_PROJECT_ROOT, "checkpoints")

# The final model file is written directly into this directory.
FINAL_MODEL_SAVE_PATH_BASE = DRIVE_PROJECT_ROOT

# Image used by the optional visualization cell; it must exist for that cell
# to display anything. Update it if your example image lives elsewhere.
EXAMPLE_IMAGE_PATH = osp.join(DATASET_ROOT_PATH, "train/Gesture_0/example_gesture.jpg")

# Echo the effective configuration so it is visible in the cell output.
for config_line in (
    f"Using Model: {MODEL_NAME}",
    f"Dataset root path: {DATASET_ROOT_PATH}",
    f"Checkpoints will be saved in subfolders of: {CHECKPOINTS_BASE_PATH}",
    f"Final model will be saved in: {FINAL_MODEL_SAVE_PATH_BASE}",
    f"Example image path: {EXAMPLE_IMAGE_PATH}",
):
    print(config_line)

## 4. CUDA Availability Check

In [None]:
# Report whether a CUDA device is usable and, if so, which one.
cuda_ready = torch.cuda.is_available()
print(f"CUDA available: {cuda_ready}")
if cuda_ready:
    visible_gpus = torch.cuda.device_count()
    print(f"CUDA version: {torch.version.cuda}")
    print(f"CUDA device count: {visible_gpus}")
    if visible_gpus > 0:
        print(f"CUDA device name: {torch.cuda.get_device_name(0)}")

## 5. Image Preprocessing Class

In [None]:
class ImageTransform:
    """Phase-dependent image preprocessing.

    Holds one torchvision pipeline per phase: the "train" pipeline applies a
    random resized crop and a random horizontal flip, the "val" pipeline
    applies a resize followed by a center crop. Both end with tensor
    conversion and normalization using ``mean`` and ``std``.
    """

    def __init__(self, size, mean, std):
        """Build the "train" and "val" pipelines.

        Args:
            size: Target size passed to the crop/resize transforms.
            mean: Per-channel mean passed to ``transforms.Normalize``.
            std: Per-channel standard deviation passed to
                ``transforms.Normalize``.
        """
        train_steps = [
            transforms.RandomResizedCrop(size, scale=(0.5, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
        val_steps = [
            transforms.Resize(size),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
        self.data_transform = {
            "train": transforms.Compose(train_steps),
            "val": transforms.Compose(val_steps),
        }

    def __call__(self, img, phase="train"):
        """Apply the pipeline registered for ``phase`` to ``img``."""
        pipeline = self.data_transform[phase]
        return pipeline(img)

## 6. Example Image Visualization (Optional)
This cell visualizes an example image and its transformed version.
Ensure `EXAMPLE_IMAGE_PATH` is set correctly in the configuration cell.

In [None]:
# Show the example image twice: once merely resized, and once after the
# training-phase transform, so the augmentation can be checked by eye.
image_file_path = EXAMPLE_IMAGE_PATH

try:
    # Open the file once and close it deterministically; both views below are
    # derived from this single handle.
    with Image.open(image_file_path) as img_originalsize:
        img_display = img_originalsize.resize((256, 256)).convert("RGB")
        img_to_transform = img_originalsize.convert("RGB")

    plt.imshow(img_display)
    plt.title("Example Gesture Image (Original-like)")
    plt.show()

    size = 256
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    transform = ImageTransform(size, mean, std)

    img_transformed_display = transform(img_to_transform, phase="train")
    print(f"Transformed image shape: {img_transformed_display.shape}")

    # (C, H, W) -> (H, W, C) for matplotlib.
    img_transformed_display_np = img_transformed_display.numpy().transpose((1, 2, 0))
    # Undo Normalize (x * std + mean per channel) before display; clipping the
    # normalized values directly would show distorted colours.
    img_transformed_display_np = img_transformed_display_np * np.array(std) + np.array(mean)
    # Guard against tiny floating-point excursions outside [0, 1].
    img_transformed_display_np = np.clip(img_transformed_display_np, 0, 1)
    plt.imshow(img_transformed_display_np)
    plt.title("Example Gesture Image (Transformed)")
    plt.show()
except FileNotFoundError:
    print(
        f"Warning: Example image for display not found at {image_file_path}. Skipping display."
    )
except Exception as e:
    # This cell is optional; report the problem instead of aborting the notebook.
    print(f"An error occurred during example image display: {e}")

## 7. Data Path List and Dataset Class

In [None]:
def make_datapath_list(phase="train"):
    """Return the sorted list of JPEG paths for one dataset phase.

    Files are searched with the pattern
    ``DATASET_ROOT_PATH/<phase>/Gesture_*/*.jpg``.

    Args:
        phase: Name of the sub-folder below ``DATASET_ROOT_PATH`` to search.

    Returns:
        A list of matching file paths in sorted order. The list is empty
        (and a warning is printed) when nothing matches.
    """
    # Uses DATASET_ROOT_PATH from the configuration cell
    target_path = osp.join(DATASET_ROOT_PATH, phase, "Gesture_*", "*.jpg")
    print(f"Searching for images in: {target_path}")

    # glob returns matches in arbitrary order; sorting makes the dataset order
    # identical between runs.
    path_list = sorted(glob.glob(target_path, recursive=False))
    if not path_list:
        print(
            f"Warning: No images found for phase '{phase}' with pattern '{target_path}'. Check your dataset structure and path."
        )
    return path_list

# Collect the image paths for each phase and report how many were found.
train_list = make_datapath_list("train")
num_train_images = len(train_list)
print(f"Number of training gesture images: {num_train_images}")

val_list = make_datapath_list("val")
num_val_images = len(val_list)
print(f"Number of validation gesture images: {num_val_images}")


class GestureDataset(data.Dataset):
    """Dataset of gesture images whose label is encoded in the folder name.

    Each file is expected to live in a directory named ``Gesture_<int>``; the
    integer is used as the class label.
    """

    def __init__(self, file_list, transform=None, phase="train"):
        """Store the file list and preprocessing settings.

        Args:
            file_list: Sequence of image file paths.
            transform: Callable invoked as ``transform(img, phase)``. When
                ``None`` the RGB PIL image is returned untransformed.
            phase: Phase name forwarded to ``transform``.
        """
        self.file_list = file_list
        self.transform = transform
        self.phase = phase

    def __len__(self):
        """Return the number of image files."""
        return len(self.file_list)

    @staticmethod
    def _label_from_path(img_path):
        """Parse the integer label from the parent folder name of ``img_path``.

        Raises:
            ValueError: If the folder name is not ``Gesture_<int>``.
        """
        label_name = osp.basename(osp.dirname(img_path))
        try:
            return int(label_name.replace("Gesture_", ""))
        except ValueError as err:
            # Fail here with a clear message rather than handing an invalid
            # label to the loss function.
            raise ValueError(
                f"Error parsing label from folder name: {label_name} in path {img_path}"
            ) from err

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            A ``(image, label)`` tuple, where ``image`` is the output of
            ``transform`` (or the RGB PIL image when no transform is set).

        Raises:
            ValueError: If the label cannot be parsed from the folder name.
        """
        img_path = self.file_list[index]
        # Parse the label first so a bad path is reported before any file I/O.
        label = self._label_from_path(img_path)

        # convert() returns an in-memory copy, so the file can be closed
        # immediately afterwards.
        with Image.open(img_path) as img_file:
            img = img_file.convert("RGB")

        if self.transform is not None:
            img = self.transform(img, self.phase)
        return img, label

## 8. Create Datasets and DataLoaders

In [None]:
# Preprocessing parameters shared by the train and val pipelines.
size = 256
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

train_dataset = GestureDataset(
    file_list=train_list, transform=ImageTransform(size, mean, std), phase="train"
)
val_dataset = GestureDataset(
    file_list=val_list, transform=ImageTransform(size, mean, std), phase="val"
)

# Training cannot run without data; a missing validation set is tolerated.
if len(train_dataset) == 0:
    raise ValueError(
        "Training dataset is empty. Please check your 'data/train' folder (inside DRIVE_PROJECT_ROOT) and `make_datapath_list` function."
    )
if len(val_dataset) == 0:
    print(
        "Warning: Validation dataset is empty. Training will proceed without validation if this is intended."
    )

# Use BATCH_SIZE from configuration
# Consider reducing num_workers if you encounter issues in Colab (e.g., to 2)
num_workers_colab = 2 # Often 2 is a safe bet for Colab free tier
print(f"Using num_workers: {num_workers_colab} for DataLoaders")

# Pinned host memory only helps when batches are copied to a GPU; requesting
# it on a CPU-only runtime just produces a warning.
use_pinned_memory = torch.cuda.is_available()

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=num_workers_colab,
    pin_memory=use_pinned_memory,
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=num_workers_colab,
    pin_memory=use_pinned_memory,
)
dataloaders_dict = {"train": train_dataloader, "val": val_dataloader}

## 9. Model Definition and Fine-tuning Setup

In [None]:
use_pretrained = True

# Instantiate the backbone selected by MODEL_NAME and replace its last linear
# layer with one that outputs NUM_CLASSES logits.
if MODEL_NAME == "vgg16":
    vgg16_weights = models.VGG16_Weights.IMAGENET1K_V1 if use_pretrained else None
    net = models.vgg16(weights=vgg16_weights)
    net.classifier[6] = nn.Linear(in_features=4096, out_features=NUM_CLASSES)
elif MODEL_NAME == "vgg19":
    vgg19_weights = models.VGG19_Weights.IMAGENET1K_V1 if use_pretrained else None
    net = models.vgg19(weights=vgg19_weights)
    net.classifier[6] = nn.Linear(in_features=4096, out_features=NUM_CLASSES)
elif MODEL_NAME == "mobilenet":  # Note: this name selects mobilenet_v3_large
    mobilenet_v3_weights = (
        models.MobileNet_V3_Large_Weights.IMAGENET1K_V1 if use_pretrained else None
    )
    net = models.mobilenet_v3_large(weights=mobilenet_v3_weights)
    v3_head_in_features = net.classifier[3].in_features
    net.classifier[3] = nn.Linear(
        in_features=v3_head_in_features, out_features=NUM_CLASSES
    )
elif MODEL_NAME == "mobilenet_v2":
    mobilenet_v2_weights = (
        models.MobileNet_V2_Weights.IMAGENET1K_V1 if use_pretrained else None
    )
    net = models.mobilenet_v2(weights=mobilenet_v2_weights)
    v2_head = net.classifier
    # Try the known classifier layouts in turn, falling back to
    # net.last_channel when none of them matches.
    if hasattr(v2_head, "1") and isinstance(v2_head[1], nn.Linear):
        v2_head[1] = nn.Linear(v2_head[1].in_features, NUM_CLASSES)
    elif isinstance(v2_head, nn.Sequential) and isinstance(v2_head[-1], nn.Linear):
        last_layer_in_features = v2_head[-1].in_features
        v2_head[-1] = nn.Linear(last_layer_in_features, NUM_CLASSES)
    elif isinstance(v2_head, nn.Linear):
        net.classifier = nn.Linear(v2_head.in_features, NUM_CLASSES)
    else:
        print("Warning: MobileNetV2 classifier structure not standard. Attempting replacement using net.last_channel.")
        if not hasattr(net, "last_channel"):
            raise AttributeError(
                "Cannot automatically determine input features for MobileNetV2 classifier. Please check model structure."
            )
        net.classifier = nn.Linear(net.last_channel, NUM_CLASSES)
else:
    raise ValueError(f"Model {MODEL_NAME} not supported.")

print(f"Using model: {MODEL_NAME}")
# train/eval mode is selected per phase inside the training loop.

criterion = nn.CrossEntropyLoss()

params_to_update = []

if MODEL_NAME in ["vgg16", "vgg19"]:
    # Three parameter groups, each trained with its own learning rate.
    update_param_names_1 = ["features"]
    update_param_names_2 = [
        "classifier.0.weight", "classifier.0.bias",
        "classifier.3.weight", "classifier.3.bias",
    ]
    update_param_names_3 = ["classifier.6.weight", "classifier.6.bias"]

    params_to_update_1, params_to_update_2, params_to_update_3 = [], [], []

    for name, param in net.named_parameters():
        # Pick the group this parameter belongs to (None means frozen).
        if update_param_names_1[0] in name:
            param_group = params_to_update_1  # convolutional feature extractor
        elif name in update_param_names_2:
            param_group = params_to_update_2  # existing fully connected layers
        elif name in update_param_names_3:
            param_group = params_to_update_3  # newly created output layer
        else:
            param_group = None

        param.requires_grad = param_group is not None
        if param_group is not None:
            param_group.append(param)

    vgg_param_groups = [
        {'params': params_to_update_1, 'lr': LEARNING_RATE / 10},  # Slower LR for features
        {'params': params_to_update_2, 'lr': LEARNING_RATE / 2},   # Medium LR
        {'params': params_to_update_3, 'lr': LEARNING_RATE},       # Higher LR for new classifier
    ]
    optimizer = optim.SGD(vgg_param_groups, momentum=0.9)

else:  # mobilenet, mobilenet_v2: train only the replaced classifier layer
    for param in net.parameters():
        param.requires_grad = False

    trainable_head = None
    if MODEL_NAME == "mobilenet":  # mobilenet_v3_large
        trainable_head = net.classifier[3]
    elif MODEL_NAME == "mobilenet_v2":
        if isinstance(net.classifier, nn.Sequential):
            trainable_head = net.classifier[-1]
        else:
            trainable_head = net.classifier

    if trainable_head is not None:
        for param in trainable_head.parameters():
            param.requires_grad = True
        params_to_update.extend(trainable_head.parameters())

    print(f"Optimizing {len(params_to_update)} parameters for {MODEL_NAME} (classifier only).")
    optimizer = optim.SGD(params_to_update, lr=LEARNING_RATE, momentum=0.9)

## 10. Training Function

In [None]:
def _phase_is_empty(dataloader):
    """Return True when ``dataloader`` is missing or its dataset has no samples."""
    if dataloader is None:
        return True
    dataset = dataloader.dataset
    return not dataset or len(dataset) == 0


def train_model(net, model_name_arg, dataloaders_dict_arg, criterion_arg, optimizer_arg, num_epochs_arg):
    """Train ``net`` and evaluate it once per epoch.

    After every epoch the model's ``state_dict`` is saved to
    ``CHECKPOINTS_BASE_PATH/checkpoints_<model_name_arg>/`` and the metrics
    collected so far are rewritten to a CSV file in the same directory.

    Args:
        net: Model to train; it is moved to the GPU when one is available.
        model_name_arg: Name used in the checkpoint directory and file names.
        dataloaders_dict_arg: Mapping with DataLoaders under the keys
            ``"train"`` and ``"val"``. A phase whose loader is missing or
            whose dataset is empty is skipped and recorded as NaN.
        criterion_arg: Loss called as ``criterion_arg(outputs, labels)``.
        optimizer_arg: Optimizer stepped after every training batch.
        num_epochs_arg: Number of epochs to run.

    Returns:
        Tuple ``(train_accuracy_list, train_loss_list, val_accuracy_list,
        val_loss_list)`` holding one float per epoch (NaN for skipped phases).
    """
    train_accuracy_list = []
    train_loss_list = []
    val_accuracy_list = []
    val_loss_list = []
    metrics_history = []
    # Per-phase (loss history, accuracy history) so both phases share one code path.
    history_by_phase = {
        "train": (train_loss_list, train_accuracy_list),
        "val": (val_loss_list, val_accuracy_list),
    }

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Training on device: {device}")
    if torch.cuda.is_available() and torch.cuda.device_count() > 0:
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    net.to(device)

    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True

    # Use CHECKPOINTS_BASE_PATH from configuration for saving
    model_save_dir = osp.join(CHECKPOINTS_BASE_PATH, f"checkpoints_{model_name_arg}")
    if not osp.exists(model_save_dir):
        # exist_ok avoids a crash if the directory appears between the check
        # and the creation.
        os.makedirs(model_save_dir, exist_ok=True)
        print(f"Created directory for model checkpoints: {model_save_dir}")

    metrics_file_path = osp.join(model_save_dir, f"training_metrics_{model_name_arg}.csv")

    for epoch in range(num_epochs_arg):
        print(f"Epoch {epoch + 1}/{num_epochs_arg}")
        print("-------------")
        epoch_metrics = {"epoch": epoch + 1}

        for phase in ("train", "val"):
            is_train = phase == "train"
            loss_history, acc_history = history_by_phase[phase]
            current_dataloader = dataloaders_dict_arg.get(phase)

            if is_train:
                net.train()

            if _phase_is_empty(current_dataloader):
                if is_train:
                    print(f"Dataset for phase '{phase}' is empty, skipping.")
                else:
                    print("Validation dataset is empty or not found, skipping validation phase.")
                # None in the CSV row, NaN in the returned lists, so every
                # list still has one entry per epoch.
                epoch_metrics[f"{phase}_loss"] = None
                epoch_metrics[f"{phase}_acc"] = None
                loss_history.append(float('nan'))
                acc_history.append(float('nan'))
                continue

            if not is_train:
                net.eval()

            current_dataset_size = len(current_dataloader.dataset)
            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in current_dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                if is_train:
                    optimizer_arg.zero_grad()

                # Gradients are only tracked while training.
                with torch.set_grad_enabled(is_train):
                    outputs = net(inputs)
                    loss = criterion_arg(outputs, labels)
                    _, preds = torch.max(outputs, 1)
                    if is_train:
                        loss.backward()
                        optimizer_arg.step()

                # loss is a batch mean; weight it by batch size so the epoch
                # figure is a per-sample average.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()

            epoch_loss = running_loss / current_dataset_size
            epoch_acc = running_corrects / current_dataset_size
            print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

            loss_history.append(epoch_loss)
            acc_history.append(epoch_acc)
            epoch_metrics[f"{phase}_loss"] = epoch_loss
            epoch_metrics[f"{phase}_acc"] = epoch_acc

        checkpoint_name = f"gesture_{model_name_arg}_epoch_{epoch+1}.pth"
        checkpoint_path = osp.join(model_save_dir, checkpoint_name)
        torch.save(net.state_dict(), checkpoint_path)
        print(f"Saved model checkpoint to {checkpoint_path}")

        metrics_history.append(epoch_metrics)
        try:
            df_metrics = pd.DataFrame(metrics_history)
            df_metrics.to_csv(metrics_file_path, index=False)
            print(f"Updated metrics saved to {metrics_file_path}")
        except Exception as e:
            # Metrics logging is best-effort; a failed write must not abort training.
            print(f"Could not save metrics to CSV: {e}")

    return train_accuracy_list, train_loss_list, val_accuracy_list, val_loss_list

## 11. Start Training

In [None]:
# Using configuration variables
train_acc_list, train_loss_list, val_acc_list, val_loss_list = train_model(
    net,
    MODEL_NAME,
    dataloaders_dict,
    criterion,
    optimizer,
    num_epochs_arg=NUM_EPOCHS
)

## 12. Save Final Model

In [None]:
# Use FINAL_MODEL_SAVE_PATH_BASE from configuration
final_save_path = osp.join(FINAL_MODEL_SAVE_PATH_BASE, f"gesture_{MODEL_NAME}_finetuned_final.pth")
try:
    torch.save(net.state_dict(), final_save_path)
    print(f"Final model saved to {final_save_path}")
except Exception as e:
    print(f"Error saving final model: {e}")

## 13. Plotting Results

In [None]:
# Check if val_acc_list and val_loss_list have valid (non-NaN) data before plotting
valid_val_acc = [x for x in val_acc_list if not np.isnan(x)]
valid_val_loss = [x for x in val_loss_list if not np.isnan(x)]

# Ensure train lists also have data
valid_train_acc = [x for x in train_acc_list if not np.isnan(x)]
valid_train_loss = [x for x in train_loss_list if not np.isnan(x)]


if valid_val_acc and valid_val_loss and valid_train_acc and valid_train_loss:
    # Determine the number of epochs plotted based on the shortest list that's not empty
    # This handles cases where validation might have been skipped for some epochs or altogether.
    num_epochs_plotted = min(len(valid_train_acc) if valid_train_acc else float('inf'), 
                             len(valid_train_loss) if valid_train_loss else float('inf'),
                             len(valid_val_acc) if valid_val_acc else float('inf'),
                             len(valid_val_loss) if valid_val_loss else float('inf'))

    if num_epochs_plotted == float('inf') or num_epochs_plotted == 0 :
        print("Not enough data to plot results.")
    else:
        epoch_plot_range = list(range(1, num_epochs_plotted + 1))
        
        fig, ax = plt.subplots(facecolor="w", figsize=(12, 6))
        
        # Plot training data up to num_epochs_plotted
        if valid_train_acc: ax.plot(epoch_plot_range, train_acc_list[:num_epochs_plotted], label="Training Accuracy", marker='o')
        if valid_train_loss: ax.plot(epoch_plot_range, train_loss_list[:num_epochs_plotted], label="Training Loss", marker='o')
        
        # Plot validation data up to num_epochs_plotted
        if valid_val_acc: ax.plot(epoch_plot_range, val_acc_list[:num_epochs_plotted], label="Validation Accuracy", marker='x')
        if valid_val_loss: ax.plot(epoch_plot_range, val_loss_list[:num_epochs_plotted], label="Validation Loss", marker='x')
        
        if epoch_plot_range:
            plt.xticks(epoch_plot_range)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Value")
        ax.set_title(f"Training and Validation Metrics for {MODEL_NAME} ({num_epochs_plotted} Epochs)")
        ax.legend()
        plt.grid(True)
        
        # Save the plot to Drive
        plot_save_path = osp.join(CHECKPOINTS_BASE_PATH, f"checkpoints_{MODEL_NAME}", f"training_plot_{MODEL_NAME}.png")
        try:
            plt.savefig(plot_save_path)
            print(f"Plot saved to {plot_save_path}")
        except Exception as e:
            print(f"Could not save plot: {e}")
        plt.show()

elif valid_train_acc and valid_train_loss:
    # Only training data is available
    num_epochs_plotted = min(len(valid_train_acc), len(valid_train_loss))
    if num_epochs_plotted > 0:
        epoch_plot_range = list(range(1, num_epochs_plotted + 1))
        fig, ax = plt.subplots(facecolor="w", figsize=(12, 6))
        ax.plot(epoch_plot_range, train_acc_list[:num_epochs_plotted], label="Training Accuracy", marker='o')
        ax.plot(epoch_plot_range, train_loss_list[:num_epochs_plotted], label="Training Loss", marker='o')
        if epoch_plot_range:
            plt.xticks(epoch_plot_range)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Value")
        ax.set_title(f"Training Metrics for {MODEL_NAME} ({num_epochs_plotted} Epochs)")
        ax.legend()
        plt.grid(True)
        
        plot_save_path = osp.join(CHECKPOINTS_BASE_PATH, f"checkpoints_{MODEL_NAME}", f"training_plot_{MODEL_NAME}.png")
        try:
            plt.savefig(plot_save_path)
            print(f"Plot saved to {plot_save_path}")
        except Exception as e:
            print(f"Could not save plot: {e}")
        plt.show()
    else:
        print("No valid training data to plot.")
else:
    print("No validation data to plot or validation was skipped. Only training data might be available if training occurred.")

---
End of Notebook. Remember to adjust paths in the "Configuration" cell.