# Data Augmentation


Our goal is to select augmentation transformations that simulate realistic variations.
This strategy aims to increase the dataset size, address class imbalance, and reduce overfitting, thereby improving the model's generalization, robustness, and overall performance and accuracy.

In [3]:
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import os
import random

class BalancedXRayDataset(Dataset):
    """
    A dataset class for loading X-ray images from ``.npy`` files and augmenting them.

    Attributes:
        x (numpy.ndarray): The array of images.
        y (numpy.ndarray): The array of labels corresponding to the images.
        class_names (list): The list of class names, indexed by label value.
        image_size (tuple): The (height, width) each sample is reshaped to.
    """

    def __init__(self, x_path, y_path, image_size=(128, 128)):
        """
        Initializes the dataset by loading images and labels from specified paths.

        Args:
            x_path (str): The file path to the numpy array of images.
            y_path (str): The file path to the numpy array of labels.
            image_size (tuple): The (height, width) of a single image. Defaults
                to (128, 128), the size this class originally hard-coded.
        """
        self.x = np.load(x_path)
        self.y = np.load(y_path)
        self.class_names = ['Atelectasis', 'Effusion', 'Infiltration', 'Healthy', 'Nodule', 'Pneumothorax']
        self.image_size = tuple(image_size)

    def __len__(self):
        """Returns the total number of samples in the dataset."""
        return len(self.y)

    def __getitem__(self, idx):
        """
        Retrieves an image and its label by index, converting the image to a PIL Image.

        Args:
            idx (int): The index of the sample to retrieve.

        Returns:
            tuple: A tuple ``(image, label)`` where ``image`` is an 8-bit
            grayscale ``PIL.Image.Image`` and ``label`` is the matching entry of ``y``.
        """
        image, label = self.x[idx], self.y[idx]
        # Drop any extra axes (e.g. a channel axis) so PIL receives a 2D array.
        image = image.reshape(self.image_size)
        # The torchvision transforms in apply_augmentation operate on PIL images.
        # Mode 'L' is 8-bit grayscale, hence the cast to uint8.
        # NOTE(review): the cast assumes pixel values are already in 0-255 - confirm
        # against how the .npy file was produced.
        image = Image.fromarray(image.astype('uint8'), 'L')
        return image, label

    def apply_augmentation(self, image):
        """
        Applies a series of random transformations to augment the image.

        The pipeline is a Gaussian blur, a random rotation, and a random affine
        transformation, intended to simulate natural variations in X-ray images.

        Args:
            image (PIL.Image.Image): The image to augment.

        Returns:
            PIL.Image.Image: The augmented image.
        """
        augmentation_transforms = transforms.Compose([
            # Blurring to simulate focus variations.
            transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 0.2)),
            # Rotation of up to 10 degrees either way; the original author notes that
            # realistic patient rotation should stay within about 20 degrees.
            transforms.RandomRotation(degrees=10),
            # Slight translations, scaling, and shearing for variability
            # (degrees=0: no further rotation is added at this step).
            transforms.RandomAffine(degrees=0, translate=(0.05, 0.05), scale=(0.95, 1.05), shear=5),
        ])
        return augmentation_transforms(image)

def augment_dataset(dataset, augmentation_targets):
    """
    Augments images from underrepresented classes in the dataset according to specified targets.

    For every class except 'Healthy', the number of new images is
    ``int(class_count * target) - class_count``, where ``target`` is looked up
    by class name in ``augmentation_targets``. Classes without a target (or
    with a target that does not grow the class) get no new images. Source
    images are drawn at random, with replacement, from the class.

    Args:
        dataset (BalancedXRayDataset): The dataset instance containing X-ray images and labels.
        augmentation_targets (dict): Maps class name to a growth factor
            (e.g. 1.2 grows the class to 120% of its current size).

    Returns:
        tuple: Two numpy arrays containing only the newly generated images
        (each with a leading channel axis) and their labels. When nothing is
        generated, both arrays are empty and the image array keeps the
        per-sample shape of ``dataset.x`` so it can still be concatenated.
    """
    augmented_images = []
    augmented_labels = []

    unique, counts = np.unique(dataset.y, return_counts=True)

    # counts[i] belongs to unique[i]. Pair them by position: indexing counts by
    # label value is wrong as soon as any class is absent from dataset.y.
    for class_id, class_count in zip(unique, counts):
        class_name = dataset.class_names[class_id]
        if class_name == 'Healthy':
            continue  # Skip augmentation for 'Healthy'

        target_increase = augmentation_targets.get(class_name, 0)
        num_to_augment = max(0, int(class_count * target_increase) - int(class_count))
        if num_to_augment == 0:
            continue

        # Use a plain Python list so random.choice sees a standard sequence of ints.
        candidate_indices = np.where(dataset.y == class_id)[0].tolist()

        for _ in range(num_to_augment):
            idx = random.choice(candidate_indices)
            image, label = dataset[idx]
            augmented_image = dataset.apply_augmentation(image)
            # Add a channel dimension so the result has the same rank as the original dataset.
            augmented_images.append(np.array(augmented_image)[np.newaxis, :, :])
            augmented_labels.append(label)

    if not augmented_images:
        # np.array([]) would have shape (0,), which cannot be concatenated with the
        # original images; keep the per-sample shape instead.
        source_images = np.asarray(dataset.x)
        source_labels = np.asarray(dataset.y)
        return (
            np.empty((0,) + source_images.shape[1:], dtype=source_images.dtype),
            np.empty((0,), dtype=source_labels.dtype),
        )

    return np.array(augmented_images), np.array(augmented_labels)


