# Wetland Training Dataset Creator - 1.5M Samples

**Output:** `wetland_dataset_1.5M_4Training.npz`

**Features:**
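- ✅ All 6 classes (0–5), including background (class 0)
- ✅ Drops samples whose embeddings contain NaN values
- ✅ ~1.5M samples drawn with fixed per-class targets (not equal counts per class; see `samples_per_class`)
- ✅ Includes inverse-frequency class weights
- ✅ Memory-efficient, block-by-block scanning of the label raster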

In [None]:
# CELL 1: Setup
print("üöÄ Setting up environment...")

import os
import sys
from google.colab import drive

# Mount Google Drive
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')
else:
    print("‚úì Drive already mounted")

# Install dependencies
!pip install -q rasterio tqdm

import numpy as np
import torch
import rasterio
from pathlib import Path
from tqdm import tqdm

print("‚úÖ Setup complete!")

In [None]:
# CELL 2: Configuration
# Defines input/output paths and the per-class sampling targets, then checks
# that the inputs exist before any expensive work starts.
print("="*70)
print("CONFIGURATION")
print("="*70)

# Paths
labels_file = "/content/drive/MyDrive/bow_river_wetlands_10m_final.tif"
embeddings_dir = Path("/content/drive/MyDrive/EarthEngine")
output_file = "/content/drive/MyDrive/wetland_dataset_1.5M_4Training.npz"

# Per-class sampling targets: class id -> number of pixels to draw.
# All 6 classes (0-5) are included; targets are deliberately unequal, and the
# imbalance is compensated later by the saved class weights.
samples_per_class = {
    0: 600_000,   # Background - INCLUDED
    1: 19_225,
    2: 150_000,
    3: 500_000,
    4: 150_000,
    5: 100_000,
}

print(f"\nLabels: {labels_file}")
print(f"Embeddings: {embeddings_dir}")
print(f"Output: {output_file}")
print(f"\nTarget: {sum(samples_per_class.values()):,} samples")

for cls, count in samples_per_class.items():
    print(f"  Class {cls}: {count:,}")

# Verify paths. Explicit raises instead of `assert`, which is skipped
# entirely when Python runs with -O.
if not os.path.exists(labels_file):
    raise FileNotFoundError(f"❌ Labels not found: {labels_file}")
if not embeddings_dir.exists():
    raise FileNotFoundError(f"❌ Embeddings not found: {embeddings_dir}")

tile_files = sorted(embeddings_dir.glob("*.tif"))
if not tile_files:
    # Without tiles every sample would silently end up "missing" in Cell 4.
    raise FileNotFoundError(f"❌ No .tif embedding tiles found in: {embeddings_dir}")
print(f"\n✓ Found {len(tile_files)} embedding tiles")
print("✅ Configuration validated!")

In [None]:
# CELL 3: Sample Pixel Coordinates (Memory-Efficient)
# Scans the label raster one internal block at a time and collects up to
# samples_per_class[cls] (row, col) pixel coordinates per class. Produces the
# shuffled arrays y_indices, x_indices (global pixel coords) and y (labels).
print("\n" + "="*70)
print("SAMPLING COORDINATES")
print("="*70)

# Per class: list of coordinate arrays (one per contributing block) and a
# running total of coordinates collected so far.
sampled_coords = {cls: {'y': [], 'x': []} for cls in samples_per_class}
samples_collected = {cls: 0 for cls in samples_per_class}

np.random.seed(42)

print("\nScanning labels in chunks...")
with rasterio.open(labels_file) as src:
    # Visit blocks in random order. NOTE(review): sampling is still spatially
    # clustered - each class takes every matching pixel of the first blocks
    # visited until its target is met, so abundant classes come from only a
    # few blocks. A two-pass scheme (count per block, then allocate quotas
    # proportionally) would spread samples across the whole raster.
    windows = list(src.block_windows(1))
    np.random.shuffle(windows)

    for idx, (block_id, window) in tqdm(enumerate(windows), total=len(windows), desc="Blocks"):
        labels_chunk = src.read(1, window=window)
        row_off = window.row_off
        col_off = window.col_off

        for cls, target in samples_per_class.items():
            needed = target - samples_collected[cls]
            if needed <= 0:
                continue

            y_local, x_local = np.where(labels_chunk == cls)
            available = len(y_local)
            if available == 0:
                continue

            # Block-local -> global raster pixel coordinates.
            y_global = y_local + row_off
            x_global = x_local + col_off

            # Take a random subset when the block has more pixels than needed,
            # otherwise take them all.
            if available > needed:
                idx_sample = np.random.choice(available, needed, replace=False)
                y_global = y_global[idx_sample]
                x_global = x_global[idx_sample]

            sampled_coords[cls]['y'].append(y_global)
            sampled_coords[cls]['x'].append(x_global)
            samples_collected[cls] += len(y_global)

        if all(samples_collected[cls] >= target for cls, target in samples_per_class.items()):
            print(f"\n✓ Got all samples after {idx+1}/{len(windows)} blocks")
            break

# Combine per-class chunks into flat arrays, reporting any class that fell
# short of its target (e.g. because the raster has fewer such pixels).
all_y, all_x, all_labels = [], [], []

for cls, target in samples_per_class.items():
    if not sampled_coords[cls]['y']:
        print(f"  ⚠ Class {cls}: no pixels found")
        continue

    y_coords = np.concatenate(sampled_coords[cls]['y'])[:target]
    x_coords = np.concatenate(sampled_coords[cls]['x'])[:target]

    all_y.append(y_coords)
    all_x.append(x_coords)
    all_labels.append(np.full(len(y_coords), cls))

    shortfall = target - len(y_coords)
    note = f"  (⚠ {shortfall:,} short of target)" if shortfall > 0 else ""
    print(f"  Class {cls}: {len(y_coords):,}{note}")

if not all_y:
    # np.concatenate([]) would otherwise fail with an unhelpful message.
    raise RuntimeError("❌ No labelled pixels matched any class in samples_per_class")

y_indices = np.concatenate(all_y)
x_indices = np.concatenate(all_x)
y = np.concatenate(all_labels)

# Shuffle so classes are interleaved rather than grouped by class.
shuffle_idx = np.random.permutation(len(y_indices))
y_indices = y_indices[shuffle_idx]
x_indices = x_indices[shuffle_idx]
y = y[shuffle_idx]

print(f"\nTotal coordinates: {len(y):,}")
print("✅ Coordinates sampled!")

In [None]:
# CELL 4: Extract Embeddings from Tiles
# For every sampled coordinate, finds the embedding tile containing it and
# copies that pixel's band vector into X. found_samples marks which rows of X
# hold a valid (all-finite) embedding.
print("\n" + "="*70)
print("EXTRACTING EMBEDDINGS")
print("="*70)

# Feature dimension: number of embedding bands expected in every tile.
n_bands = 64

n_samples = len(y_indices)
X = np.zeros((n_samples, n_bands), dtype=np.float32)
found_samples = np.zeros(n_samples, dtype=bool)
skipped_tiles = []  # tiles whose filename did not yield a row/col offset

with tqdm(total=len(tile_files), desc="Tiles", unit=" tiles") as pbar:
    for tile_file in tile_files:
        # Tile position comes from the filename, which must end in
        # "-<row_offset>-<col_offset>"; the offsets are compared directly with
        # the label-raster pixel indices from Cell 3.
        # NOTE(review): assumes the tiles share the label raster's grid
        # (same origin and pixel size) - confirm.
        parts = tile_file.stem.split('-')
        offsets = None
        if len(parts) >= 3:
            try:
                offsets = (int(parts[-2]), int(parts[-1]))
            except ValueError:
                pass
        if offsets is None:
            skipped_tiles.append(tile_file.name)
            pbar.update(1)
            continue
        tile_row_offset, tile_col_offset = offsets

        with rasterio.open(tile_file) as tile_src:
            tile_height, tile_width = tile_src.height, tile_src.width

            # Positions (in the global sample arrays) of samples inside this tile.
            in_tile_mask = (
                (y_indices >= tile_row_offset) & (y_indices < tile_row_offset + tile_height)
                & (x_indices >= tile_col_offset) & (x_indices < tile_col_offset + tile_width)
            )
            global_idx = np.flatnonzero(in_tile_mask)

            if global_idx.size:
                if tile_src.count != n_bands:
                    raise ValueError(
                        f"{tile_file.name}: expected {n_bands} bands, found {tile_src.count}"
                    )
                tile_data = tile_src.read()  # (bands, H, W)

                local_y = y_indices[global_idx] - tile_row_offset
                local_x = x_indices[global_idx] - tile_col_offset

                # Gather all pixel vectors at once -> (n_in_tile, bands).
                # (Replaces a per-sample loop that re-ran np.where over the
                # full mask for every sample, which was quadratic.)
                pixel_values = tile_data[:, local_y, local_x].T
                del tile_data  # release the full tile before the next read

                # Drop samples with ANY non-finite band (NaN or +/-Inf); the
                # verification step also requires X to contain no Inf.
                valid = np.isfinite(pixel_values).all(axis=1)
                X[global_idx[valid]] = pixel_values[valid]
                found_samples[global_idx[valid]] = True

        pbar.update(1)
        pbar.set_postfix({"found": f"{found_samples.sum():,}/{n_samples:,}"})

n_found = int(found_samples.sum())
print(f"\n✓ Extracted {n_found:,} / {n_samples:,} samples")

if skipped_tiles:
    print(f"   ⚠ Skipped {len(skipped_tiles)} tile(s) with unparseable names, e.g. {skipped_tiles[0]}")

if not found_samples.all():
    missing = int((~found_samples).sum())
    # A sample is missing either because its pixel was non-finite or because
    # no (parseable) tile covered it.
    print(f"   ⚠ {missing:,} samples dropped (NaN/Inf values or not covered by any tile)")

    print("\n   Missing by class:")
    for cls in np.unique(y):
        cls_mask = (y == cls)
        missing_cls = int((~found_samples[cls_mask]).sum())
        if missing_cls > 0:
            print(f"     Class {cls}: {missing_cls:,} / {cls_mask.sum():,}")

print("✅ Extraction complete!")

In [None]:
# CELL 5: Calculate Class Weights & Save
# Keeps only successfully extracted samples, computes inverse-frequency class
# weights and writes X, y and class_weights to a compressed .npz file.
print("\n" + "="*70)
print("FINALIZING DATASET")
print("="*70)

# Use only valid samples (rows of X filled in during extraction).
X_final = X[found_samples]
y_final = y[found_samples]

if y_final.size == 0:
    # With no samples, the weight normalisation below would divide by zero.
    raise RuntimeError("❌ No valid samples left after extraction; nothing to save")

# One weight slot per configured class id (keys 0..5 -> 6 slots).
n_classes = max(samples_per_class) + 1

# Inverse-frequency class weights, rescaled so they sum to n_classes.
# Classes absent from y_final keep a weight of 0.
unique_classes, class_counts = np.unique(y_final, return_counts=True)
class_weights = torch.zeros(n_classes)

for cls, count in zip(unique_classes, class_counts):
    class_weights[cls] = 1.0 / count

class_weights = class_weights / class_weights.sum() * n_classes

print("\nClass weights:")
for cls in range(n_classes):
    if cls in unique_classes:
        print(f"  Class {cls}: {class_weights[cls]:.4f}")
    else:
        print(f"  Class {cls}: MISSING ❌")

# Save
print(f"\nSaving to: {output_file}")
np.savez_compressed(
    output_file,
    X=X_final,
    y=y_final,
    class_weights=class_weights.numpy(),
)

print("\n" + "="*70)
print("✅ DATASET CREATED SUCCESSFULLY!")
print("="*70)
print(f"\nFile: {Path(output_file).name}")
print(f"Samples: {len(y_final):,}")
print(f"Features: {X_final.shape[1]}")
# nbytes is the uncompressed in-memory size of X; the .npz on disk is compressed.
print(f"Size in memory (X): {X_final.nbytes / (1024**3):.2f} GB")
print(f"Size on disk: {os.path.getsize(output_file) / (1024**3):.2f} GB")

print("\nFinal distribution:")
for cls, count in zip(unique_classes, class_counts):
    pct = 100 * count / len(y_final)
    print(f"  Class {cls}: {count:,} ({pct:.1f}%)")

print("\n🎉 Ready to download and train!")

In [None]:
# CELL 6: Verify (Optional)
# Reloads the saved .npz, prints a summary of every array and raises if the
# feature matrix is empty or contains NaN/Inf, or if X and y disagree in length.
print("\n" + "="*70)
print("VERIFICATION")
print("="*70)

problems = []  # descriptions of every failed check
n_rows = {}    # first-axis length of X and y, for the consistency check

# NpzFile is a context manager, so the archive is closed even if a check fails.
with np.load(output_file) as data:
    print(f"\nArrays: {list(data.keys())}")

    for key in data.keys():
        arr = data[key]
        print(f"\n{key}:")
        print(f"  Shape: {arr.shape}")
        print(f"  Type: {arr.dtype}")

        if key == 'X':
            n_rows[key] = arr.shape[0]
            has_nan = bool(np.isnan(arr).any())
            has_inf = bool(np.isinf(arr).any())
            print(f"  Has NaN: {has_nan} (should be False)")
            print(f"  Has Inf: {has_inf} (should be False)")
            if arr.size:
                print(f"  Min: {arr.min():.4f}, Max: {arr.max():.4f}")
            else:
                problems.append("X is empty")
            if has_nan:
                problems.append("X contains NaN")
            if has_inf:
                problems.append("X contains Inf")
        elif key == 'y':
            n_rows[key] = arr.shape[0]
            print(f"  Classes: {np.unique(arr)}")

if len(n_rows) == 2 and n_rows['X'] != n_rows['y']:
    problems.append(f"X has {n_rows['X']:,} rows but y has {n_rows['y']:,}")

if problems:
    raise ValueError("❌ Verification failed: " + "; ".join(problems))
print("\n✅ Verification passed!")