# Day 2 - Image Enhancement Module

**Goal:** Apply image enhancement techniques to improve MRI contrast and clarity for better model training.

**Techniques:**
1. Gaussian Blur - Simple noise reduction (baseline for comparison; not part of the final pipeline)
2. Non-Local Means (NLM) Denoising - Advanced noise reduction while preserving edges
3. CLAHE - Adaptive histogram equalization for contrast enhancement

**Outcome:** Enhanced dataset ready for CNN training

## 1. Import Libraries

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import glob
import os
import random

# Set matplotlib to display inline
# NOTE: this is an IPython magic, so it only works inside a Jupyter/IPython
# session; remove it if this cell is ever copied into a plain .py script.
%matplotlib inline

# Reaching this line means every import above succeeded.
print("✓ All libraries imported successfully")

## 2. Load a Sample Image

In [None]:
# Pick a sample image; fall back to any class-1 image if the preferred one is missing
img_path = "../outputs/ce_mri_images/1/pid100360_1.png"

if not os.path.exists(img_path):
    # sorted() keeps the fallback choice stable between runs (glob order is arbitrary)
    all_images = sorted(glob.glob("../outputs/ce_mri_images/1/*.png"))
    if not all_images:
        # Stop here with a clear message instead of failing later on `img.shape`
        raise FileNotFoundError("No images found in ../outputs/ce_mri_images/1/")
    img_path = all_images[0]

img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
# cv2.imread reports failure by returning None rather than raising
if img is None:
    raise OSError(f"Could not read image: {img_path}")

print(f"Loaded: {os.path.basename(img_path)}")
print(f"Shape: {img.shape}")
print(f"Data type: {img.dtype}")
print(f"Value range: [{img.min()}, {img.max()}]")

In [None]:
# Display the untouched scan together with an intensity colour bar
fig, ax = plt.subplots(figsize=(6, 6))
mappable = ax.imshow(img, cmap='gray')
ax.set_title(f'Original MRI\n{os.path.basename(img_path)}')
ax.axis('off')
fig.colorbar(mappable, ax=ax)
plt.show()

## 3. Experiment with Enhancement Techniques

### 3.1 Gaussian Blur (Simple Noise Reduction)

In [None]:
# Smooth with a 3x3 Gaussian kernel (sigma argument left at 0)
gaussian = cv2.GaussianBlur(img, (3, 3), 0)

# Whole-image standard deviation is used here as a rough proxy for noise
orig_std = img.std()
blur_std = gaussian.std()

print(f"Original std dev (noise level): {orig_std:.2f}")
print(f"Gaussian std dev: {blur_std:.2f}")
print(f"Noise reduction: {((orig_std - blur_std) / orig_std * 100):.1f}%")

### 3.2 Non-Local Means Denoising (Advanced)

In [None]:
# Non-Local Means denoising. Parameters:
#   h                  - filter strength (higher = more smoothing)
#   templateWindowSize - size of template patch
#   searchWindowSize   - size of search area
nlm_params = dict(h=10, templateWindowSize=7, searchWindowSize=21)
nlm = cv2.fastNlMeansDenoising(img, None, **nlm_params)

std_before = img.std()
std_after = nlm.std()

print(f"NLM std dev: {std_after:.2f}")
print(f"Noise reduction: {((std_before - std_after) / std_before * 100):.1f}%")

### 3.3 CLAHE (Contrast Limited Adaptive Histogram Equalization)

In [None]:
# Contrast-limited adaptive histogram equalisation, run on the NLM output
#   clipLimit    - threshold for contrast limiting (2.0 is good balance)
#   tileGridSize - size of grid for histogram equalization
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
enhanced = clahe.apply(nlm)

# Standard deviation of the pixel values serves as the contrast measure here
std_original = img.std()
std_enhanced = enhanced.std()

