# Exploratory Data Analysis - PneumoniaMNIST

This notebook provides a comprehensive exploratory data analysis of the PneumoniaMNIST dataset for chest X-ray pneumonia classification.

In [None]:
# Import required libraries
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from sklearn.model_selection import train_test_split
import medmnist
from medmnist import PneumoniaMNIST
import cv2
from collections import Counter
import warnings
# Silence every warning category for the whole session. NOTE(review): this also
# hides deprecation warnings from the plotting/ML libraries used below.
warnings.filterwarnings('ignore')

# Set style
# NOTE(review): the 'seaborn-v0_8' style name presumably requires a recent
# matplotlib release -- confirm against the pinned environment.
plt.style.use('seaborn-v0_8')
sns.set_palette('husl')

# Set random seed for reproducibility
# Seeds NumPy's global RNG, which the np.random.choice calls later in this
# notebook draw from.
np.random.seed(42)

: 

## 1. Dataset Overview

In [None]:
# Load the dataset
print("Loading PneumoniaMNIST dataset...")

# One dataset object per split, created in train -> val -> test order.
# download=True presumably lets medmnist fetch the files when they are missing.
train_dataset, val_dataset, test_dataset = (
    PneumoniaMNIST(split=split_key, download=True)
    for split_key in ('train', 'val', 'test')
)

# Raw image arrays are taken as-is from each dataset object.
train_images = train_dataset.imgs
val_images = val_dataset.imgs
test_images = test_dataset.imgs

# Labels are flattened to 1-D so they can be compared element-wise with class ids.
train_labels = train_dataset.labels.flatten()
val_labels = val_dataset.labels.flatten()
test_labels = test_dataset.labels.flatten()

# Shape report for every split.
print(
    "Dataset shapes:",
    f"Train: {train_images.shape}, Labels: {train_labels.shape}",
    f"Validation: {val_images.shape}, Labels: {val_labels.shape}",
    f"Test: {test_images.shape}, Labels: {test_labels.shape}",
    sep="\n",
)

# Dataset info: total size across splits and the per-image dimensions.
print(
    "\nDataset Information:",
    f"Total samples: {sum(map(len, (train_images, val_images, test_images)))}",
    f"Image dimensions: {train_images.shape[1:]}",
    sep="\n",
)

## 2. Class Distribution Analysis

In [None]:
# Pool the three splits (train, then val, then test) along the sample axis so
# that dataset-wide statistics can be computed alongside the per-split ones.
all_images = np.concatenate((train_images, val_images, test_images))
all_labels = np.concatenate((train_labels, val_labels, test_labels))

# Human-readable class names; the list position is used as the integer label id
# (index 0 -> 'Normal', index 1 -> 'Pneumonia').
class_names = ['Normal', 'Pneumonia']

def analyze_class_distribution(labels, split_name, names=None):
    """Print and return the per-class sample counts of one data split.

    Args:
        labels: Integer class labels. Any array-like is accepted; the values are
            flattened first, so both ``(N,)`` and ``(N, 1)`` inputs work.
        split_name: Label used in the printed heading (e.g. ``"Training"``).
        names: Optional sequence of class names whose positions are the class
            ids. Defaults to the notebook-level ``class_names``.

    Returns:
        collections.Counter mapping class id (plain Python int) to its count.
    """
    if names is None:
        names = class_names

    # Flatten and convert to native Python ints so the Counter keys are
    # hashable and consistent regardless of the input's shape or dtype.
    counts = Counter(np.asarray(labels).ravel().tolist())
    total = sum(counts.values())

    print(f"\n{split_name} Class Distribution:")
    for i, class_name in enumerate(names):
        count = counts[i]
        # An empty split reports 0.0% instead of raising ZeroDivisionError.
        percentage = (count / total) * 100 if total else 0.0
        print(f"{class_name}: {count} ({percentage:.1f}%)")

    return counts

