In [None]:
import os

import matplotlib.pyplot as plt
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from torchvision.datasets import CIFAR10, ImageFolder

from utils import display_image_grid

In [None]:
# CIFAR-10 train and test splits, downloaded into ./data on first run.
# Each split gets its own ToTensor transform; the train split is built first.
train_dataset, test_dataset = (
    CIFAR10(
        root="./data",
        train=is_train_split,
        download=True,
        transform=transforms.ToTensor(),
    )
    for is_train_split in (True, False)
)

In [None]:
img, label = train_dataset[0]  # first training sample

# Samples are channel-first tensors; imshow expects channels last, so move
# dim 0 to the end (same idiom as the preview cells further down).
plt.imshow(img.permute(1, 2, 0))
plt.title(train_dataset.classes[label])
plt.axis("off")
# Render the figure explicitly so the cell does not also echo the
# axis-limits tuple returned by plt.axis().
plt.show()

In [None]:
# Mini-batches of 64. Only the training split is reshuffled each epoch;
# the test split is read in a fixed order.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=64)

In [None]:
# Pull a single batch to inspect tensor shapes; `break` stops after the first.
for image, label in train_loader:
    print(
        f"Image batch shape: {image.shape}",
        f"Label batch shape: {label.shape}",
        sep="\n",
    )
    break

In [None]:
# Grab one shuffled training batch and render it with the local
# `utils.display_image_grid` helper, passing the dataset's class names.
images, labels = next(
    iter(train_loader)
)
display_image_grid(
    images,
    labels,
    train_dataset.classes,
)