def combine_datasets(x_original, y_original, x_augmented, y_augmented):
    """
    Combines original and augmented datasets into a single dataset.

    The arrays are concatenated along the sample axis (axis 0), original
    samples first. An empty augmented array is treated as "nothing to add",
    whatever its shape.

    Args:
        x_original (numpy.ndarray): Original images.
        y_original (numpy.ndarray): Original labels.
        x_augmented (numpy.ndarray): Augmented images.
        y_augmented (numpy.ndarray): Labels for augmented images.

    Returns:
        tuple: Two new numpy arrays containing combined images and their labels.

    Raises:
        ValueError: If the per-sample shape of the augmented images differs
            from that of the original images.
    """
    x_original = np.asarray(x_original)
    y_original = np.asarray(y_original)
    x_augmented = np.asarray(x_augmented)
    y_augmented = np.asarray(y_augmented)

    if x_augmented.size == 0:
        # e.g. np.array([]) has shape (0,) and could not be concatenated as-is.
        x_combined = x_original.copy()
    else:
        # Concatenation requires every axis except the sample axis to match.
        if x_original.shape[1:] != x_augmented.shape[1:]:
            raise ValueError(
                "Cannot combine datasets: original images have per-sample shape "
                f"{x_original.shape[1:]} but augmented images have {x_augmented.shape[1:]}"
            )
        x_combined = np.concatenate([x_original, x_augmented], axis=0)

    if y_augmented.size == 0:
        y_combined = y_original.copy()
    else:
        y_combined = np.concatenate([y_original, y_augmented], axis=0)

    return x_combined, y_combined

def undersample_healthy_class(x, y, healthy_class_label=3, target_count=4500):
    """
    Randomly discards 'Healthy' samples until at most ``target_count`` remain.

    Args:
        x (numpy.ndarray): Combined images including original and augmented.
        y (numpy.ndarray): Combined labels including original and augmented.
        healthy_class_label (int): The label index for the 'Healthy' class.
        target_count (int): The number of 'Healthy' samples to retain.

    Returns:
        tuple: Two numpy arrays containing undersampled images and their labels.

    Samples of every other class are kept untouched. If the 'Healthy' class
    already has ``target_count`` samples or fewer, nothing is removed. The
    default of 4500 was flagged by the original author as a value to revisit.
    """
    healthy_positions = np.nonzero(y == healthy_class_label)[0]

    # Randomise the order in place so the surplus tail is a random subset
    np.random.shuffle(healthy_positions)
    surplus_positions = healthy_positions[target_count:]

    kept_images = np.delete(x, surplus_positions, axis=0)
    kept_labels = np.delete(y, surplus_positions)
    return kept_images, kept_labels

# Main function to run the augmentation and analysis
def main():
    """
    Runs the full pipeline: augment, combine, undersample 'Healthy', and save.

    Loads the training arrays, generates augmented samples for the classes
    listed in ``augmentation_targets``, appends them to the original data,
    undersamples the 'Healthy' class, and writes the result to
    ``X_augmented.npy`` / ``Y_augmented.npy`` inside ``save_dir``.
    """
    x_train_path = '../../data/X_train.npy'
    y_train_path = '../../data/Y_train.npy'
    save_dir = '../../data_augmented_Improved'

    # exist_ok avoids the check-then-create race and is a no-op if the folder exists.
    os.makedirs(save_dir, exist_ok=True)

    dataset_instance = BalancedXRayDataset(x_train_path, y_train_path)

    # Growth factor per class: 1.2 grows a class to 120% of its current size.
    augmentation_targets = {
        'Atelectasis': 1.2,
        'Effusion': 1.2,
        'Infiltration': 1.2,
        'Nodule': 1.5,
        'Pneumothorax': 1.5,
    }

    x_augmented, y_augmented = augment_dataset(dataset_instance, augmentation_targets)

    x_combined, y_combined = combine_datasets(dataset_instance.x, dataset_instance.y, x_augmented, y_augmented)

    x_final, y_final = undersample_healthy_class(x_combined, y_combined)

    np.save(os.path.join(save_dir, 'X_augmented.npy'), x_final)
    np.save(os.path.join(save_dir, 'Y_augmented.npy'), y_final)

    print(f"Final dataset saved with {len(x_final)} images.")

