In [2]:
# -*- coding: utf-8 -*-
"""Enhanced salinity prediction pipeline with model comparison and SHAP analysis.
This script is adapted from ``new_code_codex_v2.ipynb`` and implements
additional features requested in ``README.md``.

The pipeline:
1. Load remote sensing imagery and sample soil data.
2. Train multiple regression models.
3. Evaluate models using R2, RMSE and MAE.
4. Save predictions from each model to CSV.
5. Generate comparison plots and SHAP summaries when possible.
6. Export the best model's prediction image via ``geemap.ee_export_image_to_drive``.

The script is designed to be run as a standalone Python module after
Google Earth Engine authentication.
"""

import os
from typing import Dict, Tuple

import ee
import geemap
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import GridSearchCV, KFold, train_test_split
from sklearn.svm import SVR
import xgboost as xgb

try:
    import shap  # optional, used for explanation
    SHAP_AVAILABLE = True
except Exception:  # pragma: no cover - shap is optional
    SHAP_AVAILABLE = False

import matplotlib.pyplot as plt

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
START_DATE = "2022-07-01"
END_DATE = "2022-10-01"
BOUNDARY_PATH = "/path/to/bdy.shp"
SOIL_DATA_PATH = "/path/to/soil_samples.csv"
OUTPUT_DIR = "outputs"

LANDSAT_COLLECTION = "LANDSAT/LC08/C02/T1_L2"
SENTINEL1_COLLECTION = "COPERNICUS/S1_GRD"
MODIS_ET_COLLECTION = "MODIS/061/MOD16A2"
CHIRPS_COLLECTION = "UCSB-CHG/CHIRPS/DAILY"
SRTM = "USGS/SRTMGL1_003"
WORLDCOVER = "ESA/WorldCover/v200"

CLOUD_THRESHOLD = 70


# ---------------------------------------------------------------------------
# Utility functions
# ---------------------------------------------------------------------------

def initialize_gee() -> None:
    """Authenticate and initialize Google Earth Engine."""
    try:
        ee.Initialize()
    except Exception:
        ee.Authenticate()
        ee.Initialize()


def load_boundary(path: str) -> ee.FeatureCollection:
    return geemap.shp_to_ee(path)


def mask_clouds_l8(image: ee.Image) -> ee.Image:
    qa = image.select("QA_PIXEL")
    cloud = qa.bitwiseAnd(1 << 3).eq(0)
    shadow = qa.bitwiseAnd(1 << 4).eq(0)
    snow = qa.bitwiseAnd(1 << 5).eq(0)
    mask = cloud.And(shadow).And(snow)
    return image.updateMask(mask)


def apply_scale(image: ee.Image) -> ee.Image:
    optical = image.select("SR_B.*").multiply(0.0000275).add(-0.2)
    thermal = image.select("ST_B.*").multiply(0.00341802).add(149.0)
    return image.addBands(optical, None, True).addBands(thermal, None, True)


def process_landsat(boundary: ee.FeatureCollection) -> ee.Image:
    collection = (
        ee.ImageCollection(LANDSAT_COLLECTION)
        .filterDate(START_DATE, END_DATE)
        .filterBounds(boundary)
        .map(apply_scale)
        .map(mask_clouds_l8)
    )
    composite = collection.median().clip(boundary)
    ndvi = composite.normalizedDifference(["SR_B5", "SR_B4"]).rename("NDVI")
    si = composite.select("SR_B2").multiply(composite.select("SR_B4")).sqrt().rename("SI1")
    return composite.addBands([ndvi, si])


def process_sentinel1(boundary: ee.FeatureCollection) -> ee.Image:
    def mask_edge(image: ee.Image) -> ee.Image:
        edge = image.lt(-30.0)
        mask = image.mask().And(edge.Not())
        return image.updateMask(mask)

    collection = (
        ee.ImageCollection(SENTINEL1_COLLECTION)
        .filterDate(START_DATE, END_DATE)
        .filterBounds(boundary)
        .filter(ee.Filter.listContains("transmitterReceiverPolarisation", "VV"))
        .filter(ee.Filter.listContains("transmitterReceiverPolarisation", "VH"))
        .filter(ee.Filter.eq("instrumentMode", "IW"))
        .map(mask_edge)
    )
    vv = collection.select("VV").median()
    vh = collection.select("VH").median()
    ratio = vv.subtract(vh).rename("VV_VH_diff")
    return ee.Image.cat([vv, vh, ratio]).clip(boundary)


def process_environment(boundary: ee.FeatureCollection) -> ee.Image:
    et = (
        ee.ImageCollection(MODIS_ET_COLLECTION)
        .filterDate(START_DATE, END_DATE)
        .filterBounds(boundary)
        .select("ET")
        .mean()
        .rename("ET")
    )
    precip = (
        ee.ImageCollection(CHIRPS_COLLECTION)
        .filterDate(START_DATE, END_DATE)
        .filterBounds(boundary)
        .select("precipitation")
        .sum()
        .rename("Precip")
    )
    dem = ee.Image(SRTM)
    slope = ee.Terrain.slope(dem).rename("slope")
    return ee.Image.cat([et, precip, dem.rename("elevation"), slope]).clip(boundary)


