# Comprehensive Exploratory Data Analysis (EDA)
## Breast Cancer Detection using BreakHis Dataset

This notebook provides a comprehensive analysis of the BreakHis dataset for breast cancer histopathological image classification.

### Dataset Overview
- **Dataset**: BreakHis (Breast Cancer Histopathological Database)
- **Classes**: 8 total (4 benign + 4 malignant)
- **Magnifications**: 40X, 100X, 200X, 400X
- **Image Format**: PNG
- **Total Images**: ~7,900

### Analysis Structure
1. Data Loading and Basic Statistics
2. Class Distribution Analysis
3. Magnification Analysis
4. Patient-wise Analysis
5. Image Quality Assessment
6. Visual Analysis
7. Data Imbalance Analysis
8. Recommendations for Training

In [None]:
# Import necessary libraries
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from glob import glob
from PIL import Image
import cv2
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

# Set style for better plots
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Configuration
DATASET_ROOT = "../data/breakhis/BreaKHis_v1/BreaKHis_v1/histology_slides/breast"
RANDOM_STATE = 42

print("📊 Starting Comprehensive EDA for Breast Cancer Detection")
print(f"📁 Dataset Path: {DATASET_ROOT}")

ModuleNotFoundError: No module named 'cv2'

## 1. Data Loading and Metadata Creation

In [7]:
def create_comprehensive_metadata(dataset_root):
    """
    Create comprehensive metadata from BreakHis dataset
    """
    if not os.path.exists(dataset_root):
        raise FileNotFoundError(f"Dataset path not found: {dataset_root}")
    
    # Get all image paths
    image_paths = glob(os.path.join(dataset_root, "*", "*", "*", "*", "*", "*.png"))
    print(f"🔍 Found {len(image_paths)} images")
    
    data = []
    
    for path in image_paths:
        parts = path.split(os.sep)
        try:
            # Extract metadata from path structure
            label_type = parts[-6]         # 'malignant' or 'benign'
            subclass = parts[-4]           # e.g. 'ductal_carcinoma'
            magnification = parts[-2]      # e.g. '100X'
            filename = os.path.basename(path)
            
            # Extract patient ID from filename
            # Format: SOB_B_A-14-22549AB-40-001.png
            filename_parts = filename.split('-')
            if len(filename_parts) >= 3:
                patient_id = filename_parts[2]
            else:
                patient_id = "unknown"
            
            # Extract slide number
            slide_num = filename.split('-')[-1].replace('.png', '') if '-' in filename else "001"
            
            data.append({
                "path": path,
                "filename": filename,
                "label_type": label_type,
                "subclass": subclass,
                "magnification": magnification,
                "patient_id": patient_id,
                "slide_number": slide_num
            })
            
        except IndexError as e:
            print(f"⚠️ Skipping malformed path: {path}")
            continue
    
    metadata_df = pd.DataFrame(data)
    
    # Add derived features
    metadata_df['magnification_numeric'] = metadata_df['magnification'].str.replace('X', '').astype(int)
    metadata_df['is_malignant'] = (metadata_df['label_type'] == 'malignant').astype(int)
    
    return metadata_df

# Create metadata
metadata = create_comprehensive_metadata(DATASET_ROOT)
print(f"✅ Created metadata for {len(metadata)} images")
print(f"📋 Columns: {list(metadata.columns)}")

NameError: name 'DATASET_ROOT' is not defined

In [None]:
# Display basic information
print("📊 DATASET OVERVIEW")
print("=" * 50)
print(f"Total Images: {len(metadata):,}")
print(f"Unique Patients: {metadata['patient_id'].nunique():,}")
print(f"Unique Subclasses: {metadata['subclass'].nunique()}")
print(f"Magnification Levels: {sorted(metadata['magnification'].unique())}")
print(f"Label Types: {metadata['label_type'].unique()}")

# Display first few rows
print("\n📋 Sample Data:")
metadata.head()

## 2. Class Distribution Analysis

