In [7]:
import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
import os

In [8]:
class CustomDataset(Dataset):
    """Image-classification dataset laid out as ``root_dir/<class_name>/.../<file>``.

    Every immediate sub-directory of ``root_dir`` is a class (indexed in
    sorted-name order). Every image file found recursively beneath it
    becomes a ``(path, class_index)`` sample. ``__getitem__`` loads the file
    with OpenCV, converts it from BGR to RGB, and applies ``transform`` if one
    was given.
    """

    # Lower-case file extensions treated as images. Everything else
    # (``.DS_Store``, ``Thumbs.db``, READMEs, ...) is skipped at indexing time,
    # because ``cv2.imread`` silently returns None for such files.
    IMG_EXTENSIONS = (
        ".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp",
    )

    def __init__(self, root_dir, transform=None, extensions=None):
        """
        Args:
            root_dir (string): Directory with one sub-directory per class.
            transform (callable, optional): Optional transform applied to the
                RGB image array before it is returned.
            extensions (iterable of str, optional): File extensions to accept,
                compared case-insensitively. Defaults to ``IMG_EXTENSIONS``.
        """
        self.root_dir = root_dir
        self.transform = transform
        if extensions is None:
            extensions = self.IMG_EXTENSIONS
        # str.endswith needs a tuple; lower-case once so matching is case-insensitive.
        self.extensions = tuple(ext.lower() for ext in extensions)
        self.classes, self.class_to_idx = self._find_classes(self.root_dir)
        self.imgs = self._make_dataset()

    def _find_classes(self, dir):
        """
        Finds the class folders in a dataset.

        Returns:
            (classes, class_to_idx): sorted list of sub-directory names, and a
            dict mapping each name to its position in that list.
        """
        classes = sorted(entry.name for entry in os.scandir(dir) if entry.is_dir())
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def _make_dataset(self):
        """
        Creates a sorted list of ``(file_path, class_index)`` samples.

        Only files whose extension is in ``self.extensions`` are included, so
        stray non-image files never reach ``__getitem__``.
        """
        images = []
        for target in sorted(self.class_to_idx.keys()):
            class_idx = self.class_to_idx[target]
            class_dir = os.path.join(self.root_dir, target)
            if not os.path.isdir(class_dir):
                continue
            # Sorting the walk and the file names keeps sample order deterministic.
            for root, _, fnames in sorted(os.walk(class_dir)):
                for fname in sorted(fnames):
                    if not fname.lower().endswith(self.extensions):
                        continue
                    images.append((os.path.join(root, fname), class_idx))
        return images

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns ``(image, class_index)``. The image is the RGB array, or
        ``transform``'s output if one was given.

        Raises:
            OSError: if OpenCV cannot decode the file (e.g. it is corrupt).
        """
        path, target = self.imgs[idx]
        img = cv2.imread(path)
        # cv2.imread signals failure by returning None rather than raising;
        # surface it with the offending path instead of a cvtColor assertion.
        if img is None:
            raise OSError(f"Could not read image file: {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # Convert from BGR to RGB

        if self.transform:
            img = self.transform(img)

        return img, target


In [9]:
# Example transform function that can be used with the CustomDataset
def transform(image, size=(224, 224), flip_prob=0.5):
    """Resize, randomly flip horizontally, and convert an image array to a tensor.

    Args:
        image (numpy.ndarray): ``H x W x C`` image (as returned by
            ``CustomDataset``) or a 2-D ``H x W`` grayscale image. Assumed to
            hold 8-bit values in [0, 255]; TODO confirm if other dtypes are used.
        size (tuple[int, int]): Output size as ``(width, height)``, which is
            OpenCV's argument order.
        flip_prob (float): Probability of mirroring the image left-to-right.

    Returns:
        torch.Tensor: float32 tensor of shape ``C x H x W``, divided by 255
        (so in [0, 1] for 8-bit input).
    """
    image = cv2.resize(image, size)
    # Uses NumPy's global RNG: seed it (np.random.seed) for reproducible flips.
    if np.random.rand() < flip_prob:
        image = cv2.flip(image, 1)  # flipCode=1: mirror around the vertical axis
    # cv2.resize returns a 2-D array for single-channel input; restore the
    # channel axis so the HWC -> CHW permute below works.
    if image.ndim == 2:
        image = image[:, :, np.newaxis]
    # Reorder to PyTorch's channel-first layout and rescale to [0, 1]. Keeping
    # inputs in a small, fixed range aids numerical stability, and torchvision's
    # pretrained models expect [0, 1] inputs (followed by per-channel mean/std
    # normalisation, which is NOT applied here).
    return torch.from_numpy(image).permute(2, 0, 1).float() / 255.0

In [10]:
# Usage example: build the training set and a shuffled mini-batch loader.
DATA_DIR = 'data/train'  # one sub-folder per class
BATCH_SIZE = 32
SEED = 42

# Seed both RNGs in play: torch drives the DataLoader's shuffle order,
# NumPy drives the random flips inside `transform`.
torch.manual_seed(SEED)
np.random.seed(SEED)

train_dataset = CustomDataset(root_dir=DATA_DIR, transform=transform)
# With shuffle=True, an empty dataset makes DataLoader's sampler raise a
# generic "num_samples" error; fail early with a message naming the folder.
assert len(train_dataset) > 0, f"No images found under {DATA_DIR!r}"
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Now you can iterate over train_loader in your training loop