# Dataset Exploration Notebook

**Federated Learning for Skin Cancer Classification with DSCATNet**

This notebook provides comprehensive dataset exploration and verification, including:

- **Dataset Verification**: Check all datasets are properly downloaded
- **Class Distribution Analysis**: Understand class imbalances across FL clients
- **Image Statistics**: Analyze dimensions, aspect ratios, and pixel values
- **Non-IID Visualization**: Visualize data heterogeneity across clients
- **Preprocessing Pipeline**: Test and visualize augmentation transforms
- **Sample Visualization**: Display sample images from each dataset

---

## 1. Setup and Imports

In [None]:
# Standard library
import sys
from pathlib import Path

# Data science
import numpy as np
import pandas as pd

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Image processing
from PIL import Image

# Add project root to path so project-local packages (e.g. `src`, imported in
# later cells) can be imported.
# NOTE(review): `.parent` of the current working directory assumes the kernel
# was started one level below the project root (e.g. a notebooks/ folder) —
# confirm if the notebook is moved.
project_root = Path().resolve().parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

# Settings
# NOTE(review): the 'seaborn-v0_8-*' style names presumably require a recent
# matplotlib (>= 3.6) — verify if plt.style.use raises on an older install.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('husl')
%matplotlib inline

print(f"Project root: {project_root}")

## 2. Configuration

In [None]:
# =============================================================================
# CONFIGURATION
# =============================================================================

# Data directory (under the project root resolved in the setup cell)
DATA_ROOT = project_root / "data"

# Create data directory if it doesn't exist (leaf only; project_root must exist)
DATA_ROOT.mkdir(exist_ok=True)

# Define unified class names (7-class scheme)
UNIFIED_CLASSES = [
    'Actinic Keratosis',
    'Basal Cell Carcinoma',
    'Benign Keratosis',
    'Dermatofibroma',
    'Melanoma',
    'Melanocytic Nevus',
    'Vascular Lesion'
]

# Class abbreviations, index-aligned with UNIFIED_CLASSES
CLASS_ABBREV = ['AK', 'BCC', 'BKL', 'DF', 'MEL', 'NV', 'VASC']

# The two lists above are parallel; fail fast if an edit makes them drift apart
assert len(CLASS_ABBREV) == len(UNIFIED_CLASSES), \
    "CLASS_ABBREV must be index-aligned with UNIFIED_CLASSES"

# Image directory sampled for each FL client (used for image statistics and
# sample plots below).
# NOTE: HAM10000 points at the `_part_1` folder only, so image-level statistics
# for HAM10000 presumably cover a subset of that dataset.
IMAGE_DIRS = {
    'HAM10000': DATA_ROOT / 'HAM10000' / 'HAM10000_images_part_1',
    'ISIC2018': DATA_ROOT / 'ISIC2018' / 'ISIC2018_Task3_Training_Input',
    'ISIC2019': DATA_ROOT / 'ISIC2019' / 'ISIC_2019_Training_Input',
    'ISIC2020': DATA_ROOT / 'ISIC2020' / 'train'
}

print(f"Data root: {DATA_ROOT}")
print(f"Exists: {DATA_ROOT.exists()}")

## 3. Dataset Verification

First, let's verify that all datasets are properly downloaded and organized.

In [None]:
# Project-local verifier (importable because the setup cell put project_root on sys.path)
from src.data.verify import DatasetVerifier

# Run verification over the datasets under DATA_ROOT.
# `results` is iterated later as (dataset name, result dict) pairs; downstream
# cells read the 'class_distribution', 'total_images' and 'valid' keys.
verifier = DatasetVerifier(str(DATA_ROOT))
results = verifier.verify_all(verbose=True)

In [None]:
# Summary statistics reported by the verifier
summary = verifier.get_summary_stats()

