In [1]:
import os 
from uvars import f1,f2,f3,f4,f5

In [None]:
import os
import time
import rasterio
import numpy as np
import joblib

# Scikit-learn and related models
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
from sklearn.ensemble import RandomForestRegressor
import lightgbm as lgb
import xgboost as xgb
import catboost as ctb

# --- 1. Helper Functions ---

def read_rasters_to_numpy(raster_paths, target_raster_path):
    """
    Read band 1 of every raster, flatten it, and stack the results column-wise.

    Each raster's own nodata value is replaced by NaN so that the imputer
    treats those pixels as missing.

    Args:
        raster_paths (list): File paths of the rasters to read. Column order in
                             the returned array follows the order of this list.
        target_raster_path (str): The path of the raster we intend to impute.
                                  Its metadata (profile, shape) is returned.
                                  Matching is by exact equality with an entry
                                  of `raster_paths`.

    Returns:
        tuple: A tuple containing:
            - np.ndarray: A 2D float32 array where each column is a flattened raster.
            - list: Column names (the base file name of each raster path).
            - dict: The rasterio profile of the target raster, or None if the
                    target path was not among `raster_paths`.
            - tuple: The (height, width) shape of the target raster, or None
                     if the target path was not among `raster_paths`.

    Raises:
        ValueError: If the rasters do not all have the same number of pixels,
                    or if `raster_paths` is empty (raised by `np.stack`).
    """
    raster_data_list = []
    column_names = []
    target_profile = None
    target_shape = None
    expected_size = None

    print("\nReading raster data...")
    for path in raster_paths:
        with rasterio.open(path) as src:
            # Every raster declares its own nodata value; read it per file
            # instead of reusing whatever the target raster declared.
            nodata_val = src.nodata

            if path == target_raster_path:
                target_profile = src.profile
                target_shape = src.shape

            # astype() already yields a fresh, contiguous, writable array, so
            # ravel() is a cheap view rather than a second full copy.
            band_data = src.read(1).astype(np.float32).ravel()

            # Columns can only be stacked if every raster has the same number
            # of pixels; fail with the offending file name instead of a
            # generic stacking error at the very end.
            if expected_size is None:
                expected_size = band_data.size
            elif band_data.size != expected_size:
                raise ValueError(
                    f"Raster '{os.path.basename(path)}' has shape {src.shape} "
                    f"({band_data.size} pixels), expected {expected_size} pixels "
                    "to match the previously read rasters."
                )

            # Convert the raster's specific nodata value to NaN
            if nodata_val is not None:
                band_data[band_data == nodata_val] = np.nan

            raster_data_list.append(band_data)
            column_names.append(os.path.basename(path))
            print(f"Read '{os.path.basename(path)}' with shape {src.shape}. Nodata value: {nodata_val}")

    # Stack the 1D arrays as columns in a 2D NumPy array
    return np.stack(raster_data_list, axis=1), column_names, target_profile, target_shape


def get_estimator(method: str):
    """
    Build the regressor that IterativeImputer should use for `method`.

    Supported names are 'rf', 'lgbm', 'xgb', 'ctb' and 'mice'. For 'mice' the
    function returns None, which leaves IterativeImputer on its default
    estimator. Any other name raises ValueError.
    """
    print(f"Initializing estimator for method: '{method}'")

    def _random_forest():
        # RandomForest: solid general-purpose choice, but memory hungry.
        return RandomForestRegressor(
            n_estimators=100,
            min_samples_leaf=10,  # larger leaves act as regularization
            n_jobs=-1,            # use every available CPU core
            random_state=42
        )

    def _lightgbm():
        # LightGBM: quick to train and light on memory.
        return lgb.LGBMRegressor(
            n_estimators=1000,
            n_jobs=-1,
            learning_rate=0.05,
            num_leaves=31,
            random_state=42
        )

    def _xgboost():
        # XGBoost: the 'hist' tree method keeps training fast on large data.
        return xgb.XGBRegressor(
            n_estimators=1000,
            n_jobs=-1,
            tree_method='hist',
            eta=0.05,  # learning rate
            random_state=42
        )

    def _catboost():
        # CatBoost: robust defaults; training output is silenced.
        return ctb.CatBoostRegressor(
            n_estimators=1000,
            thread_count=-1,
            verbose=0,
            random_state=42
        )

    def _mice_default():
        # None means "use IterativeImputer's default estimator".
        return None

    # Factories are only invoked for the requested method, so no unused
    # estimator is ever constructed.
    factories = {
        'rf': _random_forest,
        'lgbm': _lightgbm,
        'xgb': _xgboost,
        'ctb': _catboost,
        'mice': _mice_default,
    }

    try:
        build = factories[method]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which must be rejected the same
        # way as any other unknown method name.
        raise ValueError(f"Unknown method: '{method}'. Choose from 'rf', 'lgbm', 'xgb', 'ctb', 'mice', 'ensemble'.") from None
    return build()


