# Wetland Mapping - Dataset Preparation
This notebook prepares the balanced dataset for wetland classification using Google Earth Engine embeddings.

## Kaggle Dataset Setup:
- **Dataset name**: `Bo_River_and_Google_Earth`
- **File structure**:
  - `bow_river_wetlands_10m_final.tif` (labels)
  - `Google_Dataset/` (folder with 77 embedding TIF tiles)

In [None]:
# Install dependencies
!pip install -q rasterio tqdm

In [None]:
# Clone your repository (gee_embed_CNN_dev branch)
!git clone -b gee_embed_CNN_dev https://github.com/Jcub05/Wetland-Mapping-ELEC498-Group-46.git
%cd Wetland-Mapping-ELEC498-Group-46

## Step 1: Link Kaggle Dataset Files
Create symbolic links to the Kaggle input files so the ~5 GB of data does not need to be copied.

In [None]:
import os

# Your Kaggle dataset (Kaggle converts underscores to hyphens in URLs)
KAGGLE_INPUT = '/kaggle/input/bo-river-and-google-earth'

# Verify dataset exists
if not os.path.exists(KAGGLE_INPUT):
    print(f"âš  Error: Dataset not found at {KAGGLE_INPUT}")
    print("Available datasets:")
    !ls -la /kaggle/input/
else:
    print(f"âœ“ Dataset found: {KAGGLE_INPUT}")
    print("\nContents:")
    !ls -lh {KAGGLE_INPUT}
    
# Create symbolic links (no copying needed!)
print("\nLinking dataset files...")
if not os.path.exists('Google_Dataset'):
    os.symlink(os.path.join(KAGGLE_INPUT, 'Google_Dataset'), 'Google_Dataset')
    print("  âœ“ Linked Google_Dataset folder")

if not os.path.exists('bow_river_wetlands_10m_final.tif'):
    os.symlink(
        os.path.join(KAGGLE_INPUT, 'bow_river_wetlands_10m_final.tif'),
        'bow_river_wetlands_10m_final.tif'
    )
    print("  âœ“ Linked labels file")

print("\nâœ“ All files ready!")

## Step 2: Build VRT (Virtual Raster)
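Combine all embedding TIF tiles into a single virtual raster (`bow_river_embeddings_2020_matched.vrt`) using the repository's `build_vrt_and_verify.py` script.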

In [None]:
# Run the VRT builder
!python build_vrt_and_verify.py

## Step 3: Load and Analyze Data

In [None]:
import rasterio
import numpy as np
import torch
from tqdm import tqdm
from collections import defaultdict

# File paths
embeddings_file = "bow_river_embeddings_2020_matched.vrt"
labels_file = "bow_river_wetlands_10m_final.tif"

print("Loading labels...")
with rasterio.open(labels_file) as labels_src:
    labels_full = labels_src.read(1)
    # Keep georeferencing so alignment with the embeddings can be checked below.
    labels_crs = labels_src.crs
    labels_transform = labels_src.transform
    print(f"Labels (original): {labels_full.shape}")

print(f"\nOpening embeddings VRT: {embeddings_file}")
# Left open on purpose: the embedding-extraction step reads from this handle.
embeddings_src = rasterio.open(embeddings_file)
print(f"Embeddings: {embeddings_src.count} bands x {embeddings_src.height} x {embeddings_src.width}")

# Cropping from the top-left corner is only pixel-aligned when both rasters
# share CRS, origin and pixel size; warn (not fail) if they do not.
if labels_crs != embeddings_src.crs or not labels_transform.almost_equals(embeddings_src.transform):
    print("⚠ Warning: labels and embeddings differ in CRS or geotransform; "
          "top-left cropping may misalign pixels")

# Crop labels to match embeddings
labels = labels_full[:embeddings_src.height, :embeddings_src.width]
print(f"Labels (cropped): {labels.shape}")

# Verify dimensions match (fails if the labels raster is smaller than the embeddings)
assert (embeddings_src.height, embeddings_src.width) == labels.shape, (
    f"Dimension mismatch! embeddings {(embeddings_src.height, embeddings_src.width)} "
    f"vs labels {labels.shape}"
)
print("✓ Dimensions match!")

In [None]:
# Report how many pixels carry a class id in 0..5 and how they split by class.
valid_mask = np.logical_and(labels >= 0, labels <= 5)
valid_count = valid_mask.sum()
labeled_pct = 100 * valid_count / labels.size
print(f"\nTotal labeled pixels: {valid_count:,} out of {labels.size:,} ({labeled_pct:.2f}%)")

# unique_classes is reused by the sampling step below.
unique_classes, class_counts = np.unique(labels[valid_mask], return_counts=True)
print("\nClass distribution:")
for cls, count in zip(unique_classes, class_counts):
    share = 100 * count / valid_count
    print(f"  Class {cls}: {count:,} pixels ({share:.2f}%)")

## Step 4: Balanced Sampling (~1.5M samples)

In [None]:
# Balanced sampling strategy: per-class sample targets (~1.5M in total).
# A class with fewer pixels than its target is taken in full.
samples_per_class = {
    0: 600_000,   # Background: plenty available, need good "not wetland" examples
    1: 19_225,    # Class 1: USE ALL (smallest class - only 19K available)
    2: 150_000,   # Class 2: moderate wetland type
    3: 500_000,   # Class 3: largest wetland class, get lots of variety
    4: 150_000,   # Class 4: moderate wetland type
    5: 100_000,   # Class 5: moderate wetland type
}
total_target = sum(samples_per_class.values())
print(f"Balanced sampling strategy (target: {total_target:,} samples)\n")

# Seed BEFORE any random draw so both the per-class subsampling and the final
# shuffle are reproducible (previously only the shuffle was seeded).
SEED = 42
rng = np.random.default_rng(SEED)

sampled_indices_y = []
sampled_indices_x = []
sampled_labels = []

print("Sampling pixels from each class...")
for cls in unique_classes:
    # Flat indices take half the memory of np.where's (row, col) pair, which
    # matters for the very large background class.
    flat_idx = np.flatnonzero(labels == cls)
    n_available = flat_idx.size
    n_target = samples_per_class[int(cls)]
    n_sample = min(n_target, n_available)

    # Random sampling without replacement; otherwise keep every pixel.
    if n_available > n_target:
        chosen = rng.choice(flat_idx, size=n_target, replace=False)
    else:
        chosen = flat_idx
        # Warn only on a real shortfall, not when availability equals the target.
        if n_available < n_target:
            print(f"  ⚠ Class {cls}: only {n_available:,} available (target: {n_target:,})")

    y_idx, x_idx = np.unravel_index(chosen, labels.shape)
    sampled_indices_y.append(y_idx)
    sampled_indices_x.append(x_idx)
    sampled_labels.append(np.full(n_sample, cls))

    print(f"  Class {cls}: sampled {n_sample:,} / {n_available:,} pixels")

# Combine and shuffle so classes are interleaved
y_indices = np.concatenate(sampled_indices_y)
x_indices = np.concatenate(sampled_indices_x)
y = np.concatenate(sampled_labels)

shuffle_idx = rng.permutation(len(y_indices))
y_indices = y_indices[shuffle_idx]
x_indices = x_indices[shuffle_idx]
y = y[shuffle_idx]

print(f"\nTotal balanced samples: {len(y):,}")
unique_sampled, sampled_counts = np.unique(y, return_counts=True)
print("\nSampled distribution:")
for cls, count in zip(unique_sampled, sampled_counts):
    print(f"  Class {cls}: {count:,} samples ({100*count/len(y):.2f}%)")

In [None]:
# Inverse-frequency class weights for the loss, computed from the SAMPLED counts
# and normalised so they sum to the number of classes (mean weight 1).
# A class absent from the sample keeps weight 0.
n_classes = len(samples_per_class)
class_weights = torch.zeros(n_classes)
for cls, count in zip(unique_sampled, sampled_counts):
    # int() so the tensor is indexed by a plain Python integer, whatever the label dtype.
    class_weights[int(cls)] = 1.0 / count
class_weights = class_weights / class_weights.sum() * n_classes

print("Class weights for loss function:")
for cls in range(n_classes):
    print(f"  Class {cls}: {class_weights[cls]:.4f}")
print("\n💡 Use in training: nn.CrossEntropyLoss(weight=class_weights)")

## Step 5: Extract Embeddings (Optimized Batch Reading)
**This is the slow part**: reading the embedding vectors of ~1.5M sampled pixels from disk.  
Samples are grouped by row so each raster row is read only once (~100x faster than reading pixel by pixel).

**Expected time: 10-20 minutes**

In [None]:
# Extract the embedding vector of every sampled pixel. Samples are grouped by
# raster row so each row of the VRT is read exactly once.
print("Reading embeddings for sampled pixels (optimized batching)...")

# Reopen if an earlier run of this cell already closed the dataset, so the cell
# can be re-executed without re-running the loading step.
src = rasterio.open(embeddings_file) if embeddings_src.closed else embeddings_src

n_samples = len(y_indices)
X = np.zeros((n_samples, src.count), dtype=np.float32)

# Sort sample positions by row; each row's samples then form one contiguous run
# of `order`, starting at row_starts[i] and ending at row_ends[i].
order = np.argsort(y_indices, kind='stable')
unique_rows, row_starts = np.unique(y_indices[order], return_index=True)
row_ends = np.append(row_starts[1:], n_samples)

print(f"Grouped {n_samples:,} samples into {len(unique_rows):,} unique rows")
print("This will process rows in batches...\n")

sample_count = 0
try:
    row_iter = zip(unique_rows, row_starts, row_ends)
    for row_idx, start, end in tqdm(row_iter, total=len(unique_rows),
                                    desc="Reading rows", unit=" rows"):
        row = int(row_idx)
        # Read the full row at once: (bands, 1, width) -> (bands, width)
        row_data = src.read(window=((row, row + 1), (0, src.width)))[:, 0, :]
        members = order[start:end]
        # Copy all of this row's samples in one vectorised step.
        X[members] = row_data[:, x_indices[members]].T
        sample_count += members.size
finally:
    # Always release the dataset, even if a read fails part-way.
    src.close()

print(f"\n✓ Successfully loaded {sample_count:,} samples")
print(f"  X shape: {X.shape}")
print(f"  Memory: {X.nbytes / (1024**3):.2f} GB")
print(f"  y shape: {y.shape}")

## Step 6: Save Preprocessed Dataset

In [None]:
# Save the dataset as a compressed .npz archive.
import os

output_file = 'wetland_dataset_1.5M.npz'
np.savez_compressed(
    output_file,
    X=X,
    y=y,
    class_weights=class_weights.numpy(),
    # Per-class sampling targets, explicitly ordered so index == class id.
    samples_per_class=np.array([samples_per_class[c] for c in sorted(samples_per_class)])
)

file_size_gb = os.path.getsize(output_file) / (1024**3)
print(f"✓ Dataset saved to: {output_file}")
print(f"  File size: {file_size_gb:.2f} GB")
print("\n📥 Download this file to use for training!")

## Step 7: Verify Saved Dataset

In [None]:
# Reload the saved archive and check its contents. The context manager closes
# the file handle held by the NpzFile.
with np.load(output_file) as data:
    print("Dataset contents:")
    for key in data.files:
        print(f"  {key}: {data[key].shape}")
    # Each data[...] access decompresses the array again, so read y only once.
    y_loaded = data['y']

# Round-trip check against the labels still in memory.
assert np.array_equal(y_loaded, y), "Saved labels differ from in-memory labels"

print("\n✓ Final class distribution:")
unique, counts = np.unique(y_loaded, return_counts=True)
for cls, count in zip(unique, counts):
    print(f"  Class {cls}: {count:,} samples ({100*count/len(y_loaded):.2f}%)")

print("\n\nTo use in training:")
print("```python")
print("data = np.load('wetland_dataset_1.5M.npz')")
print("X, y = data['X'], data['y']")
print("class_weights = torch.from_numpy(data['class_weights'])")
print("```")