In [None]:
# Overall class distribution
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# 1. Benign vs Malignant
label_counts = metadata['label_type'].value_counts()
axes[0, 0].pie(label_counts.values, labels=label_counts.index, autopct='%1.1f%%', startangle=90)
axes[0, 0].set_title('Benign vs Malignant Distribution', fontsize=14, fontweight='bold')

# 2. Subclass distribution
subclass_counts = metadata['subclass'].value_counts()
axes[0, 1].barh(range(len(subclass_counts)), subclass_counts.values)
axes[0, 1].set_yticks(range(len(subclass_counts)))
axes[0, 1].set_yticklabels([label.replace('_', ' ').title() for label in subclass_counts.index])
axes[0, 1].set_xlabel('Number of Images')
axes[0, 1].set_title('Subclass Distribution', fontsize=14, fontweight='bold')

# Add value labels on bars
for i, v in enumerate(subclass_counts.values):
    axes[0, 1].text(v + 50, i, str(v), va='center', fontweight='bold')

# 3. Magnification distribution
mag_counts = metadata['magnification'].value_counts().sort_index()
axes[1, 0].bar(mag_counts.index, mag_counts.values, color='skyblue', edgecolor='navy')
axes[1, 0].set_xlabel('Magnification Level')
axes[1, 0].set_ylabel('Number of Images')
axes[1, 0].set_title('Magnification Level Distribution', fontsize=14, fontweight='bold')

# Add value labels on bars
for i, v in enumerate(mag_counts.values):
    axes[1, 0].text(i, v + 30, str(v), ha='center', fontweight='bold')

# 4. Images per patient distribution
patient_counts = metadata['patient_id'].value_counts()
axes[1, 1].hist(patient_counts.values, bins=20, color='lightcoral', edgecolor='darkred', alpha=0.7)
axes[1, 1].set_xlabel('Images per Patient')
axes[1, 1].set_ylabel('Number of Patients')
axes[1, 1].set_title('Images per Patient Distribution', fontsize=14, fontweight='bold')
axes[1, 1].axvline(patient_counts.mean(), color='red', linestyle='--', 
                   label=f'Mean: {patient_counts.mean():.1f}')
axes[1, 1].legend()

plt.tight_layout()
plt.show()

# Print detailed statistics
print("\n📊 DETAILED CLASS STATISTICS")
print("=" * 50)
for label_type in metadata['label_type'].unique():
    subset = metadata[metadata['label_type'] == label_type]
    print(f"\n{label_type.upper()}:")
    print(f"  Total Images: {len(subset):,}")
    print(f"  Unique Patients: {subset['patient_id'].nunique()}")
    print(f"  Subclasses: {subset['subclass'].nunique()}")
    print(f"  Subclass breakdown:")
    for subclass, count in subset['subclass'].value_counts().items():
        print(f"    - {subclass.replace('_', ' ').title()}: {count:,} images")

## 3. Cross-tabulation Analysis

In [None]:
# Create cross-tabulation matrices
fig, axes = plt.subplots(1, 2, figsize=(18, 6))

# 1. Subclass vs Magnification
crosstab_mag = pd.crosstab(metadata['subclass'], metadata['magnification'])
sns.heatmap(crosstab_mag, annot=True, fmt='d', cmap='Blues', ax=axes[0])
axes[0].set_title('Subclass vs Magnification Cross-tabulation', fontsize=14, fontweight='bold')
axes[0].set_ylabel('Subclass')
axes[0].set_xlabel('Magnification')

# Format y-axis labels
axes[0].set_yticklabels([label.get_text().replace('_', ' ').title() for label in axes[0].get_yticklabels()], 
                       rotation=0)

# 2. Label Type vs Magnification
crosstab_label = pd.crosstab(metadata['label_type'], metadata['magnification'])
sns.heatmap(crosstab_label, annot=True, fmt='d', cmap='Reds', ax=axes[1])
axes[1].set_title('Label Type vs Magnification Cross-tabulation', fontsize=14, fontweight='bold')
axes[1].set_ylabel('Label Type')
axes[1].set_xlabel('Magnification')

plt.tight_layout()
plt.show()