def write_raster(output_path, data_1d, profile, shape):
    """
    Write a 1D NumPy array to a single-band raster file.

    The write is best-effort: any failure is reported on stdout and the
    function returns normally, so one failed write does not abort a long run.

    Args:
        output_path (str): Destination file path.
        data_1d (np.ndarray): Flattened pixel values; must contain exactly
                              `shape[0] * shape[1]` elements.
        profile (dict): rasterio profile used as the template for the output.
                        It is NOT modified; a copy is adjusted instead.
        shape (tuple): The (height, width) to reshape `data_1d` into.
    """
    print(f"Writing filled raster to '{output_path}'...")
    try:
        # Reshape the 1D array back to its original 2D raster shape
        reshaped_data = data_1d.reshape(shape)

        # Work on a copy: the caller reuses the same profile for several
        # writes, so it must not be altered by one particular output.
        out_profile = dict(profile)
        # The dtype must match the data, and exactly one band is written
        # below, so the band count is pinned to 1 whatever the template had.
        out_profile.update(dtype=reshaped_data.dtype.name, count=1)

        with rasterio.open(output_path, 'w', **out_profile) as dst:
            dst.write(reshaped_data, 1)
        print("Write complete.")
    except Exception as e:
        # Deliberate best-effort: report the problem and carry on.
        print(f"Error writing raster to {output_path}: {e}")

# --- 2. Main Imputation Logic ---

def run_imputation_process(raster_paths, target_raster_path, output_dir, method='ensemble', save_policy='last', save_imputer_model=True):
    """
    Main orchestration function for the raster imputation process.

    Reads all rasters, fills the missing pixels of the target raster with
    IterativeImputer (using the other rasters as predictors) and writes the
    filled target to `<output_dir>/<target name>_filled_<method>.tif`.
    If the target has no missing pixels, nothing is imputed or written.

    Args:
        raster_paths (list): Paths to all input rasters (features + target).
        target_raster_path (str): Path to the specific raster to be filled.
                                  Must be one of `raster_paths`.
        output_dir (str): Directory to save output files (created if missing).
        method (str): Imputation method ('rf', 'lgbm', 'xgb', 'ctb', 'mice', 'ensemble').
                      'ensemble' averages the results of 'lgbm', 'xgb' and 'ctb'.
        save_policy (str): 'all' additionally saves each ensemble member's
                           result; any other value saves only the final result.
                           Has no effect for single-model methods.
        save_imputer_model (bool): If True, saves the fitted imputer object
                                   with joblib. Only honoured for single-model
                                   methods; ignored for 'ensemble'.

    Raises:
        ValueError: If `method` is unknown, if `target_raster_path` is not in
                    `raster_paths`, or if the target raster has no valid pixel.
    """
    # --- Validate arguments before any expensive work ---
    # Reading and imputing large rasters can take hours, so a typo must be
    # reported immediately, not after the data has been loaded.
    known_methods = ('rf', 'lgbm', 'xgb', 'ctb', 'mice', 'ensemble')
    if method not in known_methods:
        raise ValueError(f"Unknown method: '{method}'. Choose from 'rf', 'lgbm', 'xgb', 'ctb', 'mice', 'ensemble'.")
    if target_raster_path not in raster_paths:
        raise ValueError(
            f"Target raster '{target_raster_path}' is not one of the input raster_paths."
        )
    target_col_index = raster_paths.index(target_raster_path)

    os.makedirs(output_dir, exist_ok=True)

    # --- Read Data ---
    data_np, col_names, profile, shape = read_rasters_to_numpy(raster_paths, target_raster_path)

    # Report initial nulls
    target_column = data_np[:, target_col_index]
    initial_nulls = np.isnan(target_column).sum()
    print(f"\nNumber of nulls in '{os.path.basename(target_raster_path)}' before imputation: {initial_nulls}")
    if initial_nulls == 0:
        print("No missing values to impute. Exiting.")
        return
    if initial_nulls == target_column.size:
        # With no observed pixel the imputer has nothing to learn from and
        # would drop the column altogether.
        raise ValueError(
            f"Target raster '{os.path.basename(target_raster_path)}' contains no valid pixels; nothing to learn from."
        )

    # IterativeImputer (default keep_empty_features=False) removes columns
    # that are entirely NaN from its output, which shifts every later column
    # to the left. Translate the target's input index to its output index.
    dropped_before_target = sum(
        bool(np.isnan(data_np[:, col]).all()) for col in range(target_col_index)
    )
    imputed_col_index = target_col_index - dropped_before_target

    # --- Imputation ---
    start_time = time.time()

    base_target_name = os.path.splitext(os.path.basename(target_raster_path))[0]

    if method == 'ensemble':
        imputed_target_cols = []
        base_methods = ['lgbm', 'xgb', 'ctb']  # 'rf' is currently left out

        for m in base_methods:
            print(f"\n--- Running Ensemble Component: {m.upper()} ---")
            estimator = get_estimator(m)
            imputer = IterativeImputer(estimator=estimator, max_iter=10, random_state=42, verbose=2)

            imputed_data = imputer.fit_transform(data_np)
            imputed_target_cols.append(imputed_data[:, imputed_col_index])

            if save_policy == 'all':
                output_path = os.path.join(output_dir, f"{base_target_name}_filled_{m}.tif")
                write_raster(output_path, imputed_data[:, imputed_col_index], profile, shape)

        # Average the results from all models
        print("\nAveraging results for ensemble...")
        final_imputed_target = np.mean(np.stack(imputed_target_cols, axis=1), axis=1)

    else:  # Single model logic
        estimator = get_estimator(method)
        imputer = IterativeImputer(estimator=estimator, max_iter=10, random_state=42, verbose=2)
        imputed_data = imputer.fit_transform(data_np)
        final_imputed_target = imputed_data[:, imputed_col_index]

        if save_imputer_model:
            imputer_path = os.path.join(output_dir, f"imputer_model_{method}.joblib")
            print(f"\nSaving fitted imputer to '{imputer_path}'...")
            joblib.dump(imputer, imputer_path)

    # --- Write Final Output ---
    final_output_path = os.path.join(output_dir, f"{base_target_name}_filled_{method}.tif")
    write_raster(final_output_path, final_imputed_target, profile, shape)

    end_time = time.time()
    print("\n--- Process complete. ---")
    print(f"Total time taken: {end_time - start_time:.2f} seconds.")
    print(f"The final imputed raster has been saved as '{final_output_path}'.")