# Report the class distribution of every split and of the pooled data.
# The generator is consumed in order, so the reports are printed as
# Training -> Validation -> Test -> Overall.
train_counts, val_counts, test_counts, all_counts = (
    analyze_class_distribution(split_labels, split_title)
    for split_labels, split_title in (
        (train_labels, "Training"),
        (val_labels, "Validation"),
        (test_labels, "Test"),
        (all_labels, "Overall"),
    )
)

In [None]:
# Bar charts of the class counts: one panel per split plus the pooled data.
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

splits = [('Training', train_labels), ('Validation', val_labels), ('Test', test_labels), ('Overall', all_labels)]

# axes.flat walks the 2x2 grid row by row, matching the order of `splits`.
for ax, (split_name, labels) in zip(axes.flat, splits):
    tally = Counter(labels)
    values = [tally[class_id] for class_id, _ in enumerate(class_names)]
    n_total = sum(values)
    # Lift the annotation slightly above the bar, scaled to the tallest bar.
    label_offset = 0.01 * max(values)

    bars = ax.bar(class_names, values, alpha=0.8)
    ax.set_title(f'{split_name} Set Class Distribution')
    ax.set_ylabel('Number of Samples')

    # Annotate each bar with its absolute count and its share of the split.
    for bar, value in zip(bars, values):
        ax.text(
            bar.get_x() + bar.get_width() / 2.,
            bar.get_height() + label_offset,
            f'{value}\n({value/n_total*100:.1f}%)',
            ha='center',
            va='bottom',
        )

plt.tight_layout()
plt.show()

## 3. Image Visualization and Analysis

In [None]:
# Sample images from each class
def show_sample_images(images, labels, n_samples=8):
    """Show a grid of randomly chosen images, one row per class.

    Args:
        images: Array of images indexable by integer position.
        labels: 1-D array of integer class ids aligned with ``images``.
        n_samples: Number of images to draw per class. If a class has fewer
            images than this, all of them are shown and the remaining panels
            in that row are left blank.

    Notes:
        Sampling uses NumPy's global RNG (``np.random.choice``), so the
        selection depends on the seed set earlier in the notebook.
    """
    n_classes = len(class_names)

    # squeeze=False keeps `axes` two-dimensional even when n_samples == 1;
    # otherwise matplotlib returns a 1-D array and axes[row, col] fails.
    fig, axes = plt.subplots(n_classes, n_samples, figsize=(20, 6), squeeze=False)

    for class_idx in range(n_classes):
        # Get indices for current class
        class_indices = np.where(labels == class_idx)[0]

        # Sampling without replacement cannot draw more items than exist,
        # so cap the request at the class size.
        n_draw = min(n_samples, len(class_indices))
        sample_indices = np.random.choice(class_indices, n_draw, replace=False)

        for i, idx in enumerate(sample_indices):
            ax = axes[class_idx, i]
            ax.imshow(images[idx], cmap='gray')
            ax.set_title(f'{class_names[class_idx]}\nSample {i+1}')
            ax.axis('off')

        # Hide the frames of any panels that received no image.
        for ax in axes[class_idx, n_draw:]:
            ax.axis('off')

    plt.suptitle('Sample Images from Each Class', fontsize=16)
    plt.tight_layout()
    plt.show()

show_sample_images(all_images, all_labels)