print("\n=== Summary ===")
# Denominator is the number of datasets actually verified, not a hard-coded 4,
# so the ratio stays correct if datasets are added or removed.
print(f"Valid datasets: {summary['valid_datasets']}/{len(results)}")
print(f"Total images: {summary['total_images']:,}")
print("\nImages per dataset:")
for name, count in summary['images_per_dataset'].items():
    print(f"  {name}: {count:,}")

## 4. Class Distribution Analysis

Understanding class distributions is crucial for federated learning, especially when dealing with non-IID data.

In [None]:
# Keep only datasets whose verification produced a non-empty class distribution
distributions = {
    name: result['class_distribution']
    for name, result in results.items()
    if result['class_distribution']
}

# Show the raw (dataset-specific) label counts per dataset
for name, dist in distributions.items():
    print(f"\n{name}:")
    for label, n_images in dist.items():
        print(f"  {label}: {n_images:,}")

In [None]:
# Visualize class distributions: one bar-chart panel per FL client
n_clients = len(distributions)
n_cols = 2
n_rows = max(1, -(-n_clients // n_cols))  # ceiling division, at least one row
# squeeze=False keeps `axes` 2-D for any grid size; 5in per row matches the
# original 2x2 layout (14x10)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 5 * n_rows), squeeze=False)
axes = axes.flatten()

# Palette sized to the largest label set so bars within a panel never repeat colours
max_labels = max((len(dist) for dist in distributions.values()), default=0)
colors = sns.color_palette('husl', max(8, max_labels))

for idx, (name, dist) in enumerate(distributions.items()):
    ax = axes[idx]

    classes = list(dist.keys())
    counts = list(dist.values())

    bars = ax.bar(classes, counts, color=colors[:len(classes)])
    ax.set_title(f'{name} (Client {idx+1})', fontsize=12, fontweight='bold')
    ax.set_xlabel('Class')
    ax.set_ylabel('Count')
    ax.tick_params(axis='x', rotation=45)

    # Count labels above bars; the offset scales with the tallest bar so it
    # looks the same for datasets of very different sizes
    label_offset = 0.01 * max(counts, default=0)
    for bar, count in zip(bars, counts):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + label_offset,
                f'{count:,}', ha='center', va='bottom', fontsize=8)

# Hide grid cells left over when the number of clients doesn't fill the grid
for ax in axes[n_clients:]:
    ax.set_visible(False)

plt.suptitle('Class Distribution Across FL Clients (Non-IID)',
             fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()

# savefig does not create missing directories, so make sure the target exists
figures_dir = DATA_ROOT.parent / 'experiments'
figures_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(figures_dir / 'class_distribution.png',
            dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Create unified distribution comparison table

# Dataset-specific label -> unified 7-class abbreviation (see CLASS_ABBREV).
# Keys are lower-case; map_to_unified lower-cases each label before lookup, so
# e.g. 'MEL', 'mel' and 'Mel' all resolve to 'MEL'.
LABEL_TO_UNIFIED = {
    # HAM10000 / ISIC2018 labels (ISIC2019 uses the upper-case forms)
    'akiec': 'AK',
    'ak': 'AK',
    'bcc': 'BCC',
    'bkl': 'BKL',
    'df': 'DF',
    'mel': 'MEL',
    'nv': 'NV',
    'vasc': 'VASC',
    # ISIC2019 additional: SCC is folded into BCC (both carcinomas)
    'scc': 'BCC',
    # ISIC2020 binary labels
    'benign': 'NV',
    'malignant': 'MEL',
}


def map_to_unified(dist, dataset_name, class_order=None):
    """Map dataset-specific label counts onto the unified 7-class scheme.

    Args:
        dist: Mapping of raw label -> image count for one dataset.
        dataset_name: Dataset name; only used in the warning about dropped labels.
        class_order: Unified class abbreviations to report, in order.
            Defaults to ``CLASS_ABBREV``.

    Returns:
        Dict of unified abbreviation -> summed count. Every class in
        ``class_order`` is present (0 when the dataset has none of it).

    Labels with no entry in ``LABEL_TO_UNIFIED`` (or whose target is not in
    ``class_order``) are excluded from the counts and reported with
    ``warnings.warn`` instead of disappearing silently.
    """
    import warnings

    classes = CLASS_ABBREV if class_order is None else class_order
    unified = {c: 0 for c in classes}
    dropped = []

    for label, count in dist.items():
        target = LABEL_TO_UNIFIED.get(str(label).lower())
        if target in unified:
            unified[target] += count
        else:
            dropped.append(label)

    if dropped:
        warnings.warn(
            f"{dataset_name}: labels without a unified class were dropped: {dropped}",
            stacklevel=2,
        )

    return unified

# Build a clients x classes table in the unified label space, plus a row total
unified_dists = {name: map_to_unified(dist, name) for name, dist in distributions.items()}

df_dist = pd.DataFrame(unified_dists).T.assign(Total=lambda frame: frame.sum(axis=1))

print("\nUnified Class Distribution (7 Classes):")
print(df_dist.to_string())

In [None]:
# Visualize unified distribution as heatmap
fig, ax = plt.subplots(figsize=(12, 5))

# Normalize by row (percentage within each client). A client with zero mapped
# images produces NaN, which seaborn renders as a blank row.
df_pct = df_dist.drop(columns='Total')
df_pct = df_pct.div(df_pct.sum(axis=1), axis=0) * 100

sns.heatmap(df_pct, annot=True, fmt='.1f', cmap='YlOrRd',
            ax=ax, cbar_kws={'label': 'Percentage (%)'})
ax.set_title('Class Distribution per Client (% within each client)',
             fontsize=12, fontweight='bold')
ax.set_xlabel('Class')
ax.set_ylabel('Dataset (FL Client)')

plt.tight_layout()
plt.show()

print("\nKey Observations:")
print("- HAM10000/ISIC2018: Same 7-class scheme, heavily imbalanced (see heatmap)")
print("- ISIC2019: Includes SCC (folded into BCC above), more diverse")
print("- ISIC2020: Binary only → strong non-IID with other clients")

## 5. Image Statistics

Analyze image dimensions, aspect ratios, and mean RGB pixel values on a sample of images from each dataset.

In [None]:
def sample_image_stats(image_dir, n_samples=500):
    """Compute size and mean-colour statistics for a sample of JPEG images.

    Reads up to ``n_samples`` ``*.jpg`` files from ``image_dir``. Paths are
    sorted before sampling, so repeated runs use the same images.

    Args:
        image_dir: Directory containing the images (str or Path).
        n_samples: Maximum number of images to read.

    Returns:
        ``None`` if ``image_dir`` does not exist. Otherwise a dict of parallel
        lists ``widths``, ``heights``, ``aspect_ratios``, ``mean_r``,
        ``mean_g``, ``mean_b``, with one entry per successfully read image.
        Channel means are taken over a 64x64 RGB thumbnail (0-255 scale).
    """
    image_dir = Path(image_dir)
    if not image_dir.exists():
        return None

    images = sorted(image_dir.glob("*.jpg"))[:n_samples]

    stats = {
        'widths': [],
        'heights': [],
        'aspect_ratios': [],
        'mean_r': [],
        'mean_g': [],
        'mean_b': []
    }
    n_failed = 0

    for img_path in images:
        try:
            with Image.open(img_path) as img:
                width, height = img.width, img.height
                # Force 3-channel RGB so grayscale/palette/RGBA files all give
                # comparable per-channel means; the small resize keeps it cheap.
                thumb = np.asarray(img.convert('RGB').resize((64, 64)), dtype=np.float64)
        except (OSError, ValueError):
            # Unreadable/truncated file: skip it, but count it rather than hide it
            n_failed += 1
            continue

        # Append only after every read succeeded so all lists keep the same length
        stats['widths'].append(width)
        stats['heights'].append(height)
        stats['aspect_ratios'].append(width / height)
        stats['mean_r'].append(thumb[:, :, 0].mean())
        stats['mean_g'].append(thumb[:, :, 1].mean())
        stats['mean_b'].append(thumb[:, :, 2].mean())

    if n_failed:
        print(f"  Skipped {n_failed} unreadable image(s) in {image_dir}")

    return stats

# Sample every configured client directory; datasets that are missing or
# yield no readable images are left out of all_stats
all_stats = {}
for name, path in IMAGE_DIRS.items():
    print(f"Sampling {name}...")
    stats = sample_image_stats(path, n_samples=300)
    if not (stats and stats['widths']):
        continue
    all_stats[name] = stats
    print(f"  Sampled {len(stats['widths'])} images")

In [None]:
# Visualize image dimensions: one histogram panel per geometric property
if not all_stats:
    print("No image statistics available. Please download datasets first.")
else:
    panel_specs = [
        ('widths', 'Width (pixels)', 'Image Width Distribution'),
        ('heights', 'Height (pixels)', 'Image Height Distribution'),
        ('aspect_ratios', 'Aspect Ratio (W/H)', 'Aspect Ratio Distribution'),
    ]
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    for panel_ax, (stat_key, x_label, panel_title) in zip(axes, panel_specs):
        for name, stats in all_stats.items():
            panel_ax.hist(stats[stat_key], bins=20, alpha=0.5, label=name)
        panel_ax.set_xlabel(x_label)
        panel_ax.set_ylabel('Count')
        panel_ax.set_title(panel_title)
        if stat_key == 'aspect_ratios':
            # Reference line for a perfectly square image
            panel_ax.axvline(x=1.0, color='red', linestyle='--', label='Square')
        panel_ax.legend()

    plt.tight_layout()
    plt.show()

In [None]:
# Per-channel colour statistics (useful when choosing normalization constants)
if all_stats:
    channel_titles = {
        'mean_r': 'Red Channel Mean',
        'mean_g': 'Green Channel Mean',
        'mean_b': 'Blue Channel Mean',
    }
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    for channel_ax, (channel_key, channel_title) in zip(axes, channel_titles.items()):
        for name, stats in all_stats.items():
            if not stats[channel_key]:
                continue
            channel_ax.hist(stats[channel_key], bins=30, alpha=0.5, label=name)
        channel_ax.set_xlabel('Pixel Value (0-255)')
        channel_ax.set_ylabel('Count')
        channel_ax.set_title(channel_title)
        channel_ax.legend()

    plt.tight_layout()
    plt.show()

    # Average of the per-image channel means, per dataset
    print("\nMean RGB values per dataset:")
    for name, stats in all_stats.items():
        if not stats['mean_r']:
            continue
        r, g, b = (np.mean(stats[key]) for key in ('mean_r', 'mean_g', 'mean_b'))
        print(f"  {name}: R={r:.1f}, G={g:.1f}, B={b:.1f}")

## 6. Non-IID Visualization

Visualize how the federated clients differ (non-IID) in both data quantity and class coverage. This section uses each dataset's raw labels; the unified 7-class view is the heatmap in Section 4.

In [None]:
# Create non-IID visualization (raw, dataset-specific labels)
if distributions:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    clients = list(distributions.keys())

    # All raw labels seen on any client, sorted for a stable colour/column order.
    # NOTE: labels are not unified here, so the same class spelled differently
    # by two datasets (e.g. 'MEL' vs 'mel') appears as two columns.
    all_classes = sorted({cls for dist in distributions.values() for cls in dist})

    # 1. Stacked bar chart showing absolute numbers
    ax1 = axes[0]
    bottom = np.zeros(len(clients))
    for cls in all_classes:
        values = np.array([distributions[c].get(cls, 0) for c in clients], dtype=float)
        ax1.bar(clients, values, bottom=bottom, label=cls)
        bottom += values

    ax1.set_xlabel('FL Client (Dataset)')
    ax1.set_ylabel('Number of Images')
    ax1.set_title('Data Quantity per Client (Non-IID)')
    ax1.legend(bbox_to_anchor=(1.02, 1), loc='upper left')
    ax1.tick_params(axis='x', rotation=45)

    # 2. Class-presence heatmap (1 = client has at least one image of the class)
    ax2 = axes[1]
    presence_df = pd.DataFrame(
        [[int(distributions[client].get(cls, 0) > 0) for cls in all_classes]
         for client in clients],
        index=clients,
        columns=all_classes,
    )
    sns.heatmap(presence_df, annot=True, cmap='RdYlGn', ax=ax2,
                cbar_kws={'label': 'Class Present'})
    ax2.set_title('Class Presence per Client')
    ax2.set_xlabel('Class')
    ax2.set_ylabel('Client')

    plt.tight_layout()
    plt.show()

    # Derive the summary from the data instead of hard-coding dataset facts
    client_totals = pd.Series({c: sum(distributions[c].values()) for c in clients})
    classes_per_client = presence_df.sum(axis=1)
    sparsest_client = classes_per_client.idxmin()

    print("\nNon-IID Analysis:")
    print("- Each client has different class distributions")
    print(f"- {sparsest_client} has the fewest classes "
          f"({int(classes_per_client[sparsest_client])}) -> most extreme label skew")
    print(f"- Data quantity varies from {int(client_totals.min()):,} "
          f"to {int(client_totals.max()):,} images per client")

## 7. Preprocessing Pipeline Test

Test the standardized preprocessing pipeline. This covers the training transforms at the `'medium'` augmentation level and the validation transforms, which apply no augmentation.

In [None]:
from src.data.preprocessing import (
    IMAGENET_MEAN,
    IMAGENET_STD,
    get_train_transforms,
    get_val_transforms,
)

# Training pipeline uses the 'medium' augmentation preset; validation has no augmentation.
# Both produce 224x224 outputs.
train_transform = get_train_transforms(img_size=224, augmentation_level='medium')
val_transform = get_val_transforms(img_size=224)

print("Training transforms:", train_transform, sep="\n")
print("\nValidation transforms:", val_transform, sep="\n")

In [None]:
def _to_display_image(transform_result):
    """Turn a transform's normalized CHW output back into an HWC array in [0, 1].

    Assumes ``transform_result['image']`` is a torch-style tensor (has
    ``.permute`` / ``.numpy()``) normalized with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.
    """
    img = transform_result['image'].permute(1, 2, 0).numpy()
    img = img * np.array(IMAGENET_STD) + np.array(IMAGENET_MEAN)
    return np.clip(img, 0, 1)


def visualize_augmentations(image_path, n_augmentations=6, n_cols=4):
    """Show an image, its validation transform, and several random training augmentations.

    Uses the module-level ``val_transform`` / ``train_transform`` (called as
    ``transform(image=array)``) defined in the preprocessing cell.

    Args:
        image_path: Path to an image readable by PIL.
        n_augmentations: Number of random training augmentations to draw.
        n_cols: Panels per row; rows are added as needed to fit all panels.
    """
    # Decode inside a context manager so the underlying file is closed promptly
    with Image.open(image_path) as src:
        img_array = np.array(src.convert('RGB'))

    panels = [
        ('Original', img_array),
        # Validation (no augmentation)
        ('Validation (224x224)', _to_display_image(val_transform(image=img_array))),
    ]
    # Each call to train_transform draws a fresh random augmentation
    panels += [
        (f'Augmentation {i}', _to_display_image(train_transform(image=img_array)))
        for i in range(1, n_augmentations + 1)
    ]

    n_rows = -(-len(panels) // n_cols)  # ceiling division
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
    axes = axes.flatten()

    for ax, (title, image) in zip(axes, panels):
        ax.imshow(image)
        ax.set_title(title)
    # Turn off every axis, including unused grid cells when panels don't fill the grid
    for ax in axes:
        ax.axis('off')

    plt.suptitle('Preprocessing Pipeline Visualization', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

# Preview augmentations on the first *.jpg found in any configured client directory
sample_found = False
for img_dir in IMAGE_DIRS.values():
    if not img_dir.exists():
        continue
    images = list(img_dir.glob("*.jpg"))
    if not images:
        continue
    visualize_augmentations(images[0])
    sample_found = True
    break

if not sample_found:
    print("No sample images found. Please download datasets first.")

## 8. Sample Visualization

Display sample images from each dataset/client.

In [None]:
def display_samples_per_client(n_samples=4):
    """Display the first ``n_samples`` images (sorted by path) from each FL client.

    One row per entry in ``IMAGE_DIRS``, labelled "<dataset> (Client <i>)" in
    ``IMAGE_DIRS`` order. Directories that do not exist get a 'Not Found' row.

    Args:
        n_samples: Number of images (columns) to show per client.
    """
    n_rows = len(IMAGE_DIRS)
    # squeeze=False keeps `axes` 2-D even when n_rows or n_samples is 1
    fig, axes = plt.subplots(n_rows, n_samples, figsize=(3 * n_samples, 3 * n_rows),
                             squeeze=False)

    for row, (name, img_dir) in enumerate(IMAGE_DIRS.items()):
        row_axes = axes[row]

        if img_dir.exists():
            # Sort so the same samples are shown on every run
            for ax, img_path in zip(row_axes, sorted(img_dir.glob("*.jpg"))[:n_samples]):
                # Copy pixels out while the file is open, then close it
                with Image.open(img_path) as img:
                    ax.imshow(np.asarray(img))
        else:
            for ax in row_axes:
                ax.text(0.5, 0.5, 'Not Found', ha='center', va='center')

        # Hide ticks and frame but keep the axes "on": ax.axis('off') would also
        # hide the y-label, which is where the client name is drawn.
        for ax in row_axes:
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
        row_axes[0].set_ylabel(f'{name} (Client {row + 1})', fontsize=12)

    plt.suptitle('Sample Images from Each FL Client', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

display_samples_per_client()

## 9. Summary and Next Steps

Generate a summary report for the thesis.

In [None]:
# Generate summary table for thesis: one row per verified dataset (client ids start at 1)
summary_data = [
    {
        'Client': client_id,
        'Dataset': name,
        'Images': result['total_images'],
        'Classes': len(result['class_distribution']) if result['class_distribution'] else 0,
        'Status': 'OK' if result['valid'] else 'FAIL',
    }
    for client_id, (name, result) in enumerate(results.items(), start=1)
]
summary_df = pd.DataFrame(summary_data)

rule = "=" * 60
print("\n" + rule)
print("DATASET SUMMARY FOR THESIS")
print(rule)
print(summary_df.to_string(index=False))
print(rule)

# LaTeX table format (Status column intentionally omitted)
latex_lines = [
    "\\begin{table}[h]",
    "\\centering",
    "\\caption{Dermoscopy Datasets Used in Federated Learning Experiments}",
    "\\begin{tabular}{|c|l|r|c|}",
    "\\hline",
    "\\textbf{Client} & \\textbf{Dataset} & \\textbf{Images} & \\textbf{Classes} \\\\",
    "\\hline",
]
latex_lines.extend(
    f"{row['Client']} & {row['Dataset']} & {row['Images']:,} & {row['Classes']} \\\\"
    for _, row in summary_df.iterrows()
)
latex_lines.extend([
    "\\hline",
    "\\end{tabular}",
    "\\label{tab:datasets}",
    "\\end{table}",
])

print("\n\nLaTeX Table:")
print("\n".join(latex_lines))

In [None]:
# Banner followed by the checklist of follow-up tasks
print("\n" + "=" * 60, "NEXT STEPS", "=" * 60, sep="\n")
print("""
1. Download missing datasets (see data/download.py instructions)
2. Verify all datasets are correctly organized
3. Run preprocessing pipeline tests
4. Create IID vs Non-IID split configurations
5. Test DataLoader for each client
6. Proceed to model training with run_fl.py or run_experiment.py
""")