# --- 3. Example Usage ---

if __name__ == "__main__":

    # --- Run the Imputation ---

    # The five raster paths come from `uvars`; the first one is the raster
    # whose gaps are filled, the others serve as predictors.
    all_raster_paths = [f1, f2, f3, f4, f5]
    target_to_fill = all_raster_paths[0]
    output_directory = 'output_filled_rasters'
    os.makedirs(output_directory, exist_ok=True)

    # === SCENARIO 1: a single model (e.g., XGBoost) ===
    # Swap the call below for this one to run one model and keep its imputer:
    # run_imputation_process(
    #     raster_paths=all_raster_paths,
    #     target_raster_path=target_to_fill,
    #     output_dir=output_directory,
    #     method='xgb', # one of 'rf', 'lgbm', 'xgb', 'ctb', 'mice'
    #     save_policy='last',
    #     save_imputer_model=True
    # )

    # === SCENARIO 2: the ENSEMBLE of several models ===
    run_imputation_process(
        all_raster_paths,
        target_to_fill,
        output_directory,
        method='ensemble',
        # 'all' also writes one raster per ensemble member; 'last' writes
        # only the averaged ensemble result.
        save_policy='all',
        # The ensemble path never saves an imputer, so this is off.
        save_imputer_model=False,
    )

    # --- Example of Loading a Saved Imputer ---
   


Reading raster data...
Read 'N10E105_tdem_dem_egm_v_3_gap.tif' with shape (9001, 9001). Nodata value: -9999.0
Read 'N10E105_esawc_x.tif' with shape (9001, 9001). Nodata value: -9999.0
Read 'N10E105_tdem_hem.tif' with shape (9001, 9001). Nodata value: -9999.0
Read 'N10E105_s1.tif' with shape (9001, 9001). Nodata value: -9999.0
Read 'N10E105_tdem_dem_egm.tif' with shape (9001, 9001). Nodata value: -9999.0

Number of nulls in 'N10E105_tdem_dem_egm_v_3_gap.tif' before imputation: 6149456

--- Running Ensemble Component: LGBM ---
Initializing estimator for method: 'lgbm'
[IterativeImputer] Completing matrix with shape (81018001, 5)
[LightGBM] [Info] Auto-choosing row-wise multi-threading, the overhead of testing was 0.308736 seconds.
You can set `force_row_wise=true` to remove the overhead.
And if memory is not enough, you can set `force_col_wise=true`.
[LightGBM] [Info] Total Bins 1020
[LightGBM] [Info] Number of data points in the train set: 81018001, number of used features: 4
[LightGBM

In [None]:
# It took about 9 h to run all of them on the CPU.