In [1]:
import os
from tqdm import tqdm
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import random
import shutil
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

print("="*80)
print("OCT IMAGE CLASSIFICATION PROJECT - NOTEBOOK 2: SEGMENTATION")
print("Classical Segmentation (Otsu + Morphology)")
print("="*80)

# ========== CONFIGURATION ==========
DATA_DIR = '/kaggle/input/oct2017/OCT2017'
OUTPUT_DIR = '/kaggle/working'
SEGMENTED_DIR = os.path.join(OUTPUT_DIR, 'segmented_images')
COLORMAP_DIR = os.path.join(OUTPUT_DIR, 'segmented_colormaps')
OVERLAY_DIR = os.path.join(OUTPUT_DIR, 'segmented_overlays')
VISUALIZATION_DIR = os.path.join(OUTPUT_DIR, 'segmentation_vis')

os.makedirs(SEGMENTED_DIR, exist_ok=True)
os.makedirs(COLORMAP_DIR, exist_ok=True)
os.makedirs(OVERLAY_DIR, exist_ok=True)
os.makedirs(VISUALIZATION_DIR, exist_ok=True)

TARGET_SIZE = (224, 224)
NUM_TOTAL = 20000  # Change to 200 for quick test
DEVICE = 'cpu'

print(f"âœ“ Device: {DEVICE}")
print(f"âœ“ Target image size: {TARGET_SIZE}")
print(f"âœ“ Processing {NUM_TOTAL} images total")

