
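# 03 – Dataset and transforms

Defines the `ISICSkinDataset` class, the training/validation image transforms, and `DataLoader`s, then sanity-checks one training batch.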

In [1]:
from pathlib import Path

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset

class ISICSkinDataset(Dataset):
    """Map-style dataset of ISIC skin-lesion images paired with integer labels.

    Each CSV row supplies an image id (column ``isic_id``) and a label
    (column ``label``). The image is read from ``<image_dir>/<isic_id>.jpg``.

    Args:
        csv_file: Path to a CSV file with at least ``isic_id`` and ``label`` columns.
        image_dir: Directory holding the ``.jpg`` images; ``str`` or ``Path``.
        transform: Optional callable applied to the RGB ``PIL.Image`` before it
            is returned.
    """

    def __init__(self, csv_file, image_dir, transform=None):
        self.df = pd.read_csv(csv_file)
        # Coerce to Path so the `/` join in __getitem__ also works when a str is passed.
        self.image_dir = Path(image_dir)
        self.transform = transform

    def __len__(self):
        """Return the number of samples (CSV rows)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(image, label)``. ``image`` is the RGB ``PIL.Image``, or whatever
            ``transform`` returns when one is set. ``label`` is
            ``torch.tensor(row["label"], dtype=torch.long)``.
        """
        row = self.df.iloc[idx]

        image_id = row["isic_id"]
        label = torch.tensor(row["label"], dtype=torch.long)

        image_path = self.image_dir / f"{image_id}.jpg"
        # The context manager closes the file handle deterministically. convert()
        # returns a new, fully loaded image that remains valid after the block.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        # Explicit None check: a transform object's truthiness is not a reliable signal.
        if self.transform is not None:
            image = self.transform(image)

        return image, label


In [2]:
from torchvision import transforms

# Shared preprocessing constants so the train and val pipelines cannot drift apart.
# The mean/std are the standard ImageNet channel statistics used by torchvision's
# ImageNet-pretrained models.
IMAGE_SIZE = (224, 224)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training: fixed-size resize, then light geometric augmentation, then tensor + normalize.
train_transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),  # angle drawn uniformly from [-15, +15] degrees
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Validation: deterministic; same resize and normalization, no augmentation.
val_transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])


In [3]:
from pathlib import Path

# Assumes the notebook is launched from a direct subfolder of the project root.
project_root = Path.cwd().parent

train_csv = project_root / "data/processed/train/train_binary.csv"
train_image_dir = project_root / "data/raw/train/images_train"
val_csv = project_root / "data/processed/val/val_binary.csv"
val_image_dir = project_root / "data/raw/val/images_val"

# Fail fast if project_root was resolved wrongly. A bad image directory would
# otherwise only surface when the DataLoader loads the first sample.
for required_path in (train_csv, train_image_dir, val_csv, val_image_dir):
    if not required_path.exists():
        raise FileNotFoundError(
            f"{required_path} not found; is the notebook running from a "
            f"subfolder of the project root (resolved as {project_root})?"
        )

train_dataset = ISICSkinDataset(
    csv_file=train_csv,
    image_dir=train_image_dir,
    transform=train_transforms
)

val_dataset = ISICSkinDataset(
    csv_file=val_csv,
    image_dir=val_image_dir,
    transform=val_transforms
)

print(len(train_dataset))  # 9885
print(len(val_dataset))    # 193


9885
193


In [4]:
from torch.utils.data import DataLoader, Subset

# Loader settings. SUBSET_SIZE keeps training iterations fast while developing.
SUBSET_SIZE = 500
BATCH_SIZE = 8
SEED = 42

# A seeded generator makes the subset draw and the shuffle order reproducible
# across Restart & Run All.
generator = torch.Generator().manual_seed(SEED)

# Draw a random sample rather than the first SUBSET_SIZE rows; taking the head
# would give a skewed subset if the CSV is ordered (e.g. by label).
subset_indices = torch.randperm(len(train_dataset), generator=generator)[:SUBSET_SIZE].tolist()
train_subset = Subset(train_dataset, subset_indices)

train_loader = DataLoader(
    train_subset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    generator=generator,  # drives the RandomSampler behind shuffle=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
)


In [5]:
images, labels = next(iter(train_loader))

# Pin expectations as assertions (not just comments) so a broken transform or
# label column fails loudly instead of silently producing bad batches.
assert images.shape[1:] == (3, 224, 224), f"unexpected image shape {tuple(images.shape)}"
assert labels.dtype == torch.long, f"unexpected label dtype {labels.dtype}"
assert set(labels.tolist()) <= {0, 1}, f"non-binary labels in batch: {labels.tolist()}"

print(images.shape)   # [8, 3, 224, 224]
print(labels)         # tensor of 0s and 1s


torch.Size([8, 3, 224, 224])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
