TODO:
- Double check class_mapping.json
- Create necessary hyperparameter constants
- Add references to README

In [None]:
import os
import json
import timm
import torch
import shutil
import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from timm.data import resolve_model_data_config
from sklearn.model_selection import train_test_split

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
# Dataset locations: the raw per-breed image folders and the root of the generated splits
ORIGINAL_DATA_PATH = "Data/Raw/Images"
SPLIT_DATA_PATH = "Data/Processed"

# Fractions of each breed's images held out for validation and for testing
VALID_SIZE, TEST_SIZE = 0.15, 0.15

# Number of output classes requested for the model's replaced classifier head
NUM_CLASSES = 120

In [None]:
def is_image(path):
    """Return True if the file at ``path`` is an uncorrupted image; False otherwise.

    Args:
        path: Filesystem path of the candidate image file.

    Returns:
        True when Pillow can open the file and ``verify()`` raises nothing;
        False when opening or verifying raises any exception.
    """
    try:
        # The context manager closes the file handle that Image.open() would
        # otherwise leave open for every file scanned.
        with Image.open(path) as img:
            img.verify()
    except Exception:
        # Deliberately broad: any failure to open or verify means "not usable".
        return False
    return True

# Performs a stratified split of the data into training, validation, and test sets
def split_data(input_path, output_path, valid_size=VALID_SIZE, test_size=TEST_SIZE):
    """Split every breed folder into Train/Valid/Test copies under ``output_path``.

    Each sub-directory of ``input_path`` is treated as one breed. Its valid
    images are split independently, so every breed contributes to each set in
    roughly the requested proportions.

    Args:
        input_path: Directory containing one sub-directory of images per breed.
        output_path: Directory that receives ``Train/``, ``Valid/`` and
            ``Test/``. Anything already at this path is deleted first.
        valid_size: Fraction of a breed's images to place in the validation set.
        test_size: Fraction of a breed's images to place in the test set.

    Note:
        ``train_test_split`` is presumably going to raise if a breed has too few
        valid images to populate every split -- confirm against the dataset.
    """
    # Clean up any existing data splits. Only output_path is removed: the split
    # directories live inside it, and deleting bare "Train"/"Valid"/"Test" would
    # target unrelated folders in the current working directory.
    if os.path.exists(output_path):
        shutil.rmtree(output_path)

    # Create the Train, Valid, and Test directories
    split_names = ("Train", "Valid", "Test")
    for split_name in split_names:
        os.makedirs(os.path.join(output_path, split_name), exist_ok=True)

    # Collect the breed directories. Sorting makes the run order (and, below, the
    # split itself) independent of the filesystem's listing order.
    breed_dirs = sorted(
        d for d in os.listdir(input_path) if os.path.isdir(os.path.join(input_path, d))
    )
    for breed in breed_dirs:
        # Get all valid image names for the current breed, in a stable order so
        # the fixed random_state yields the same split on every machine
        breed_path = os.path.join(input_path, breed)
        images = sorted(
            f for f in os.listdir(breed_path) if is_image(os.path.join(breed_path, f))
        )

        # Split the images into (train + validation) and test sets
        train_valid_imgs, test_imgs = train_test_split(
            images, test_size=test_size, random_state=27
        )

        # Further split the (train + validation) set. valid_size is rescaled
        # because it is a fraction of the whole, not of the remainder.
        train_imgs, valid_imgs = train_test_split(
            train_valid_imgs, test_size=valid_size / (1 - test_size), random_state=27
        )

        # Copy images into their appropriate locations
        for split_name, img_list in zip(split_names, (train_imgs, valid_imgs, test_imgs)):
            split_breed_path = os.path.join(output_path, split_name, breed)
            os.makedirs(split_breed_path, exist_ok=True)
            for img in img_list:
                shutil.copy(os.path.join(breed_path, img), os.path.join(split_breed_path, img))

# Build the Train/Valid/Test folders from the raw images using the configured fractions
split_data(
    input_path=ORIGINAL_DATA_PATH,
    output_path=SPLIT_DATA_PATH,
    valid_size=VALID_SIZE,
    test_size=TEST_SIZE,
)

In [None]:
# Build an EfficientNet-B3 pretrained on ImageNet-1K, with its head replaced by a NUM_CLASSES-way classifier
model = timm.create_model("efficientnet_b3", pretrained=True, num_classes=NUM_CLASSES)

# Look up the preprocessing settings tied to the model so that later inputs can match them
data_config = resolve_model_data_config(model)
# Last entry of input_size; presumably the image side length -- confirm against timm's config layout
img_size = data_config["input_size"][-1]
mean, std = data_config["mean"], data_config["std"]

In [None]:
# Provisional loader settings
# TODO tune BATCH_SIZE and NUM_WORKERS according to machine specifications
BATCH_SIZE = 30
NUM_WORKERS = 0  # 0 loads batches in the main process


def _make_loader(dataset, batch_size, num_workers, shuffle):
    """Wrap ``dataset`` in a DataLoader using the settings shared by all three splits."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        # Pinned host memory only helps host-to-CUDA copies, so request it only
        # when CUDA is actually available.
        pin_memory=torch.cuda.is_available(),
    )


def get_dataloaders(data_dir, img_size=img_size, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS):
    """Create the datasets and data loaders for the Train, Valid and Test splits.

    Args:
        data_dir: Directory containing the ``Train``, ``Valid`` and ``Test``
            folders, each laid out as expected by ``datasets.ImageFolder``.
        img_size: Side length, in pixels, of the crops fed to the model.
        batch_size: Number of samples per batch for every loader.
        num_workers: Number of worker processes used by every loader.

    Returns:
        A tuple ``(image_datasets, data_loaders)`` of dicts, both keyed by
        ``"train"``, ``"valid"`` and ``"test"``.
    """
    # Training transforms: random crop plus light augmentation, then the
    # normalization read from the model's data config.
    # TODO consider adding a rotation augmentation
    train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(img_size, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(
            brightness=0.1,
            contrast=0.1,
            saturation=0.1,
            hue=0.02,
        ),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # Evaluation transforms are deterministic: resize slightly larger than the
    # target, then take the central crop.
    valid_test_transforms = transforms.Compose([
        transforms.Resize(int(img_size * 1.15)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # Datasets
    split_specs = {
        "train": ("Train", train_transforms),
        "valid": ("Valid", valid_test_transforms),
        "test": ("Test", valid_test_transforms),
    }
    image_datasets = {
        key: datasets.ImageFolder(root=os.path.join(data_dir, folder), transform=split_transform)
        for key, (folder, split_transform) in split_specs.items()
    }

    # DataLoaders: only the training split is shuffled
    data_loaders = {
        key: _make_loader(dataset, batch_size, num_workers, shuffle=(key == "train"))
        for key, dataset in image_datasets.items()
    }

    return image_datasets, data_loaders


# Keyword arguments keep the image size and the batch size from being swapped
image_datasets, data_loaders = get_dataloaders(
    SPLIT_DATA_PATH,
    img_size=img_size,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

In [None]:
# Two-phase training: first on full images, then on cropped ones

# Freeze all layers at first, then gradually unfreeze them


# Choose the compute device (Apple's MPS backend when available, otherwise the CPU)
# and move the model onto it; inputs and labels will need the same move later on
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
model.to(device);  # Semicolon to suppress output