# ========== SEGMENTATION FUNCTION ==========
def segment_oct_image(image_np):
    """
    Classical segmentation using Otsu thresholding + morphology
    Returns: segmented image, binary mask
    """
    # Convert to grayscale
    if len(image_np.shape) == 3:
        gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
    else:
        gray = image_np
    
    # Apply Gaussian blur to reduce noise
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    
    # Otsu's thresholding - automatically finds best threshold
    _, binary = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    
    # Morphological operations to clean up
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    morph = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel, iterations=2)
    morph = cv2.morphologyEx(morph, cv2.MORPH_OPEN, kernel, iterations=1)
    
    # Find contours
    contours, _ = cv2.findContours(morph, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    
    # Create mask from largest contour (the actual retina ROI)
    mask = np.zeros_like(gray)
    if contours:
        largest_contour = max(contours, key=cv2.contourArea)
        cv2.drawContours(mask, [largest_contour], -1, 255, -1)
    else:
        mask = np.ones_like(gray) * 255
    
    # Apply mask to original image
    if len(image_np.shape) == 3:
        segmented = cv2.bitwise_and(image_np, image_np, mask=mask)
    else:
        segmented = cv2.bitwise_and(gray, gray, mask=mask)
    
    return segmented, mask

# ========== COLORMAP FUNCTION ==========
def save_colormap(mask_array, out_path):
    """Convert binary mask to colored visualization"""
    mask_norm = (mask_array / 255.0) if mask_array.max() > 1 else mask_array
    cmap = plt.get_cmap('viridis')
    rgba = cmap(mask_norm)
    rgb = np.uint8(rgba[:,:,:3] * 255)
    Image.fromarray(rgb).save(out_path)

# ========== OVERLAY FUNCTION ==========
def make_overlay(orig_img_array, mask_array, out_path):
    """Create overlay: original + segmentation boundary"""
    overlay = cv2.cvtColor(orig_img_array, cv2.COLOR_GRAY2RGB)
    contours, _ = cv2.findContours(mask_array, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    overlay = cv2.drawContours(overlay, contours, -1, (0, 255, 0), 2)
    Image.fromarray(overlay).save(out_path)

# ========== BUILD IMAGE INDEX ==========
print("\n[1/6] Building image index...")
image_index = []
splits = ['train', 'val', 'test']

for split in splits:
    split_path = os.path.join(DATA_DIR, split)
    if not os.path.exists(split_path):
        continue
    for cls in os.listdir(split_path):
        class_path = os.path.join(split_path, cls)
        if not os.path.isdir(class_path):
            continue
        img_files = [f for f in os.listdir(class_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
        for fname in img_files:
            image_index.append((split, cls, fname))

print(f"âœ“ Total images found: {len(image_index)}")

# ========== SAMPLE IMAGES ==========
print(f"\n[2/6] Sampling {NUM_TOTAL} random images...")
random.seed(42)
if len(image_index) > NUM_TOTAL:
    image_index = random.sample(image_index, NUM_TOTAL)
print(f"âœ“ Processing {len(image_index)} images")

# ========== RUN SEGMENTATION ==========
print(f"\n[3/6] Running classical segmentation on all images...")
segmentation_stats = []
processed = 0

for (split, cls, fname) in tqdm(image_index, desc="Segmenting"):
    img_path = os.path.join(DATA_DIR, split, cls, fname)
    
    # Load image
    image = Image.open(img_path).convert('RGB')
    image_np = np.array(image)
    
    # Resize
    image_resized = cv2.resize(image_np, TARGET_SIZE)
    
    # Segment
    segmented, mask = segment_oct_image(image_resized)
    
    # Create output directories
    seg_cls_dir = os.path.join(SEGMENTED_DIR, split, cls)
    color_cls_dir = os.path.join(COLORMAP_DIR, split, cls)
    overlay_cls_dir = os.path.join(OVERLAY_DIR, split, cls)
    
    os.makedirs(seg_cls_dir, exist_ok=True)
    os.makedirs(color_cls_dir, exist_ok=True)
    os.makedirs(overlay_cls_dir, exist_ok=True)
    
    # Save segmented image
    img_name = os.path.basename(img_path)
    seg_output_path = os.path.join(seg_cls_dir, img_name)
    Image.fromarray(segmented).save(seg_output_path)
    
    # Save colormap
    colormap_path = os.path.join(color_cls_dir, img_name.replace('.', '_colormap.'))
    save_colormap(mask, colormap_path)
    
    # Save overlay
    overlay_path = os.path.join(overlay_cls_dir, img_name.replace('.', '_overlay.'))
    make_overlay(cv2.cvtColor(image_resized, cv2.COLOR_RGB2GRAY), mask, overlay_path)
    
    # Track stats
    total_pixels = mask.size
    foreground_pixels = np.sum(mask > 0)
    foreground_ratio = foreground_pixels / total_pixels
    
    segmentation_stats.append({
        'split': split,
        'class': cls,
        'original_path': img_path,
        'segmented_path': seg_output_path,
        'foreground_ratio': foreground_ratio
    })
    
    processed += 1

print(f"âœ“ Segmentation complete! Processed {processed} images")

# ========== SAVE STATISTICS ==========
print(f"\n[4/6] Computing and saving segmentation statistics...")
df_stats = pd.DataFrame(segmentation_stats)
df_stats.to_csv(os.path.join(OUTPUT_DIR, 'segmentation_mapping.csv'), index=False)
print(f"âœ“ Saved segmentation mapping")

# Print summary
print("\n" + "="*60)
print("SEGMENTATION QUALITY METRICS")
print("="*60)
print(df_stats.groupby('class')['foreground_ratio'].describe())

# ========== VISUALIZATIONS ==========
print(f"\n[5/6] Creating visualizations...")

# Sample visualization
fig, axes = plt.subplots(5, 3, figsize=(15, 20))
fig.suptitle('Segmentation Results: Original vs Mask vs Overlay', fontsize=16, fontweight='bold')

sample_indices = np.random.choice(len(image_index), min(5, len(image_index)), replace=False)

for idx, sample_idx in enumerate(sample_indices):
    split, cls, fname = image_index[sample_idx]
    img_path = os.path.join(DATA_DIR, split, cls, fname)
    image = Image.open(img_path).convert('RGB')
    image_np = np.array(image)
    image_resized = cv2.resize(image_np, TARGET_SIZE)
    segmented, mask = segment_oct_image(image_resized)
    
    # Original
    axes[idx, 0].imshow(image_resized, cmap='gray')
    axes[idx, 0].set_title(f'Original - {cls}', fontsize=10, fontweight='bold')
    axes[idx, 0].axis('off')
    
    # Mask
    axes[idx, 1].imshow(mask, cmap='gray')
    axes[idx, 1].set_title('Segmentation Mask', fontsize=10)
    axes[idx, 1].axis('off')
    
    # Segmented
    axes[idx, 2].imshow(segmented, cmap='gray')
    axes[idx, 2].set_title('Segmented ROI', fontsize=10)
    axes[idx, 2].axis('off')

plt.tight_layout()
plt.savefig(os.path.join(VISUALIZATION_DIR, 'segmentation_comparison.png'), dpi=300, bbox_inches='tight')
print(f"âœ“ Saved segmentation comparison visualization")
plt.close()

# ========== VERIFY DATASET ==========
print(f"\n[6/6] Verifying segmented dataset...")

verification = {}
for split in ['train', 'val', 'test']:
    split_path = os.path.join(SEGMENTED_DIR, split)
    if not os.path.exists(split_path):
        continue
    verification[split] = {}
    classes = [d for d in os.listdir(split_path) if os.path.isdir(os.path.join(split_path, d))]
    for cls in classes:
        cls_path = os.path.join(split_path, cls)
        images = [f for f in os.listdir(cls_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
        verification[split][cls] = len(images)

print("\n" + "="*60)
print("SEGMENTED DATASET VERIFICATION")
print("="*60)
for split, classes in verification.items():
    print(f"\n{split.upper()}:")
    for cls, count in classes.items():
        print(f"  {cls}: {count} images")

# ========== CREATE SUMMARY REPORT ==========
print("\n" + "="*60)
print("Creating summary report...")
print("="*60)

summary_report = f"""
{'='*80}
SEGMENTATION PREPROCESSING - SUMMARY REPORT
{'='*80}

1. DATASET INFORMATION
   - Total images processed: {len(df_stats)}
   - Train split: {len(df_stats[df_stats['split']=='train'])} images
   - Validation split: {len(df_stats[df_stats['split']=='val'])} images
   - Test split: {len(df_stats[df_stats['split']=='test'])} images

2. SEGMENTATION METHOD
   - Algorithm: Otsu Thresholding + Morphological Operations
   - Target size: {TARGET_SIZE}
   - Processing: Gaussian Blur â†’ Otsu â†’ Morphology â†’ Largest Contour â†’ ROI Extraction

3. OUTPUT LOCATION
   - Segmented images: {SEGMENTED_DIR}
   - Colorized masks: {COLORMAP_DIR}
   - Overlay visualizations: {OVERLAY_DIR}
   - Mapping file: segmentation_mapping.csv

4. QUALITY METRICS (Average across all classes)
   - Mean foreground ratio: {df_stats['foreground_ratio'].mean():.3f}
   - Std foreground ratio: {df_stats['foreground_ratio'].std():.3f}
   - Min foreground ratio: {df_stats['foreground_ratio'].min():.3f}
   - Max foreground ratio: {df_stats['foreground_ratio'].max():.3f}

5. METHOD ADVANTAGES
   âœ“ Fast processing (no training required)
   âœ“ Reproducible and interpretable
   âœ“ Well-suited for OCT images with clear tissue/background contrast
   âœ“ Robust to image variations

6. NEXT STEPS
   - Proceed to Notebook 3 for data augmentation and preprocessing
   - Use segmented images as input for your classification models
   - All outputs ready for downstream ML pipelines

{'='*80}
"""

print(summary_report)

with open(os.path.join(OUTPUT_DIR, 'segmentation_report.txt'), 'w') as f:
    f.write(summary_report)

# ========== ZIP OUTPUTS ==========
print("\nZipping outputs for download...")
shutil.make_archive(os.path.join(OUTPUT_DIR, 'segmented_images'), 'zip', SEGMENTED_DIR)
print("âœ“ Segmented images zipped: segmented_images.zip")

shutil.make_archive(os.path.join(OUTPUT_DIR, 'segmented_colormaps'), 'zip', COLORMAP_DIR)
print("âœ“ Colorized masks zipped: segmented_colormaps.zip")

shutil.make_archive(os.path.join(OUTPUT_DIR, 'segmented_overlays'), 'zip', OVERLAY_DIR)
print("âœ“ Overlays zipped: segmented_overlays.zip")

# ========== FINAL SUMMARY ==========
print("\n" + "="*80)
print("NOTEBOOK 2 COMPLETED SUCCESSFULLY!")
print("="*80)
print(f"\nâœ“ Segmented images saved to: {SEGMENTED_DIR}")
print(f"âœ“ Total images processed: {len(df_stats)}")
print(f"\nGenerated output files:")
print(f"  1. segmented_images.zip         (ROI-masked images)")
print(f"  2. segmented_colormaps.zip      (Colored mask visualizations)")
print(f"  3. segmented_overlays.zip       (Original + mask overlays)")
print(f"  4. segmentation_mapping.csv     (Image-to-output mapping)")
print(f"  5. segmentation_report.txt      (Summary report)")
print(f"  6. segmentation_comparison.png  (Visual comparison)")
print(f"\nðŸ“Œ NEXT STEP: Run Notebook 3 for Preprocessing & Augmentation")
print("="*80)


OCT IMAGE CLASSIFICATION PROJECT - NOTEBOOK 2: SEGMENTATION
Classical Segmentation (Otsu + Morphology)
âœ“ Device: cpu
âœ“ Target image size: (224, 224)
âœ“ Processing 20000 images total

[1/6] Building image index...
âœ“ Total images found: 84484

[2/6] Sampling 20000 random images...
âœ“ Processing 20000 images

[3/6] Running classical segmentation on all images...


Segmenting: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 20000/20000 [04:33<00:00, 73.25it/s]


âœ“ Segmentation complete! Processed 20000 images

[4/6] Computing and saving segmentation statistics...
âœ“ Saved segmentation mapping

SEGMENTATION QUALITY METRICS
         count      mean       std       min       25%       50%       75%  \
class                                                                        
CNV     8902.0  0.181105  0.082741  0.008849  0.107825  0.200943  0.241684   
DME     2738.0  0.164061  0.082225  0.013194  0.088394  0.177087  0.226144   
DRUSEN  2057.0  0.155218  0.078921  0.012297  0.078444  0.170839  0.214505   
NORMAL  6303.0  0.145180  0.072294  0.013393  0.075096  0.160176  0.200155   

             max  
class             
CNV     0.718212  
DME     0.573202  
DRUSEN  0.664740  
NORMAL  0.562938  

[5/6] Creating visualizations...
âœ“ Saved segmentation comparison visualization

[6/6] Verifying segmented dataset...

SEGMENTED DATASET VERIFICATION

TRAIN:
  DME: 2686 images
  CNV: 8847 images
  NORMAL: 6242 images
  DRUSEN: 1995 images

VAL:
  C