In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

## Definimos la clase Dataset 

In [2]:
class CustomImageDataset(Dataset):
    """Image-classification dataset read from a folder of per-class sub-folders.

    Expected layout::

        directory/
            class_a/xxx.jpg
            class_b/yyy.png

    Each sub-directory of ``directory`` is one class; classes are numbered in
    sorted-name order. Entries directly under ``directory`` that are not
    directories (e.g. ``.DS_Store``) are ignored and do not consume a class
    index.

    Args:
        directory: Root folder containing one sub-folder per class.
        transform: Optional callable applied to each loaded PIL image.
        extensions: File-name suffix (or tuple of suffixes) accepted as
            images. Matching is case-insensitive.

    Attributes:
        class_names: Sorted list of class (sub-folder) names.
        class_to_idx: Mapping from class name to integer label.
        image_paths: Path of every sample, grouped by class and sorted by
            file name inside each class.
        labels: Integer label of every sample, aligned with ``image_paths``.
    """

    def __init__(self, directory, transform=None, extensions=('.jpg', '.png')):
        # A lone string would otherwise be iterated character by character.
        if isinstance(extensions, str):
            extensions = (extensions,)

        self.directory = directory
        self.transform = transform
        self.extensions = tuple(ext.lower() for ext in extensions)
        self.image_paths = []
        self.labels = []

        # Only real sub-directories are classes; counting stray files here
        # would shift every label and inflate the number of classes.
        self.class_names = sorted(
            entry for entry in os.listdir(directory)
            if os.path.isdir(os.path.join(directory, entry))
        )
        self.class_to_idx = {class_name: idx for idx, class_name in enumerate(self.class_names)}

        # Walk each class folder and record its image files.
        for class_name in self.class_names:
            class_path = os.path.join(directory, class_name)
            label = self.class_to_idx[class_name]
            # os.listdir order is arbitrary; sort so sample order is reproducible.
            for file_name in sorted(os.listdir(class_path)):
                if file_name.lower().endswith(self.extensions):
                    self.image_paths.append(os.path.join(class_path, file_name))
                    self.labels.append(label)

    def __len__(self):
        """Return the number of samples found under ``directory``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is always converted to 3-channel RGB before ``transform``
        is applied, so grayscale, palette and RGBA files yield the same
        channel count as ordinary colour images.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        # convert() reads the pixel data and returns a new image, so the
        # underlying file can be closed as soon as the block exits.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img, label

In [3]:
# Target (height, width) every image is resized to before batching.
IMAGE_SIZE = (224, 224)

# Per-channel RGB mean / std used for normalization. These values match the
# ImageNet statistics listed in the torchvision documentation for its
# pretrained models.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Deterministic pipeline: resize, convert to tensor, normalize.
original_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Augmentation pipeline: same resize / tensor / normalize steps, with random
# geometric and colour perturbations applied to the PIL image in between.
augmented_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),  # random horizontal flip
    transforms.RandomRotation(30),  # random rotation between -30 and 30 degrees
    # Randomly perturb brightness, contrast, saturation and hue.
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [None]:
# The training folder is wrapped twice over the same files: first with the
# deterministic pipeline, then with the augmentation pipeline.
train_original_dataset, train_augmented_dataset = [
    CustomImageDataset('archive/train', transform=pipeline)
    for pipeline in (original_transform, augmented_transform)
]

# Validation and test folders use only the deterministic pipeline.
valid_original_dataset, test_original_dataset = [
    CustomImageDataset(split_dir, transform=original_transform)
    for split_dir in ('archive/valid', 'archive/test')
]

In [None]:
# Training loaders: batches of 32 with shuffling enabled.
train_original_loader, train_augmented_loader = [
    DataLoader(train_split, batch_size=32, shuffle=True)
    for train_split in (train_original_dataset, train_augmented_dataset)
]

# Evaluation loaders: batches of 32 with shuffling disabled.
valid_original_loader, test_original_loader = [
    DataLoader(eval_split, batch_size=32, shuffle=False)
    for eval_split in (valid_original_dataset, test_original_dataset)
]