In [None]:
class CustomDataset(Dataset):
    """Minimal class-per-folder image dataset.

    Every sub-directory of ``root_dir`` (sorted by name) is a class. Every
    regular file inside a class directory is a sample, labelled with that
    class's index in ``self.classes``. Files are not validated at
    construction time; an unreadable file fails when it is indexed.

    Args:
        root_dir: directory whose sub-directories are the classes.
        transform: optional callable applied to each RGB PIL image.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        self.samples = []  # list of (file path, integer label)
        self.classes = []  # class names; label i <-> self.classes[i]

        for class_name in sorted(os.listdir(root_dir)):
            class_dir = os.path.join(root_dir, class_name)
            # Skip stray top-level files (e.g. .DS_Store); listing one as a
            # directory would raise NotADirectoryError.
            if not os.path.isdir(class_dir):
                continue

            # Label = position in self.classes, so skipped entries never shift it.
            label = len(self.classes)
            self.classes.append(class_name)

            # Sorted for a deterministic sample order across filesystems.
            for fname in sorted(os.listdir(class_dir)):
                path = os.path.join(class_dir, fname)
                if os.path.isfile(path):  # ignore nested sub-directories
                    self.samples.append((path, label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        ``image`` is an RGB PIL image, or the output of ``self.transform``
        when a transform is set.
        """
        path, label = self.samples[idx]
        # The context manager closes the file handle. convert() returns a new
        # image that stays valid after the file is closed.
        with Image.open(path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label

In [None]:
# Build the hand-written dataset over the clean image folders.
my_dataset = CustomDataset(
    root_dir="../assets/clean_dataset",
    transform=transforms.ToTensor(),
)
print(f"Total samples in dataset: {len(my_dataset)}")

# 80/20 random split. Seeding the generator keeps train/test membership
# identical across Restart & Run All.
SPLIT_SEED = 42
my_train_dataset, my_test_dataset = random_split(
    my_dataset,
    [0.8, 0.2],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
assert len(my_train_dataset) + len(my_test_dataset) == len(my_dataset)
print(f"Training samples: {len(my_train_dataset)}")
print(f"Testing samples: {len(my_test_dataset)}")

In [None]:
# Preview the first two samples of the training split. `random_split` returns
# `Subset` wrappers, so label names are looked up on the parent `my_dataset`.
img0, label0 = my_train_dataset[0]
img1, label1 = my_train_dataset[1]

fig, axes = plt.subplots(1, 2, figsize=(6, 3))
for ax, shown_img, shown_label in zip(axes, (img0, img1), (label0, label1)):
    # Channel-first tensor -> channels-last layout for imshow.
    ax.imshow(shown_img.permute(1, 2, 0))
    ax.set_title(my_dataset.classes[shown_label])
    ax.axis("off")
# Explicit show: avoids echoing the return value of the last axis() call.
plt.show()

In [None]:
# torchvision's built-in class-per-folder loader over the same clean folders,
# for comparison with CustomDataset.
folder_dataset = ImageFolder(
    root="../assets/clean_dataset",
    transform=transforms.ToTensor()
)

img0, label0 = folder_dataset[0]
img1, label1 = folder_dataset[1]

fig, axes = plt.subplots(1, 2, figsize=(6, 3))
for ax, shown_img, shown_label in zip(axes, (img0, img1), (label0, label1)):
    # Channel-first tensor -> channels-last layout for imshow.
    ax.imshow(shown_img.permute(1, 2, 0))
    ax.set_title(folder_dataset.classes[shown_label])
    ax.axis("off")
# Explicit show: avoids echoing the return value of the last axis() call.
plt.show()

In [None]:
class RobustDataset(Dataset):
    """Class-per-folder image dataset that skips unreadable or undersized images.

    Every sub-directory of ``root_dir`` (sorted by name) is a class. Every
    regular file inside it is a sample, labelled with that class's index in
    ``self.classes``. Stray files directly under ``root_dir`` are ignored.

    When a sample fails to load, validate, or transform, ``__getitem__``
    records the failure in ``self.error_logs`` and returns the next sample
    that works, wrapping around the end. The returned pair may therefore
    belong to a different index, and a different label, than the one
    requested.

    Args:
        root_dir: directory whose sub-directories are the classes.
        transform: optional callable applied to each RGB PIL image.
        min_size: minimum width and height in pixels; smaller images are
            treated as errors.

    Raises:
        RuntimeError: if no files are found in any class directory.
    """

    def __init__(self, root_dir, transform=None, min_size=32):
        self.root_dir = root_dir
        self.transform = transform
        self.min_size = min_size

        self.samples = []     # list of (file path, integer label)
        self.classes = []     # class names; label i <-> self.classes[i]
        self.error_logs = []  # dicts with "index", "path", "error"

        # discover class folders only
        for class_name in sorted(os.listdir(root_dir)):
            class_dir = os.path.join(root_dir, class_name)
            if not os.path.isdir(class_dir):
                continue

            # The label is the position in self.classes. Enumerating the raw
            # listing would also count skipped non-directory entries, which
            # shifts labels out of step with class names.
            label = len(self.classes)
            self.classes.append(class_name)

            # Sorted for a deterministic sample order across filesystems.
            for fname in sorted(os.listdir(class_dir)):
                path = os.path.join(class_dir, fname)
                if os.path.isfile(path):
                    self.samples.append((path, label))

        if not self.samples:
            raise RuntimeError("No valid image files found")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, or for the next loadable sample.

        Raises:
            RuntimeError: if every sample fails.
        """
        # try up to len(self) times to find a valid sample
        for _ in range(len(self)):
            path, label = self.samples[idx]
            try:
                image = self.load_and_validate_image(path)

                if self.transform:
                    image = self.transform(image)

                return image, label

            # Deliberately broad: any failure (decode, size check, transform)
            # is logged and the next sample is tried instead.
            except Exception as e:
                self.log_error(idx, path, e)
                idx = (idx + 1) % len(self)

        raise RuntimeError("All samples appear to be corrupted.")

    def load_and_validate_image(self, path):
        """Open ``path``, check it, and return it as an RGB PIL image.

        The file must decode and be at least ``min_size`` pixels on each side.

        Raises:
            ValueError: if the image is smaller than ``min_size``.
            Exception: whatever PIL raises for unreadable or corrupt files.
        """
        # verify structure; verify() leaves the object unusable, so reopen below
        with Image.open(path) as img:
            img.verify()

        # Reopen in a context manager so the file handle is closed even when
        # decoding or the size check fails.
        with Image.open(path) as img:
            img.load()

            # size check
            if img.size[0] < self.min_size or img.size[1] < self.min_size:
                raise ValueError(f"Image too small: {img.size}")

            # convert() returns an independent copy (also when already RGB),
            # so the result remains usable after the file is closed.
            return img.convert("RGB")

    def log_error(self, idx, path, e):
        """Record a failed sample in ``self.error_logs`` and print a warning."""
        self.error_logs.append(
            {
                "index": idx,
                "path": path,
                "error": str(e),
            }
        )
        print(f"Warning: Skipping corrupted image at {path}: {e}")

    def get_error_summary(self, max_print=5):
        """Print the number of logged failures and up to ``max_print`` of them."""
        if not self.error_logs:
            print("No errors encountered - dataset is clean.")
            return

        print(f"\nEncountered {len(self.error_logs)} problematic samples:")
        for err in self.error_logs[:max_print]:
            print(f"  [{err['index']}] {err['path']}: {err['error']}")

        if len(self.error_logs) > max_print:
            print(f"  ... and {len(self.error_logs) - max_print} more")

In [None]:
# Run the robust loader over a folder that contains bad files. Each failure is
# logged and replaced by the next loadable sample.
dataset = RobustDataset(
    root_dir="../assets/corrupt_dataset",
    transform=None
)

# Index the first few samples to trigger loading (and error logging). The count
# is capped at len(dataset) so a small folder does not raise IndexError.
N_CHECK = 10
for i in range(min(N_CHECK, len(dataset))):
    _ = dataset[i]

dataset.get_error_summary()