In [None]:
# --- STRONG OPERATIONAL BASELINE: RANDOM FOREST WITH SPATIAL FEATURES ---
# This script implements a robust statistical downscaling model.
# It trains a single, powerful Random Forest model on features
# extracted from spatial neighborhoods, providing a strong and scientifically
# valid baseline for comparison against the GAN.

import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import joblib
from scipy.ndimage import uniform_filter, sobel
import warnings
from scipy.ndimage import zoom
import gc

from sklearn.ensemble import RandomForestRegressor

warnings.filterwarnings('ignore')

# --- CONFIGURATION ---
# Constructing a Path never raises, so a try/except around it can never reach
# its fallback. Decide between Colab and local execution by checking whether
# the Google Drive mount point actually exists.
_COLAB_PROJECT_PATH = Path('/content/drive/MyDrive/AR_Downscaling')
if _COLAB_PROJECT_PATH.parent.is_dir():
    PROJECT_PATH = _COLAB_PROJECT_PATH
else:
    PROJECT_PATH = Path('.')  # For local execution

# Input dataset and output location, both relative to the project root.
DATA_DIR = PROJECT_PATH / 'final_dataset_multi_variable'
OUTPUT_DIR = PROJECT_PATH / 'publication_experiments' / 'strong_baseline_rf'
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# (height, width) of the high-resolution grid used for targets and
# interpolated predictors.
TARGET_SHAPE = (256, 256)

# --- FEATURE ENGINEERING ---

def extract_chunk_features(X_chunk, window_size=5):
    """Build per-pixel features (value, local mean, local std) for a chunk.

    Args:
        X_chunk: Array of shape (num_samples, num_channels, h, w).
        window_size: Side length of the square spatial window used for the
            local mean and local standard deviation.

    Returns:
        A float32 array of shape (num_samples * h * w, 3 * num_channels).
        With C = num_channels, columns [0, C) hold the raw values,
        [C, 2C) the local means and [2C, 3C) the local standard deviations.
        Rows are ordered sample-major, then image row, then image column.
    """
    num_samples, num_channels, h, w = X_chunk.shape
    num_pixels = num_samples * h * w

    features = np.zeros((num_pixels, 3 * num_channels), dtype=np.float32)

    # Feature 1: the original values, rearranged to (pixel, channel).
    features[:, :num_channels] = X_chunk.transpose(0, 2, 3, 1).reshape(
        num_pixels, num_channels
    )

    # Window of size 1 along the sample axis: the filter must only smooth
    # spatially. A scalar size would also average across neighbouring samples
    # in the stack, leaking information between unrelated fields and making
    # the result depend on how samples are grouped into chunks.
    spatial_window = (1, window_size, window_size)

    # Features 2 & 3: local mean and standard deviation per channel.
    for c in range(num_channels):
        # float64 limits cancellation error in E[x^2] - E[x]^2.
        plane_stack = X_chunk[:, c].astype(np.float64)

        local_mean = uniform_filter(plane_stack, size=spatial_window)
        local_sq_mean = uniform_filter(plane_stack ** 2, size=spatial_window)

        # Round-off can push the variance slightly below zero; clamp it.
        local_var = np.maximum(local_sq_mean - local_mean ** 2, 0)

        features[:, num_channels + c] = local_mean.ravel()
        features[:, 2 * num_channels + c] = np.sqrt(local_var).ravel()

    return features

def extract_neighborhood_features(X_interp, window_size=5):
    """Extract spatial features from the interpolated predictor grid.

    The samples are processed a few at a time and each chunk's feature rows
    are written straight into one preallocated output array, so the full
    feature matrix is held in memory only once (collecting chunks in a list
    and stacking them afterwards would briefly need it twice).

    Args:
        X_interp: Array of shape (num_samples, num_channels, h, w).
        window_size: Spatial window size forwarded to
            ``extract_chunk_features``.

    Returns:
        Array with num_samples * h * w rows; the columns and dtype are
        whatever ``extract_chunk_features`` produces.

    Raises:
        ValueError: If ``X_interp`` contains no samples.
    """
    print("Extracting spatial neighborhood features...")
    num_samples, _, h, w = X_interp.shape
    if num_samples == 0:
        raise ValueError("X_interp contains no samples; cannot extract features.")

    chunk_size = min(8, num_samples)  # Process at most 8 samples at a time
    rows_per_sample = h * w
    features = None

    for start_idx in tqdm(range(0, num_samples, chunk_size), desc="Processing feature chunks"):
        end_idx = min(start_idx + chunk_size, num_samples)
        chunk_features = extract_chunk_features(X_interp[start_idx:end_idx], window_size)

        if features is None:
            # Allocate lazily so the column count and dtype come from the
            # first chunk rather than being duplicated here.
            features = np.empty(
                (num_samples * rows_per_sample, chunk_features.shape[1]),
                dtype=chunk_features.dtype,
            )

        features[start_idx * rows_per_sample:end_idx * rows_per_sample] = chunk_features
        del chunk_features

    gc.collect()

    return features