print(f"Original contrast (std): {std_original:.2f}")
print(f"Enhanced contrast (std): {std_enhanced:.2f}")
print(f"Contrast improvement: {((std_enhanced - std_original) / std_original * 100):.1f}%")

## 4. Visual Comparison

In [None]:
# Show every stage of the experiment in a 2x2 grid
fig, axes = plt.subplots(2, 2, figsize=(12, 12))

stages = [
    ('Original', img),
    ('Gaussian Blur', gaussian),
    ('NLM Denoised', nlm),
    ('CLAHE Enhanced', enhanced),
]

# axes.flat walks the grid row by row, matching the order of `stages`
for ax, (stage_name, stage_img) in zip(axes.flat, stages):
    ax.imshow(stage_img, cmap='gray')
    ax.set_title(f'{stage_name}\nStd: {stage_img.std():.2f}', fontsize=12, fontweight='bold')
    ax.axis('off')

plt.tight_layout()
plt.show()

## 5. Build the Enhancement Pipeline Function

In [None]:
def enhance_image(img, *, h=10, template_window_size=7, search_window_size=21,
                  clip_limit=2.0, tile_grid_size=(8, 8)):
    """
    Apply complete enhancement pipeline:
    1. Non-Local Means Denoising
    2. CLAHE contrast enhancement
    3. Normalization

    Args:
        img: Input grayscale image (numpy array, uint8)
        h: NLM filter strength (higher = more smoothing)
        template_window_size: NLM template patch size
        search_window_size: NLM search area size
        clip_limit: CLAHE contrast-limiting threshold
        tile_grid_size: CLAHE tile grid as a (rows, cols) tuple

    Returns:
        enhanced: Enhanced image (numpy array, uint8, [0, 255])

    Raises:
        TypeError: If img is not a numpy array (e.g. the None that
            cv2.imread returns for an unreadable file).
        ValueError: If img is empty or is not of dtype uint8.
    """
    # Validate up front so callers get a clear message instead of an
    # opaque OpenCV error from deep inside the pipeline.
    if not isinstance(img, np.ndarray):
        raise TypeError(
            f"img must be a numpy array, got {type(img).__name__} "
            "(cv2.imread returns None when a file cannot be read)"
        )
    if img.size == 0:
        raise ValueError("img is empty")
    if img.dtype != np.uint8:
        raise ValueError(f"img must have dtype uint8, got {img.dtype}")

    # Step 1: Denoise using Non-Local Means
    denoised = cv2.fastNlMeansDenoising(
        img,
        None,
        h=h,
        templateWindowSize=template_window_size,
        searchWindowSize=search_window_size,
    )

    # Step 2: Apply CLAHE for contrast enhancement
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    enhanced = clahe.apply(denoised)

    # Step 3: Stretch to the full [0, 255] range
    enhanced = cv2.normalize(enhanced, None, 0, 255, cv2.NORM_MINMAX)

    # copy=False avoids a redundant copy when the array is already uint8
    return enhanced.astype(np.uint8, copy=False)

# Smoke-test the pipeline on the sample image loaded earlier
result = enhance_image(img)
out_min, out_max = result.min(), result.max()

print(f"✓ Enhancement pipeline created successfully")
print(f"Output shape: {result.shape}")
print(f"Output range: [{out_min}, {out_max}]")

In [None]:
# Side-by-side view: raw scan next to the full pipeline output
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

panels = (img, result)
captions = ('Original', 'Enhanced (NLM + CLAHE)')
for ax, panel, caption in zip(axes, panels, captions):
    ax.imshow(panel, cmap='gray')
    ax.set_title(caption, fontsize=14, fontweight='bold')
    ax.axis('off')

plt.tight_layout()
plt.show()

# Ratio of standard deviations, reported as a percentage change
std_ratio = result.std() / img.std()
print(f"Contrast improvement: {((std_ratio - 1) * 100):.1f}%")

## 6. Test on Multiple Samples

