# Data Augmentation Visualization

This notebook visualizes the augmentation pipeline to ensure transformations are appropriate for medical images.

In [None]:
import sys
from pathlib import Path

import albumentations as A
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

sys.path.append('..')  # make the project-root `src` package importable
from src.data.augmentation import get_train_augmentation, get_val_augmentation, get_strong_augmentation

%matplotlib inline

## Load Sample Image

In [None]:
# Load a sample image from your dataset.
# Update this path to point to a directory that contains real images.
sample_image_path = Path('../data/processed/CE')  # Modify to actual image path
IMAGE_SUFFIXES = {'.png', '.jpg', '.jpeg'}
DUMMY_IMAGE_SEED = 0  # seed for the fallback image so re-runs are reproducible

# Sort so the "first" image is the same on every run (rglob yields files in
# filesystem-dependent order), and match suffixes case-insensitively so files
# such as `.PNG` / `.JPG` / `.jpeg` are found as well.
image_files = sorted(
    path for path in sample_image_path.rglob('*')
    if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES
)

if image_files:
    # The context manager closes the underlying file handle once decoding is done.
    with Image.open(image_files[0]) as img:
        image = np.array(img.convert('RGB'))
    print(f"Loaded image: {image_files[0]}")
    print(f"Image shape: {image.shape}")
else:
    # Create a reproducible dummy image for demonstration so later cells still run.
    print("No images found. Creating dummy image.")
    rng = np.random.default_rng(DUMMY_IMAGE_SEED)
    # `integers` excludes the upper bound, so 256 yields the full uint8 range 0..255.
    image = rng.integers(0, 256, size=(256, 256, 3), dtype=np.uint8)

# Display original
fig, ax = plt.subplots(figsize=(5, 5))
ax.imshow(image)
ax.set_title('Original Image')
ax.axis('off')
plt.show()

## Visualize Training Augmentation

In [None]:
# Get training augmentation pipeline
train_transform = get_train_augmentation(image_size=224, p=1.0)

# Per-channel statistics used to undo normalization for display.
# NOTE(review): assumes the pipeline normalizes with these ImageNet mean/std
# values — confirm against src.data.augmentation.
DISPLAY_MEAN = np.array([0.485, 0.456, 0.406])
DISPLAY_STD = np.array([0.229, 0.224, 0.225])

# Apply the pipeline multiple times to see the random variation
fig, axes = plt.subplots(2, 4, figsize=(16, 8))

for i, ax in enumerate(axes.flat):
    augmented = train_transform(image=image)['image']

    # A numpy HWC image is also rank 3, so detect a tensor by its API rather
    # than by the number of dimensions.
    if hasattr(augmented, 'permute'):
        # Tensor (C, H, W) -> numpy (H, W, C), then denormalize into [0, 1].
        aug_display = augmented.detach().cpu().permute(1, 2, 0).numpy()
        aug_display = np.clip(DISPLAY_STD * aug_display + DISPLAY_MEAN, 0, 1)
    else:
        aug_display = augmented

    ax.imshow(aug_display)
    ax.set_title(f'Augmentation {i+1}')
    ax.axis('off')

plt.suptitle('Training Augmentation Variations', fontsize=16)
plt.tight_layout()
plt.show()

## Visualize Individual Transformations

In [None]:
# Each entry is (panel title, transform); None shows the untouched image for reference.
transforms = [
    ('Original', None),
    ('Horizontal Flip', A.HorizontalFlip(p=1.0)),
    ('Rotation ±30°', A.Rotate(limit=30, p=1.0)),
    ('Brightness/Contrast', A.RandomBrightnessContrast(p=1.0)),
    ('Gaussian Noise', A.GaussNoise(p=1.0)),
    ('Gaussian Blur', A.GaussianBlur(p=1.0)),
    # NOTE(review): with alpha=1 and sigma=50 the elastic displacement may be
    # too small to see, so this panel could look almost identical to the
    # original — confirm these match the values used for training.
    ('Elastic Transform', A.ElasticTransform(alpha=1, sigma=50, p=1.0)),
    ('Shift/Scale/Rotate', A.ShiftScaleRotate(p=1.0)),
]

# Size the grid from the number of transforms so adding/removing entries
# cannot index past the available axes.
n_cols = 4
n_rows = (len(transforms) + n_cols - 1) // n_cols
fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
axes = axes.flatten()

for ax, (name, transform) in zip(axes, transforms):
    transformed = image if transform is None else transform(image=image)['image']
    ax.imshow(transformed)
    ax.set_title(name)
    ax.axis('off')

# Hide leftover panels when len(transforms) is not a multiple of n_cols.
for ax in axes[len(transforms):]:
    ax.axis('off')

plt.suptitle('Individual Augmentation Effects', fontsize=16)
plt.tight_layout()
plt.show()

## Compare Normal vs Strong Augmentation

In [None]:
# Get both augmentation pipelines
normal_aug = get_train_augmentation(image_size=224, p=0.8)
strong_aug = get_strong_augmentation(image_size=224)

# Per-channel statistics used to undo normalization for display; defined here
# (not inside a branch) so every pipeline can use them.
# NOTE(review): assumes both pipelines normalize with these ImageNet mean/std
# values — confirm against src.data.augmentation.
DISPLAY_MEAN = np.array([0.485, 0.456, 0.406])
DISPLAY_STD = np.array([0.229, 0.224, 0.225])

# One row per pipeline, below a reference row of the original image.
pipelines = [(normal_aug, 'Normal Aug'), (strong_aug, 'Strong Aug')]
n_samples = 4
n_rows = 1 + len(pipelines)

# Compare side by side
fig, axes = plt.subplots(n_rows, n_samples, figsize=(4 * n_samples, 4 * n_rows), squeeze=False)

for col in range(n_samples):
    # Original
    axes[0, col].imshow(image)
    axes[0, col].set_title('Original')
    axes[0, col].axis('off')

    for row, (pipeline, title) in enumerate(pipelines, start=1):
        result = pipeline(image=image)['image']
        # Detect a tensor by its API: a numpy HWC image is also rank 3. The
        # panel is always assigned, so no stale value from a previous column
        # can be shown.
        if hasattr(result, 'permute'):
            panel = result.detach().cpu().permute(1, 2, 0).numpy()
            panel = np.clip(DISPLAY_STD * panel + DISPLAY_MEAN, 0, 1)
        else:
            panel = result
        axes[row, col].imshow(panel)
        axes[row, col].set_title(title)
        axes[row, col].axis('off')

plt.suptitle('Normal vs Strong Augmentation Comparison', fontsize=16)
plt.tight_layout()
plt.show()

## Validation: Check Augmentation Doesn't Change Diagnosis

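**Important**: Visually inspect the augmented images above to confirm that:
1. Key diagnostic features remain visible
2. Augmentations don't introduce unrealistic artifacts
3. Intensity ranges stay plausible for the imaging modality
4. Clinically meaningful properties are preserved (e.g. check whether flips or large rotations produce orientations that would never occur in real acquisitions)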

## Recommendations

Based on the visualizations above, document:
- Which augmentations are most appropriate for your data
- Any augmentations that should be disabled or adjusted
- Optimal probability values for each transformation
- Whether strong augmentation is appropriate for your use case