# Wetland Dataset Preparation - Tile-Optimized with Debug Logging

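This notebook builds a class-balanced dataset for wetland classification: it samples labelled pixels from the label raster, reads the matching 64-band embedding vector for each sampled pixel tile by tile, validates the extracted values with debug logging at every step, and saves the result as a compressed `.npz` file.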

In [None]:
# Install dependencies
!pip install -q rasterio tqdm

In [None]:
import rasterio
import numpy as np
import torch
from tqdm import tqdm
from pathlib import Path
from collections import defaultdict
import os

# Kaggle input locations: the dataset root, the label raster inside it,
# and the directory that holds the embedding tiles.
KAGGLE_INPUT = '/kaggle/input/bo-river-and-google-earth'
labels_file = KAGGLE_INPUT + "/bow_river_wetlands_10m_final.tif"
embeddings_dir = Path(KAGGLE_INPUT) / "Google_Dataset"

# Opening banner: a title framed by two 70-character rules.
banner_rule = "=" * 70
for banner_line in (banner_rule, "WETLAND DATASET PREPARATION - TILE-OPTIMIZED", banner_rule):
    print(banner_line)

## Step 1: Load Labels and Verify

In [None]:
print("\n[1/7] Loading labels...")

# Read band 1 of the label raster fully into memory; the dataset handle is
# closed as soon as the array has been read.
with rasterio.open(labels_file) as labels_src:
    labels_full = labels_src.read(1)

# The extrema are needed for both the report and the range check, so scan once.
label_min = labels_full.min()
label_max = labels_full.max()

print(f"  ✓ Labels shape: {labels_full.shape}")
print(f"  ✓ Labels dtype: {labels_full.dtype}")
print(f"  ✓ Labels range: [{label_min}, {label_max}]")

# Sanity check: class ids are expected to lie in [0, 5]
labels_out_of_range = label_max > 5 or label_min < 0
print(
    "  ⚠ WARNING: Labels outside expected range [0, 5]!"
    if labels_out_of_range
    else "  ✓ Labels in valid range"
)

## Step 2: Find and Validate Tiles

In [None]:
print("\n[2/7] Finding embedding tiles...")

# Sorted so the processing order of tiles is deterministic across runs.
tile_files = sorted(embeddings_dir.glob("*.tif"))
print(f"  ✓ Found {len(tile_files)} tile files")

if len(tile_files) == 0:
    print(f"  ⚠ ERROR: No tiles found in {embeddings_dir}!")
    print(f"  Directory contents:")
    # Pure-Python listing instead of a `!ls` shell escape: no unquoted path is
    # interpolated into a shell command, the cell stays valid Python, and a
    # missing directory is reported explicitly.
    if embeddings_dir.is_dir():
        dir_entries = sorted(embeddings_dir.iterdir())
        if not dir_entries:
            print("    (directory is empty)")
        for entry in dir_entries:
            try:
                size_text = f"{entry.stat().st_size:,} B"
            except OSError:
                # e.g. a dangling symlink; still show the name
                size_text = "? B"
            entry_kind = "dir " if entry.is_dir() else "file"
            print(f"    {entry_kind}  {size_text:>18}  {entry.name}")
    else:
        print("    (directory does not exist)")
else:
    print(f"  Sample filenames:")
    for tile in tile_files[:3]:
        print(f"    - {tile.name}")

## Step 3: Balanced Sampling

In [None]:
print("\n[3/7] Balanced sampling...")

# Seed the global NumPy RNG BEFORE any random draw. Seeding only ahead of the
# final shuffle (as was done previously) left the per-class `np.random.choice`
# draws unseeded, so a fresh run selected different pixels every time.
# Seeding at the top of the cell also makes re-running this cell idempotent.
np.random.seed(42)

# Analyze class distribution (only labels inside the expected [0, 5] range count)
valid_mask = (labels_full >= 0) & (labels_full <= 5)
unique_classes, class_counts = np.unique(labels_full[valid_mask], return_counts=True)
n_valid_pixels = valid_mask.sum()

print("\n  Class distribution in full dataset:")
for cls, count in zip(unique_classes, class_counts):
    print(f"    Class {cls}: {count:,} pixels ({100*count/n_valid_pixels:.2f}%)")

# Sampling strategy: per-class cap on the number of pixels to draw
samples_per_class = {0: 600_000, 1: 19_225, 2: 150_000, 3: 500_000, 4: 150_000, 5: 100_000}
print(f"\n  Target: {sum(samples_per_class.values()):,} balanced samples")

# Sample coordinates (row/column positions in the label raster) class by class
sampled_indices_y = []
sampled_indices_x = []
sampled_labels = []

print("\n  Sampling from each class:")
for cls in unique_classes:
    class_mask = (labels_full == cls)
    y_idx, x_idx = np.where(class_mask)

    n_available = len(y_idx)
    n_target = samples_per_class[cls]
    n_sample = min(n_target, n_available)

    if n_available > n_target:
        # More pixels than needed: draw n_target distinct positions
        sample_idx = np.random.choice(n_available, n_target, replace=False)
    else:
        # Not enough pixels to reach the target: keep every available one
        sample_idx = np.arange(n_available)
        print(f"    ⚠ Class {cls}: only {n_available:,} available (target: {n_target:,})")

    sampled_indices_y.append(y_idx[sample_idx])
    sampled_indices_x.append(x_idx[sample_idx])
    sampled_labels.append(np.full(n_sample, cls))
    print(f"    ✓ Class {cls}: sampled {n_sample:,}")

# Combine and shuffle (one shared permutation keeps coordinates and labels aligned)
y_indices = np.concatenate(sampled_indices_y)
x_indices = np.concatenate(sampled_indices_x)
y = np.concatenate(sampled_labels)

shuffle_idx = np.random.permutation(len(y_indices))
y_indices = y_indices[shuffle_idx]
x_indices = x_indices[shuffle_idx]
y = y[shuffle_idx]

print(f"\n  ✓ Total samples: {len(y):,}")
print(f"  ✓ Coordinate ranges: y=[{y_indices.min()}, {y_indices.max()}], x=[{x_indices.min()}, {x_indices.max()}]")

## Step 4: Calculate Class Weights

In [None]:
print("\n[4/7] Calculating class weights...")

# Inverse-frequency weight per sampled class; any of the six slots whose class
# does not occur in `y` keeps its initial weight of 0.
unique_sampled, sampled_counts = np.unique(y, return_counts=True)
class_weights = torch.zeros(6)
for cls, count in zip(unique_sampled, sampled_counts):
    class_weights[cls] = 1.0 / count

# Rescale so that the six weights add up to 6.
class_weights = class_weights.div(class_weights.sum()).mul(6)

print("  Class weights for nn.CrossEntropyLoss:")
for cls, weight in enumerate(class_weights):
    print(f"    Class {cls}: {weight:.4f}")

## Step 5: Extract Embeddings (Tile-by-Tile with Validation)

In [None]:
print("\n[5/7] Extracting embeddings from tiles...")
print(f"  Processing {len(tile_files)} tiles")

# Pre-allocate one 64-wide float32 row per sampled pixel, plus a flag per row
# recording whether some tile actually supplied its values.
n_samples = len(y_indices)
X = np.zeros((n_samples, 64), dtype=np.float32)
found_samples = np.zeros(n_samples, dtype=bool)

print(f"  ✓ Allocated array: {X.shape} ({X.nbytes / (1024**2):.1f} MB)")

# Process tiles
tiles_processed = 0
total_extracted = 0

with tqdm(total=len(tile_files), desc="  Processing tiles", unit=" tiles") as pbar:
    for tile_idx, tile_file in enumerate(tile_files):
        try:
            # Parse the tile's row/column offset from its filename first, so a
            # tile with an unusable name is skipped without touching the disk.
            parts = tile_file.stem.split('-')
            if len(parts) != 3:
                pbar.write(f"    ⚠ Skipping {tile_file.name}: unexpected filename format")
                continue
            tile_row_offset = int(parts[1])
            tile_col_offset = int(parts[2])

            with rasterio.open(tile_file) as tile_src:
                # Find samples in this tile using only the header dimensions;
                # the pixel data is read further down, and only if needed.
                tile_height, tile_width = tile_src.height, tile_src.width
                in_tile_y = (y_indices >= tile_row_offset) & (y_indices < tile_row_offset + tile_height)
                in_tile_x = (x_indices >= tile_col_offset) & (x_indices < tile_col_offset + tile_width)
                in_tile_mask = in_tile_y & in_tile_x

                # Positions (rows of X) of the samples that fall inside this
                # tile, computed once instead of once per sample.
                global_idx = np.flatnonzero(in_tile_mask)

                if global_idx.size > 0:
                    # Get local coordinates
                    local_y = y_indices[global_idx] - tile_row_offset
                    local_x = x_indices[global_idx] - tile_col_offset

                    # VALIDATION: Check coordinates are in bounds
                    if (local_y < 0).any() or (local_y >= tile_height).any():
                        pbar.write(f"    ⚠ ERROR: Y coords out of bounds in {tile_file.name}!")
                        pbar.write(f"      Tile height: {tile_height}, local_y range: [{local_y.min()}, {local_y.max()}]")
                        continue

                    if (local_x < 0).any() or (local_x >= tile_width).any():
                        pbar.write(f"    ⚠ ERROR: X coords out of bounds in {tile_file.name}!")
                        pbar.write(f"      Tile width: {tile_width}, local_x range: [{local_x.min()}, {local_x.max()}]")
                        continue

                    # Read tile (all bands) now that we know it is needed
                    tile_data = tile_src.read()

                    # VALIDATION: band count must match the embedding width.
                    # Without this, a 1-band tile would silently broadcast its
                    # single value across all columns of X.
                    if tile_data.shape[0] != X.shape[1]:
                        raise ValueError(
                            f"expected {X.shape[1]} bands, found {tile_data.shape[0]}"
                        )

                    # Extract embeddings in one vectorized gather:
                    # tile_data[:, ys, xs] is (bands, k); transpose to (k, bands).
                    X[global_idx] = tile_data[:, local_y, local_x].T
                    found_samples[global_idx] = True

                    total_extracted += len(local_y)

                    # VALIDATION: Check extracted data
                    if tile_idx == 0:  # First tile - detailed check
                        sample_values = tile_data[:, local_y[0], local_x[0]]
                        pbar.write(f"    ✓ First tile extraction check:")
                        pbar.write(f"      Tile: {tile_file.name}")
                        pbar.write(f"      Samples in tile: {len(local_y)}")
                        pbar.write(f"      Sample data: {sample_values[:5]}... (first 5 of 64)")
                        if np.all(sample_values == 0) or np.all(np.isnan(sample_values)):
                            pbar.write(f"      ⚠⚠⚠ WARNING: Extracted values are all zero/NaN!")

            tiles_processed += 1

        except Exception as e:
            # Best effort: report the failing tile and keep going with the rest.
            pbar.write(f"    ⚠ Error processing {tile_file.name}: {e}")

        finally:
            # Runs on every path (including `continue`), so skipped tiles
            # still advance the progress bar.
            pbar.update(1)
            pbar.set_postfix({"extracted": f"{total_extracted:,}/{n_samples:,}"})

print(f"\n  ✓ Processed {tiles_processed}/{len(tile_files)} tiles")
print(f"  ✓ Extracted {found_samples.sum():,} / {n_samples:,} samples ({100*found_samples.sum()/n_samples:.1f}%)")

if not found_samples.all():
    print(f"  ⚠ WARNING: {(~found_samples).sum():,} samples not found in tiles!")

## Step 6: Validate Extracted Data

In [None]:
print("\n[6/7] Validating extracted embeddings...")

# Element-level tallies over the whole matrix (rows never filled by any tile
# are still all-zero and therefore count towards the zero tally).
total_values = X.size
nan_count = np.isnan(X).sum()
zero_count = (X == 0).sum()
usable_count = total_values - nan_count - zero_count

print(f"  Data quality:")
print(f"    NaN values: {nan_count:,} / {total_values:,} ({100*nan_count/total_values:.2f}%)")
print(f"    Zero values: {zero_count:,} / {total_values:,} ({100*zero_count/total_values:.2f}%)")
print(f"    Non-zero, non-NaN: {usable_count:,} ({100*usable_count/total_values:.2f}%)")

# Peek at up to three rows: first three components and the last one
print(f"\n  Sample of first 3 rows:")
for i, row in enumerate(X[:3]):
    print(f"    Row {i}: [{row[0]:.3f}, {row[1]:.3f}, {row[2]:.3f}, ..., {row[-1]:.3f}]")

# Summary statistics over every non-NaN element (skipped if nothing is usable)
if nan_count < total_values:
    valid_X = X[~np.isnan(X)]
    summary_stats = (
        ("Min", valid_X.min()),
        ("Max", valid_X.max()),
        ("Mean", valid_X.mean()),
        ("Std", valid_X.std()),
    )
    print(f"\n  Statistics (non-NaN values):")
    for stat_name, stat_value in summary_stats:
        print(f"    {stat_name}: {stat_value:.3f}")

# Final verdict: more than 99% NaN or more than 99% zero means extraction failed
failure_threshold = total_values * 0.99
if nan_count > failure_threshold:
    verdict = "\n  ⚠⚠⚠ CRITICAL ERROR: >99% of data is NaN! Extraction failed."
elif zero_count > failure_threshold:
    verdict = "\n  ⚠⚠⚠ CRITICAL ERROR: >99% of data is zero! Extraction failed."
else:
    verdict = "\n  ✓ Data appears valid!"
print(verdict)

## Step 7: Save Dataset

In [None]:
print("\n[7/7] Saving dataset...")

output_file = 'wetland_dataset_1.5M.npz'

# Only save samples that were found in some tile
X_final = X[found_samples]
y_final = y[found_samples]

# NOTE(review): class_weights was computed from `y` before the not-found
# samples were dropped here; if many samples are missing, the saved weights
# may not match the class balance of y_final — confirm this is acceptable.
np.savez_compressed(
    output_file,
    X=X_final,
    y=y_final,
    class_weights=class_weights.numpy(),
)

file_size = os.path.getsize(output_file)
print(f"  ✓ Saved: {output_file}")
print(f"  ✓ File size: {file_size / (1024**2):.1f} MB")
print(f"  ✓ Samples saved: {len(X_final):,}")

# VALIDATION: Reload and check
print("\n  Validating saved file...")
# The context manager closes the archive even if a check below raises.
# Each array is pulled out exactly once: every `test_data['X']` access
# re-reads and decompresses the whole array from the archive.
with np.load(output_file) as test_data:
    saved_keys = test_data.files
    X_saved = test_data['X']
    y_saved = test_data['y']

print(f"    Keys: {saved_keys}")
print(f"    X shape: {X_saved.shape}")
print(f"    y shape: {y_saved.shape}")
if len(X_saved) > 0:
    # Guarded: indexing row 0 of an empty array would raise IndexError
    print(f"    Sample X values: {X_saved[0, :5]}")

# An empty array also lands in the error branch (np.all over no elements is True).
if np.all(X_saved == 0) or np.all(np.isnan(X_saved)):
    print("\n  ⚠⚠⚠ ERROR: Saved data is all zeros/NaN!")
else:
    print("\n  ✓ Saved data looks good!")

print("\n" + "="*70)
print("COMPLETE!")
print("="*70)
print(f"Download: {output_file}")
print(f"Expected size: >100 MB (not 4-5 MB!)")