In [None]:
# Test on 3 random samples (one from each class)
fig, axes = plt.subplots(3, 2, figsize=(10, 15))

for i, label in enumerate(['1', '2', '3']):
    # Pick a random image from this class
    img_files = glob.glob(f"../outputs/ce_mri_images/{label}/*.png")
    sample_path = random.choice(img_files) if img_files else None
    sample_img = None
    if sample_path is not None:
        # cv2.imread returns None (it does not raise) when the file cannot be decoded
        sample_img = cv2.imread(sample_path, cv2.IMREAD_GRAYSCALE)

    if sample_img is None:
        # Nothing to draw: blank the row instead of leaving empty tick-marked axes
        print(f"Class {label}: no readable image found, leaving row empty")
        for ax in axes[i]:
            ax.axis('off')
        continue

    sample_enhanced = enhance_image(sample_img)

    # Only add the ellipsis when the name was actually shortened
    sample_name = os.path.basename(sample_path)
    if len(sample_name) > 30:
        sample_name = sample_name[:30] + '...'

    # Plot original
    axes[i, 0].imshow(sample_img, cmap='gray')
    axes[i, 0].set_title(f'Class {label}: Original\n{sample_name}', fontsize=10)
    axes[i, 0].axis('off')

    # Plot enhanced
    axes[i, 1].imshow(sample_enhanced, cmap='gray')
    axes[i, 1].set_title(f'Class {label}: Enhanced', fontsize=10)
    axes[i, 1].axis('off')

plt.tight_layout()
plt.show()

## 7. Quantitative Analysis

In [None]:
# Analyze enhancement metrics on sample images
def compute_metrics(img):
    """Summarise the intensity distribution of an image.

    Args:
        img: Image as a numpy array.

    Returns:
        dict with keys 'mean', 'std', 'min', 'max' and 'contrast', where
        'contrast' is the spread between the largest and smallest value.
    """
    stats = {'mean': img.mean(), 'std': img.std()}
    # Compute the extremes once and reuse them for the spread
    lowest, highest = img.min(), img.max()
    stats.update(min=lowest, max=highest, contrast=highest - lowest)
    return stats

# Sample up to 10 random images from each class
metrics_before = []
metrics_after = []

for label in ['1', '2', '3']:
    img_files = glob.glob(f"../outputs/ce_mri_images/{label}/*.png")
    samples = random.sample(img_files, min(10, len(img_files)))

    # Use `sample_path` rather than `img_path` so the sample image path set in
    # section 2 is not overwritten for cells that are re-run afterwards.
    for sample_path in samples:
        img_orig = cv2.imread(sample_path, cv2.IMREAD_GRAYSCALE)
        if img_orig is None:
            # cv2.imread returns None for files it cannot decode
            print(f"Skipping unreadable image: {sample_path}")
            continue
        img_enh = enhance_image(img_orig)

        metrics_before.append(compute_metrics(img_orig))
        metrics_after.append(compute_metrics(img_enh))

# Convert to DataFrames for easy summary statistics
import pandas as pd
df_before = pd.DataFrame(metrics_before)
df_after = pd.DataFrame(metrics_after)

print("=" * 60)
print("ENHANCEMENT METRICS COMPARISON")
print("=" * 60)
if df_before.empty:
    # An empty DataFrame has no 'std' column, so the report below would raise KeyError
    print("\nNo images were analysed - check that ../outputs/ce_mri_images/ is populated")
else:
    print("\nBEFORE Enhancement:")
    print(df_before.describe())
    print("\nAFTER Enhancement:")
    print(df_after.describe())
    print("\nIMPROVEMENT:")
    print(f"Mean contrast increase: {((df_after['std'].mean() / df_before['std'].mean() - 1) * 100):.1f}%")

## 8. Ready for Batch Processing

The `enhance_image()` function is now ready to process all 3,064 images.

Next step: Run `src/module1_enhance.py` to create the enhanced dataset in `outputs/ce_mri_enhanced/`