In [29]:
import os

import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [2]:
class MutableDataLoader(DataLoader):
    """DataLoader that can switch its dataset between a train and an eval transform.

    Two torchvision pipelines are built at construction time:

    * ``train_transform`` - random vertical/horizontal flips, random rotation,
      colour jitter, then ``ToTensor`` and ``Normalize``.
    * ``eval_transform`` - only ``ToTensor`` and ``Normalize``.

    Neither pipeline is installed by ``__init__``; call :meth:`train` or
    :meth:`eval` before iterating to choose one. All positional and keyword
    arguments are forwarded unchanged to ``DataLoader``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Augmenting pipeline used while fitting the model.
        self.train_transform = transforms.Compose(
            [
                transforms.RandomVerticalFlip(p=0.3),
                transforms.RandomHorizontalFlip(p=0.3),
                transforms.RandomRotation(degrees=72),
                transforms.ColorJitter(contrast=0.3, saturation=0.3, brightness=0.3),
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
            ]
        )
        # Deterministic pipeline: same tensor conversion and normalisation,
        # without any random augmentation.
        self.eval_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
            ]
        )

    def _set_transform(self, transform):
        """Install ``transform`` on the wrapped dataset and drop stale workers."""
        if getattr(self.dataset, "transform", None) is transform:
            # Already in the requested mode; keep any live workers.
            return
        self.dataset.transform = transform  # type: ignore
        # With persistent_workers=True the worker processes keep their own copy
        # of the dataset, so a transform changed here would never reach them.
        # Dropping the cached iterator makes the next iteration start fresh
        # workers. NOTE(review): `_iterator` is a private attribute of torch's
        # DataLoader - re-verify when upgrading torch.
        self._iterator = None

    def train(self):
        """Switch the dataset to the augmenting ``train_transform``."""
        self._set_transform(self.train_transform)

    def eval(self):
        """Switch the dataset to the deterministic ``eval_transform``."""
        self._set_transform(self.eval_transform)

In [17]:
def create_dataloaders(train_dir, valid_dir, batch_size, **kwargs):
    """Build train and validation loaders from two ``ImageFolder`` directories.

    Args:
        train_dir: Root directory passed to ``datasets.ImageFolder`` for training data.
        valid_dir: Root directory passed to ``datasets.ImageFolder`` for validation data.
        batch_size: Batch size used for both loaders.
        **kwargs: Extra keyword arguments forwarded to both ``MutableDataLoader``
            constructors (e.g. ``num_workers``, ``pin_memory``).

    Returns:
        A tuple ``(train_loader, valid_loader, info)`` where ``info`` is a dict
        with ``"class_names"`` (the training dataset's ``classes``) and
        ``"class_to_id"`` (its ``class_to_idx``). The train loader is shuffled
        and returned in train mode; the validation loader is unshuffled and
        returned in eval mode.

    Raises:
        ValueError: If the two directories do not yield the same class list.
    """
    train_dataset = datasets.ImageFolder(train_dir)
    valid_dataset = datasets.ImageFolder(valid_dir)
    if train_dataset.classes != valid_dataset.classes:
        raise ValueError("train and valid dataset have different classes")

    class_names = train_dataset.classes
    class_to_id = train_dataset.class_to_idx

    train_loader = MutableDataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        **kwargs,
    )
    valid_loader = MutableDataLoader(
        valid_dataset,
        batch_size=batch_size,
        shuffle=False,
        **kwargs,
    )

    # The datasets above are created without a transform, so pick a sensible
    # mode up front: otherwise a loader iterated before an explicit
    # train()/eval() call would hand untransformed samples to the collate step.
    # Callers can still switch modes afterwards.
    train_loader.train()
    valid_loader.eval()

    return (train_loader, valid_loader, {"class_names": class_names, "class_to_id": class_to_id})

In [33]:
# Locations of the pre-processed image folders, relative to this notebook.
train_path = "../data/processed/train/"
valid_path = "../data/processed/valid/"
# os.cpu_count() returns None when the CPU count cannot be determined. Fall
# back to a single worker rather than 0, because persistent_workers=True below
# requires at least one worker process.
num_workers = os.cpu_count() or 1

# `classes` is the dict with "class_names" and "class_to_id" returned by
# create_dataloaders.
train_loader, valid_loader, classes = create_dataloaders(
    train_path,
    valid_path,
    batch_size=32,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=True,
)