# Print cross-tabulation statistics
print("\n📊 CROSS-TABULATION ANALYSIS")
print("=" * 50)
print("\nSubclass vs Magnification:")
print(crosstab_mag)
print("\nLabel Type vs Magnification:")
print(crosstab_label)

## 4. Patient-wise Analysis

In [None]:
# Patient-wise analysis
patient_analysis = metadata.groupby('patient_id').agg({
    'path': 'count',
    'label_type': lambda x: x.iloc[0],  # Assuming all images from same patient have same label
    'subclass': lambda x: x.iloc[0],
    'magnification': lambda x: list(x.unique())
}).rename(columns={'path': 'image_count'})

patient_analysis['magnification_count'] = patient_analysis['magnification'].apply(len)

fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# 1. Patient distribution by label type
patient_label_counts = patient_analysis['label_type'].value_counts()
axes[0, 0].pie(patient_label_counts.values, labels=patient_label_counts.index, 
               autopct='%1.1f%%', startangle=90)
axes[0, 0].set_title('Patient Distribution: Benign vs Malignant', fontsize=14, fontweight='bold')

# 2. Images per patient by label type
benign_patients = patient_analysis[patient_analysis['label_type'] == 'benign']['image_count']
malignant_patients = patient_analysis[patient_analysis['label_type'] == 'malignant']['image_count']

axes[0, 1].hist([benign_patients, malignant_patients], bins=15, alpha=0.7, 
                label=['Benign', 'Malignant'], color=['green', 'red'])
axes[0, 1].set_xlabel('Images per Patient')
axes[0, 1].set_ylabel('Number of Patients')
axes[0, 1].set_title('Images per Patient by Label Type', fontsize=14, fontweight='bold')
axes[0, 1].legend()

# 3. Magnification coverage per patient
mag_coverage = patient_analysis['magnification_count'].value_counts().sort_index()
axes[1, 0].bar(mag_coverage.index, mag_coverage.values, color='orange', edgecolor='darkorange')
axes[1, 0].set_xlabel('Number of Different Magnifications')
axes[1, 0].set_ylabel('Number of Patients')
axes[1, 0].set_title('Magnification Coverage per Patient', fontsize=14, fontweight='bold')

# Add value labels
for i, v in enumerate(mag_coverage.values):
    axes[1, 0].text(mag_coverage.index[i], v + 1, str(v), ha='center', fontweight='bold')

# 4. Patient distribution by subclass
patient_subclass = patient_analysis['subclass'].value_counts()
axes[1, 1].barh(range(len(patient_subclass)), patient_subclass.values)
axes[1, 1].set_yticks(range(len(patient_subclass)))
axes[1, 1].set_yticklabels([label.replace('_', ' ').title() for label in patient_subclass.index])
axes[1, 1].set_xlabel('Number of Patients')
axes[1, 1].set_title('Patient Distribution by Subclass', fontsize=14, fontweight='bold')

# Add value labels
for i, v in enumerate(patient_subclass.values):
    axes[1, 1].text(v + 0.5, i, str(v), va='center', fontweight='bold')

plt.tight_layout()
plt.show()

# Print patient statistics
print("\n👥 PATIENT-WISE STATISTICS")
print("=" * 50)
print(f"Total Unique Patients: {len(patient_analysis)}")
print(f"Benign Patients: {len(patient_analysis[patient_analysis['label_type'] == 'benign'])}")
print(f"Malignant Patients: {len(patient_analysis[patient_analysis['label_type'] == 'malignant'])}")
print(f"\nImages per Patient Statistics:")
print(f"  Mean: {patient_analysis['image_count'].mean():.2f}")
print(f"  Median: {patient_analysis['image_count'].median():.2f}")
print(f"  Min: {patient_analysis['image_count'].min()}")
print(f"  Max: {patient_analysis['image_count'].max()}")
print(f"  Std: {patient_analysis['image_count'].std():.2f}")

## 5. Image Quality and Properties Analysis