In [None]:
# Pixel intensity analysis
def analyze_pixel_intensities(images, labels):
    """Plot pixel-intensity histograms and return per-class mean intensities.

    Three panels are drawn: the histogram of all pixels, the pixel histogram
    of each class, and the histogram of per-image mean intensity of each class.

    Args:
        images: Image array; means are taken over axes 1 and 2, so a layout of
            (n_images, height, width) is assumed.
        labels: Array of class ids (0 or 1) aligned with ``images``.

    Returns:
        List with one array per class (class 0 first) holding the mean
        intensity of every image in that class.
    """
    fig, (ax_overall, ax_by_class, ax_means) = plt.subplots(1, 3, figsize=(18, 5))

    # Panel 1: every pixel of every image pooled together.
    ax_overall.hist(images.flatten(), bins=50, alpha=0.7, density=True)
    ax_overall.set(
        title='Overall Pixel Intensity Distribution',
        xlabel='Pixel Intensity',
        ylabel='Density',
    )

    # Panels 2 and 3 are filled class by class; each axes object keeps its own
    # colour cycle, so interleaving the calls does not change the colours.
    mean_intensities = []
    for class_idx in range(2):
        class_label = class_names[class_idx]
        class_images = images[labels == class_idx]

        # Panel 2: raw pixel values of this class.
        ax_by_class.hist(class_images.flatten(), bins=50, alpha=0.7,
                         label=class_label, density=True)

        # Panel 3: one brightness value per image.
        per_image_mean = np.mean(class_images, axis=(1, 2))
        mean_intensities.append(per_image_mean)
        ax_means.hist(per_image_mean, bins=30, alpha=0.7,
                      label=class_label, density=True)

    ax_by_class.set(
        title='Pixel Intensity by Class',
        xlabel='Pixel Intensity',
        ylabel='Density',
    )
    ax_by_class.legend()

    ax_means.set(
        title='Mean Image Intensity by Class',
        xlabel='Mean Intensity',
        ylabel='Density',
    )
    ax_means.legend()

    plt.tight_layout()
    plt.show()

    return mean_intensities

mean_intensities = analyze_pixel_intensities(all_images, all_labels)

In [None]:
# Statistical analysis of intensities
from scipy import stats

# Descriptive statistics of the per-image mean intensity, one block per class.
print("Statistical Analysis of Mean Intensities:")
for class_name, intensities in zip(class_names, mean_intensities):
    print(f"\n{class_name}:")
    print(f"  Mean: {np.mean(intensities):.3f}")
    print(f"  Std: {np.std(intensities):.3f}")
    print(f"  Min: {np.min(intensities):.3f}")
    print(f"  Max: {np.max(intensities):.3f}")
    print(f"  Median: {np.median(intensities):.3f}")

# Statistical test
# Welch's t-test (equal_var=False): the two classes have very different sample
# sizes and there is no reason to assume equal variances, which the default
# Student's t-test requires.
t_stat, p_value = stats.ttest_ind(
    mean_intensities[0], mean_intensities[1], equal_var=False
)
print("\nWelch's t-test between classes:")
print(f"T-statistic: {t_stat:.3f}")
# Scientific notation: with thousands of images the p-value can be far below
# 1e-6, which a fixed-point format would display as a misleading 0.000000.
print(f"P-value: {p_value:.3e}")
print(f"Significant difference: {'Yes' if p_value < 0.05 else 'No'}")

## 4. Image Quality and Preprocessing Analysis

In [None]:
# Analyze image contrast and other quality metrics
def analyze_image_quality(images, labels, max_images=1000):
    """Compute simple per-image quality metrics.

    Args:
        images: Sequence of images; only the leading ``max_images`` entries are
            processed (a slice from the start, not a random sample).
        labels: Accepted for signature compatibility; currently unused.
        max_images: Upper bound on the number of images analysed. Pass ``None``
            to process every image. Defaults to 1000 to keep the cell fast.

    Returns:
        Dict with keys ``'contrast'``, ``'brightness'`` and ``'sharpness'``,
        each a list with one value per analysed image:
        contrast is the pixel standard deviation, brightness the pixel mean,
        and sharpness the variance of the image's Laplacian.
    """
    subset = images if max_images is None else images[:max_images]

    metrics = {
        'contrast': [],
        'brightness': [],
        'sharpness': []
    }

    for img in subset:
        # Contrast (standard deviation)
        metrics['contrast'].append(np.std(img))

        # Brightness (mean)
        metrics['brightness'].append(np.mean(img))

        # Sharpness (Laplacian variance); CV_64F output depth is requested so
        # the filter response is stored as 64-bit floats.
        laplacian = cv2.Laplacian(img, cv2.CV_64F)
        metrics['sharpness'].append(laplacian.var())

    return metrics

