In [1]:
import torch
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Subset, random_split
from torch import nn, optim
import matplotlib.pyplot as plt
from pathlib import Path
from typing import List, Tuple, Dict

#### Initial configuration

In [2]:
# --- Data settings ---
# Root directory of the filtered image dataset.
filtered_data_dir = Path('data/filtered_images')
# Side length in pixels; the transform cell resizes images to (img_size, img_size).
img_size = 224
batch_size = 32

# --- Training settings ---
learning_rate = 0.001
num_epochs = 15
model_save_path = 'f1_track_layout_resnet18_v1.pth'

# --- Compute device ---
# Preference order: CUDA, then Apple MPS, then CPU. The MPS check only runs
# when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: mps


#### Set up image transformations

In [3]:
# One preprocessing pipeline per split, keyed by split name ('train' / 'val').
# NOTE(review): the mean/std values look like the commonly used ImageNet channel
# statistics — confirm they match whatever pretrained backbone is used.
data_transforms = dict(
    # Training: resize, random augmentation, then tensor conversion + normalisation.
    train=transforms.Compose([
        transforms.Resize(size=(img_size, img_size)),
        transforms.RandomRotation(degrees=10),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
    # Validation: no random augmentation — resize, tensor conversion and the
    # same normalisation only.
    val=transforms.Compose([
        transforms.Resize(size=(img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
)

In [4]:
class CustomDataset(datasets.ImageFolder):
    """ImageFolder subclass whose item lookup applies only the image transform.

    ``__getitem__`` loads the file with the inherited ``loader``, runs
    ``self.transform`` on it when one is set, and returns the target unchanged
    (no target transform is applied here).
    """

    def __init__(self, root, transform=None):
        """Index the images under ``root``; ``transform`` may be ``None``."""
        super().__init__(root, transform=transform)
        # Keep ``imgs`` pointing at the same (path, target) list as ``samples``.
        self.imgs = self.samples

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample at ``index``."""
        image_path, label = self.imgs[index]
        image = self.loader(image_path)
        image_transform = self.transform
        if image_transform is None:
            return image, label
        return image_transform(image), label


def get_dataloaders(data_dir, transforms_dict, batch_size=32, val_split=0.2,
                    num_workers=4, seed=None):
    """Build training and validation DataLoaders from a single image folder.

    The samples under ``data_dir`` are split at random into a training part
    and a validation part. Each part is backed by its *own* ``CustomDataset``
    instance so the two transforms stay independent. Assigning
    ``subset.dataset.transform`` on the two halves returned by
    ``random_split`` does not achieve this: both halves wrap the same dataset
    object, so the second assignment replaces the first and training would
    silently run with the validation transform.

    Args:
        data_dir: Root directory passed to ``CustomDataset``.
        transforms_dict: Mapping with a ``'train'`` and a ``'val'`` entry, each
            a callable applied to a loaded image.
        batch_size: Batch size used by both loaders.
        val_split: Fraction of samples assigned to validation, in ``[0, 1)``.
            The validation size is ``int(total * val_split)``.
        num_workers: Worker process count used by both loaders.
        seed: Optional integer seed for the split. ``None`` uses PyTorch's
            global random generator.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles; the
        validation loader does not.

    Raises:
        ValueError: If ``val_split`` is outside ``[0, 1)``.
    """
    if not 0.0 <= val_split < 1.0:
        raise ValueError(f"val_split must be in [0, 1), got {val_split!r}")

    # Two views of the same directory, one per transform. Index i refers to the
    # same file in both, provided the directory is unchanged between the scans.
    train_base = CustomDataset(root=data_dir, transform=transforms_dict['train'])
    val_base = CustomDataset(root=data_dir, transform=transforms_dict['val'])

    total = len(train_base)
    val_size = int(total * val_split)
    train_size = total - val_size

    # Only pass a generator when a seed was requested, so the unseeded call is
    # the same random_split call as before.
    split_kwargs = {}
    if seed is not None:
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)
    train_dataset, val_part = random_split(
        train_base, [train_size, val_size], **split_kwargs
    )
    # Reuse the validation indices on the dataset that carries the validation
    # transform.
    val_dataset = Subset(val_base, val_part.indices)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers)

    return train_loader, val_loader

In [5]:
class ImageClassificationDataModule:
    """
    Data module for image classification tasks.
    Handles the train/validation split and DataLoader creation for an
    ``ImageFolder``-style directory.

    The training and validation subsets each wrap their own ``ImageFolder``
    instance (carrying the matching transform) and share one seeded random
    index split, so the two transforms never interfere with each other.

    Args:
        data_dir: Root directory handed to ``datasets.ImageFolder``.
        transform: Mapping with a ``'train'`` and a ``'val'`` entry.
        batch_size: Batch size used by both DataLoaders.
        val_split: Fraction of samples assigned to validation, in ``[0, 1]``.
            The validation size is ``int(total * val_split)``.
        image_size: Stored on the instance only; this class never reads it.
            Resizing is governed entirely by the supplied transforms.
        seed: Seed for the generator that drives the split.
        num_workers: Worker process count passed to both DataLoaders.

    Raises:
        ValueError: If ``val_split`` is outside ``[0, 1]``.
    """
    def __init__(
        self,
        data_dir: str,
        transform: Dict[str, transforms.Compose],
        batch_size: int = 32,
        val_split: float = 0.2,
        image_size: Tuple[int, int] = (224, 224),
        seed: int = 42,
        num_workers: int = 0,
    ):
        if not 0.0 <= val_split <= 1.0:
            raise ValueError(f"val_split must be in [0, 1], got {val_split!r}")

        self.data_dir = data_dir
        self.batch_size = batch_size
        self.val_split = val_split
        self.image_size = image_size  # informational only; see class docstring
        self.seed = seed
        self.num_workers = num_workers

        self.train_transforms = transform['train']
        self.val_transforms = transform['val']

        self.train_dataset = None
        self.val_dataset = None

        # Datasets are built eagerly, so the loaders are usable right after
        # construction.
        self.setup()

    def setup(self) -> None:
        """Scan ``data_dir`` and (re)build ``train_dataset`` / ``val_dataset``."""
        # One ImageFolder per transform. The split is drawn on the training
        # instance, which avoids a third directory scan just to compute indices.
        train_base = datasets.ImageFolder(root=self.data_dir, transform=self.train_transforms)
        val_base = datasets.ImageFolder(root=self.data_dir, transform=self.val_transforms)

        total_samples = len(train_base)
        val_size = int(total_samples * self.val_split)
        train_size = total_samples - val_size

        # A dedicated seeded generator keeps the split reproducible without
        # touching the global RNG.
        generator = torch.Generator().manual_seed(self.seed)
        train_subset, val_subset = random_split(
            train_base, [train_size, val_size], generator=generator
        )

        self.train_dataset = train_subset
        # Same validation indices, applied to the dataset that carries the
        # validation transform.
        self.val_dataset = Subset(val_base, val_subset.indices)

    def train_dataloader(self) -> DataLoader:
        """Return a shuffling DataLoader for the training set.

        Raises:
            RuntimeError: If ``train_dataset`` is ``None``.
        """
        if self.train_dataset is None:
            raise RuntimeError("Call setup() before train_dataloader().")
        return DataLoader(
            dataset=self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self) -> DataLoader:
        """Return a non-shuffling DataLoader for the validation set.

        Raises:
            RuntimeError: If ``val_dataset`` is ``None``.
        """
        if self.val_dataset is None:
            raise RuntimeError("Call setup() before val_dataloader().")
        return DataLoader(
            dataset=self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

In [None]:
# Build the data module from the values defined in the configuration cell
# instead of repeating the literals here, so a change to the config takes effect.
data_module = ImageClassificationDataModule(
    data_dir=str(filtered_data_dir),
    transform=data_transforms,
    batch_size=batch_size,
    val_split=0.2,
    image_size=(img_size, img_size),
)
train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader()

print(f"Train samples: {len(data_module.train_dataset)}")
print(f"Validation samples: {len(data_module.val_dataset)}")

# Sanity check: pull a single batch and show its tensor shape.
images, labels = next(iter(train_loader))
print(f"Batch shape: {images.shape}")

Train samples: 879
Validation samples: 219
Batch shape: torch.Size([32, 3, 224, 224])


In [10]:
# The training split is a Subset; the class-name -> index mapping lives on the
# dataset it wraps.
class_to_idx = data_module.train_dataset.dataset.class_to_idx
print(class_to_idx)
# Reverse lookup (index -> class name).
idx_to_class = dict(zip(class_to_idx.values(), class_to_idx.keys()))
print(idx_to_class)

{'Albert Park Circuit': 0, 'Autódromo Hermanos Rodríguez': 1, 'Bahrain International Circuit': 2, 'Baku City Circuit': 3, 'Circuit Gilles Villeneuve': 4, 'Circuit Zandvoort': 5, 'Circuit de Barcelona-Catalunya': 6, 'Circuit de Monaco': 7, 'Circuit de Spa-Francorchamps': 8, 'Circuit of the Americas': 9, 'Hungaroring': 10, 'Imola (Autodromo Enzo e Dino Ferrari)': 11, 'Interlagos (Autódromo José Carlos Pace)': 12, 'Jeddah Corniche Circuit': 13, 'Las Vegas Street Circuit': 14, 'Lusail International Circuit': 15, 'Marina Bay Street Circuit': 16, 'Miami International Autodrome': 17, 'Monza (Autodromo Nazionale Monza)': 18, 'Red Bull Ring': 19, 'Shanghai International Circuit': 20, 'Silverstone Circuit': 21, 'Suzuka International Racing Course': 22, 'Yas Marina Circuit': 23}
{0: 'Albert Park Circuit', 1: 'Autódromo Hermanos Rodríguez', 2: 'Bahrain International Circuit', 3: 'Baku City Circuit', 4: 'Circuit Gilles Villeneuve', 5: 'Circuit Zandvoort', 6: 'Circuit de Barcelona-Catalunya', 7: 'Ci