In [None]:
def analyze_image_properties(metadata_sample, sample_size=100):
    """
    Analyze image properties from a sample of images
    """
    # Sample images for analysis (to avoid processing all images)
    sample_metadata = metadata_sample.sample(n=min(sample_size, len(metadata_sample)), 
                                            random_state=RANDOM_STATE)
    
    image_properties = []
    
    print(f"🔍 Analyzing {len(sample_metadata)} sample images...")
    
    for idx, row in sample_metadata.iterrows():
        try:
            # Load image
            img = Image.open(row['path'])
            img_array = np.array(img)
            
            # Calculate properties
            properties = {
                'width': img.width,
                'height': img.height,
                'channels': len(img_array.shape) if len(img_array.shape) == 2 else img_array.shape[2],
                'file_size_kb': os.path.getsize(row['path']) / 1024,
                'mean_intensity': np.mean(img_array),
                'std_intensity': np.std(img_array),
                'min_intensity': np.min(img_array),
                'max_intensity': np.max(img_array),
                'subclass': row['subclass'],
                'magnification': row['magnification'],
                'label_type': row['label_type']
            }
            
            image_properties.append(properties)
            
        except Exception as e:
            print(f"⚠️ Error processing {row['path']}: {e}")
            continue
    
    return pd.DataFrame(image_properties)

# Analyze image properties
image_props = analyze_image_properties(metadata, sample_size=200)
print(f"✅ Analyzed {len(image_props)} images")

In [None]:
# Visualize image properties
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# 1. Image dimensions
axes[0, 0].scatter(image_props['width'], image_props['height'], alpha=0.6, c='blue')
axes[0, 0].set_xlabel('Width (pixels)')
axes[0, 0].set_ylabel('Height (pixels)')
axes[0, 0].set_title('Image Dimensions Distribution', fontsize=12, fontweight='bold')
axes[0, 0].grid(True, alpha=0.3)

# 2. File size distribution
axes[0, 1].hist(image_props['file_size_kb'], bins=20, color='green', alpha=0.7, edgecolor='darkgreen')
axes[0, 1].set_xlabel('File Size (KB)')
axes[0, 1].set_ylabel('Frequency')
axes[0, 1].set_title('File Size Distribution', fontsize=12, fontweight='bold')
axes[0, 1].axvline(image_props['file_size_kb'].mean(), color='red', linestyle='--', 
                   label=f'Mean: {image_props["file_size_kb"].mean():.1f} KB')
axes[0, 1].legend()

# 3. Mean intensity by label type
sns.boxplot(data=image_props, x='label_type', y='mean_intensity', ax=axes[0, 2])
axes[0, 2].set_title('Mean Intensity by Label Type', fontsize=12, fontweight='bold')
axes[0, 2].set_ylabel('Mean Intensity')

# 4. Mean intensity by magnification
sns.boxplot(data=image_props, x='magnification', y='mean_intensity', ax=axes[1, 0])
axes[1, 0].set_title('Mean Intensity by Magnification', fontsize=12, fontweight='bold')
axes[1, 0].set_ylabel('Mean Intensity')

# 5. Standard deviation of intensity
axes[1, 1].hist(image_props['std_intensity'], bins=20, color='orange', alpha=0.7, edgecolor='darkorange')
axes[1, 1].set_xlabel('Standard Deviation of Intensity')
axes[1, 1].set_ylabel('Frequency')
axes[1, 1].set_title('Intensity Variation Distribution', fontsize=12, fontweight='bold')

# 6. File size by magnification
sns.boxplot(data=image_props, x='magnification', y='file_size_kb', ax=axes[1, 2])
axes[1, 2].set_title('File Size by Magnification', fontsize=12, fontweight='bold')
axes[1, 2].set_ylabel('File Size (KB)')

plt.tight_layout()
plt.show()