# Run the pipeline only when this code is executed directly (as a script or a
# notebook cell), not when it is imported as a module.
if __name__ == "__main__":
    main()

Final dataset saved with 18264 images.


In [5]:
import numpy as np

def print_augmented_class_distribution(augmented_y_path, class_names):
    """
    Prints how many samples each class has in a saved label array.

    Args:
        augmented_y_path (str): The file path to the numpy array of labels.
        class_names (list): Class names, indexed by label value.
    """
    labels = np.load(augmented_y_path)

    print("Class Distribution in Augmented Dataset:")

    # np.unique returns the distinct labels (sorted) together with their frequencies
    label_ids, frequencies = np.unique(labels, return_counts=True)
    for label_id, n_samples in zip(label_ids, frequencies):
        class_label = class_names[label_id]
        print(f"{class_label}: {n_samples} samples")

# Path to the label file written by the augmentation cell above, and the class
# names in label-index order (index 3 is 'Healthy').
augmented_y_path = "../../data_augmented_Improved/Y_augmented.npy"
class_names = ['Atelectasis', 'Effusion', 'Infiltration', 'Healthy', 'Nodule', 'Pneumothorax']

# Print the per-class sample counts of the augmented dataset
print_augmented_class_distribution(augmented_y_path, class_names)


Class Distribution in Augmented Dataset:
Atelectasis: 3025 samples
Effusion: 2781 samples
Infiltration: 3556 samples
Healthy: 4500 samples
Nodule: 2449 samples
Pneumothorax: 1953 samples


# To be improved

The code below visualizes original images side by side with their augmented versions.

In [6]:
import numpy as np
import matplotlib.pyplot as plt
from torchvision.transforms.functional import to_pil_image, to_tensor

def visualize_augmentation(dataset, num_samples_per_class=3):
    """
    Visualizes original and augmented images from the dataset for every class except 'Healthy'.

    Each class gets one row of the figure; each sampled image takes two
    adjacent columns (original on the left, augmented on the right). If a
    class has fewer than ``num_samples_per_class`` samples, all of its samples
    are shown and the remaining cells of the row stay blank.

    Args:
        dataset (BalancedXRayDataset): The dataset instance containing X-ray images and labels.
        num_samples_per_class (int): Number of samples per class to visualize.
    """
    # Classes to visualize (excluding 'Healthy'); for the standard class list
    # this is [0, 1, 2, 4, 5].
    classes_to_visualize = [
        class_id for class_id, name in enumerate(dataset.class_names) if name != 'Healthy'
    ]

    # Create a figure with subplots; squeeze=False keeps axs two-dimensional.
    num_classes = len(classes_to_visualize)
    _, axs = plt.subplots(
        num_classes,
        num_samples_per_class * 2,
        figsize=(num_samples_per_class * 8, num_classes * 4),
        squeeze=False,
    )

    # Hide every axis up front so cells that end up unused are blank
    for ax in axs.ravel():
        ax.axis('off')

    for i, class_id in enumerate(classes_to_visualize):
        class_indices = np.where(dataset.y == class_id)[0]

        # Sampling without replacement raises if more samples are requested
        # than the class contains, so cap the request.
        num_to_show = min(num_samples_per_class, len(class_indices))
        if num_to_show == 0:
            continue
        selected_indices = np.random.choice(class_indices, size=num_to_show, replace=False)

        for j, idx in enumerate(selected_indices):
            original_image, _ = dataset[idx]
            # Convert PIL Image to NumPy array for plotting
            original_image_np = np.array(original_image)

            # Apply augmentation and convert back to NumPy array
            augmented_image = dataset.apply_augmentation(original_image)
            augmented_image_np = np.array(augmented_image)

            # Plot original image
            ax = axs[i, j * 2]
            ax.imshow(original_image_np, cmap='gray')
            ax.set_title(f"Original - {dataset.class_names[class_id]}")

            # Plot augmented image
            ax = axs[i, j * 2 + 1]
            ax.imshow(augmented_image_np, cmap='gray')
            ax.set_title(f"Augmented - {dataset.class_names[class_id]}")

    plt.tight_layout()
    plt.show()

# Example usage, assuming 'dataset_instance' is an already created BalancedXRayDataset:
# visualize_augmentation(dataset_instance, num_samples_per_class=3)