def build_feature_stack(boundary: ee.FeatureCollection) -> ee.Image:
    l8 = process_landsat(boundary)
    s1 = process_sentinel1(boundary)
    env = process_environment(boundary)
    worldcover = ee.ImageCollection(WORLDCOVER).first().select("Map")
    return ee.Image.cat([l8, s1, env, worldcover.rename("landcover")])


def sample_points(image: ee.Image, boundary: ee.FeatureCollection, sample_path: str) -> pd.DataFrame:
    samples = geemap.shp_to_ee(sample_path)
    sample = image.sampleRegions(
        collection=samples,
        properties=["salinity"],
        scale=30,
        geometries=True,
    )
    df = geemap.ee_to_pandas(sample)
    return df.dropna()


# ---------------------------------------------------------------------------
# Model training and evaluation
# ---------------------------------------------------------------------------

def train_models(df: pd.DataFrame) -> Tuple[Dict[str, Dict], pd.DataFrame, pd.Series]:
    X = df.drop(columns=["salinity", "longitude", "latitude"], errors="ignore")
    y = df["salinity"]
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=42
    )

    models = {
        "RandomForest": RandomForestRegressor(random_state=42),
        "SVR": SVR(),
        "Linear": LinearRegression(),
        "XGB": xgb.XGBRegressor(objective="reg:squarederror", random_state=42),
    }

    params = {
        "RandomForest": {"n_estimators": [100, 200], "max_depth": [5, 10]},
        "SVR": {"C": [1, 10], "gamma": ["scale", "auto"]},
        "Linear": {},
        "XGB": {"n_estimators": [100, 200], "max_depth": [3, 6]},
    }

    results = {}
    metrics_list = []

    for name, model in models.items():
        grid = GridSearchCV(
            model, params[name], cv=KFold(n_splits=5, shuffle=True, random_state=42)
        )
        grid.fit(X_train, y_train)
        pred = grid.predict(X_test)

        r2 = r2_score(y_test, pred)
        rmse = mean_squared_error(y_test, pred, squared=False)
        mae = mean_absolute_error(y_test, pred)

        metrics_list.append({"model": name, "R2": r2, "RMSE": rmse, "MAE": mae})
        results[name] = {"model": grid.best_estimator_, "pred": pred}

        # save predictions
        pred_df = pd.DataFrame({"actual": y_test.values, "pred": pred})
        pred_df.to_csv(os.path.join(OUTPUT_DIR, f"{name}_predictions.csv"), index=False)

        # feature importance
        if hasattr(grid.best_estimator_, "feature_importances_"):
            importances = grid.best_estimator_.feature_importances_
            imp_df = pd.DataFrame({"feature": X_train.columns, "importance": importances})
            imp_df.sort_values("importance", ascending=False).to_csv(
                os.path.join(OUTPUT_DIR, f"{name}_feature_importance.csv"), index=False
            )

        # SHAP explanation
        if SHAP_AVAILABLE:
            explainer = shap.Explainer(grid.best_estimator_, X_train)
            shap_values = explainer(X_test)
            shap_df = pd.DataFrame(shap_values.values, columns=X_train.columns)
            shap_df.to_csv(os.path.join(OUTPUT_DIR, f"{name}_shap_values.csv"), index=False)
            plt.figure()
            shap.summary_plot(shap_values, X_test, show=False)
            plt.tight_layout()
            plt.savefig(os.path.join(OUTPUT_DIR, f"{name}_shap_summary.png"))
            plt.close()

    metrics_df = pd.DataFrame(metrics_list)
    metrics_df.to_csv(os.path.join(OUTPUT_DIR, "model_performance.csv"), index=False)
    return results, X_test, y_test


# ---------------------------------------------------------------------------
# Visualization utilities
# ---------------------------------------------------------------------------

def plot_model_comparison(metrics_df: pd.DataFrame) -> None:
    plt.figure(figsize=(8, 4))
    metrics_df.set_index("model")[["R2", "RMSE", "MAE"]].plot(kind="bar")
    plt.ylabel("Metric value")
    plt.title("Model performance comparison")
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, "model_comparison.png"))
    plt.close()


# ---------------------------------------------------------------------------
# Main pipeline
# ---------------------------------------------------------------------------

def main() -> None:
    initialize_gee()
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    boundary = load_boundary(BOUNDARY_PATH)
    feature_image = build_feature_stack(boundary)

    df = sample_points(feature_image, boundary, SOIL_DATA_PATH)
    df.to_csv(os.path.join(OUTPUT_DIR, "training_samples.csv"), index=False)

    results, X_test, y_test = train_models(df)

    metrics_df = pd.read_csv(os.path.join(OUTPUT_DIR, "model_performance.csv"))
    plot_model_comparison(metrics_df)

    best_name = metrics_df.sort_values("R2", ascending=False).iloc[0]["model"]
    best_model = results[best_name]["model"]

    print(f"Best model: {best_name}")

    # Predict across image and export
    feature_bands = feature_image.bandNames()
    predictors = feature_image.select(feature_bands)
    model_img = geemap.sk_export_model(best_model, predictors)

    geemap.ee_export_image_to_drive(
        model_img,
        description="salinity_prediction",
        folder="gee_outputs",
        region=boundary.geometry(),
        scale=30,
        fileFormat="GeoTIFF",
    )
    print("Export task started.")


if __name__ == "__main__":  # pragma: no cover - script entry
    main()

The input shapefile is invalid.


AttributeError: module 'geemap' has no attribute 'ee_to_pandas'