In [16]:
import os

import scipy
import scipy.io
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split

In [10]:
class OxfordFlowersDataset(Dataset):
    """Oxford Flowers images paired with zero-based class labels.

    ``root_dir`` must contain ``imagelabels.mat`` and a ``jpg`` sub-directory
    whose files are named ``image_00001.jpg``, ``image_00002.jpg``, ...
    (file numbering is 1-based, dataset indices are 0-based).

    Images that cannot be loaded are not fatal: ``__getitem__`` records the
    failure in ``self.error_logs`` and falls back to the next index.

    Args:
        root_dir: Directory holding ``imagelabels.mat`` and ``jpg/``.
        transform: Optional callable applied to every loaded PIL image.
    """

    # Images narrower or shorter than this many pixels are rejected.
    MIN_IMAGE_SIZE = 64

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.image_dir = os.path.join(self.root_dir, "jpg")
        self.labels = self.load_labels()
        self.transform = transform
        # One dict per failed load attempt; see log_error for the keys.
        self.error_logs = []

    def __len__(self):
        """Return the number of labelled samples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, skipping samples that fail.

        If the requested sample cannot be loaded or transformed, the error is
        logged and the following index (wrapping around) is tried instead, so
        the returned sample may belong to a different index than requested.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
            RuntimeError: If no sample in the whole dataset could be loaded.
        """
        size = len(self)
        if idx < 0:
            idx += size
        # Raise IndexError explicitly: otherwise an out-of-range index would be
        # "recovered" by the fallback below and plain iteration would never stop.
        if not 0 <= idx < size:
            raise IndexError(f"index {idx} is out of range for dataset of size {size}")

        for _ in range(size):
            try:
                image = self.retrieve_images(idx)
                label = self.labels[idx]
                if self.transform:
                    image = self.transform(image)
                return image, label

            except Exception as e:
                # Deliberate best-effort: record the failure and try the next sample.
                self.log_error(idx, e)
                idx = (idx + 1) % size

        raise RuntimeError(
            f"No loadable image found after trying all {size} samples; "
            "call get_error_summary() for details."
        )

    def _image_path(self, idx):
        """Return the JPEG path for zero-based ``idx`` (files are numbered from 1)."""
        return os.path.join(self.image_dir, f"image_{idx+1:05d}.jpg")

    def retrieve_images(self, idx):
        """Load, validate and return the image at ``idx`` as an RGB PIL image.

        Raises:
            ValueError: If either image dimension is below ``MIN_IMAGE_SIZE``.
            Exception: Whatever PIL raises for missing or corrupted files.
        """
        image_path = self._image_path(idx)

        # verify() checks file integrity but leaves the object unusable,
        # so the file is opened a second time for the actual decode.
        with Image.open(image_path) as img:
            img.verify()

        image = Image.open(image_path)
        image.load()

        if image.size[0] < self.MIN_IMAGE_SIZE or image.size[1] < self.MIN_IMAGE_SIZE:
            raise ValueError("Image size is too small.")

        if image.mode != "RGB":
            # convert() returns a new image; the result must be kept.
            image = image.convert("RGB")

        return image

    def load_labels(self):
        """Read ``imagelabels.mat`` from ``root_dir`` and return zero-based labels."""
        loadlabels = scipy.io.loadmat(os.path.join(self.root_dir, "imagelabels.mat"))
        # The file is assumed to store 1-based class ids (hence the -1) -- TODO confirm.
        # Cast to int64 first so the subtraction cannot wrap if the stored dtype is unsigned.
        labels = loadlabels["labels"][0].astype("int64") - 1
        return labels

    def log_error(self, idx, e):
        """Record a failed load of sample ``idx`` with exception ``e`` and warn."""
        self.error_logs.append(
            {
                "Index": idx,
                "Error": e,
                "path": self._image_path(idx),
            }
        )
        print(f"Warning: Skipping corrupted image {idx}: {e}")

    def get_error_summary(self):
        """Print how many load failures were recorded, followed by one line each."""
        if not self.error_logs:
            print("No errors encountered - dataset is clean.")

        else:
            print(f"\nEncountered {len(self.error_logs)} problematic images:")
            for entry in self.error_logs:
                print(f"  index {entry['Index']} ({entry['path']}): {entry['Error']}")

In [12]:
# Dataset location. Set the FLOWER_DATA_DIR environment variable to run the
# notebook on another machine; without it the original local path is used.
root_dir = os.environ.get(
    "FLOWER_DATA_DIR",
    r"D:\WorkSpace\Machine Learning\PyTorch-for-Deep-Learning\C1-PyTorch Fundamentals\Module3\LAB1\flower_data",
)
# No transform here: the base dataset yields raw PIL images.
dataset = OxfordFlowersDataset(root_dir=root_dir, transform=None)