print("Analyzing image quality metrics (sample of 1000 images)...")
quality_metrics = analyze_image_quality(all_images, all_labels)

# Visualize quality metrics: one histogram panel per metric.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Display names and the matching keys of `quality_metrics`, in panel order.
metric_names = ['Contrast', 'Brightness', 'Sharpness']
metric_keys = ['contrast', 'brightness', 'sharpness']

for ax, name, key in zip(axes, metric_names, metric_keys):
    ax.hist(quality_metrics[key], bins=30, alpha=0.7, edgecolor='black')
    ax.set(title=f'{name} Distribution', xlabel=name, ylabel='Frequency')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Demonstrate preprocessing effects
def show_preprocessing_effects(image):
    """Visualise the preprocessing pipeline and one example of each augmentation.

    The top row shows the deterministic steps (original, resize to 224x224,
    scaling to [0, 1], stacking to three channels). The bottom row shows one
    augmentation per panel, each produced in isolation so that the panel title
    describes exactly what was applied.

    Args:
        image: Single 2-D grayscale image with values in the 0-255 range
            (it is divided by 255 for normalisation).
    """
    # Deferred import: TensorFlow is heavy and only this demo needs it.
    from tensorflow.keras.preprocessing.image import ImageDataGenerator

    fig, axes = plt.subplots(2, 4, figsize=(20, 10))

    # ---- Top row: deterministic preprocessing steps ----
    resized = cv2.resize(image, (224, 224))
    normalized = resized.astype(np.float32) / 255.0
    three_channel = np.stack([normalized] * 3, axis=-1)

    preprocessing_steps = [
        (image, 'Original (28x28)', 'gray'),
        (resized, 'Resized (224x224)', 'gray'),
        (normalized, 'Normalized [0,1]', 'gray'),
        (three_channel, '3-Channel RGB', None),  # RGB data needs no colormap
    ]
    for ax, (step_img, title, cmap) in zip(axes[0], preprocessing_steps):
        ax.imshow(step_img, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')

    # ---- Bottom row: one augmentation per panel ----
    # A single generator configured with every transform would apply all of
    # them to each output, so the per-panel titles would be wrong. Use one
    # generator per transform instead.
    img_batch = three_channel[np.newaxis, ...]  # add the batch dimension
    single_transform_generators = [
        ('Rotation', ImageDataGenerator(rotation_range=10)),
        ('Translation', ImageDataGenerator(width_shift_range=0.1,
                                           height_shift_range=0.1)),
        ('Zoom', ImageDataGenerator(zoom_range=0.1)),
    ]
    augmented = [
        (title, next(generator.flow(img_batch, batch_size=1))[0])
        for title, generator in single_transform_generators
    ]
    # The generator's horizontal_flip option flips at random, so a drawn sample
    # may not be flipped at all; mirror the width axis explicitly instead.
    augmented.append(('Horizontal Flip', three_channel[:, ::-1, :]))

    for ax, (title, aug_img) in zip(axes[1], augmented):
        # All three channels are identical, so show the first one in grayscale.
        ax.imshow(aug_img[:, :, 0], cmap='gray')
        ax.set_title(f'Augmented: {title}')
        ax.axis('off')

    plt.suptitle('Preprocessing and Augmentation Pipeline', fontsize=16)
    plt.tight_layout()
    plt.show()

# Show preprocessing for a sample image
sample_idx = np.random.choice(len(all_images))
sample_image = all_images[sample_idx]
show_preprocessing_effects(sample_image)

## 5. Data Imbalance Analysis

In [None]:
# Calculate class weights for handling imbalance
from sklearn.utils.class_weight import compute_class_weight

def calculate_class_weights(labels):
    """Return 'balanced' class weights for a label array.

    Args:
        labels: 1-D array of class labels.

    Returns:
        Dict mapping each class present in ``labels`` to its weight as computed
        by scikit-learn's ``compute_class_weight('balanced', ...)``. Keys and
        values are native Python scalars, so the dict prints cleanly and can be
        serialised or handed to a training API without NumPy scalar types
        leaking through.
    """
    classes = np.unique(labels)
    class_weights = compute_class_weight('balanced', classes=classes, y=labels)
    # .tolist() converts NumPy scalars to plain Python int/float.
    return dict(zip(classes.tolist(), class_weights.tolist()))

# Calculate for each split
train_weights = calculate_class_weights(train_labels)
val_weights = calculate_class_weights(val_labels)
test_weights = calculate_class_weights(test_labels)
all_weights = calculate_class_weights(all_labels)

print("Class Weights for Balanced Training:")
print(f"Training set: {train_weights}")
print(f"Validation set: {val_weights}")
print(f"Test set: {test_weights}")
print(f"Overall: {all_weights}")

In [None]:
# Visualize imbalance ratio
def plot_imbalance_analysis():
    """Plot the class-imbalance ratio and minority-class share of every split.

    Reads the notebook-level label arrays (train, val, test and the pooled
    labels) and draws two bar charts side by side: the majority/minority count
    ratio, and the minority class as a percentage of the split.
    """
    splits = ['Train', 'Val', 'Test', 'Overall']
    labels_list = [train_labels, val_labels, test_labels, all_labels]

    def _imbalance_stats(split_labels):
        # Returns (majority/minority ratio, minority share in percent).
        class_sizes = Counter(split_labels).values()
        largest, smallest = max(class_sizes), min(class_sizes)
        return largest / smallest, (smallest / len(split_labels)) * 100

    def _annotate_bars(ax, bars, values, offset, template):
        # Writes each value just above its bar.
        for bar, value in zip(bars, values):
            ax.text(bar.get_x() + bar.get_width()/2., bar.get_height() + offset,
                    template.format(value), ha='center', va='bottom')

    imbalance_ratios = []
    minority_percentages = []
    for split_labels in labels_list:
        ratio, minority_share = _imbalance_stats(split_labels)
        imbalance_ratios.append(ratio)
        minority_percentages.append(minority_share)

    fig, (ax_ratio, ax_share) = plt.subplots(1, 2, figsize=(15, 6))

    # Left panel: imbalance ratio, with a reference line at 1 (balanced).
    ratio_bars = ax_ratio.bar(splits, imbalance_ratios, alpha=0.8, color='skyblue')
    ax_ratio.set_title('Class Imbalance Ratio (Majority/Minority)')
    ax_ratio.set_ylabel('Ratio')
    ax_ratio.axhline(y=1, color='red', linestyle='--', label='Perfect Balance')
    ax_ratio.legend()
    _annotate_bars(ax_ratio, ratio_bars, imbalance_ratios, 0.01, '{:.2f}')

    # Right panel: minority share, with a reference line at 50% (balanced).
    share_bars = ax_share.bar(splits, minority_percentages, alpha=0.8, color='lightcoral')
    ax_share.set_title('Minority Class Percentage')
    ax_share.set_ylabel('Percentage (%)')
    ax_share.axhline(y=50, color='red', linestyle='--', label='Perfect Balance (50%)')
    ax_share.legend()
    _annotate_bars(ax_share, share_bars, minority_percentages, 0.5, '{:.1f}%')

    plt.tight_layout()
    plt.show()

plot_imbalance_analysis()

## 6. Recommendations for Model Training

In [None]:
# Summary and recommendations
# Each print call below emits one section; sep="\n" puts every argument on its
# own line, so the output matches one print per line.
print(
    "=" * 60,
    "DATASET ANALYSIS SUMMARY AND RECOMMENDATIONS",
    "=" * 60,
    sep="\n",
)

print(
    "📊 DATASET CHARACTERISTICS:",
    f"• Total samples: {len(all_images):,}",
    f"• Image size: {all_images.shape[1:]} (grayscale)",
    f"• Classes: {len(class_names)} ({', '.join(class_names)})",
    sep="\n",
)

# Class distribution over the pooled labels (0 = Normal, 1 = Pneumonia).
normal_count = (all_labels == 0).sum()
pneumonia_count = (all_labels == 1).sum()
imbalance_ratio = max(normal_count, pneumonia_count) / min(normal_count, pneumonia_count)

print(
    "⚖️ CLASS IMBALANCE:",
    f"• Normal: {normal_count:,} ({normal_count/len(all_labels)*100:.1f}%)",
    f"• Pneumonia: {pneumonia_count:,} ({pneumonia_count/len(all_labels)*100:.1f}%)",
    f"• Imbalance ratio: {imbalance_ratio:.2f}:1",
    sep="\n",
)

print(
    "🎯 TRAINING RECOMMENDATIONS:",
    "1. DATA PREPROCESSING:",
    "   • Resize images from 28x28 to 224x224 for transfer learning",
    "   • Normalize pixel values to [0, 1] range",
    "   • Convert grayscale to 3-channel by stacking",
    "   • Apply data augmentation (rotation ≤10°, translation, zoom, horizontal flip)",
    sep="\n",
)

print(
    "2. CLASS IMBALANCE HANDLING:",
    f"   • Use class weights: {all_weights}",
    "   • Consider stratified sampling for train/val/test splits",
    "   • Focus on sensitivity/recall for pneumonia detection",
    sep="\n",
)

print(
    "3. MODEL ARCHITECTURE:",
    "   • Start with MobileNetV2 for quick baseline",
    "   • Upgrade to ResNet50 for better performance",
    "   • Use transfer learning with ImageNet weights",
    "   • Add dropout (0.3) and global average pooling",
    sep="\n",
)

print(
    "4. TRAINING STRATEGY:",
    "   • Phase 1: Freeze backbone, train head (5-10 epochs)",
    "   • Phase 2: Unfreeze top layers, fine-tune (10-15 epochs)",
    "   • Use lower learning rate for fine-tuning (1e-4 → 1e-5)",
    "   • Early stopping with patience=6, monitor val_auc",
    sep="\n",
)

print(
    "5. EVALUATION METRICS:",
    "   • Primary: AUC-ROC (handles class imbalance well)",
    "   • Clinical: Sensitivity (recall) for pneumonia detection",
    "   • Additional: Specificity, F1-score, precision",
    "   • Use calibration for reliable probability estimates",
    sep="\n",
)

print(
    "6. EXPLAINABILITY:",
    "   • Implement Grad-CAM for visual explanations",
    "   • Focus on lung field regions in heatmaps",
    "   • Analyze false positives/negatives with Grad-CAM",
    sep="\n",
)

print(
    "⚠️ LIMITATIONS TO CONSIDER:",
    "   • Small dataset size (may limit generalization)",
    "   • Low resolution images (28x28 original)",
    "   • Class imbalance requires careful handling",
    "   • Not suitable for clinical deployment without validation",
    sep="\n",
)

## 7. Next Steps

Based on this analysis, the next steps in the project should be:

1. **Data Preparation**: Run `prepare_data.py` to preprocess and split the data
2. **Baseline Model**: Train a simple CNN or MobileNetV2 model
3. **Advanced Model**: Implement ResNet50 with transfer learning
4. **Evaluation**: Comprehensive evaluation with multiple metrics
5. **Explainability**: Implement Grad-CAM for model interpretability
6. **Demo Application**: Create Gradio/Streamlit app for interactive use

The analysis shows that while the dataset is relatively small and imbalanced, it provides a good foundation for building a pneumonia detection system with proper preprocessing and training strategies.