# Print image property statistics
print("\n🖼️ IMAGE PROPERTY STATISTICS")
print("=" * 50)
print(f"Image Dimensions:")
print(f"  Width - Mean: {image_props['width'].mean():.0f}, Std: {image_props['width'].std():.0f}")
print(f"  Height - Mean: {image_props['height'].mean():.0f}, Std: {image_props['height'].std():.0f}")
print(f"\nFile Size:")
print(f"  Mean: {image_props['file_size_kb'].mean():.2f} KB")
print(f"  Range: {image_props['file_size_kb'].min():.2f} - {image_props['file_size_kb'].max():.2f} KB")
print(f"\nIntensity Statistics:")
print(f"  Mean Intensity: {image_props['mean_intensity'].mean():.2f} ± {image_props['mean_intensity'].std():.2f}")
print(f"  Intensity Range: {image_props['min_intensity'].min()} - {image_props['max_intensity'].max()}")

## 6. Visual Sample Analysis

In [None]:
def display_sample_images(metadata, samples_per_class=2, magnification='100X'):
    """
    Display sample images from each subclass
    """
    # Filter by magnification for consistency
    mag_data = metadata[metadata['magnification'] == magnification]
    
    # Get unique subclasses
    subclasses = mag_data['subclass'].unique()
    
    # Calculate grid size
    n_classes = len(subclasses)
    n_cols = samples_per_class
    n_rows = n_classes
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 4, n_rows * 3))
    
    if n_rows == 1:
        axes = axes.reshape(1, -1)
    if n_cols == 1:
        axes = axes.reshape(-1, 1)
    
    for i, subclass in enumerate(subclasses):
        # Get samples for this subclass
        class_data = mag_data[mag_data['subclass'] == subclass]
        samples = class_data.sample(n=min(samples_per_class, len(class_data)), 
                                   random_state=RANDOM_STATE)
        
        for j, (_, row) in enumerate(samples.iterrows()):
            if j >= n_cols:
                break
                
            try:
                # Load and display image
                img = Image.open(row['path'])
                axes[i, j].imshow(img)
                axes[i, j].axis('off')
                
                # Set title
                title = f"{subclass.replace('_', ' ').title()}\n{row['label_type'].title()}"
                axes[i, j].set_title(title, fontsize=10, fontweight='bold')
                
            except Exception as e:
                axes[i, j].text(0.5, 0.5, f'Error loading\n{row["filename"]}', 
                               ha='center', va='center', transform=axes[i, j].transAxes)
                axes[i, j].axis('off')
    
    plt.suptitle(f'Sample Images from Each Subclass (Magnification: {magnification})', 
                 fontsize=16, fontweight='bold', y=0.98)
    plt.tight_layout()
    plt.show()

# Display sample images
display_sample_images(metadata, samples_per_class=3, magnification='100X')

## 7. Data Imbalance Analysis

In [None]:
def calculate_imbalance_metrics(metadata):
    """
    Calculate various imbalance metrics
    """
    # Class distribution
    class_counts = metadata['subclass'].value_counts()
    
    # Calculate imbalance ratio
    max_class = class_counts.max()
    min_class = class_counts.min()
    imbalance_ratio = max_class / min_class
    
    # Calculate class weights (inverse frequency)
    total_samples = len(metadata)
    n_classes = len(class_counts)
    class_weights = {}
    
    for class_name, count in class_counts.items():
        weight = total_samples / (n_classes * count)
        class_weights[class_name] = weight
    
    return {
        'class_counts': class_counts,
        'imbalance_ratio': imbalance_ratio,
        'class_weights': class_weights,
        'total_samples': total_samples,
        'n_classes': n_classes
    }

# Calculate imbalance metrics
imbalance_metrics = calculate_imbalance_metrics(metadata)

# Visualize imbalance
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# 1. Class distribution (sorted)
class_counts = imbalance_metrics['class_counts']
axes[0].bar(range(len(class_counts)), class_counts.values, color='lightblue', edgecolor='navy')
axes[0].set_xticks(range(len(class_counts)))
axes[0].set_xticklabels([label.replace('_', '\n').title() for label in class_counts.index], 
                       rotation=45, ha='right')
axes[0].set_ylabel('Number of Images')
axes[0].set_title('Class Distribution (Sorted by Count)', fontsize=14, fontweight='bold')

# Add value labels
for i, v in enumerate(class_counts.values):
    axes[0].text(i, v + 50, str(v), ha='center', fontweight='bold')

