# 01 — Data Exploration (EDA)

In this notebook we explore the **HAM10000** dataset to understand:
- How many images per class (class distribution)
- What the images look like (sample grid)
- Image dimensions and properties
- Class imbalance — and why it matters for medical data

> **Dataset:** HAM10000 — 10,015 dermatoscopic images of 7 skin lesion types

---

## Setup

**Before running this notebook:**
1. Download HAM10000 from Kaggle: search "Skin Cancer MNIST: HAM10000"
2. Extract into `data/HAM10000/`
3. You should have:
   - `data/HAM10000/HAM10000_images_part_1/` (images)
   - `data/HAM10000/HAM10000_images_part_2/` (images)
   - `data/HAM10000/HAM10000_metadata.csv` (labels)

Or run the download cell below if you have the Kaggle API set up.
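After extracting, you can confirm that the three entries from step 3 are in place with a small check like the one below. `missing_dataset_entries` is a hypothetical helper and is not part of `src.config`. The folder and file names come from the list above. Point the root at wherever `DATA_DIR` resolves on your machine.

```python
from pathlib import Path

# Entries expected under the dataset root (taken from the setup steps above).
EXPECTED_ENTRIES = [
    "HAM10000_images_part_1",
    "HAM10000_images_part_2",
    "HAM10000_metadata.csv",
]


def missing_dataset_entries(root):
    """Return the expected entries that do not exist under `root` (hypothetical helper)."""
    root = Path(root)
    return [name for name in EXPECTED_ENTRIES if not (root / name).exists()]


# Usage (path assumed; adjust to your layout):
# print(missing_dataset_entries("../data/HAM10000") or "All entries present")
```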

In [None]:
# OPTIONAL: Download dataset via Kaggle API
# Uncomment the lines below if you have kaggle CLI installed
# !pip install kaggle
# !kaggle datasets download -d kmader/skin-cancer-mnist-ham10000 -p ../data/HAM10000 --unzip

In [None]:
import sys
sys.path.append('..')

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from PIL import Image
from collections import Counter

from src.config import DATA_DIR, CLASS_NAMES, CLASS_LABELS, RESULTS_DIR

sns.set_style('whitegrid')
plt.rcParams['figure.dpi'] = 120

# Verify data folder
print(f'Data dir: {DATA_DIR}')
if DATA_DIR.exists():
    print(f'Contents: {os.listdir(DATA_DIR)}')
else:
    print('ERROR: Data folder not found!')
    print('Download HAM10000 from Kaggle and extract to data/HAM10000/')

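# Make sure the output folder for saved figures exists (savefig does not create it)
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

print('\nSetup complete!')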

---
## 1. Load Metadata

In [None]:
# Load the metadata CSV
metadata_path = DATA_DIR / 'HAM10000_metadata.csv'
df = pd.read_csv(metadata_path)

print(f'Total samples: {len(df)}')
print(f'Columns: {list(df.columns)}')
print()
df.head(10)

In [None]:
# Basic info
print('Data types:')
print(df.dtypes)
print()
print('Missing values:')
print(df.isnull().sum())
print()
print('Unique values per column:')
for col in df.columns:
    print(f'  {col}: {df[col].nunique()}')

---
## 2. Class Distribution

This is **the most important chart** for understanding our data.
Medical datasets are typically imbalanced: some conditions are far more common than others, so a model can reach high accuracy simply by favoring the majority class.
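The imbalance can be summarized in one number: the ratio between the largest and smallest class counts. The sketch below uses only the standard library. `imbalance_summary` is a hypothetical helper name, and it accepts any iterable of labels, such as `df['dx']`.

```python
from collections import Counter


def imbalance_summary(labels):
    """Return (majority_class, minority_class, max_count / min_count)."""
    counts = Counter(labels)
    ranked = counts.most_common()        # sorted by count, largest first
    majority, n_max = ranked[0]
    minority, n_min = ranked[-1]
    return majority, minority, n_max / n_min


# Usage with the metadata loaded above:
# majority, minority, ratio = imbalance_summary(df['dx'])
# print(f'{majority} vs {minority}: {ratio:.1f}x')
```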

In [None]:
# Count per class
class_counts = df['dx'].value_counts()
print('Samples per class:')
for cls, count in class_counts.items():
    label = CLASS_LABELS.get(cls, cls)
    pct = count / len(df) * 100
    print(f'  {label:30s}  ({cls})  ->  {count:5d}  ({pct:.1f}%)')

In [None]:
# Visualize class distribution
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Bar chart
labels = [CLASS_LABELS[c] for c in class_counts.index]
colors = sns.color_palette('husl', len(class_counts))

bars = axes[0].barh(labels, class_counts.values, color=colors)
axes[0].set_xlabel('Number of Images')
axes[0].set_title('Class Distribution - HAM10000')
axes[0].invert_yaxis()

# Add count labels on bars
for bar, count in zip(bars, class_counts.values):
    axes[0].text(bar.get_width() + 50, bar.get_y() + bar.get_height()/2,
                 f'{count}', va='center', fontsize=10)

# Pie chart
axes[1].pie(class_counts.values, labels=labels, autopct='%1.1f%%',
            colors=colors, startangle=90)
axes[1].set_title('Class Proportions')

plt.tight_layout()
fig.savefig(RESULTS_DIR / 'class_distribution.png', dpi=150, bbox_inches='tight')
plt.show()

print(f'\nSaved to {RESULTS_DIR / "class_distribution.png"}')

### Key Observation: Class Imbalance

Notice how **Melanocytic Nevus (nv)** dominates the dataset (~67%), while **Dermatofibroma** and **Vascular Lesion** have very few samples.

This is a classic medical data problem. We'll handle it in the training notebook using:
- **Weighted loss function** (penalize misclassifying rare classes more)
- **Data augmentation** (random transforms that produce varied copies of existing images, which matters most for the underrepresented classes)
- **Stratified splitting** (keep each class's proportion roughly the same across train/val/test)
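A common way to set the weights for a weighted loss is inverse class frequency, w_c = N / (K · n_c). Here N is the number of samples, K is the number of classes, and n_c is the count for class c. This is the same formula scikit-learn uses for `class_weight='balanced'`. The sketch below is an illustration only and may not match the training notebook's exact weighting. `balanced_class_weights` is a hypothetical helper. It returns weights in the order of `class_order`, which must match the model's output indices (assumed here to follow `CLASS_NAMES`).

```python
from collections import Counter


def balanced_class_weights(labels, class_order):
    """Inverse-frequency weights w_c = N / (K * n_c), ordered like `class_order`."""
    counts = Counter(labels)
    empty = [c for c in class_order if counts[c] == 0]
    if empty:
        raise ValueError(f"no samples for classes: {empty}")
    n_total = sum(counts[c] for c in class_order)
    k = len(class_order)
    return [n_total / (k * counts[c]) for c in class_order]


# Usage (assumes CLASS_NAMES matches the model's label indices):
# weights = balanced_class_weights(df['dx'], CLASS_NAMES)
```

With these weights, every class contributes the same total weight n_c · w_c = N / K, so rare classes are no longer drowned out. If the training code uses PyTorch (an assumption), the list can be passed as `torch.tensor(weights)` to the `weight` argument of `nn.CrossEntropyLoss`.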

---
## 3. Sample Images

Let's see what each class actually looks like.

In [None]:
# Build a lookup: image_id -> file path
# HAM10000 splits images across two folders
image_dirs = [
    DATA_DIR / 'HAM10000_images_part_1',
    DATA_DIR / 'HAM10000_images_part_2',
]

# Build a dict for fast lookup
image_path_map = {}
for d in image_dirs:
    if d.exists():
        for f in d.iterdir():
            if f.suffix == '.jpg':
                image_path_map[f.stem] = f

print(f'Found {len(image_path_map)} images')

def find_image(image_id):
    return image_path_map.get(image_id)

# Quick test
sample_id = df['image_id'].iloc[0]
path = find_image(sample_id)
if path:
    img = Image.open(path)
    print(f'Sample image: {path.name} | Size: {img.size} | Mode: {img.mode}')
else:
    print(f'ERROR: Could not find image {sample_id}')
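The quick test above checks only the first image. To catch incomplete extractions, you can compare every `image_id` in the metadata against the keys of `image_path_map`. `unmatched_ids` is a hypothetical helper name, and it relies only on set arithmetic.

```python
def unmatched_ids(metadata_ids, file_map):
    """Return (ids in metadata with no file, ids on disk not in metadata), both sorted."""
    meta = set(metadata_ids)
    on_disk = set(file_map)                  # keys are image ids (file stems)
    missing_files = sorted(meta - on_disk)   # listed in the CSV, no .jpg found
    unlisted_files = sorted(on_disk - meta)  # .jpg on disk, absent from the CSV
    return missing_files, unlisted_files


# Usage with the lookup built above:
# missing, extra = unmatched_ids(df['image_id'], image_path_map)
# print(f'{len(missing)} metadata rows without an image, {len(extra)} unlisted images')
```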

In [None]:
# Show sample images per class (3 samples each)
n_samples = 3
fig, axes = plt.subplots(len(CLASS_NAMES), n_samples, figsize=(12, 4 * len(CLASS_NAMES)))

for row, cls in enumerate(CLASS_NAMES):
    class_df = df[df['dx'] == cls].sample(n=n_samples, random_state=42)

    for col, (_, sample) in enumerate(class_df.iterrows()):
        img_path = find_image(sample['image_id'])
        if img_path:
            img = Image.open(img_path)
            axes[row, col].imshow(img)

        if col == 0:
            axes[row, col].set_ylabel(CLASS_LABELS[cls], fontsize=11, fontweight='bold')

        axes[row, col].set_xticks([])
        axes[row, col].set_yticks([])

plt.suptitle('Sample Images per Class', fontsize=16, fontweight='bold', y=1.01)
plt.tight_layout()
fig.savefig(RESULTS_DIR / 'sample_images.png', dpi=150, bbox_inches='tight')
plt.show()

---
## 4. Image Properties

In [None]:
# Check image dimensions for a random sample
sample_df = df.sample(n=200, random_state=42)
widths = []
heights = []

for _, row in sample_df.iterrows():
    img_path = find_image(row['image_id'])
    if img_path:
        img = Image.open(img_path)
        widths.append(img.size[0])
        heights.append(img.size[1])

print(f'Image dimensions (sample of {len(widths)}):')
print(f'  Width  - min: {min(widths)}, max: {max(widths)}, mean: {np.mean(widths):.0f}')
print(f'  Height - min: {min(heights)}, max: {max(heights)}, mean: {np.mean(heights):.0f}')
print(f'  Most common size: {Counter(zip(widths, heights)).most_common(1)[0]}')

---
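## 5. Lesion Analysis

Important: HAM10000 identifies **lesions** (`lesion_id`), not patients, and many lesions have **multiple images**. Images of the same lesion must not end up in both the train and test sets. Otherwise the model is evaluated on near-duplicates of its training images, which inflates test scores (data leakage!).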

In [None]:
if 'lesion_id' in df.columns:
    unique_lesions = df['lesion_id'].nunique()
    print(f'Total images: {len(df)}')
    print(f'Unique lesions: {unique_lesions}')
    print(f'Avg images per lesion: {len(df) / unique_lesions:.1f}')
    print()

    # Distribution of images per lesion
    imgs_per_lesion = df['lesion_id'].value_counts()
    print('Images per lesion:')
    print(imgs_per_lesion.describe())
else:
    print('No lesion_id column - skipping lesion analysis')
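One way to combine stratification with leakage protection is to split **lesions** rather than images. Within each class, you shuffle the lesion ids, send a fraction of them to the test set, and every image follows its lesion. The sketch below assumes that all images of a lesion share the same `dx` label. `lesion_grouped_split` is a hypothetical helper and not necessarily the split used in the training notebook. Because it splits by lesion count, image-level class proportions are only approximately preserved. scikit-learn's `StratifiedGroupKFold` is a library alternative.

```python
import random
from collections import defaultdict


def lesion_grouped_split(records, test_frac=0.2, seed=42):
    """
    records: iterable of (image_id, lesion_id, dx) tuples.
    Returns (train_image_ids, test_image_ids). No lesion_id appears in both
    lists, and lesions are allocated per class (stratified by dx).
    Assumes every image of a lesion carries the same dx label.
    """
    images_by_lesion = defaultdict(list)
    dx_by_lesion = {}
    for image_id, lesion_id, dx in records:
        images_by_lesion[lesion_id].append(image_id)
        dx_by_lesion[lesion_id] = dx

    lesions_by_class = defaultdict(list)
    for lesion_id, dx in dx_by_lesion.items():
        lesions_by_class[dx].append(lesion_id)

    rng = random.Random(seed)  # seeded for a reproducible split
    train, test = [], []
    for dx in sorted(lesions_by_class):
        lesions = sorted(lesions_by_class[dx])
        rng.shuffle(lesions)
        n_test = round(len(lesions) * test_frac)
        for i, lesion_id in enumerate(lesions):
            target = test if i < n_test else train
            target.extend(images_by_lesion[lesion_id])
    return train, test


# Usage with the metadata (itertuples with name=None yields plain tuples):
# rows = df[['image_id', 'lesion_id', 'dx']].itertuples(index=False, name=None)
# train_ids, test_ids = lesion_grouped_split(rows)
```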

---
## 6. Age & Sex Distribution

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Age distribution
if 'age' in df.columns:
    df['age'].dropna().hist(bins=30, ax=axes[0], color='steelblue', edgecolor='white')
    axes[0].set_xlabel('Age')
    axes[0].set_ylabel('Count')
    axes[0].set_title('Age Distribution')
    axes[0].axvline(df['age'].median(), color='red', linestyle='--',
                    label=f'Median: {df["age"].median():.0f}')
    axes[0].legend()

# Sex distribution
if 'sex' in df.columns:
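    sex_counts = df['sex'].value_counts()
    # The column can also contain 'unknown', so use one color per category
    sex_colors = sns.color_palette('deep', len(sex_counts))
    axes[1].bar(sex_counts.index, sex_counts.values, color=sex_colors)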
    axes[1].set_ylabel('Count')
    axes[1].set_title('Sex Distribution')

plt.tight_layout()
fig.savefig(RESULTS_DIR / 'demographics.png', dpi=150, bbox_inches='tight')
plt.show()

---
## 7. Localization Distribution

In [None]:
if 'localization' in df.columns:
    fig, ax = plt.subplots(figsize=(10, 6))

    loc_counts = df['localization'].value_counts()
    loc_counts.plot(kind='barh', ax=ax, color='steelblue')
    ax.set_xlabel('Number of Images')
    ax.set_title('Lesion Localization (Body Part)')
    ax.invert_yaxis()
    plt.tight_layout()
    fig.savefig(RESULTS_DIR / 'localization.png', dpi=150, bbox_inches='tight')
    plt.show()
else:
    print('No localization column')

---
## 8. Pixel Intensity Distribution

Let's compare the average value of each RGB channel across classes. With at most 20 images per class, this is only a rough check, but it can hint at whether certain conditions have distinct color profiles.

In [None]:
# Average RGB values per class
fig, axes = plt.subplots(1, len(CLASS_NAMES), figsize=(20, 3))

for i, cls in enumerate(CLASS_NAMES):
    class_df = df[df['dx'] == cls].sample(n=min(20, len(df[df['dx'] == cls])), random_state=42)
    r_vals, g_vals, b_vals = [], [], []

    for _, row in class_df.iterrows():
        img_path = find_image(row['image_id'])
        if img_path:
            img = np.array(Image.open(img_path))
            r_vals.append(img[:,:,0].mean())
            g_vals.append(img[:,:,1].mean())
            b_vals.append(img[:,:,2].mean())

    axes[i].bar(['R', 'G', 'B'],
                [np.mean(r_vals), np.mean(g_vals), np.mean(b_vals)],
                color=['red', 'green', 'blue'], alpha=0.7)
    axes[i].set_title(cls, fontsize=10)
    axes[i].set_ylim(0, 255)
    axes[i].set_yticks([0, 128, 255])

plt.suptitle('Average RGB per Class', fontsize=14, fontweight='bold')
plt.tight_layout()
fig.savefig(RESULTS_DIR / 'rgb_per_class.png', dpi=150, bbox_inches='tight')
plt.show()
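The three separate `.mean()` calls per image in the cell above can be collapsed into a single NumPy reduction over the height and width axes. `channel_means` is a hypothetical helper name. Converting with Pillow's `.convert('RGB')` first guarantees three channels even if an image were stored in another mode.

```python
import numpy as np


def channel_means(img):
    """Mean of each channel of an H x W x C image, returned as a length-C float array."""
    arr = np.asarray(img, dtype=np.float64)  # works for NumPy arrays and PIL images
    if arr.ndim != 3:
        raise ValueError("expected an H x W x C image array")
    return arr.mean(axis=(0, 1))             # average over height and width


# Usage inside the per-class loop:
# r, g, b = channel_means(Image.open(img_path).convert('RGB'))
```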

---
## Summary

Key findings from our EDA:

1. **Class imbalance**: Melanocytic Nevus dominates (~67%), so rare classes need special handling
2. **Image dimensions**: every image in the 200-image sample has the same size; we'll resize to 128x128 for CPU training
3. **Repeated lesions**: the same lesion can appear in several images, so train/test splits must keep all images of a `lesion_id` together
4. **Demographics**: the metadata includes age, sex, and body location
5. **Color profiles**: different conditions may have subtle color differences (based on a small per-class sample)

### Next Steps

-> Move on to **02_model_training.ipynb** to build and train our classification model!