# Week 3 Activity 1: Phase 2 - Batch Dataset Creation

## Learning Objectives

By the end of this notebook, you will understand:

1. **How to extract patches at scale** - Moving from single test to batch processing
2. **Spatial augmentation with jitter** - Creating diversity without synthetic transformations
3. **Quality control workflows** - Ensuring dataset integrity through validation
4. **Train/validation splitting** - Creating stratified splits that maintain class balance

## What is Phase 2?

Phase 2 transforms validated extraction parameters into a production-ready dataset. We will:

- Load Phase 1 configuration (patch size, parameters)
- Extract **378 patches** from 126 polygons (3 patches per polygon with spatial jitter)
- Perform comprehensive quality control
- Create stratified train/validation split (80/20)
- Validate dataset integrity
- Generate summary statistics and visualizations

**Flow**: Phase 0 (Polygons) → Phase 1 (Validate) → **Phase 2 (Extract)** → Phase 3 (Train CNN)

**Expected Time**: 60-90 minutes (including Earth Engine extraction)

---

## Section 1: Setup & Load Configuration

We'll start by loading the configuration determined in Phase 1.
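
For reference, the loader in the next cell reads four keys from `phase1_config.json`: `PATCH_SIZE`, `PATCHES_PER_POLYGON`, `BANDS`, and `COMPOSITE_ASSET`. The sketch below writes and re-reads a config with those keys, using the notebook's fallback values as placeholders. The file your Phase 1 run produced may hold different values or extra keys, and the temporary directory is only there so the example cannot overwrite a real config.

```python
import json
import tempfile
from pathlib import Path

# Placeholder values copied from the notebook's fallback branch
# (assumption: your Phase 1 run may have chosen different values).
example_config = {
    "PATCH_SIZE": 8,
    "PATCHES_PER_POLYGON": 3,
    "BANDS": ["B2", "B3", "B4", "B8", "B11", "B12"],
    "COMPOSITE_ASSET": "users/markstonegobigred/Parcela/s2_2019_median_6b",
}

with tempfile.TemporaryDirectory() as tmp:
    config_path = Path(tmp) / "phase1_config.json"
    config_path.write_text(json.dumps(example_config, indent=2))

    # Read it back the same way the notebook does.
    with open(config_path, "r") as f:
        loaded = json.load(f)

patch_size = loaded["PATCH_SIZE"]
patches_per_polygon = loaded.get("PATCHES_PER_POLYGON", 3)
print(patch_size, patches_per_polygon, loaded["BANDS"])
```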

In [None]:
# Import required packages
import numpy as np
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import ee
import geemap
from pathlib import Path
import json
from datetime import datetime
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm

# Set plotting style
plt.style.use('seaborn-v0_8-darkgrid')

# Set random seed for reproducibility
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

# Initialize Earth Engine
ee.Initialize()

print("✓ Packages imported")
print("✓ Random seed set:", RANDOM_SEED)
print("✓ Earth Engine initialized")

In [None]:
# Define paths
REPO = Path.cwd().parent
DATA = REPO / 'data'
POLYGONS_FILE = DATA / 'labels' / 'larger_polygons.geojson'

# Load Phase 1 configuration
PHASE1_CONFIG_PATH = Path('phase1_config.json')

if not PHASE1_CONFIG_PATH.exists():
    print("⚠️  Phase 1 config not found! Using fallback values.")
    # Fallback values if Phase 1 wasn't run
    PATCH_SIZE = 8
    PATCHES_PER_POLYGON = 3
    BANDS = ['B2', 'B3', 'B4', 'B8', 'B11', 'B12']
    ASSET_ID = 'users/markstonegobigred/Parcela/s2_2019_median_6b'
else:
    with open(PHASE1_CONFIG_PATH, 'r') as f:
        phase1_config = json.load(f)
    PATCH_SIZE = phase1_config['PATCH_SIZE']
    PATCHES_PER_POLYGON = phase1_config.get('PATCHES_PER_POLYGON', 3)
    BANDS = phase1_config['BANDS']
    ASSET_ID = phase1_config['COMPOSITE_ASSET']
    print("✓ Phase 1 configuration loaded")

# Create output directories
PHASE2_DIR = Path('phase2_outputs')
PATCHES_DIR = PHASE2_DIR / 'patches'
METADATA_DIR = PHASE2_DIR / 'metadata'
PATCHES_DIR.mkdir(parents=True, exist_ok=True)
METADATA_DIR.mkdir(parents=True, exist_ok=True)

print(f"\nConfiguration:")
print(f"  Patch size: {PATCH_SIZE}×{PATCH_SIZE} pixels ({PATCH_SIZE*10}m)")
print(f"  Patches per polygon: {PATCHES_PER_POLYGON}")
print(f"  Bands: {', '.join(BANDS)}")
print(f"  Output directory: {PHASE2_DIR}")

---

## Section 2: Load Polygons and Composite

Load our training polygons and Sentinel-2 composite from Phase 0/1.

In [None]:
# Load training polygons
polygons = gpd.read_file(POLYGONS_FILE)
if polygons.crs.to_string() != 'EPSG:4326':
    polygons = polygons.to_crs('EPSG:4326')

print(f"✓ Loaded {len(polygons)} training polygons")
print(f"  Classes: {sorted(polygons['class_name'].unique())}")

# Class distribution
class_counts = polygons['class_name'].value_counts().sort_index()
print(f"\n  Class distribution:")
for cls, count in class_counts.items():
    print(f"    {cls:12s}: {count:3d} polygons")

In [None]:
# Load Sentinel-2 composite
composite = ee.Image(ASSET_ID).select(BANDS)

print(f"✓ Composite loaded")
print(f"  Asset ID: {ASSET_ID}")
print(f"  Bands: {composite.bandNames().getInfo()}")

---

## Section 3: Extraction Strategy with Spatial Jitter

### What is spatial jitter?

**Jitter** is a small random offset applied when extracting patches. Instead of always extracting from the exact center of a polygon, we shift the extraction location slightly (±10m).

**Why use jitter?**
- **Data augmentation**: Creates 3× more patches from the same polygons
- **Reduces overfitting**: Model doesn't learn to expect features at exact pixel locations
- **Authentic variation**: Uses real imagery, not synthetic transformations

**How it works:**
```
For each polygon:
  Patch 1: Extract from center (jitter = 0, 0)
  Patch 2: Extract with a random offset of up to ±10m in x and y
  Patch 3: Extract with a different random offset of up to ±10m in x and y
```
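
Because the polygons are stored in degrees (EPSG:4326), a jitter expressed in meters has to be converted before it can be added to a centroid. The sketch below shows that conversion on its own, using the same 111,320 m per degree approximation as the notebook. The helper name `offset_m_to_deg` is made up for this illustration.

```python
import math

METERS_PER_DEG_LAT = 111_320  # approximate value, same constant the notebook uses

def offset_m_to_deg(east_m, north_m, lat_deg):
    """Convert an east/north offset in meters to (d_lon, d_lat) in degrees."""
    # A degree of longitude shrinks with the cosine of the latitude.
    meters_per_deg_lon = METERS_PER_DEG_LAT * math.cos(math.radians(lat_deg))
    return east_m / meters_per_deg_lon, north_m / METERS_PER_DEG_LAT

# The same 10 m eastward shift at the equator and at 60 degrees latitude.
d_lon_equator, d_lat_equator = offset_m_to_deg(10.0, -8.0, lat_deg=0.0)
d_lon_60, _ = offset_m_to_deg(10.0, 0.0, lat_deg=60.0)

print(f"equator: d_lon={d_lon_equator:.7f}, d_lat={d_lat_equator:.7f}")
print(f"60 deg:  d_lon={d_lon_60:.7f}")
```

At 60° latitude the longitude offset in degrees is about twice the equatorial value, which is why the conversion is recomputed for each polygon's latitude.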

Let's calculate extraction locations for all patches.

In [None]:
# Calculate extraction manifest
extraction_manifest = []

for idx, poly in polygons.iterrows():
    centroid = poly.geometry.centroid
    lat, lon = centroid.y, centroid.x
    
    # Calculate meters per degree at this latitude
    meters_per_deg_lat = 111320
    meters_per_deg_lon = 111320 * np.cos(np.radians(lat))
    
    for patch_idx in range(PATCHES_PER_POLYGON):
        # Generate jitter
        if patch_idx == 0:
            # First patch: no jitter (center)
            offset_m = (0, 0)
        else:
            # Subsequent patches: random jitter ±1 pixel (±10m)
            offset_m = (np.random.uniform(-10, 10), np.random.uniform(-10, 10))
        
        # Convert offset to degrees
        offset_lon = offset_m[0] / meters_per_deg_lon
        offset_lat = offset_m[1] / meters_per_deg_lat
        
        # Calculate patch center with jitter
        patch_lon = lon + offset_lon
        patch_lat = lat + offset_lat
        
        # Store in manifest
        extraction_manifest.append({
            'patch_id': f"patch_{idx:03d}_{patch_idx}",
            'polygon_id': idx,
            'class_name': poly['class_name'],
            'class_id': poly['class_id'],
            'center_lon': patch_lon,
            'center_lat': patch_lat,
            'offset_lon_m': offset_m[0],
            'offset_lat_m': offset_m[1],
            'jitter_idx': patch_idx
        })

manifest_df = pd.DataFrame(extraction_manifest)

print(f"✓ Extraction manifest created")
print(f"  Total patches planned: {len(manifest_df)}")
print(f"  Patches per polygon: {PATCHES_PER_POLYGON}")
print(f"\n  Patches by class:")
for cls, count in manifest_df['class_name'].value_counts().sort_index().items():
    print(f"    {cls:12s}: {count:3d} patches")

### Visualize jitter pattern (sample)

Let's look at jitter offsets for a few polygons to verify the pattern looks reasonable.

In [None]:
# Sample 6 polygons for visualization
sample_poly_ids = np.random.choice(polygons.index, size=min(6, len(polygons)), replace=False)
sample_patches = manifest_df[manifest_df['polygon_id'].isin(sample_poly_ids)]

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

for i, poly_id in enumerate(sample_poly_ids):
    ax = axes[i]
    
    # Get patches for this polygon
    poly_patches = sample_patches[sample_patches['polygon_id'] == poly_id]
    poly_row = polygons.loc[poly_id]
    
    # Plot polygon
    poly_geom = poly_row.geometry
    if poly_geom.geom_type == 'Polygon':
        x, y = poly_geom.exterior.xy
        ax.plot(x, y, 'k-', linewidth=2, label='Polygon boundary')
    
    # Plot patch centers
    colors = ['red', 'blue', 'green']
    for _, patch in poly_patches.iterrows():
        ax.plot(patch['center_lon'], patch['center_lat'], 'o', 
               color=colors[patch['jitter_idx']], markersize=10,
               label=f"Patch {patch['jitter_idx']}" if i == 0 else "")
    
    ax.set_title(f"{poly_row['class_name']} (ID {poly_id})", fontweight='bold')
    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')
    ax.grid(True, alpha=0.3)
    ax.axis('equal')
    
    if i == 0:
        ax.legend(fontsize=8)

plt.tight_layout()
plt.show()

print("💡 Interpretation:")
print("   Red = center patch (no jitter)")
print("   Blue/Green = jittered patches (±10m offset)")
print("   All patches should fall well within polygon boundaries.")

---

## Section 4: Batch Patch Extraction

### The main event!

Now we'll extract all patches from Earth Engine. This will take **30-60 minutes** depending on:
- Network speed
- Earth Engine server load
- Number of patches

**What's happening:**
For each patch in the manifest:
1. Calculate the bounding box from the center coordinates
2. Extract the patch from the Earth Engine composite
3. Save it as a NumPy array (.npy file)
4. Log success/failure and quality metrics

**Progress tracking:** The progress bar shows real-time status.
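
Since a run of this length can be interrupted (a dropped connection, a restarted kernel), it may help to skip patches that were already saved. The sketch below is one possible way to do that and is not part of the original notebook. `pending_patch_ids` is a hypothetical helper, and it only checks that a file exists, not that its contents are valid.

```python
import tempfile
from pathlib import Path

def pending_patch_ids(patch_ids, patches_dir):
    """Return the patch IDs that do not yet have a saved .npy file."""
    patches_dir = Path(patches_dir)
    return [pid for pid in patch_ids if not (patches_dir / f"{pid}.npy").exists()]

# Demonstration with a temporary directory and made-up patch IDs.
with tempfile.TemporaryDirectory() as tmp:
    all_ids = ["patch_000_0", "patch_000_1", "patch_000_2"]
    # Pretend the first patch was extracted in an earlier run.
    (Path(tmp) / "patch_000_0.npy").touch()
    todo = pending_patch_ids(all_ids, tmp)

print(todo)
```

In the notebook, the result could be used to filter the manifest before the loop, for example `manifest_df[manifest_df['patch_id'].isin(todo)]`. The quality metrics for skipped patches would then need to be recomputed from the saved files.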

In [None]:
# Batch extraction function
def extract_patch(patch_info, composite, patch_size, bands):
    """
    Extract a single patch from Earth Engine.
    
    Returns:
        tuple: (success, patch_array, metadata_dict)
    """
    try:
        lat, lon = patch_info['center_lat'], patch_info['center_lon']
        
        # Calculate bounding box
        patch_half_m = (patch_size * 10) / 2
        meters_per_deg_lat = 111320
        meters_per_deg_lon = 111320 * np.cos(np.radians(lat))
        
        half_deg_lon = patch_half_m / meters_per_deg_lon
        half_deg_lat = patch_half_m / meters_per_deg_lat
        
        patch_geom = ee.Geometry.Rectangle([
            lon - half_deg_lon,
            lat - half_deg_lat,
            lon + half_deg_lon,
            lat + half_deg_lat
        ])
        
        # Extract from Earth Engine
        start_time = datetime.now()
        patch = geemap.ee_to_numpy(
            composite,
            region=patch_geom,
            scale=10,
            bands=bands
        )
        extraction_time = (datetime.now() - start_time).total_seconds()
        
        # Handle size mismatch
        if patch.shape[:2] != (patch_size, patch_size):
            h, w, c = patch.shape
            resized = np.full((patch_size, patch_size, len(bands)), np.nan)
            h_copy = min(h, patch_size)
            w_copy = min(w, patch_size)
            resized[:h_copy, :w_copy, :] = patch[:h_copy, :w_copy, :]
            patch = resized
        
        # Calculate quality metrics
        nan_pct = np.isnan(patch).sum() / patch.size * 100
        valid = patch[~np.isnan(patch)]
        
        metadata = {
            'success': True,
            'nan_pct': nan_pct,
            'value_min': float(valid.min()) if len(valid) > 0 else np.nan,
            'value_max': float(valid.max()) if len(valid) > 0 else np.nan,
            'value_mean': float(valid.mean()) if len(valid) > 0 else np.nan,
            'extraction_time': extraction_time
        }
        
        return True, patch, metadata
        
    except Exception as e:
        metadata = {
            'success': False,
            'error': str(e),
            'nan_pct': np.nan,
            'value_min': np.nan,
            'value_max': np.nan,
            'value_mean': np.nan,
            'extraction_time': np.nan
        }
        return False, None, metadata

print("✓ Extraction function defined")
print("\nReady to begin batch extraction...")
print(f"  Total patches: {len(manifest_df)}")
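# Rough estimate assuming ~5-10 seconds per patch (an assumption chosen to match
# the 30-60 minute range above; actual speed depends on network and server load)
est_min = len(manifest_df) * 5 / 60
est_max = len(manifest_df) * 10 / 60
print(f"  Estimated time: {est_min:.0f}-{est_max:.0f} minutes")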

In [None]:
# Run batch extraction
print("Starting batch extraction...\n")
start_time = datetime.now()

extraction_log = []
successful = 0
failed = 0

for _, patch_info in tqdm(manifest_df.iterrows(), total=len(manifest_df), desc="Extracting patches"):
    patch_id = patch_info['patch_id']
    
    # Extract patch
    success, patch, metadata = extract_patch(patch_info, composite, PATCH_SIZE, BANDS)
    
    if success:
        # Save patch as .npy file
        patch_path = PATCHES_DIR / f"{patch_id}.npy"
        np.save(patch_path, patch.astype(np.float32))
        successful += 1
    else:
        failed += 1
    
    # Log extraction
    log_entry = {
        'patch_id': patch_id,
        'polygon_id': patch_info['polygon_id'],
        'class_name': patch_info['class_name'],
        'class_id': patch_info['class_id'],  # kept so the split CSVs carry integer labels
        **metadata
    }
    extraction_log.append(log_entry)

total_time = (datetime.now() - start_time).total_seconds()

# Create extraction log DataFrame
extraction_df = pd.DataFrame(extraction_log)

print(f"\n✓ Batch extraction complete!")
print(f"\nResults:")
print(f"  Successful: {successful}/{len(manifest_df)} ({successful/len(manifest_df)*100:.1f}%)")
print(f"  Failed: {failed}/{len(manifest_df)} ({failed/len(manifest_df)*100:.1f}%)")
print(f"  Total time: {total_time/60:.1f} minutes")
print(f"  Mean time per patch: {total_time/len(manifest_df):.2f} seconds")

---

## Section 5: Quality Control Analysis

### Assessing dataset quality

Before using these patches for training, we need to verify:
1. **NaN percentage** - How much missing data?
2. **Value ranges** - Are values reasonable for Sentinel-2?
3. **Class distribution** - Did all classes extract successfully?

**Quality tiers:**
- **Excellent**: 0% NaN
- **Good**: more than 0% but less than 10% NaN
- **Acceptable**: 10% to less than 20% NaN
- **Poor**: 20% NaN or more (should exclude)
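
As a worked example of how a patch lands in a tier, the sketch below builds a synthetic 8×8×6 patch, blanks out one row, and applies the same thresholds. The array is made up for illustration, and `nan_percentage` is a hypothetical helper that mirrors the metric computed during extraction.

```python
import numpy as np

def nan_percentage(patch):
    """Percentage of NaN values across all pixels and bands."""
    return np.isnan(patch).sum() / patch.size * 100

def classify_quality(nan_pct):
    """Apply the tier thresholds listed above."""
    if nan_pct == 0:
        return "excellent"
    elif nan_pct < 10:
        return "good"
    elif nan_pct < 20:
        return "acceptable"
    return "poor"

# Synthetic 8x8 patch with 6 bands (384 values in total).
patch = np.ones((8, 8, 6), dtype=np.float32)
patch[0, :, :] = np.nan  # one missing row: 8 pixels x 6 bands = 48 NaNs

pct = nan_percentage(patch)
print(f"{pct:.1f}% NaN -> {classify_quality(pct)}")
```

Here 48 of 384 values are missing (12.5%), so the patch falls in the acceptable tier and would be left out of the high-quality subset used for splitting.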

In [None]:
# Quality analysis on successful patches
successful_patches = extraction_df[extraction_df['success']].copy()

# Classify quality tiers
def classify_quality(nan_pct):
    if nan_pct == 0:
        return 'excellent'
    elif nan_pct < 10:
        return 'good'
    elif nan_pct < 20:
        return 'acceptable'
    else:
        return 'poor'

successful_patches['quality_tier'] = successful_patches['nan_pct'].apply(classify_quality)

# Calculate statistics
print("Quality Control Results:")
print("=" * 60)
print(f"\nNaN Percentage:")
print(f"  Mean:   {successful_patches['nan_pct'].mean():.2f}%")
print(f"  Median: {successful_patches['nan_pct'].median():.2f}%")
print(f"  Range:  {successful_patches['nan_pct'].min():.2f}% - {successful_patches['nan_pct'].max():.2f}%")

print(f"\nQuality Tier Distribution:")
for tier in ['excellent', 'good', 'acceptable', 'poor']:
    count = (successful_patches['quality_tier'] == tier).sum()
    pct = count / len(successful_patches) * 100
    print(f"  {tier.capitalize():12s}: {count:3d} ({pct:5.1f}%)")

print(f"\nValue Ranges:")
print(f"  Min value: {successful_patches['value_min'].min():.0f}")
print(f"  Max value: {successful_patches['value_max'].max():.0f}")
print(f"  (Expected range for Sentinel-2: 0-10000)")

# Class distribution
print(f"\nSuccessful Patches by Class:")
for cls, count in successful_patches['class_name'].value_counts().sort_index().items():
    pct = count / len(successful_patches) * 100
    print(f"  {cls:12s}: {count:3d} ({pct:5.1f}%)")

### Visualize quality metrics

In [None]:
# Create quality visualizations
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# NaN distribution histogram
ax = axes[0]
ax.hist(successful_patches['nan_pct'], bins=20, color='steelblue', edgecolor='black', alpha=0.7)
ax.axvline(20, color='red', linestyle='--', linewidth=2, label='Acceptable threshold (20%)')
ax.set_xlabel('NaN Percentage (%)', fontsize=12, fontweight='bold')
ax.set_ylabel('Number of Patches', fontsize=12, fontweight='bold')
ax.set_title('Distribution of Missing Data (NaN)', fontsize=13, fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3)

# Quality tier pie chart
ax = axes[1]
tier_counts = successful_patches['quality_tier'].value_counts()
colors_map = {'excellent': 'green', 'good': 'lightgreen', 'acceptable': 'yellow', 'poor': 'red'}
colors = [colors_map[tier] for tier in tier_counts.index]
ax.pie(tier_counts.values, labels=tier_counts.index, autopct='%1.1f%%', 
       colors=colors, startangle=90)
ax.set_title('Quality Tier Distribution', fontsize=13, fontweight='bold')

plt.tight_layout()
plt.show()

print("\n💡 Interpretation:")
print("   Left: Most patches should have 0% NaN (excellent quality)")
print("   Right: Green (excellent/good) should dominate")

---

## Section 6: Train/Validation Split

### Why stratified splitting?

A **stratified split** maintains the same class proportions in both training and validation sets.

**Example:**
- If Agriculture is 28% of total dataset
- Then Agriculture will be 28% of training set AND 28% of validation set

**Why this matters:**
- Fair evaluation (validation set represents all classes)
- Prevents bias from imbalanced validation
- Standard practice in machine learning

**Split ratio:** 80% training, 20% validation
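
One caveat worth keeping in mind: the patches from a single polygon overlap heavily (offsets of at most 10 m on an 80 m patch with the default configuration), so a patch-level split can place near-duplicates of the same location in both sets. A possible alternative, which is not what the next cell does, is to split at the polygon level and then assign each polygon's patches to its side. The sketch below uses a synthetic table with the same column names as the notebook. The data and the variable names `patches` and `poly_table` are made up for illustration.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the patch table: 20 polygons x 3 patches, 2 classes.
patches = pd.DataFrame({
    "patch_id": [f"patch_{p:03d}_{j}" for p in range(20) for j in range(3)],
    "polygon_id": [p for p in range(20) for _ in range(3)],
    "class_name": ["Forest" if p < 10 else "Water"
                   for p in range(20) for _ in range(3)],
})

# One row per polygon, so the split decision is made per polygon.
poly_table = patches.drop_duplicates("polygon_id")[["polygon_id", "class_name"]]
train_polys, val_polys = train_test_split(
    poly_table,
    test_size=0.2,
    stratify=poly_table["class_name"],
    random_state=42,
)

# Every patch follows its parent polygon into train or validation.
train_patches = patches[patches["polygon_id"].isin(train_polys["polygon_id"])]
val_patches = patches[patches["polygon_id"].isin(val_polys["polygon_id"])]

shared = set(train_patches["polygon_id"]) & set(val_patches["polygon_id"])
print(len(train_patches), len(val_patches), len(shared))
```

With this approach the class proportions are stratified over polygons, not patches, so they can drift slightly if quality filtering removes different numbers of patches per polygon.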

In [None]:
# Filter to high-quality patches (excellent or good)
high_quality = successful_patches[successful_patches['quality_tier'].isin(['excellent', 'good'])].copy()

print(f"Using high-quality patches for splitting:")
print(f"  Total: {len(high_quality)} ({len(high_quality)/len(successful_patches)*100:.1f}% of successful)")
print(f"  Excluding: {len(successful_patches) - len(high_quality)} low-quality patches")

# Perform stratified split
train_df, val_df = train_test_split(
    high_quality,
    test_size=0.2,
    stratify=high_quality['class_name'],
    random_state=RANDOM_SEED
)

print(f"\n✓ Stratified split complete:")
print(f"  Training: {len(train_df)} patches ({len(train_df)/len(high_quality)*100:.0f}%)")
print(f"  Validation: {len(val_df)} patches ({len(val_df)/len(high_quality)*100:.0f}%)")

# Verify stratification
print(f"\nClass Distribution Verification:")
print("-" * 60)
print(f"{'Class':<12s} {'Total':>8s} {'Train':>8s} {'Val':>8s} {'Train%':>8s} {'Val%':>8s}")
print("-" * 60)

for cls in sorted(high_quality['class_name'].unique()):
    total = (high_quality['class_name'] == cls).sum()
    train = (train_df['class_name'] == cls).sum()
    val = (val_df['class_name'] == cls).sum()
    train_pct = train / len(train_df) * 100
    val_pct = val / len(val_df) * 100
    print(f"{cls:<12s} {total:8d} {train:8d} {val:8d} {train_pct:7.1f}% {val_pct:7.1f}%")

print("-" * 60)
print(f"{'TOTAL':<12s} {len(high_quality):8d} {len(train_df):8d} {len(val_df):8d} {100.0:7.1f}% {100.0:7.1f}%")

### Save splits to files

In [None]:
# Save train/val splits
train_df.to_csv(METADATA_DIR / 'train_split.csv', index=False)
val_df.to_csv(METADATA_DIR / 'val_split.csv', index=False)

# Save split metadata
split_metadata = {
    'split_method': 'stratified',
    'train_ratio': 0.8,
    'val_ratio': 0.2,
    'random_seed': RANDOM_SEED,
    'train_count': len(train_df),
    'val_count': len(val_df),
    'quality_filter': ['excellent', 'good'],
    'class_distribution': {
        cls: {
            'train': int((train_df['class_name'] == cls).sum()),
            'val': int((val_df['class_name'] == cls).sum())
        }
        for cls in sorted(high_quality['class_name'].unique())
    }
}

with open(METADATA_DIR / 'split_metadata.json', 'w') as f:
    json.dump(split_metadata, f, indent=2)

print(f"\n✓ Split files saved:")
print(f"  {METADATA_DIR / 'train_split.csv'}")
print(f"  {METADATA_DIR / 'val_split.csv'}")
print(f"  {METADATA_DIR / 'split_metadata.json'}")

---

## Section 7: Dataset Validation

### Final integrity checks

Before declaring the dataset ready, let's verify:
1. All patch files exist
2. Patches are readable and have a consistent shape (8×8×6 with the default configuration)
3. No overlap between train and validation
4. Minimum training samples per class met
5. Minimum validation samples per class met
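
The readability and shape check can also be packaged as a small standalone function, which makes it easy to rerun after re-extracting patches. This is a minimal sketch under the assumption that patches are stored as `<patch_id>.npy`, as in this notebook. `find_bad_patches` is a hypothetical name, and the files are created in a temporary directory purely for demonstration.

```python
import tempfile
from pathlib import Path

import numpy as np

def find_bad_patches(patches_dir, patch_ids, expected_shape):
    """Return IDs whose file is missing, unreadable, or has the wrong shape."""
    bad = []
    for pid in patch_ids:
        try:
            arr = np.load(Path(patches_dir) / f"{pid}.npy")
        except Exception:
            # Missing or corrupt file
            bad.append(pid)
            continue
        if arr.shape != expected_shape:
            bad.append(pid)
    return bad

with tempfile.TemporaryDirectory() as tmp:
    np.save(Path(tmp) / "ok.npy", np.zeros((8, 8, 6), dtype=np.float32))
    np.save(Path(tmp) / "short.npy", np.zeros((7, 8, 6), dtype=np.float32))
    bad = find_bad_patches(tmp, ["ok", "short", "missing"], (8, 8, 6))

print(bad)
```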

In [None]:
# Run validation checks
print("Running Dataset Validation Checks...")
print("=" * 60)

checks_passed = 0
checks_total = 0

# Check 1: All files exist
checks_total += 1
missing_files = []
for patch_id in high_quality['patch_id']:
    if not (PATCHES_DIR / f"{patch_id}.npy").exists():
        missing_files.append(patch_id)

if len(missing_files) == 0:
    print("✅ Check 1: All patch files exist")
    checks_passed += 1
else:
    print(f"❌ Check 1: {len(missing_files)} patch files missing")

# Check 2: All files readable and correct shape
checks_total += 1
shape_errors = []
expected_shape = (PATCH_SIZE, PATCH_SIZE, len(BANDS))
for patch_id in high_quality['patch_id']:  # Check every patch; the files are small
    try:
        patch = np.load(PATCHES_DIR / f"{patch_id}.npy")
        if patch.shape != expected_shape:
            shape_errors.append(patch_id)
    except Exception:
        # Missing, unreadable, or corrupt file
        shape_errors.append(patch_id)

if len(shape_errors) == 0:
    print(f"✅ Check 2: All patches have correct shape ({PATCH_SIZE}×{PATCH_SIZE}×{len(BANDS)})")
    checks_passed += 1
else:
    print(f"❌ Check 2: {len(shape_errors)} patches are unreadable or have incorrect shape")

# Check 3: No overlap between train and validation
checks_total += 1
overlap = set(train_df['patch_id']) & set(val_df['patch_id'])
if len(overlap) == 0:
    print("✅ Check 3: No overlap between train and validation sets")
    checks_passed += 1
else:
    print(f"❌ Check 3: {len(overlap)} patches appear in both splits")

# Check 4: Minimum samples per class (train)
checks_total += 1
min_train_samples = 5
train_class_counts = train_df['class_name'].value_counts()
insufficient_train = train_class_counts[train_class_counts < min_train_samples]
if len(insufficient_train) == 0:
    print(f"✅ Check 4: All classes have ≥ {min_train_samples} training samples")
    checks_passed += 1
else:
    print(f"❌ Check 4: {len(insufficient_train)} classes have < {min_train_samples} training samples")

# Check 5: Minimum samples per class (val)
checks_total += 1
min_val_samples = 2
val_class_counts = val_df['class_name'].value_counts()
insufficient_val = val_class_counts[val_class_counts < min_val_samples]
if len(insufficient_val) == 0:
    print(f"✅ Check 5: All classes have ≥ {min_val_samples} validation samples")
    checks_passed += 1
else:
    print(f"❌ Check 5: {len(insufficient_val)} classes have < {min_val_samples} validation samples")

print("=" * 60)
print(f"\nValidation Summary: {checks_passed}/{checks_total} checks passed")

if checks_passed == checks_total:
    print("\n✅ DATASET VALIDATED - Ready for CNN training!")
else:
    print("\n⚠️  Some validation checks failed - review above.")

---

## Section 8: Summary & Next Steps

### What we accomplished in Phase 2

✅ **Extracted patches at scale** - Moved from single test to batch processing

✅ **Applied spatial augmentation** - 3 patches per polygon with jitter

✅ **Performed quality control** - Classified all patches by quality tier

✅ **Created stratified split** - 80/20 train/val with balanced classes

✅ **Validated dataset integrity** - All checks passed

### Final Dataset Statistics

In [None]:
# Compile final statistics
final_stats = pd.DataFrame([
    ['Total polygons', len(polygons)],
    ['Patches attempted', len(manifest_df)],
    ['Successful extractions', successful],
    ['Success rate', f"{successful/len(manifest_df)*100:.1f}%"],
    ['High-quality patches', len(high_quality)],
    ['Training patches', len(train_df)],
    ['Validation patches', len(val_df)],
    ['Patch size', f"{PATCH_SIZE}×{PATCH_SIZE} pixels ({PATCH_SIZE*10}m)"],
    ['Number of bands', len(BANDS)],
    ['Number of classes', len(high_quality['class_name'].unique())],
    ['Mean NaN percentage', f"{successful_patches['nan_pct'].mean():.2f}%"],
    ['Excellent quality', f"{(successful_patches['quality_tier']=='excellent').sum()} patches"],
    ['Total extraction time', f"{total_time/60:.1f} minutes"],
    ['Random seed', RANDOM_SEED]
], columns=['Metric', 'Value'])

print("="*70)
print("PHASE 2 COMPLETE: Final Dataset Statistics")
print("="*70)
print()
print(final_stats.to_string(index=False))
print()
print("="*70)

### Key Takeaways

**1. Spatial jitter creates authentic diversity**
- 3× as many patches from the same polygons
- No synthetic transformations needed
- Model sees each site at slightly shifted positions, which encourages tolerance to small spatial offsets

**2. Quality control ensures clean training**
- Filtering by NaN percentage removes problematic patches
- Excellent/good quality patches have reliable spectral signatures
- Bad data = bad model performance

**3. Stratified splitting prevents bias**
- Both splits represent all classes proportionally
- Validation results will be representative
- Standard ML best practice

**4. Validation gives confidence**
- Systematic checks catch issues early
- Know your data is ready before training
- Saves debugging time later

---

## Next Steps: CNN Training (Phase 3)

Now that you have a clean, validated dataset, you're ready for:

**Week 3 Lab: CNN Training**
```python
# Loading your data is now simple:
import pandas as pd
import numpy as np

train_split = pd.read_csv('phase2_outputs/metadata/train_split.csv')
val_split = pd.read_csv('phase2_outputs/metadata/val_split.csv')

X_train = np.array([np.load(f"phase2_outputs/patches/{pid}.npy") 
                     for pid in train_split['patch_id']])
y_train = train_split['class_id'].values

# Ready for model.fit()!
```
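
Patches in the "good" tier can still contain a small share of NaN values (under 10%), and many training frameworks will propagate those into the loss. The sketch below shows one possible preprocessing step, which is not part of the original notebook: fill NaNs with per-band means and divide by 10000, the expected Sentinel-2 value range noted in Section 5. `fill_and_scale` is a hypothetical helper and the input array is synthetic.

```python
import numpy as np

def fill_and_scale(X, band_means=None, scale=10000.0):
    """Replace NaNs with per-band means, then divide by `scale`.

    X has shape (n_patches, height, width, n_bands).
    """
    X = X.astype(np.float32)
    if band_means is None:
        # Average over patches, rows, and columns, ignoring NaNs:
        # the result has one value per band.
        band_means = np.nanmean(X, axis=(0, 1, 2))
    filled = np.where(np.isnan(X), band_means, X)
    return filled / scale, band_means

# Tiny synthetic batch: 2 patches, 2x2 pixels, 2 bands.
X = np.full((2, 2, 2, 2), 1000.0, dtype=np.float32)
X[..., 1] = 3000.0
X[0, 0, 0, 0] = np.nan  # one missing value in band 0

X_clean, means = fill_and_scale(X)
print(X_clean[0, 0, 0], means)
```

To keep validation statistics out of training, the means computed on the training set can be reused for the validation set, for example `X_val_clean, _ = fill_and_scale(X_val, band_means=means)`.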

**Phase 3 Workflow:**
1. Build CNN architecture (Conv2D, MaxPooling, Dense layers)
2. Train model on training set
3. Evaluate on validation set
4. Analyze results (confusion matrix, per-class accuracy)
5. Visualize learned features
6. Apply to full study area

---

### Congratulations! 🎉

You've successfully completed Phase 2 and created a production-ready dataset for CNN training.

**The key insight from Phase 2:**

**Quality over quantity - A small, clean dataset beats a large, messy one!**

This approach:
- ✅ Ensures high-quality training data
- ✅ Guards against common pitfalls (duplicate patches across splits, skewed class proportions)
- ✅ Provides complete documentation and traceability
- ✅ Sets you up for CNN training success

**Ready for Phase 3: CNN Training!** 🚀