TODO:
- Double check class_mapping.json
- Create necessary hyperparameter constants

In [None]:
import os
import json
import timm
import torch
import shutil
import numpy as np
import pandas as pd
from PIL import Image
from sklearn.model_selection import train_test_split

In [None]:
# Define constants and hyperparameters
# Source directory handed to split_data(); each subdirectory is treated as one breed
RAW_DATA_PATH = "Data/Raw/Images"
# Destination directory for the Train/Valid/Test split (wiped and rebuilt by split_data())
PROCESSED_DATA_PATH = "Data/Processed"

# Fraction of each breed's images reserved for the validation set
VALID_SIZE = 0.15
# Fraction of each breed's images reserved for the test set
TEST_SIZE = 0.15

# Number of output classes requested from the model via num_classes
NUM_CLASSES = 120

In [None]:
def is_image(path):
    """Return True if the file at ``path`` is an uncorrupted image, False otherwise.

    The file is opened with Pillow and checked with ``Image.verify()``.
    Any failure to open or verify it (missing file, directory, unsupported
    format, corrupted data, ...) is reported as ``False`` rather than raised.

    Args:
        path: Path of the file to check.

    Returns:
        bool: ``True`` when Pillow can open and verify the file.
    """
    try:
        # The context manager guarantees the underlying file handle is closed
        # even when verify() raises, so scanning thousands of files cannot
        # exhaust the process's file descriptors.
        with Image.open(path) as img:
            img.verify()
    except Exception:
        # Deliberately broad: whatever went wrong, the file is not usable as an image
        return False
    return True

def split_data(input_path, output_path, valid_size, test_size, random_state=27):
    """Perform a stratified split of the data into training, validation, and test sets.

    Every immediate subdirectory of ``input_path`` is treated as one breed.
    Each breed's images are split independently with the same proportions,
    and the files are copied to ``output_path/<Train|Valid|Test>/<breed>/``.

    Args:
        input_path: Directory containing one subdirectory of images per breed.
        output_path: Destination directory. If it already exists it is deleted
            and rebuilt from scratch.
        valid_size: Fraction of each breed's images to place in the validation set.
        test_size: Fraction of each breed's images to place in the test set.
        random_state: Seed forwarded to ``train_test_split`` for both splits.

    Raises:
        ValueError: Raised by ``train_test_split`` when a breed has too few
            valid images to be split with the requested fractions.
    """
    split_names = ("Train", "Valid", "Test")

    # Clean up any existing data splits. Only output_path is removed: the split
    # folders live inside it, and deleting bare "Train"/"Valid"/"Test" would hit
    # unrelated directories in the current working directory.
    if os.path.exists(output_path):
        shutil.rmtree(output_path)

    # Create the Train, Valid, and Test directories
    for split_name in split_names:
        os.makedirs(os.path.join(output_path, split_name), exist_ok=True)

    # Each subdirectory of input_path holds the images of one breed
    breed_dirs = [d for d in os.listdir(input_path) if os.path.isdir(os.path.join(input_path, d))]
    for breed in breed_dirs:
        # Get all image names for the current breed. os.listdir() returns entries
        # in arbitrary order, so sort them to make the seeded split reproducible.
        breed_path = os.path.join(input_path, breed)
        images = sorted(f for f in os.listdir(breed_path) if is_image(os.path.join(breed_path, f)))

        # Split the images into (train + validation) and test sets
        train_valid_imgs, test_imgs = train_test_split(
            images, test_size=test_size, random_state=random_state
        )

        # Further split the (train + validation) image set into train and validation sets.
        # valid_size is a fraction of the whole breed, so rescale it to a fraction
        # of the images that remain after the test set was removed.
        train_imgs, valid_imgs = train_test_split(
            train_valid_imgs,
            test_size=valid_size / (1 - test_size),
            random_state=random_state,
        )

        # Copy images into their appropriate locations
        for split_name, img_list in zip(split_names, (train_imgs, valid_imgs, test_imgs)):
            split_breed_path = os.path.join(output_path, split_name, breed)
            os.makedirs(split_breed_path, exist_ok=True)
            for img in img_list:
                shutil.copy(os.path.join(breed_path, img), os.path.join(split_breed_path, img))

# Build the Train/Valid/Test folders from the raw images
split_data(
    input_path=RAW_DATA_PATH,
    output_path=PROCESSED_DATA_PATH,
    valid_size=VALID_SIZE,
    test_size=TEST_SIZE,
)

In [None]:
# Create transforms to be applied dynamically
    # Resize (TODO: look up ideal input to model)
    # Normalize pixel values (TODO: look up ideal range and/or mean and std)
    # For training, include data augmentations (including flips) to increase generalizability

# Create data set definitions

# Create data loaders

In [None]:
# Load the model: EfficientNet-B3 with pretrained weights, requesting
# NUM_CLASSES outputs from timm via num_classes
base_model = timm.create_model(
    "efficientnet_b3",
    pretrained=True,
    num_classes=NUM_CLASSES,
)

# TODO: Freeze the layers

# Pick the best available device: CUDA GPU first, then Apple-silicon MPS, then CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
base_model.to(device)

In [None]:
# Two phase train; first on full images and then on cropped ones