# --- DATA PREPARATION ---

def load_and_prepare_data(split='train'):
    """Load one dataset split, interpolate predictors and extract features.

    Every ``*_predictor.npy`` file in ``DATA_DIR / split`` is paired with the
    ``*_target.npy`` file of the same stem. Predictors are read as
    (channels, h, w) arrays and upsampled channel by channel to
    ``TARGET_SHAPE`` with cubic spline interpolation. Targets are read as 2-D
    arrays, centre-cropped when larger than ``TARGET_SHAPE`` and zero-padded
    at the bottom/right when smaller.

    Args:
        split: Name of the sub-directory of ``DATA_DIR`` to load.

    Returns:
        Tuple ``(X_features, Y_flat)``: the per-pixel feature matrix returned
        by ``extract_neighborhood_features`` and the flattened target values.

    Raises:
        FileNotFoundError: If the split directory contains no predictor files
            (or a matching target file is missing).
    """
    print(f"\n--- Preparing '{split}' data ---")

    split_dir = DATA_DIR / split
    predictor_files = sorted(split_dir.glob('*_predictor.npy'))
    if not predictor_files:
        raise FileNotFoundError(f"No '*_predictor.npy' files found in {split_dir}")

    num_samples = len(predictor_files)

    # Take the channel count and coarse grid size from the data instead of
    # hard-coding them; a mismatching file then fails loudly on assignment
    # rather than being silently broadcast.
    first_predictor = np.load(predictor_files[0])
    num_channels = first_predictor.shape[0]
    coarse_shape = first_predictor.shape[1:]
    del first_predictor

    th, tw = TARGET_SHAPE
    X_coarse = np.zeros((num_samples, num_channels, *coarse_shape), dtype=np.float32)
    Y_high_res = np.zeros((num_samples, th, tw), dtype=np.float32)

    predictor_suffix = '_predictor.npy'
    for i, pred_path in enumerate(tqdm(predictor_files, desc=f"Loading {split} files")):
        # Swap the suffix on the file name only, so directory names that
        # happen to contain the suffix are left alone.
        targ_path = pred_path.with_name(
            pred_path.name[:-len(predictor_suffix)] + '_target.npy'
        )
        predictor_data = np.load(pred_path)
        target_data = np.load(targ_path)

        # Centre-crop any dimension that is larger than the target grid.
        h, w = target_data.shape
        if h != th or w != tw:
            start_h = max(0, (h - th) // 2)
            start_w = max(0, (w - tw) // 2)
            target_data = target_data[start_h:start_h + th, start_w:start_w + tw]

        # Zero-pad (bottom/right) any dimension that is still too small.
        if target_data.shape != TARGET_SHAPE:
            padded_target = np.zeros(TARGET_SHAPE, dtype=np.float32)
            padded_target[:target_data.shape[0], :target_data.shape[1]] = target_data
            target_data = padded_target

        X_coarse[i] = predictor_data
        Y_high_res[i] = target_data

    # Interpolate predictors
    print("Interpolating coarse predictors...")
    X_interp = np.zeros((num_samples, num_channels, th, tw), dtype=np.float32)
    zoom_factors = (th / coarse_shape[0], tw / coarse_shape[1])
    for i in tqdm(range(num_samples), desc="Interpolating samples"):
        for c in range(num_channels):
            X_interp[i, c] = zoom(X_coarse[i, c], zoom_factors, order=3)

    # Release each large intermediate as soon as it is no longer needed.
    del X_coarse
    gc.collect()

    # Extract spatial features
    X_features = extract_neighborhood_features(X_interp)

    del X_interp
    gc.collect()

    # Flatten target variable (same sample/row/column order as the features)
    Y_flat = Y_high_res.flatten()

    del Y_high_res
    gc.collect()

    return X_features, Y_flat

# --- STRATIFIED SAMPLING FOR RIGOROUS TRAINING ---

def stratified_sample(X_features, Y_flat, sample_size=500000, random_state=42):
    """Draw a target-stratified subsample of the training rows.

    The target values are split into 20 bins whose edges are the 0th, 5th,
    ..., 100th percentiles of ``Y_flat``. Every bin receives the same quota
    of ``sample_size // 20`` rows, drawn without replacement; a bin holding
    fewer rows than the quota contributes all of its rows. The result can
    therefore be smaller than ``sample_size`` (e.g. when many identical
    target values collapse several percentile edges into one bin).

    Args:
        X_features: Feature matrix, one row per sample.
        Y_flat: 1-D array of target values aligned with ``X_features``.
        sample_size: Requested total number of sampled rows.
        random_state: Seed for the random draws.

    Returns:
        Tuple ``(X_sampled, Y_sampled)`` with the selected rows, ordered bin
        by bin.
    """
    print(f"\n--- Performing Stratified Sampling ---")
    print(f"Original dataset size: {len(X_features):,} samples")

    # A private legacy RandomState yields exactly the draws that
    # np.random.seed(random_state) + np.random.choice would, but without
    # reseeding the process-wide generator as a side effect.
    rng = np.random.RandomState(random_state)

    # Create value bins for stratification
    n_bins = 20
    bin_edges = np.percentile(Y_flat, np.linspace(0, 100, n_bins + 1))
    # digitize is 1-based and puts the maximum past the last edge; shift and
    # clip so every value lands in a bin 0 .. n_bins - 1.
    bin_indices = np.clip(np.digitize(Y_flat, bin_edges) - 1, 0, n_bins - 1)

    # Equal quota per bin
    samples_per_bin = sample_size // n_bins
    per_bin_picks = []

    for bin_idx in range(n_bins):
        members = np.flatnonzero(bin_indices == bin_idx)
        if members.size > 0:
            n_samples_bin = min(samples_per_bin, members.size)
            per_bin_picks.append(rng.choice(members, n_samples_bin, replace=False))

    if per_bin_picks:
        sampled_indices = np.concatenate(per_bin_picks)
    else:
        # Keep an integer dtype so the fancy indexing below stays valid.
        sampled_indices = np.array([], dtype=np.intp)
    print(f"Sampled dataset size: {len(sampled_indices):,} samples")

    return X_features[sampled_indices], Y_flat[sampled_indices]

# --- MODEL TRAINING ---

def _spatial_feature_names(num_features):
    """Return column names matching the value/mean/std block layout."""
    if num_features % 3 != 0:
        # Not the expected three-blocks layout; fall back to neutral names.
        return [f'feature_{i}' for i in range(num_features)]
    num_channels = num_features // 3
    return [
        f'var_{c}_{stat}'
        for stat in ('value', 'mean', 'std')
        for c in range(num_channels)
    ]


def train_rf_model(X_train_features, Y_train_flat):
    """Train a single Random Forest on a stratified subsample and save it.

    Args:
        X_train_features: Feature matrix, one row per pixel.
        Y_train_flat: 1-D target array aligned with ``X_train_features``.

    Returns:
        The fitted ``RandomForestRegressor``. The model is also written to
        ``OUTPUT_DIR / 'strong_baseline_rf.joblib'``.
    """
    print("\n--- Training Strong Baseline: Random Forest ---")

    # Perform stratified sampling for computational feasibility while maintaining rigor
    X_sampled, Y_sampled = stratified_sample(X_train_features, Y_train_flat,
                                             sample_size=500000, random_state=42)

    # Drops only this function's references; the arrays are actually freed
    # once the caller releases its own references too.
    del X_train_features, Y_train_flat
    gc.collect()

    # Define a powerful and rigorous model
    rf_model = RandomForestRegressor(
        n_estimators=100,      # Sufficient trees for robust performance
        max_depth=20,          # Deep enough to capture complex patterns
        min_samples_split=10,  # Prevent overfitting
        min_samples_leaf=5,    # Ensure leaf nodes have sufficient samples
        max_features='sqrt',   # Standard feature selection
        bootstrap=True,        # Enable bootstrap sampling
        n_jobs=-1,             # Use all available cores
        random_state=42,       # Reproducibility
        verbose=2              # Show training progress
    )

    print(f"Training Random Forest with {rf_model.n_estimators} trees...")
    print(f"Training data shape: {X_sampled.shape}")
    print(f"Target data shape: {Y_sampled.shape}")

    # Train the model
    rf_model.fit(X_sampled, Y_sampled)

    # Save the model
    model_path = OUTPUT_DIR / 'strong_baseline_rf.joblib'
    joblib.dump(rf_model, model_path)
    print(f"✅ Strong RF baseline saved to {model_path}")

    # Feature importance summary. extract_chunk_features lays the columns out
    # in three blocks (all raw values, then all local means, then all local
    # stds), so the names must follow that order; interleaving them per
    # variable would attach the importances to the wrong features.
    feature_names = _spatial_feature_names(X_sampled.shape[1])

    importances = rf_model.feature_importances_
    top_features = np.argsort(importances)[-10:][::-1]

    print("\nTop 10 Most Important Features:")
    for idx in top_features:
        print(f"  {feature_names[idx]}: {importances[idx]:.4f}")

    return rf_model

# --- MAIN EXECUTION ---
def main():
    """Run the baseline pipeline: build the training set, then fit the model."""
    # Step 1: load the 'train' split and turn it into per-pixel features.
    train_features, train_targets = load_and_prepare_data(split='train')

    # Step 2: fit the Random Forest baseline (saving is handled by the trainer).
    train_rf_model(train_features, train_targets)


if __name__ == "__main__":
    main()


--- Preparing 'train' data ---


Loading train files: 100%|██████████| 1200/1200 [00:33<00:00, 36.19it/s]


Interpolating coarse predictors...


Interpolating samples: 100%|██████████| 1200/1200 [00:51<00:00, 23.46it/s]


Extracting spatial neighborhood features...


Processing feature chunks: 100%|██████████| 150/150 [00:41<00:00,  3.59it/s]


Concatenating all features...

--- Training Strong Baseline: Random Forest ---

--- Performing Stratified Sampling ---
Original dataset size: 78,643,200 samples
Sampled dataset size: 500,000 samples
Training Random Forest with 100 trees...
Training data shape: (500000, 15)
Target data shape: (500000,)
building tree 1 of 100
building tree 2 of 100
building tree 3 of 100
building tree 4 of 100
building tree 5 of 100
building tree 6 of 100
building tree 7 of 100
building tree 8 of 100


[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.


building tree 9 of 100
building tree 10 of 100
building tree 11 of 100
building tree 12 of 100
building tree 13 of 100
building tree 14 of 100
building tree 15 of 100
building tree 16 of 100
building tree 17 of 100
building tree 18 of 100
building tree 19 of 100
building tree 20 of 100
building tree 21 of 100
building tree 22 of 100
building tree 23 of 100
building tree 24 of 100
building tree 25 of 100
building tree 26 of 100
building tree 27 of 100
building tree 28 of 100
building tree 29 of 100
building tree 30 of 100
building tree 31 of 100
building tree 32 of 100
building tree 33 of 100
building tree 34 of 100


[Parallel(n_jobs=-1)]: Done  25 tasks      | elapsed:   18.4s


building tree 35 of 100
building tree 36 of 100
building tree 37 of 100
building tree 38 of 100
building tree 39 of 100
building tree 40 of 100
building tree 41 of 100
building tree 42 of 100
building tree 43 of 100
building tree 44 of 100
building tree 45 of 100
building tree 46 of 100
building tree 47 of 100
building tree 48 of 100
building tree 49 of 100
building tree 50 of 100
building tree 51 of 100
building tree 52 of 100
building tree 53 of 100
building tree 54 of 100
building tree 55 of 100
building tree 56 of 100
building tree 57 of 100
building tree 58 of 100
building tree 59 of 100
building tree 60 of 100
building tree 61 of 100
building tree 62 of 100
building tree 63 of 100
building tree 64 of 100
building tree 65 of 100
building tree 66 of 100
building tree 67 of 100
building tree 68 of 100
building tree 69 of 100
building tree 70 of 100
building tree 71 of 100
building tree 72 of 100
building tree 73 of 100
building tree 74 of 100
building tree 75 of 100
building tree 76

[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed:   58.9s finished


✅ Strong RF baseline saved to /content/drive/MyDrive/AR_Downscaling/publication_experiments/strong_baseline_rf/strong_baseline_rf.joblib

Top 10 Most Important Features:
  var_2_std: 0.0917
  var_2_value: 0.0878
  var_2_mean: 0.0828
  var_0_mean: 0.0812
  var_0_std: 0.0741
  var_1_value: 0.0729
  var_1_std: 0.0656
  var_3_mean: 0.0598
  var_0_value: 0.0597
  var_4_mean: 0.0576