In [13]:
# Split dataset
def get_train_val_split(dataset, val_frac=0.15, test_frac=0.15, seed=None):
    """Randomly split ``dataset`` into train, validation and test subsets.

    Validation and test sizes are ``int(len(dataset) * frac)``; the training
    subset receives every remaining sample, so the three sizes always add up
    to ``len(dataset)``.

    Args:
        dataset: Any sized, indexable dataset.
        val_frac: Fraction of samples for the validation subset.
        test_frac: Fraction of samples for the test subset.
        seed: Optional integer. When given, the split is drawn from a dedicated
            generator seeded with it and is therefore reproducible; when
            ``None`` the global PyTorch generator is used.

    Returns:
        Tuple ``(train_dataset, val_dataset, test_dataset)``.

    Raises:
        ValueError: If a fraction is negative or the two fractions exceed 1.
    """
    if val_frac < 0 or test_frac < 0 or val_frac + test_frac > 1:
        raise ValueError(
            "val_frac and test_frac must be non-negative and sum to at most 1, "
            f"got val_frac={val_frac}, test_frac={test_frac}"
        )

    total = len(dataset)
    val = int(total * val_frac)
    test = int(total * test_frac)
    train = total - val - test

    # Only pass a generator when a seed was requested so the default call is unchanged.
    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)

    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [train, val, test], **split_kwargs
    )

    return train_dataset, val_dataset, test_dataset

In [15]:
# random_split draws from PyTorch's global generator by default, so seed it
# here to get the same split on every Restart & Run All.
SPLIT_SEED = 42
torch.manual_seed(SPLIT_SEED)

train_dataset, val_dataset, test_dataset = get_train_val_split(dataset)
print(f"Length of train dataset: {len(train_dataset)}")
print(f"Length of validation dataset: {len(val_dataset)}")
print(f"Length of test dataset: {len(test_dataset)}")

Length of train dataset: 5733
Length of validation dataset: 1228
Length of test dataset: 1228


In [25]:
def get_transformation(mean=None,std=None):
    """Build the preprocessing and augmentation transform lists.

    Args:
        mean: Per-channel means for ``transforms.Normalize``.
        std: Per-channel standard deviations for ``transforms.Normalize``.
            When both ``mean`` and ``std`` are ``None`` the normalization step
            is left out (a ``Normalize`` built from ``None`` fails when applied).

    Returns:
        Tuple ``(transform, augmentation_transform)`` of two plain lists of
        transform objects. They are lists, not callables: wrap them in
        ``transforms.Compose`` before applying them to an image.

    Raises:
        ValueError: If only one of ``mean`` / ``std`` is provided.
    """
    if (mean is None) != (std is None):
        raise ValueError("mean and std must be provided together")

    transform = [
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
    if mean is not None:
        transform.append(transforms.Normalize(mean=mean,std=std))

    augmentation_transform = [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=10),
        transforms.ColorJitter(brightness=0.2)
    ]

    return transform, augmentation_transform

In [26]:
# Per-channel normalization statistics (these match the widely used ImageNet values).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

transform, augmentation_transform = get_transformation(mean=mean, std=std)
# Augment the raw PIL image first, then resize / crop / tensorize / normalize.
# Running the augmentations after Normalize is wrong: ColorJitter clamps float
# tensors to [0, 1], which destroys the normalized (partly negative) values.
train_transform = transforms.Compose(augmentation_transform + transform)

In [None]:
class SubsetWithTransform(Dataset):
    """Dataset view that applies a transform to the image of each wrapped sample.

    Args:
        subset: Indexable, sized collection yielding ``(image, label)`` pairs.
        transform: Callable applied to the image; a falsy value disables it.
    """

    def __init__(self, subset, transform):
        self.subset, self.transform = subset, transform

    def __len__(self):
        """Return the size of the wrapped subset."""
        return len(self.subset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` at ``idx`` with the transform applied to the image."""
        sample, target = self.subset[idx]
        if not self.transform:
            return sample, target
        return self.transform(sample), target

In [29]:
# `transform` is a plain list of transform objects, not a callable; it must be
# wrapped in Compose or SubsetWithTransform would fail when calling it.
eval_transform = transforms.Compose(transform)

# NOTE: this cell rebinds the split names, so run it only once per split;
# re-running it would wrap the already-wrapped datasets a second time.
train_dataset = SubsetWithTransform(train_dataset, train_transform)
val_dataset = SubsetWithTransform(val_dataset, eval_transform)
test_dataset = SubsetWithTransform(test_dataset, eval_transform)

In [30]:
# Only report the sizes here. Calling get_train_val_split(dataset) again would
# overwrite the transform-wrapped datasets with raw, freshly reshuffled subsets.
print(f"Length of train dataset: {len(train_dataset)}")
print(f"Length of validation dataset: {len(val_dataset)}")
print(f"Length of test dataset: {len(test_dataset)}")

Length of train dataset: 5733
Length of validation dataset: 1228
Length of test dataset: 1228
