In [1]:
import torch
from torchvision import transforms
from torch.utils.data import Dataset
import os
from PIL import Image

**Custom Dataset Class**

In [2]:
class CustomImageDataset(Dataset):
    """Map-style dataset over the image files found directly inside ``image_dir``.

    Files are selected by extension (compared case-insensitively) and sorted by
    path, so a given index always maps to the same file; ``os.listdir`` itself
    returns entries in arbitrary order.

    Args:
        image_dir: Directory holding the images (not searched recursively).
        transform: Optional callable applied to each loaded RGB PIL image.
        target_transform: Optional callable applied to each label.
        extensions: File suffixes treated as images; matched case-insensitively.
    """

    def __init__(self, image_dir, transform=None, target_transform=None,
                 extensions=('.png', '.jpg', '.jpeg')):
        super().__init__()
        self.image_dir = image_dir
        # Lower-case once so '.JPG' / '.Png' files are not silently skipped.
        suffixes = tuple(ext.lower() for ext in extensions)
        self.image_paths = sorted(
            os.path.join(image_dir, name)
            for name in os.listdir(image_dir)
            if name.lower().endswith(suffixes)
        )
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of images (needed by ``len()`` and DataLoader samplers)."""
        return len(self.image_paths)

    def len(self):
        """Backward-compatible alias for ``len(dataset)``."""
        # Inside this method the name ``len`` resolves to the builtin, which
        # dispatches to ``__len__`` above.
        return len(self)

    def __getitem__(self, index):
        """Load image ``index`` as RGB and return ``(image, label)``.

        ``transform`` / ``target_transform`` are applied when they were given.
        """
        image_path = self.image_paths[index]
        # The context manager closes the file handle; convert() returns a new,
        # fully loaded image that stays valid after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        label = self._get_label(image_path)

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label

    def _get_label(self, image_path):
        """Return the label for ``image_path``.

        Placeholder: every image is labelled 0. TODO: derive real labels
        (e.g. from the filename or a per-class subdirectory).
        """
        return 0

**Data Augmentation and Dataset Creation**

In [None]:
# Shared preprocessing settings: train and test must resize and normalize
# identically, so each value is defined exactly once.
IMAGE_SIZE = (224, 224)
# Per-channel RGB mean/std used by torchvision's ImageNet-pretrained models.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
TRAIN_DIR = 'train'
TEST_DIR = 'test'

# Training transforms (with augmentation): random flip, rotation and colour
# jitter are applied before conversion to a tensor.
train_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Test transforms (no augmentation): deterministic resize + normalize only.
test_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

train_data = CustomImageDataset(image_dir=TRAIN_DIR, transform=train_transform)
test_data = CustomImageDataset(image_dir=TEST_DIR, transform=test_transform)