## 🧠 Introduction to Patch Datasets in Pathology

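Whole Slide Images (WSIs) are far too large to feed to a neural network in one piece, so computational pathology pipelines typically cut them into small, fixed-size patches and train on those. Working at the patch level brings its own challenges. Stain variation between slides, labs, and scanners is one; artifacts such as tissue folds, blur, or pen marks are another. Both make careful dataset handling essential for downstream tasks like classification and segmentation. This notebook focuses on building a robust data pipeline for patch-level classification.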
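To make the scale concrete, the sketch below counts how many square patches fit on a slide for a given patch size and stride. The slide dimensions and the 256-pixel patch size are hypothetical values chosen for illustration, and `count_tiles` / `tile_origins` are our own helper names, not part of any library.

```python
from typing import Iterator, Tuple

def count_tiles(width: int, height: int, patch: int, stride: int) -> Tuple[int, int]:
    """Return (columns, rows) of patch positions that fit entirely inside the image."""
    if width < patch or height < patch:
        return 0, 0
    cols = (width - patch) // stride + 1
    rows = (height - patch) // stride + 1
    return cols, rows

def tile_origins(width: int, height: int, patch: int, stride: int) -> Iterator[Tuple[int, int]]:
    """Yield the (x, y) top-left corner of every patch, row by row."""
    cols, rows = count_tiles(width, height, patch, stride)
    for r in range(rows):
        for c in range(cols):
            yield c * stride, r * stride

# Hypothetical 100,000 x 80,000 pixel slide cut into non-overlapping 256 x 256 patches
cols, rows = count_tiles(100_000, 80_000, patch=256, stride=256)
print(cols, rows, cols * rows)  # 390 312 121680
```

Real pipelines usually also discard patches that are mostly background glass; that tissue-filtering step is not shown here.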

## Notebook Plan

This notebook will guide you through the essential steps of preparing a patch-based image dataset for a classification task:

1.  **Setup & Configuration:** Import necessary libraries (PyTorch, Lightning, Albumentations, scikit-learn, Plotly) and configure paths and class names for our dataset.
2.  **Exploring Augmentation (Albumentations):**
    *   Load a sample image.
    *   Introduce `albumentations` and demonstrate various image transforms (flips, rotation, color jitter).
    *   Visualize the effect of these transforms on the sample image.
3.  **Defining Data Transforms:** Create specific augmentation pipelines using `albumentations.Compose` for the training set (including augmentations) and the validation/test sets (typically only normalization and tensor conversion).
4.  **Implementing the `PatchDataset`:** Build a custom `torch.utils.data.Dataset` class capable of:
    *   Loading image paths and corresponding labels.
    *   Applying the defined transforms.
    *   Returning image tensors and label indices.
5.  **Stratified Data Splitting:** Use `sklearn.model_selection.train_test_split` to divide the dataset into training, validation, and test sets while preserving the original class distribution (stratification).
6.  **Creating Datasets:** Instantiate the `PatchDataset` for each split (train, validation, test) using the appropriate file lists and transforms.
7.  **Implementing `DataLoader`:**
    *   Explain the role of `torch.utils.data.DataLoader` for batching, shuffling, and parallel loading.
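    *   Create `DataLoader` instances for the train, validation, and test sets.
8.  **Visualizing a Batch:** Define and use a helper function to load a batch from the `DataLoader` and display the images (with denormalization) in a grid to verify the pipeline.
9.  **Wrapping Up in a `LightningDataModule`:** Encapsulate file scanning, splitting, dataset creation, and dataloaders in a `PatchDataModule` that plugs directly into PyTorch Lightning's `Trainer`.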

In [None]:

# --- Common Imports ---
import os
import glob
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
import albumentations as A
from albumentations.pytorch import ToTensorV2 # Important for Albumentations with PyTorch
from PIL import Image # For loading images
import cv2 # Often used by Albumentations backend
from sklearn.model_selection import train_test_split

# --- Visualization (using Plotly) ---
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots

# Set default Plotly template
pio.templates.default = "plotly_white"

# --- Reproducibility ---
# Set random seeds for reproducibility
pl.seed_everything(42, workers=True)

# --- Configuration ---
# Define the path to your raw dataset
# Ensure your dataset (e.g., NSCLC IHC images) is in this directory
# with subfolders for each class ('TER', 'Necrotic', etc.)
DATA_DIR = Path('../data/raw/patch_classification_dataset')

# Automatically get class names from subfolder names
try:
    CLASS_NAMES = sorted([p.name for p in DATA_DIR.glob('*') if p.is_dir()])
    if not CLASS_NAMES:
        raise FileNotFoundError # Handle case where directory exists but is empty or has no subdirs
    NUM_CLASSES = len(CLASS_NAMES)

    print(f'Dataset directory: {DATA_DIR}')
    print(f'Found {NUM_CLASSES} classes: {CLASS_NAMES}')

    # Create mappings
    class_to_idx = {name: i for i, name in enumerate(CLASS_NAMES)}
    idx_to_class = {i: name for name, i in class_to_idx.items()}

    print(f'Class to index mapping: {class_to_idx}')

except FileNotFoundError:
    print(f'Error: Dataset directory not found or no class subfolders found at {DATA_DIR}')
    print('Please ensure your dataset is placed correctly as per the README instructions.')
    # You might want to stop execution here in a real notebook if the data isn't found
    CLASS_NAMES = []
    NUM_CLASSES = 0
    class_to_idx = {}
    idx_to_class = {}

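The configuration cell relies on a folder-per-class layout: every subfolder of `DATA_DIR` is a class, and sorting the folder names fixes the class-to-index mapping. The sketch below rebuilds that convention on a throwaway directory so it runs without the real dataset. The folder names `Necrotic` and `TER` and the placeholder file names are illustrative only, and `discover_classes` is a hypothetical helper.

```python
import tempfile
from pathlib import Path

def discover_classes(root: Path) -> dict:
    """Map sorted subfolder names to consecutive integer indices (plain files are ignored)."""
    names = sorted(p.name for p in root.iterdir() if p.is_dir())
    return {name: i for i, name in enumerate(names)}

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for cls in ["TER", "Necrotic"]:             # created out of order on purpose
        (root / cls).mkdir()
        (root / cls / "patch_000.png").touch()   # empty placeholder file
    (root / "README.txt").touch()                # stray file, not a class
    mapping = discover_classes(root)

print(mapping)  # {'Necrotic': 0, 'TER': 1}
```

Because the order comes from `sorted()`, the indices are stable across runs and machines as long as the set of class folders does not change; adding a new class folder can shift the indices of existing classes.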

## 🔄 Albumentations: Transforms & Visualization

Data augmentation is crucial for training robust deep learning models, especially when dealing with limited medical imaging data. Albumentations is a powerful library specifically designed for fast and flexible image augmentations.

Let's explore some common transformations and visualize their effects on a single sample image.

In [None]:
# --- Load a single sample image for demonstration ---
sample_image_path = None
if CLASS_NAMES and DATA_DIR.exists():
    first_class_dir = DATA_DIR / CLASS_NAMES[0]
    try:
        # Pick the first PNG in the first class directory (sorted, so the choice is deterministic)
        sample_image_path = next(iter(sorted(first_class_dir.glob('*.png'))))
        print(f"Loading sample image: {sample_image_path}")
        sample_image_pil = Image.open(sample_image_path).convert('RGB')
        sample_image_np = np.array(sample_image_pil)

        # Visualize the original sample
        fig = px.imshow(sample_image_np, title=f"Original Sample ({CLASS_NAMES[0]})")
        fig.update_layout(coloraxis_showscale=False).update_xaxes(showticklabels=False).update_yaxes(showticklabels=False)
        fig.show()

    except StopIteration:
        print(f"Error: No PNG images found in the first class directory: {first_class_dir}")
        sample_image_np = None # Ensure variable exists but is None
    except Exception as e:
        print(f"Error loading sample image {sample_image_path}: {e}")
        sample_image_np = None
else:
    print("Cannot load sample image: Data directory or class subfolders not found or empty.")
    sample_image_np = None


In [None]:
# --- Define some example Albumentations transforms ---

# 1. Simple Horizontal Flip
transform_flip = A.Compose([
    A.HorizontalFlip(p=1.0), # p=1.0 means always apply
])

# 2. Rotation
transform_rotate = A.Compose([
    A.Rotate(limit=45, p=1.0, border_mode=cv2.BORDER_CONSTANT, value=0), # Rotate up to 45 degrees
])

# 3. Color Jitter
transform_color = A.Compose([
    A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2, p=1.0),
])

# 4. Combined Transforms (more realistic)
# Note: We add ToTensorV2() at the end when using with PyTorch/Lightning
transform_combined = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Rotate(limit=30, p=0.5, border_mode=cv2.BORDER_CONSTANT, value=0),
    A.RandomBrightnessContrast(p=0.3),
    A.GaussNoise(p=0.2),
    # IMPORTANT: Resize if your model expects a fixed input size
    # A.Resize(height=224, width=224),
    # A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # Example: ImageNet normalization
    # ToTensorV2() # Converts image to PyTorch tensor (CHW format) - ADD THIS LATER IN DATASET/DATAMODULE
])


# --- Apply transforms to the sample image ---
if sample_image_np is not None:
    augmented_flip = transform_flip(image=sample_image_np)['image']
    augmented_rotate = transform_rotate(image=sample_image_np)['image']
    augmented_color = transform_color(image=sample_image_np)['image']
    augmented_combined = transform_combined(image=sample_image_np)['image']
    print("Applied various transformations to the sample image.")
else:
    print("Skipping transform application as sample image failed to load.")
    # Assign None to prevent errors in the next cell
    augmented_flip, augmented_rotate, augmented_color, augmented_combined = None, None, None, None


# --- Visualize the results ---
if sample_image_np is not None and augmented_combined is not None:
    fig = make_subplots(rows=1, cols=5, subplot_titles=("Original", "Flipped", "Rotated", "Color Jitter", "Combined"))

    # Add original image
    fig.add_trace(px.imshow(sample_image_np).data[0], row=1, col=1)
    # Add augmented images
    fig.add_trace(px.imshow(augmented_flip).data[0], row=1, col=2)
    fig.add_trace(px.imshow(augmented_rotate).data[0], row=1, col=3)
    fig.add_trace(px.imshow(augmented_color).data[0], row=1, col=4)
    fig.add_trace(px.imshow(augmented_combined).data[0], row=1, col=5)

    fig.update_layout(title_text="Albumentations Examples", height=400, width=1200)
    fig.update_xaxes(showticklabels=False)
    fig.update_yaxes(showticklabels=False)
    fig.update_layout(coloraxis_showscale=False)
    fig.show()
else:
    print("Cannot visualize augmentations because the sample image or augmented images are missing.")
    

We've applied several common augmentations:
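*   **Geometric:** `HorizontalFlip`, `VerticalFlip`, `Rotate`. Tissue in a patch has no canonical orientation, so these teach the model that a mirrored or rotated patch keeps the same label.
*   **Color:** `ColorJitter`, `RandomBrightnessContrast`. These help the model generalize to variations in staining and illumination.
*   **Noise:** `GaussNoise`. Adds random pixel noise, which can make the model less sensitive to scanner noise and small pixel-level perturbations.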
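Since a histology patch has no canonical "up", flips and 90-degree rotations are natural label-preserving augmentations. Unlike `A.Rotate` with an arbitrary angle, which fills the exposed corners with a constant border value, combining flips with 90-degree rotations only moves pixels around. The NumPy sketch below enumerates all eight such variants of a toy patch; in Albumentations the same building blocks are `A.HorizontalFlip`, `A.VerticalFlip`, and `A.RandomRotate90`. The helper `dihedral_variants` is hypothetical.

```python
import numpy as np

def dihedral_variants(patch: np.ndarray) -> list:
    """Return the 8 lossless flip / rotate-by-90 variants of an HWC patch."""
    variants = []
    for base in (patch, np.fliplr(patch)):       # identity and horizontal mirror
        for k in range(4):                        # rotate by 0, 90, 180, 270 degrees
            variants.append(np.rot90(base, k=k, axes=(0, 1)))
    return variants

# Toy 3x3 single-channel "patch" with all-distinct values, so no variant repeats
toy = np.arange(9, dtype=np.uint8).reshape(3, 3, 1)
variants = dihedral_variants(toy)

print(len(variants))                           # 8
print(len({v.tobytes() for v in variants}))    # 8 distinct arrays
# Every variant contains exactly the same pixel values: nothing is filled in or lost
print(all(np.array_equal(np.sort(v, axis=None), np.sort(toy, axis=None)) for v in variants))
```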

**Important Considerations:**
*   **`p` parameter:** Controls the probability of applying a transform (e.g., `p=0.5` means 50% chance).
*   **`Compose`:** Chains multiple transforms together and applies them in the listed order. Order matters: for example, `Normalize` must run after the color augmentations (which operate on raw pixel values) and before `ToTensorV2`.
*   **`ToTensorV2()`:** Converts the NumPy array from HWC layout into a CHW PyTorch tensor. It does **not** rescale or normalize pixel values, so it should be the *last* transform in the sequence, after `Normalize`. We will add it later when integrating with our `Dataset`.
*   **Normalization:** `A.Normalize` operates on the NumPy array, so it comes *before* `ToTensorV2`. With its default `max_pixel_value=255.0`, it computes `(pixel / 255 - mean) / std` per channel. Standard practice is to use statistics (mean, std) from a large dataset such as ImageNet, or to compute them from your own dataset.
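To see what `A.Normalize` followed by `ToTensorV2` does to the numbers, here is a NumPy re-implementation of the two steps, assuming the default `max_pixel_value=255.0` and the ImageNet statistics. It is an illustrative sketch, not Albumentations' actual code; `normalize_hwc` and `hwc_to_chw` are hypothetical names.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def normalize_hwc(img_uint8: np.ndarray, mean=IMAGENET_MEAN, std=IMAGENET_STD,
                  max_pixel_value: float = 255.0) -> np.ndarray:
    """Per-channel (pixel / max_pixel_value - mean) / std on an HWC array."""
    img = img_uint8.astype(np.float32) / max_pixel_value
    return (img - mean) / std  # mean/std broadcast over the last (channel) axis

def hwc_to_chw(img: np.ndarray) -> np.ndarray:
    """Layout change only, analogous to ToTensorV2 (no rescaling)."""
    return np.ascontiguousarray(img.transpose(2, 0, 1))

patch = np.full((4, 4, 3), 255, dtype=np.uint8)  # a pure-white 4x4 RGB patch
chw = hwc_to_chw(normalize_hwc(patch))

print(chw.shape)             # (3, 4, 4)
print(chw[:, 0, 0])          # white maps to (1 - mean) / std in each channel
```

Note that a white pixel ends up well above 1 after normalization; values outside [0, 1] are expected and are why the batch visualization later has to denormalize.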

Now that we understand how Albumentations works, we can integrate these transforms into our `PatchDataset` class.

## 📦 PyTorch Dataset Implementation

Here we implement a custom `torch.utils.data.Dataset` class that serves our patch images and their labels. The class receives lists of image paths and labels (produced by the stratified split further below), loads each image on demand, applies the transforms, and returns an (image tensor, label index) pair.

In [None]:
# --- Define Transforms ---

# Define image size if resizing is needed
# IMG_HEIGHT = 224
# IMG_WIDTH = 224

# Define normalization constants (e.g., ImageNet stats)
# Calculate these from your specific dataset for potentially better results if needed
IMG_MEAN = [0.485, 0.456, 0.406]
IMG_STD = [0.229, 0.224, 0.225]

# Define transforms for the training set (including augmentation)
train_transforms = A.Compose([
    # A.Resize(height=IMG_HEIGHT, width=IMG_WIDTH), # Uncomment if resizing needed
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
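    # Note: Albumentations 2.x renames `value` to `fill` for geometric transforms; adjust for your installed version
    A.Rotate(limit=30, p=0.5, border_mode=cv2.BORDER_CONSTANT, value=0),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.3),
    A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.3), # Color jitter alternative
    # Note: `var_limit` is the Albumentations 1.x argument; 2.x replaces it with `std_range`
    A.GaussNoise(var_limit=(10.0, 50.0), p=0.2),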
    # Normalization should come after color augmentations and before ToTensorV2
    A.Normalize(mean=IMG_MEAN, std=IMG_STD),
    ToTensorV2(), # HWC NumPy array -> CHW torch tensor; does not rescale (Normalize already did)
])

# Define transforms for validation and test sets (usually no augmentation, just normalization and tensor conversion)
val_test_transforms = A.Compose([
    # A.Resize(height=IMG_HEIGHT, width=IMG_WIDTH), # Uncomment if resizing needed
    A.Normalize(mean=IMG_MEAN, std=IMG_STD),
    ToTensorV2(),
])

print("Defined train_transforms and val_test_transforms.")

# Optional: Visualize an image after applying train_transforms
if sample_image_np is not None:
    transformed_sample = train_transforms(image=sample_image_np)['image']
    print("Sample image shape after train_transforms (should be CHW tensor):", transformed_sample.shape)
    # Note: Visualization of the normalized tensor might look strange directly with imshow
    # We might need to denormalize or just show one channel if needed later.
else:
    print("Cannot apply train_transforms as sample_image_np is None.")
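The comment above suggests computing normalization statistics from your own data. One memory-friendly way is to accumulate per-channel sums and sums of squares over the training patches, one image at a time. This is a sketch assuming uint8 RGB arrays; in practice you would feed it `np.array(Image.open(p).convert('RGB'))` for each training path. Computing the statistics on the training split only avoids leaking information from the validation and test sets. `channel_mean_std` is a hypothetical helper, and the random patches stand in for real data.

```python
import numpy as np
from typing import Iterable, Tuple

def channel_mean_std(images: Iterable[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
    """Per-channel mean and std over all pixels, on the [0, 1] scale used by A.Normalize."""
    total = np.zeros(3, dtype=np.float64)
    total_sq = np.zeros(3, dtype=np.float64)
    n_pixels = 0
    for img in images:
        x = img.reshape(-1, 3).astype(np.float64) / 255.0  # (H*W, 3)
        total += x.sum(axis=0)
        total_sq += (x ** 2).sum(axis=0)
        n_pixels += x.shape[0]
    mean = total / n_pixels
    std = np.sqrt(total_sq / n_pixels - mean ** 2)  # population std: sqrt(E[x^2] - E[x]^2)
    return mean, std

rng = np.random.default_rng(0)
fake_patches = [rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8) for _ in range(5)]
demo_mean, demo_std = channel_mean_std(fake_patches)
print(demo_mean.round(3), demo_std.round(3))
```

The results can be passed as `IMG_MEAN` and `IMG_STD` (as lists) in place of the ImageNet values.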


In [None]:
# --- PatchDataset Class ---
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image

class PatchDataset(Dataset):
    """
    PyTorch Dataset for patch classification.
    Accepts a list of filepaths and corresponding labels.
    """
    def __init__(self, filepaths, labels, transform=None, label_map=None):
        """
        Args:
            filepaths (list): List of paths to image files.
            labels (list): List of corresponding labels (e.g. class-name strings; must be hashable and sortable).
            transform (callable, optional): Optional transform to be applied on a sample.
            label_map (dict, optional): Dictionary mapping string labels to integer indices.
        """
        self.filepaths = filepaths
        self.labels = labels
        self.transform = transform

        # Create or use label map
        if label_map:
            self.label_map = label_map
        else:
            # Create map from unique labels found
            unique_labels = sorted(list(set(labels)))
            self.label_map = {label: idx for idx, label in enumerate(unique_labels)}

        self.idx_to_label = {v: k for k, v in self.label_map.items()} # For potential reverse lookup

        print(f"Initialized Dataset. Label map: {self.label_map}")

    def __len__(self):
        return len(self.filepaths)

    def __getitem__(self, idx):
        img_path = self.filepaths[idx]
        label_str = self.labels[idx]
        label_idx = self.label_map[label_str]

        # Load image using PIL (ensure RGB)
        image = Image.open(img_path).convert('RGB')

        # Apply transforms if they exist
        if self.transform:
            # Albumentations requires numpy array
            image_np = np.array(image)
            augmented = self.transform(image=image_np)
            image_tensor = augmented['image'] # Albumentations returns a dict
        else:
            # No transform given: convert HWC uint8 to a CHW float tensor scaled to [0, 1].
            # Unlike the Albumentations pipelines above, this applies no mean/std normalization.
            image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0

        return image_tensor, label_idx

    def get_label_map(self):
        return self.label_map

    def get_idx_to_label(self):
        return self.idx_to_label
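Why should the validation and test datasets receive the training set's `label_map` instead of building their own? If a split happened to lack one class, a map derived from that split alone would assign different indices to the remaining classes. The pure-Python sketch below mirrors the fallback rule in `PatchDataset.__init__` (sorted unique labels) to show the mismatch; the class names, including `Stroma`, are placeholders.

```python
def build_label_map(labels):
    """Same rule as PatchDataset's fallback: sorted unique labels -> 0..K-1."""
    return {label: idx for idx, label in enumerate(sorted(set(labels)))}

demo_train_labels = ["Necrotic", "Stroma", "TER", "Stroma"]
demo_val_labels = ["Stroma", "TER"]  # imagine 'Necrotic' is absent from this split

train_map = build_label_map(demo_train_labels)
val_only_map = build_label_map(demo_val_labels)

print(train_map)     # {'Necrotic': 0, 'Stroma': 1, 'TER': 2}
print(val_only_map)  # {'Stroma': 0, 'TER': 1}  <- 'TER' would be encoded as 1 instead of 2
print([train_map[label] for label in demo_val_labels])  # [1, 2]: consistent with training
```

Stratified splitting makes a missing class unlikely, but reusing the training map keeps the encoding correct even when it happens.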


In [None]:
# --- Stratified Train/Validation/Test Split ---

# Define split ratios
TRAIN_RATIO = 0.70
VAL_RATIO = 0.15
TEST_RATIO = 0.15 # VAL_RATIO + TEST_RATIO should equal 1.0 - TRAIN_RATIO

all_image_paths = []
all_labels = []

print("Scanning for all image paths and labels...")
if DATA_DIR.exists() and CLASS_NAMES:
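    for class_name in CLASS_NAMES:
        class_dir = DATA_DIR / class_name
        if class_dir.is_dir():
            # sorted() makes the file order, and therefore the seeded split, reproducible
            for img_path in sorted(glob.glob(str(class_dir / '*.png'))): # Adjust pattern if needed
                all_image_paths.append(Path(img_path))
                all_labels.append(class_name) # Store the class name; PatchDataset maps it to an index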
    print(f"Found {len(all_image_paths)} total images.")
else:
    print("Error: Cannot perform split. Data directory or class folders not found/empty.")
    # Assign empty lists to prevent errors later
    all_image_paths = []
    all_labels = []


# Perform the split only if images were found
if all_image_paths:
    # First split: separate train set from (validation + test) set
    # train_test_split raises ValueError if a class has too few samples to stratify; the except branch below handles that
    try:
        train_paths, temp_paths, train_labels, temp_labels = train_test_split(
            all_image_paths,
            all_labels,
            test_size=(VAL_RATIO + TEST_RATIO), # Size of the temporary set
            random_state=42, # for reproducibility
            stratify=all_labels # Ensure class distribution is similar
        )

        # Second split: separate validation set from test set
        # Adjust test_size relative to the *temporary* set size
        relative_test_size = TEST_RATIO / (VAL_RATIO + TEST_RATIO)

        val_paths, test_paths, val_labels, test_labels = train_test_split(
            temp_paths,
            temp_labels,
            test_size=relative_test_size,
            random_state=42, # for reproducibility
            stratify=temp_labels # Ensure class distribution is similar
        )

        print(f"\nDataset split:")
        print(f"  Train samples: {len(train_paths)}")
        print(f"  Validation samples: {len(val_paths)}")
        print(f"  Test samples: {len(test_paths)}")

        # Verify stratification (optional but recommended)
        from collections import Counter
        print("\nClass distribution:")
        print(f"  Overall: {Counter(all_labels)}")
        print(f"  Train:   {Counter(train_labels)}")
        print(f"  Val:     {Counter(val_labels)}")
        print(f"  Test:    {Counter(test_labels)}")

    except ValueError as e:
        print(f"\nError during stratified split: {e}")
        print("This might happen if a class has too few samples for the requested split ratios.")
        print("Assigning all data to train set as fallback (adjust as needed).")
        train_paths, val_paths, test_paths = all_image_paths, [], []
        train_labels, val_labels, test_labels = all_labels, [], []

else:
    print("Skipping split as no images were loaded.")
    train_paths, val_paths, test_paths = [], [], []
    train_labels, val_labels, test_labels = [], [], []
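Stratification means that each class is split with the same proportions. The stdlib sketch below does this by hand, class by class, to make the arithmetic of the 70/15/15 scheme visible. It illustrates the idea and is not how `train_test_split` is implemented internally, so its rounding may differ from scikit-learn's by a sample or so. `stratified_three_way_split` is a hypothetical helper and the label list is synthetic.

```python
import random
from collections import Counter, defaultdict

def stratified_three_way_split(labels, val_ratio=0.15, test_ratio=0.15, seed=42):
    """Return (train, val, test) index lists with per-class proportions preserved."""
    by_class = defaultdict(list)
    for i, label in enumerate(labels):
        by_class[label].append(i)

    rng = random.Random(seed)
    train, val, test = [], [], []
    for label in sorted(by_class):               # fixed class order for reproducibility
        idxs = by_class[label][:]
        rng.shuffle(idxs)
        n_test = round(len(idxs) * test_ratio)
        n_val = round(len(idxs) * val_ratio)
        test += idxs[:n_test]
        val += idxs[n_test:n_test + n_val]
        train += idxs[n_test + n_val:]
    return train, val, test

demo_labels = ["TER"] * 100 + ["Necrotic"] * 20  # imbalanced toy label list
train_idx, val_idx, test_idx = stratified_three_way_split(demo_labels)
for name, split in [("train", train_idx), ("val", val_idx), ("test", test_idx)]:
    print(name, len(split), Counter(demo_labels[i] for i in split))
```

Both classes keep roughly a 5:1 ratio in every split, which is exactly what the `Counter` printout in the cell above lets you verify on the real data.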


In [None]:
# --- Create Datasets using the split file lists ---

# Make sure train_transforms and val_test_transforms are defined earlier

# Check if paths lists are not empty before creating datasets
if train_paths:
    train_dataset = PatchDataset(
        filepaths=train_paths,
        labels=train_labels,
        transform=train_transforms
    )
    # Use the label map from the training set for consistency
    label_map = train_dataset.get_label_map()
    print(f"\nTrain dataset size: {len(train_dataset)}")
else:
    train_dataset = None
    label_map = {} # Define an empty map if no training data
    print("\nWarning: No training data found. Cannot create training dataset.")

if val_paths:
    val_dataset = PatchDataset(
        filepaths=val_paths,
        labels=val_labels,
        transform=val_test_transforms, # Use non-augmenting transforms for validation
        label_map=label_map # Use map from training data
    )
    print(f"Validation dataset size: {len(val_dataset)}")
else:
    val_dataset = None
    print("Warning: No validation data found. Cannot create validation dataset.")


if test_paths:
    test_dataset = PatchDataset(
        filepaths=test_paths,
        labels=test_labels,
        transform=val_test_transforms, # Use non-augmenting transforms for test
        label_map=label_map # Use map from training data
    )
    print(f"Test dataset size: {len(test_dataset)}")
else:
    test_dataset = None
    print("Warning: No test data found. Cannot create test dataset.")


## 🚚 PyTorch DataLoader Implementation

Now that we have a `Dataset`, we need a `DataLoader` to efficiently load batches of data during training. The `DataLoader` handles several key aspects:

*   **Batching:** Grouping individual samples into mini-batches.
*   **Shuffling:** Randomly shuffling the data order at each epoch (important for training).
*   **Parallel Loading:** Using multiple worker processes (`num_workers`) to load and augment data in the background, so the GPU spends less time waiting for the next batch.
*   **Memory Pinning:** Optionally pinning memory (`pin_memory=True`) for faster CPU-to-GPU transfers.

We will create a basic DataLoader example first and then see how to integrate it properly within a LightningDataModule later.
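Conceptually, in each epoch the `DataLoader` takes an ordering of dataset indices (shuffled for training), cuts it into chunks of `batch_size`, and collates the samples of each chunk into one batch. The stdlib sketch below reproduces only that index bookkeeping, using a hypothetical helper name. It also shows why `len(dataloader)` equals `ceil(len(dataset) / batch_size)`, unless `drop_last=True` discards a final partial batch.

```python
import math
import random

def epoch_index_batches(n_samples, batch_size, shuffle=True, drop_last=False, seed=None):
    """Split one epoch's sample indices into lists of at most batch_size indices."""
    order = list(range(n_samples))
    if shuffle:
        random.Random(seed).shuffle(order)   # a different seed gives a different epoch order
    batches = [order[i:i + batch_size] for i in range(0, n_samples, batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()                         # discard the incomplete final batch
    return batches

batches = epoch_index_batches(10, batch_size=4, seed=0)
print([len(b) for b in batches])                                  # [4, 4, 2]
print(len(batches) == math.ceil(10 / 4))                          # True
print(sorted(i for b in batches for i in b) == list(range(10)))   # True: each sample exactly once
print(len(epoch_index_batches(10, 4, drop_last=True)))            # 2
```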

In [None]:
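# --- Create DataLoaders ---
batch_size = 32 # Example batch size
num_workers = 4 # If workers fail to start (a known issue with notebook-defined Datasets where workers are spawned, e.g. Windows/macOS), use 0
pin_memory = torch.cuda.is_available() # Pinned memory only speeds up host-to-GPU copies
loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)

# Shuffle only the training data; skip any split whose dataset could not be created
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs) if train_dataset is not None else None
val_dataloader = DataLoader(val_dataset, shuffle=False, **loader_kwargs) if val_dataset is not None else None
test_dataloader = DataLoader(test_dataset, shuffle=False, **loader_kwargs) if test_dataset is not None else None # For later use

if train_dataloader is not None:
    print(f"\nTrain DataLoader batches: {len(train_dataloader)}")
if val_dataloader is not None:
    print(f"Validation DataLoader batches: {len(val_dataloader)}")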

In [None]:
import math # Needed for math.sqrt in show_batch_grid
import matplotlib.pyplot as plt
import torchvision.utils

def denormalize(tensor, mean=IMG_MEAN, std=IMG_STD):
    """Denormalizes a tensor image with mean and standard deviation."""
    # Clone to avoid modifying the original tensor
    tensor = tensor.clone()
    for t, m, s in zip(tensor, mean, std):
        t.mul_(s).add_(m)  # Reverse the normalization: (tensor * std) + mean
    # We need to clamp values to [0, 1] after denormalization
    tensor = torch.clamp(tensor, 0, 1)
    return tensor

# --- Helper: display one batch from a DataLoader as an image grid ---
# Pass an index-to-class mapping (e.g. idx_to_class, or the DataModule's map later on) to print readable labels.

def show_batch_grid(dataloader, num_images=16, title="Sample Batch", idx_map=None):
    """Fetches one batch and displays it using torchvision.utils.make_grid and matplotlib."""
    if not dataloader:
        print("DataLoader is None, cannot show batch.")
        return
    if not idx_map:
        print("Index-to-class map not provided.")
        idx_map = {} # Default to empty map

    try:
        images, labels = next(iter(dataloader))
    except StopIteration:
        print("DataLoader is empty or exhausted.")
        return
    except Exception as e:
        print(f"Error fetching batch: {e}")
        return

    images_to_show = images[:num_images]
    labels_to_show = labels[:num_images]
    denormalized_images = [denormalize(img) for img in images_to_show]
    grid = torchvision.utils.make_grid(denormalized_images, nrow=int(math.sqrt(num_images)))
    grid_np = grid.numpy()
    grid_display = np.transpose(grid_np, (1, 2, 0))

    plt.figure(figsize=(10, 10))
    plt.imshow(grid_display)
    plt.title(title)
    plt.axis('off')
    plt.show()

    print("Labels for displayed batch:")
    print([idx_map.get(l.item(), "Unknown") for l in labels_to_show])
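The `denormalize` helper loops over channels. The same inversion can be written with broadcasting, which also makes it easy to check that normalizing and then denormalizing recovers the original `[0, 1]` image. Below is a NumPy sketch assuming CHW float arrays and the ImageNet statistics used above; `normalize_chw` and `denormalize_chw` are hypothetical names. With PyTorch tensors the equivalent expression is `tensor * std.view(3, 1, 1) + mean.view(3, 1, 1)`.

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)  # shaped to broadcast over CHW
STD = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

def normalize_chw(x: np.ndarray) -> np.ndarray:
    return (x - MEAN) / STD

def denormalize_chw(x: np.ndarray) -> np.ndarray:
    """Inverse of normalize_chw, clipped to the displayable [0, 1] range."""
    return np.clip(x * STD + MEAN, 0.0, 1.0)

rng = np.random.default_rng(0)
img = rng.random((3, 8, 8))                   # CHW image with values in [0, 1)
restored = denormalize_chw(normalize_chw(img))
print(np.allclose(restored, img))             # True
```

The clipping step matters only for values that were pushed out of range (for example by noise augmentation); for an untouched image it is a no-op.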






## ⚡ PyTorch Lightning DataModule

To streamline the data loading process and integrate seamlessly with PyTorch Lightning's `Trainer`, we encapsulate all the logic (finding files, splitting, creating datasets, defining dataloaders) into a `LightningDataModule`.

**Key Benefits:**
*   **Organization:** Keeps all data-related code in one place.
*   **Reproducibility:** Ensures data splitting and loading are consistent.
*   **Decoupling:** Separates data logic from the model definition (`LightningModule`).
*   **Trainer Integration:** The `Trainer` automatically calls the appropriate methods (`prepare_data`, `setup`, `train_dataloader`, etc.).

We will define a `PatchDataModule` that handles our specific patch classification dataset.
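When `Trainer.fit` runs, Lightning roughly calls `prepare_data()` (from one process per node by default), then `setup(stage="fit")` in every process, and then requests the train and validation dataloaders; `Trainer.test` later calls `setup(stage="test")` again. That is why `PatchDataModule.setup` checks whether the datasets already exist before re-running the split. The framework-free sketch below imitates that call sequence with a hypothetical `ToyDataModule` to show the guard in action; it does not use Lightning itself.

```python
class ToyDataModule:
    """Framework-free stand-in that mimics PatchDataModule's setup guard."""

    def __init__(self, n_samples):
        self.n_samples = n_samples
        self.train_idx = None
        self.holdout_idx = None
        self.split_calls = 0
        self.calls = []

    def prepare_data(self):
        self.calls.append("prepare_data")

    def setup(self, stage=None):
        self.calls.append(f"setup({stage})")
        if self.train_idx is None:                   # split only on the first call
            self.split_calls += 1
            cut = self.n_samples * 7 // 10           # 70% train, integer arithmetic
            self.train_idx = list(range(cut))
            self.holdout_idx = list(range(cut, self.n_samples))

    def train_dataloader(self):
        self.calls.append("train_dataloader")
        return self.train_idx

dm = ToyDataModule(10)
# Approximate order in which a fit run followed by a test run touches the module
dm.prepare_data()
dm.setup("fit")
dm.train_dataloader()
dm.setup("test")

print(dm.calls)        # ['prepare_data', 'setup(fit)', 'train_dataloader', 'setup(test)']
print(dm.split_calls)  # 1: the second setup() reused the existing split
```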

In [None]:
# --- LightningDataModule Implementation ---
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import glob # Ensure glob is imported if not already in this cell context
from pathlib import Path # Ensure Path is imported

class PatchDataModule(pl.LightningDataModule):
    def __init__(self,
                 data_dir: str,
                 class_to_idx: dict,
                 train_transform: A.Compose,
                 val_test_transform: A.Compose,
                 batch_size: int = 32,
                 num_workers: int = 4,
                 train_ratio: float = 0.7,
                 val_ratio: float = 0.15,
                 test_ratio: float = 0.15,
                 seed: int = 42):
        super().__init__()
        self.data_dir = Path(data_dir)
        self.class_to_idx = class_to_idx
        self.idx_to_class = {v: k for k, v in class_to_idx.items()}
        self.train_transform = train_transform
        self.val_test_transform = val_test_transform
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self.test_ratio = test_ratio
        self.seed = seed

        # Placeholders for datasets and file paths after setup
        self.train_paths, self.val_paths, self.test_paths = None, None, None
        self.train_labels, self.val_labels, self.test_labels = None, None, None
        self.train_dataset, self.val_dataset, self.test_dataset = None, None, None
        self.label_map = None # Will be determined by train_dataset

        # Save hyperparameters for logging (optional but good practice)
        self.save_hyperparameters(ignore=['train_transform', 'val_test_transform', 'class_to_idx']) # Avoid saving complex objects

    def prepare_data(self):
        # Called from a single process (by default, local rank 0 on each node) before setup().
        # Use it to download data or check that it exists; don't assign state here, since other processes won't see it.
        if not self.data_dir.exists():
             raise FileNotFoundError(f"Data directory not found: {self.data_dir}")
        print(f"Data directory check passed: {self.data_dir}")
        # We could add more checks here if needed

    def setup(self, stage: str = None):
        # Called on every GPU/TPU in distributed settings.
        # Assign train/val/test datasets for use in dataloaders
        # `stage` can be 'fit', 'validate', 'test', 'predict'
        if not self.train_dataset and not self.val_dataset and not self.test_dataset:
            print(f"Setting up data for stage: {stage}")

            # --- Scan & Split ---
            all_image_paths = []
            all_labels_str = [] # Use string labels for splitting/map creation
            class_names = sorted(list(self.class_to_idx.keys()))

            print("Scanning for all image paths and string labels...")
            if self.data_dir.exists() and class_names:
                for class_name in class_names:
                    class_dir = self.data_dir / class_name
                    if class_dir.is_dir():
                        # sorted() keeps the file order, and therefore the seeded split, reproducible
                        for img_path in sorted(glob.glob(str(class_dir / '*.png'))): # Adjust pattern if needed
                            all_image_paths.append(Path(img_path))
                            all_labels_str.append(class_name) # Store string label
                print(f"Found {len(all_image_paths)} total images.")
            else:
                print("Error: Cannot perform split. Data directory or class folders not found/empty.")
                # Handle error or return if necessary
                return

            if not all_image_paths:
                print("No images found, cannot proceed with setup.")
                return

            # --- Perform Splits ---
            try:
                # Split 1: Train vs Temp (Val+Test)
                self.train_paths, temp_paths, self.train_labels, temp_labels = train_test_split(
                    all_image_paths, all_labels_str, # Use string labels
                    test_size=(self.val_ratio + self.test_ratio),
                    random_state=self.seed,
                    stratify=all_labels_str
                )
                # Split 2: Val vs Test from Temp
                relative_test_size = self.test_ratio / (self.val_ratio + self.test_ratio)
                self.val_paths, self.test_paths, self.val_labels, self.test_labels = train_test_split(
                    temp_paths, temp_labels, # Use string labels
                    test_size=relative_test_size,
                    random_state=self.seed,
                    stratify=temp_labels
                )
                print("Dataset split completed.")
                print(f"Train size: {len(self.train_paths)}, Val size: {len(self.val_paths)}, Test size: {len(self.test_paths)}")

            except ValueError as e:
                print(f"Error during stratified split: {e}. Check class distribution and split ratios.")
                # Implement fallback or raise error
                return

            # --- Instantiate Datasets ---
            if self.train_paths:
                self.train_dataset = PatchDataset(
                    filepaths=self.train_paths,
                    labels=self.train_labels,
                    transform=self.train_transform
                    # Let PatchDataset create the label_map from these string labels
                )
                self.label_map = self.train_dataset.get_label_map() # Get the map

            if self.val_paths and self.label_map is not None:
                self.val_dataset = PatchDataset(
                    filepaths=self.val_paths,
                    labels=self.val_labels,
                    transform=self.val_test_transform,
                    label_map=self.label_map # Use map from train set
                )

            if self.test_paths and self.label_map is not None:
                self.test_dataset = PatchDataset(
                    filepaths=self.test_paths,
                    labels=self.test_labels,
                    transform=self.val_test_transform,
                    label_map=self.label_map # Use map from train set
                )

            print("Datasets instantiated.")

    def train_dataloader(self):
        if self.train_dataset:
            return DataLoader(self.train_dataset,
                              batch_size=self.batch_size,
                              shuffle=True,
                              num_workers=self.num_workers,
                              pin_memory=True,
                              persistent_workers=True if self.num_workers > 0 else False) # Good practice
        return None

    def val_dataloader(self):
        if self.val_dataset:
            return DataLoader(self.val_dataset,
                              batch_size=self.batch_size,
                              shuffle=False,
                              num_workers=self.num_workers,
                              pin_memory=True,
                              persistent_workers=True if self.num_workers > 0 else False)
        return None

    def test_dataloader(self):
        if self.test_dataset:
            return DataLoader(self.test_dataset,
                              batch_size=self.batch_size,
                              shuffle=False,
                              num_workers=self.num_workers,
                              pin_memory=True,
                              persistent_workers=True if self.num_workers > 0 else False)
        return None

    # Helper to get label map easily
    def get_label_map(self):
        if self.label_map is None and self.train_dataset:
            self.label_map = self.train_dataset.get_label_map()
        return self.label_map

    def get_idx_to_label(self):
        label_map = self.get_label_map()
        if label_map:
            return {v: k for k, v in label_map.items()}
        return {}

In [None]:
# --- Instantiate the DataModule ---

# Ensure your transforms and class_to_idx map are defined earlier

# Example instantiation (adjust parameters as needed)
data_module = PatchDataModule(
    data_dir=DATA_DIR,
    class_to_idx=class_to_idx, # Use the map created in the config cell
    train_transform=train_transforms,
    val_test_transform=val_test_transforms,
    batch_size=32,
    num_workers=4, # Adjust based on your system
    train_ratio=TRAIN_RATIO,
    val_ratio=VAL_RATIO,
    test_ratio=TEST_RATIO
)

# --- Trigger the setup process ---
# This performs the splitting and creates the internal datasets
data_module.prepare_data() # Check if data dir exists
data_module.setup()        # Perform split and dataset creation

# --- Get DataLoaders from the module ---
train_dl = data_module.train_dataloader()
val_dl = data_module.val_dataloader()
test_dl = data_module.test_dataloader() # For later

# Verify
if train_dl:
    print(f"\nSuccessfully obtained Train Dataloader with {len(train_dl)} batches.")
if val_dl:
    print(f"Successfully obtained Validation Dataloader with {len(val_dl)} batches.")

# Index-to-class mapping built from the training split's label map (used to label the batch plots below)
idx_to_class_from_module = data_module.get_idx_to_label()


In [None]:
if train_dl:
    print("\nVisualizing batch from Training DataLoader (via DataModule):")
    show_batch_grid(train_dl, title="Sample Training Batch (DataModule)", idx_map=idx_to_class_from_module)
else:
    print("\nCannot visualize training batch as train_dl is None.")

if val_dl:
    print("\nVisualizing batch from Validation DataLoader (via DataModule):")
    # Note: Validation batches won't show augmentations like flips/rotations
    show_batch_grid(val_dl, title="Sample Validation Batch (DataModule)", idx_map=idx_to_class_from_module)
else:
    print("\nCannot visualize validation batch as val_dl is None.")