# 2. Class weights
class_weights = imbalance_metrics['class_weights']
weights_sorted = dict(sorted(class_weights.items(), key=lambda x: x[1], reverse=True))
axes[1].bar(range(len(weights_sorted)), list(weights_sorted.values()), 
           color='lightcoral', edgecolor='darkred')
axes[1].set_xticks(range(len(weights_sorted)))
axes[1].set_xticklabels([label.replace('_', '\n').title() for label in weights_sorted.keys()], 
                       rotation=45, ha='right')
axes[1].set_ylabel('Class Weight')
axes[1].set_title('Calculated Class Weights', fontsize=14, fontweight='bold')

# 3. Imbalance visualization
normalized_counts = class_counts / class_counts.sum() * 100
axes[2].pie(normalized_counts.values, labels=[label.replace('_', ' ').title() for label in normalized_counts.index], 
           autopct='%1.1f%%', startangle=90)
axes[2].set_title('Class Distribution (Percentage)', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()

# Print imbalance statistics
print("\n⚖️ DATA IMBALANCE ANALYSIS")
print("=" * 50)
print(f"Imbalance Ratio (Max/Min): {imbalance_metrics['imbalance_ratio']:.2f}")
print(f"Most Common Class: {class_counts.index[0]} ({class_counts.iloc[0]:,} images)")
print(f"Least Common Class: {class_counts.index[-1]} ({class_counts.iloc[-1]:,} images)")
print(f"\nClass Distribution:")
for class_name, count in class_counts.items():
    percentage = (count / imbalance_metrics['total_samples']) * 100
    print(f"  {class_name.replace('_', ' ').title()}: {count:,} ({percentage:.1f}%)")

print(f"\nRecommended Class Weights:")
for class_name, weight in sorted(class_weights.items(), key=lambda x: x[1], reverse=True):
    print(f"  {class_name.replace('_', ' ').title()}: {weight:.3f}")

## 8. Train/Validation/Test Split Analysis

In [None]:
from sklearn.model_selection import train_test_split

def create_patient_wise_splits(metadata, test_size=0.15, val_size=0.15, random_state=42):
    """
    Create patient-wise stratified splits to avoid data leakage
    """
    # Get unique patients with their labels
    patient_labels = metadata.groupby('patient_id')['subclass'].first().reset_index()
    
    # First split: train+val vs test
    train_val_patients, test_patients = train_test_split(
        patient_labels, test_size=test_size, 
        stratify=patient_labels['subclass'], random_state=random_state
    )
    
    # Second split: train vs val
    val_size_adjusted = val_size / (1 - test_size)
    train_patients, val_patients = train_test_split(
        train_val_patients, test_size=val_size_adjusted,
        stratify=train_val_patients['subclass'], random_state=random_state
    )
    
    # Map back to full metadata
    train_data = metadata[metadata['patient_id'].isin(train_patients['patient_id'])]
    val_data = metadata[metadata['patient_id'].isin(val_patients['patient_id'])]
    test_data = metadata[metadata['patient_id'].isin(test_patients['patient_id'])]
    
    return train_data, val_data, test_data

# Create splits
train_data, val_data, test_data = create_patient_wise_splits(metadata)

# Analyze splits
splits_info = {
    'Train': train_data,
    'Validation': val_data,
    'Test': test_data
}

# Visualize splits
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# 1. Split sizes
split_sizes = [len(data) for data in splits_info.values()]
split_names = list(splits_info.keys())
axes[0, 0].pie(split_sizes, labels=split_names, autopct='%1.1f%%', startangle=90)
axes[0, 0].set_title('Dataset Split Distribution', fontsize=14, fontweight='bold')

# 2. Class distribution across splits
split_class_data = []
for split_name, data in splits_info.items():
    for subclass in data['subclass'].unique():
        count = len(data[data['subclass'] == subclass])
        split_class_data.append({
            'Split': split_name,
            'Subclass': subclass.replace('_', ' ').title(),
            'Count': count
        })

split_class_df = pd.DataFrame(split_class_data)
pivot_data = split_class_df.pivot(index='Subclass', columns='Split', values='Count')
pivot_data.plot(kind='bar', ax=axes[0, 1], width=0.8)
axes[0, 1].set_title('Class Distribution Across Splits', fontsize=14, fontweight='bold')
axes[0, 1].set_xlabel('Subclass')
axes[0, 1].set_ylabel('Number of Images')
axes[0, 1].legend(title='Split')
axes[0, 1].tick_params(axis='x', rotation=45)

# 3. Patient distribution across splits
patient_splits = []
for split_name, data in splits_info.items():
    n_patients = data['patient_id'].nunique()
    patient_splits.append(n_patients)

axes[1, 0].bar(split_names, patient_splits, color=['skyblue', 'lightgreen', 'lightcoral'])
axes[1, 0].set_ylabel('Number of Patients')
axes[1, 0].set_title('Patient Distribution Across Splits', fontsize=14, fontweight='bold')

# Add value labels
for i, v in enumerate(patient_splits):
    axes[1, 0].text(i, v + 1, str(v), ha='center', fontweight='bold')

# 4. Magnification distribution across splits
mag_split_data = []
for split_name, data in splits_info.items():
    for mag in data['magnification'].unique():
        count = len(data[data['magnification'] == mag])
        mag_split_data.append({
            'Split': split_name,
            'Magnification': mag,
            'Count': count
        })

mag_split_df = pd.DataFrame(mag_split_data)
mag_pivot = mag_split_df.pivot(index='Magnification', columns='Split', values='Count')
mag_pivot.plot(kind='bar', ax=axes[1, 1], width=0.8)
axes[1, 1].set_title('Magnification Distribution Across Splits', fontsize=14, fontweight='bold')
axes[1, 1].set_xlabel('Magnification')
axes[1, 1].set_ylabel('Number of Images')
axes[1, 1].legend(title='Split')
axes[1, 1].tick_params(axis='x', rotation=0)

plt.tight_layout()
plt.show()

# Print split statistics
print("\n📊 TRAIN/VALIDATION/TEST SPLIT ANALYSIS")
print("=" * 60)
for split_name, data in splits_info.items():
    print(f"\n{split_name.upper()} SET:")
    print(f"  Images: {len(data):,}")
    print(f"  Patients: {data['patient_id'].nunique()}")
    print(f"  Percentage: {(len(data)/len(metadata)*100):.1f}%")
    print(f"  Class distribution:")
    for subclass, count in data['subclass'].value_counts().items():
        percentage = (count / len(data)) * 100
        print(f"    - {subclass.replace('_', ' ').title()}: {count} ({percentage:.1f}%)")

## 9. Recommendations and Summary

In [None]:
# Generate comprehensive summary and recommendations
print("\n" + "="*80)
print("📋 COMPREHENSIVE EDA SUMMARY & RECOMMENDATIONS")
print("="*80)

print("\n🔍 DATASET OVERVIEW:")
print(f"  • Total Images: {len(metadata):,}")
print(f"  • Unique Patients: {metadata['patient_id'].nunique()}")
print(f"  • Classes: {metadata['subclass'].nunique()} subclasses (4 benign + 4 malignant)")
print(f"  • Magnifications: {len(metadata['magnification'].unique())} levels (40X, 100X, 200X, 400X)")

print("\n⚖️ DATA IMBALANCE INSIGHTS:")
print(f"  • Imbalance Ratio: {imbalance_metrics['imbalance_ratio']:.2f}:1")
print(f"  • Most Common: {class_counts.index[0].replace('_', ' ').title()} ({class_counts.iloc[0]:,} images)")
print(f"  • Least Common: {class_counts.index[-1].replace('_', ' ').title()} ({class_counts.iloc[-1]:,} images)")
print(f"  • Recommendation: Use weighted sampling or class weights during training")

print("\n👥 PATIENT-WISE ANALYSIS:")
print(f"  • Images per Patient: {patient_analysis['image_count'].mean():.1f} ± {patient_analysis['image_count'].std():.1f}")
print(f"  • Patient Distribution: {len(patient_analysis[patient_analysis['label_type'] == 'benign'])} benign, {len(patient_analysis[patient_analysis['label_type'] == 'malignant'])} malignant")
print(f"  • Recommendation: Use patient-wise splits to avoid data leakage")

print("\n🖼️ IMAGE PROPERTIES:")
if len(image_props) > 0:
    print(f"  • Average Dimensions: {image_props['width'].mean():.0f} x {image_props['height'].mean():.0f} pixels")
    print(f"  • Average File Size: {image_props['file_size_kb'].mean():.1f} KB")
    print(f"  • Intensity Range: {image_props['min_intensity'].min()} - {image_props['max_intensity'].max()}")
    print(f"  • Recommendation: Resize to 224x224 for EfficientNet, normalize with ImageNet stats")

print("\n📊 TRAINING RECOMMENDATIONS:")
print("  1. DATA PREPROCESSING:")
print("     • Resize images to 224x224 pixels")
print("     • Normalize with ImageNet statistics (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])")
print("     • Apply data augmentation (rotation, flip, color jitter)")

print("\n  2. DATA SPLITTING:")
print("     • Use patient-wise stratified splits (70% train, 15% val, 15% test)")
print("     • Ensure no patient appears in multiple splits")
print("     • Maintain class distribution across splits")

print("\n  3. CLASS IMBALANCE HANDLING:")
print("     • Use WeightedRandomSampler during training")
print("     • Apply class weights in loss function")
print("     • Consider focal loss for severe imbalance")

print("\n  4. MODEL ARCHITECTURE:")
print("     • Start with EfficientNetB0 (good balance of accuracy and efficiency)")
print("     • Use transfer learning with ImageNet pretrained weights")
print("     • Fine-tune the entire network with lower learning rate")

print("\n  5. TRAINING STRATEGY:")
print("     • Batch size: 32-64 (depending on GPU memory)")
print("     • Learning rate: 1e-4 with cosine annealing")
print("     • Early stopping based on validation accuracy")
print("     • Save best model based on validation performance")

print("\n  6. EVALUATION METRICS:")
print("     • Accuracy, Precision, Recall, F1-score (per class and macro/micro avg)")
print("     • Confusion matrix analysis")
print("     • ROC-AUC curves for each class")
print("     • Per-magnification performance analysis")

print("\n🚀 NEXT STEPS:")
print("  1. Implement the recommended preprocessing pipeline")
print("  2. Create patient-wise stratified splits")
print("  3. Train EfficientNetB0 baseline with class balancing")
print("  4. Evaluate performance and analyze failure cases")
print("  5. Consider advanced techniques (ensemble, multi-scale, etc.)")

print("\n" + "="*80)
print("✅ EDA COMPLETED - Ready for model development!")
print("="*80)

## 10. Export Results

In [None]:
# Save key results for later use
import pickle

# Create results directory
os.makedirs('../results/eda', exist_ok=True)

# Save metadata
metadata.to_csv('../results/eda/dataset_metadata.csv', index=False)
print("✅ Saved dataset metadata to ../results/eda/dataset_metadata.csv")

# Save splits
train_data.to_csv('../results/eda/train_split.csv', index=False)
val_data.to_csv('../results/eda/val_split.csv', index=False)
test_data.to_csv('../results/eda/test_split.csv', index=False)
print("✅ Saved train/val/test splits to ../results/eda/")

# Save class mappings and weights
class_info = {
    'class_counts': dict(class_counts),
    'class_weights': class_weights,
    'imbalance_ratio': imbalance_metrics['imbalance_ratio']
}

with open('../results/eda/class_info.pkl', 'wb') as f:
    pickle.dump(class_info, f)
print("✅ Saved class information to ../results/eda/class_info.pkl")

# Save image properties (if analyzed)
if len(image_props) > 0:
    image_props.to_csv('../results/eda/image_properties.csv', index=False)
    print("✅ Saved image properties to ../results/eda/image_properties.csv")

print("\n🎉 EDA results exported successfully!")
print("📁 Check the ../results/eda/ directory for all saved files.")