In [None]:
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce4f94a6-2a4f-4ceb-b2fb-3d3e38a2dd20",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import ee\n",
    "\n",
    "import geemap\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from datetime import datetime\n",
    "\n",
    "from sklearn.model_selection import train_test_split, GridSearchCV, KFold\n",
    "\n",
    "from sklearn.metrics import r2_score, mean_squared_error\n",
    "\n",
    "from sklearn.ensemble import RandomForestRegressor\n",
    "\n",
    "from sklearn.svm import SVR\n",
    "\n",
    "from sklearn.linear_model import LinearRegression\n",
    "\n",
    "import xgboost as xgb\n",
    "\n",
    "try:\n",
    "    import shap\n",
    "    SHAP_AVAILABLE = True\n",
    "except ImportError:\n",
    "    SHAP_AVAILABLE = False\n",
    "\n",
    "\n",
    "\n",
    "# ----------------------------------------------------------------------\n",
    "\n",
    "# Configuration\n",
    "\n",
    "# ----------------------------------------------------------------------\n",
    "\n",
    "START_DATE = '2022-07-01'\n",
    "\n",
    "END_DATE = '2022-10-01'\n",
    "BOUNDARY_PATH = '/path/to/bdy.shp'\n",
    "SOIL_DATA_PATH = '/path/to/soil_samples.csv'\n",
    "\n",
    "BOUNDARY_PATH = '/Users/hanxu/geemap/bdy.shp'\n",
    "\n",
    "SOIL_DATA_PATH = '/Users/hanxu/geemap/material/soil/soil_2022_08.csv'\n",
    "\n",
    "OUTPUT_DIR = 'outputs'\n",
    "\n",
    "\n",
    "\n",
    "LANDSAT_COLLECTION = 'LANDSAT/LC08/C02/T1_L2'\n",
    "\n",
    "SENTINEL1_COLLECTION = 'COPERNICUS/S1_GRD'\n",
    "\n",
    "MODIS_ET_COLLECTION = 'MODIS/061/MOD16A2'\n",
    "\n",
    "CHIRPS_COLLECTION = 'UCSB-CHG/CHIRPS/DAILY'\n",
    "\n",
    "SRTM = 'USGS/SRTMGL1_003'\n",
    "\n",
    "WORLDCOVER = 'ESA/WorldCover/v200'\n",
    "\n",
    "\n",
    "\n",
    "CLOUD_THRESHOLD = 70\n",
    "\n",
    "\n",
    "\n",
    "# ----------------------------------------------------------------------\n",
    "\n",
    "# Helper functions\n",
    "\n",
    "# ----------------------------------------------------------------------\n",
    "\n",
    "\n",
    "\n",
    "def initialize_gee():\n",
    "\n",
    "    \"\"\"Authenticate and initialize Google Earth Engine.\"\"\"\n",
    "\n",
    "    try:\n",
    "\n",
    "        ee.Initialize()\n",
    "\n",
    "    except Exception:\n",
    "\n",
    "        ee.Authenticate()\n",
    "\n",
    "        ee.Initialize()\n",
    "\n",
    "\n",
    "\n",
    "def load_boundary(path):\n",
    "\n",
    "    return geemap.shp_to_ee(path)\n",
    "\n",
    "\n",
    "\n",
    "def mask_clouds_l8(image):\n",
    "\n",
    "    qa = image.select('QA_PIXEL')\n",
    "\n",
    "    cloud = qa.bitwiseAnd(1 << 3).eq(0)\n",
    "\n",
    "    shadow = qa.bitwiseAnd(1 << 4).eq(0)\n",
    "\n",
    "    snow = qa.bitwiseAnd(1 << 5).eq(0)\n",
    "\n",
    "    mask = cloud.And(shadow).And(snow)\n",
    "\n",
    "    return image.updateMask(mask)\n",
    "\n",
    "\n",
    "\n",
    "def apply_scale(image):\n",
    "\n",
    "    optical = image.select('SR_B.*').multiply(0.0000275).add(-0.2)\n",
    "\n",
    "    thermal = image.select('ST_B.*').multiply(0.00341802).add(149.0)\n",
    "\n",
    "    return image.addBands(optical, None, True).addBands(thermal, None, True)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def process_landsat(boundary):\n",
    "\n",
    "    collection = (\n",
    "\n",
    "        ee.ImageCollection(LANDSAT_COLLECTION)\n",
    "\n",
    "        .filterDate(START_DATE, END_DATE)\n",
    "\n",
    "        .filterBounds(boundary)\n",
    "\n",
    "        .map(apply_scale)\n",
    "\n",
    "        .map(mask_clouds_l8)\n",
    "\n",
    "    )\n",
    "\n",
    "    composite = collection.median().clip(boundary)\n",
    "\n",
    "    ndvi = composite.normalizedDifference(['SR_B5', 'SR_B4']).rename('NDVI')\n",
    "\n",
    "    si = composite.select('SR_B2').multiply(composite.select('SR_B4')).sqrt().rename('SI1')\n",
    "\n",
    "    return composite.addBands([ndvi, si])\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def process_sentinel1(boundary):\n",
    "\n",
    "    def mask_edge(image):\n",
    "\n",
    "        edge = image.lt(-30.0)\n",
    "\n",
    "        mask = image.mask().And(edge.Not())\n",
    "\n",
    "        return image.updateMask(mask)\n",
    "\n",
    "\n",
    "\n",
    "    collection = (\n",
    "\n",
    "        ee.ImageCollection(SENTINEL1_COLLECTION)\n",
    "\n",
    "        .filterDate(START_DATE, END_DATE)\n",
    "\n",
    "        .filterBounds(boundary)\n",
    "\n",
    "        .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VV'))\n",
    "\n",
    "        .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VH'))\n",
    "\n",
    "        .filter(ee.Filter.eq('instrumentMode', 'IW'))\n",
    "\n",
    "        .map(mask_edge)\n",
    "\n",
    "    )\n",
    "\n",
    "    vv = collection.select('VV').median()\n",
    "\n",
    "    vh = collection.select('VH').median()\n",
    "\n",
    "    ratio = vv.subtract(vh).rename('VV_VH_diff')\n",
    "\n",
    "    return ee.Image.cat([vv, vh, ratio]).clip(boundary)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def process_environment(boundary):\n",
    "\n",
    "    et = (\n",
    "\n",
    "        ee.ImageCollection(MODIS_ET_COLLECTION)\n",
    "\n",
    "        .filterDate(START_DATE, END_DATE)\n",
    "\n",
    "        .filterBounds(boundary)\n",
    "\n",
    "        .select('ET')\n",
    "\n",
    "        .mean()\n",
    "\n",
    "        .rename('ET')\n",
    "\n",
    "    )\n",
    "\n",
    "    precip = (\n",
    "\n",
    "        ee.ImageCollection(CHIRPS_COLLECTION)\n",
    "\n",
    "        .filterDate(START_DATE, END_DATE)\n",
    "\n",
    "        .filterBounds(boundary)\n",
    "\n",
    "        .select('precipitation')\n",
    "\n",
    "        .sum()\n",
    "\n",
    "        .rename('Precip')\n",
    "\n",
    "    )\n",
    "\n",
    "    dem = ee.Image(SRTM)\n",
    "\n",
    "    slope = ee.Terrain.slope(dem).rename('slope')\n",
    "\n",
    "    return ee.Image.cat([et, precip, dem.rename('elevation'), slope]).clip(boundary)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def build_feature_stack(boundary):\n",
    "\n",
    "    l8 = process_landsat(boundary)\n",
    "\n",
    "    s1 = process_sentinel1(boundary)\n",
    "\n",
    "    env = process_environment(boundary)\n",
    "\n",
    "    worldcover = ee.ImageCollection(WORLDCOVER).first().select('Map')\n",
    "\n",
    "    return ee.Image.cat([l8, s1, env, worldcover.rename('landcover')])\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def sample_points(image, boundary, sample_path):\n",
    "\n",
    "    samples = geemap.shp_to_ee(sample_path)\n",
    "\n",
    "    sample = image.sampleRegions(\n",
    "\n",
    "        collection=samples,\n",
    "\n",
    "        properties=['salinity'],\n",
    "\n",
    "        scale=30,\n",
    "\n",
    "        geometries=True\n",
    "\n",
    "    )\n",
    "\n",
    "    df = geemap.ee_to_pandas(sample)\n",
    "\n",
    "    return df.dropna()\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def train_models(df):\n",
    "\n",
    "    X = df.drop(columns=['salinity', 'longitude', 'latitude'], errors='ignore')\n",
    "\n",
    "    y = df['salinity']\n",
    "\n",
    "    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)\n",
    "\n",
    "\n",
    "\n",
    "    models = {\n",
    "\n",
    "        'RandomForest': RandomForestRegressor(random_state=42),\n",
    "\n",
    "        'SVR': SVR(),\n",
    "\n",
    "        'Linear': LinearRegression(),\n",
    "\n",
    "        'XGB': xgb.XGBRegressor(objective='reg:squarederror', random_state=42)\n",
    "\n",
    "    }\n",
    "\n",
    "\n",
    "\n",
    "    params = {\n",
    "\n",
    "        'RandomForest': {'n_estimators': [100, 200], 'max_depth': [5, 10]},\n",
    "\n",
    "        'SVR': {'C': [1, 10], 'gamma': ['scale', 'auto']},\n",
    "\n",
    "        'Linear': {},\n",
    "\n",
    "        'XGB': {'n_estimators': [100, 200], 'max_depth': [3, 6]}\n",
    "\n",
    "    }\n",
    "\n",
    "\n",
    "\n",
    "    metrics_list = []\n",
    "    results = {}\n",
    "    feature_names = X_train.columns\n",
    "\n",
    "    for name, model in models.items():\n",
    "        grid = GridSearchCV(model, params[name], cv=KFold(n_splits=5, shuffle=True, random_state=42))\n",
    "        grid.fit(X_train, y_train)\n",
    "        pred = grid.predict(X_test)\n",
    "        r2 = r2_score(y_test, pred)\n",
    "        rmse = mean_squared_error(y_test, pred, squared=False)\n",
    "        results[name] = {'model': grid.best_estimator_, 'r2': r2, 'rmse': rmse}\n",
    "        metrics_list.append({'model': name, 'R2': r2, 'RMSE': rmse})\n",
    "        print(f'{name}: R2={r2:.3f}, RMSE={rmse:.3f}')\n",
    "    return results\n",
    "\n",
    "        if hasattr(grid.best_estimator_, 'feature_importances_'):\n",
    "            importances = grid.best_estimator_.feature_importances_\n",
    "            imp_df = pd.DataFrame({'feature': feature_names, 'importance': importances})\n",
    "            imp_df = imp_df.sort_values('importance', ascending=False)\n",
    "            print(imp_df.head())\n",
    "            imp_df.to_csv(os.path.join(OUTPUT_DIR, f'{name}_feature_importance.csv'), index=False)\n",
    "        if SHAP_AVAILABLE:\n",
    "            explainer = shap.Explainer(grid.best_estimator_, X_train)\n",
    "            shap_values = explainer(X_test)\n",
    "            shap_df = pd.DataFrame(shap_values.values, columns=feature_names)\n",
    "            shap_df.to_csv(os.path.join(OUTPUT_DIR, f'{name}_shap_values.csv'), index=False)\n",
    "\n",
    "    metrics_df = pd.DataFrame(metrics_list)\n",
    "    metrics_df.to_csv(os.path.join(OUTPUT_DIR, 'model_performance.csv'), index=False)\n",
    "    return results, X_test, y_test\n",
    "def main():\n",
    "\n",
    "    initialize_gee()\n",
    "\n",
    "    os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
    "\n",
    "    boundary = load_boundary(BOUNDARY_PATH)\n",
    "\n",
    "    feature_image = build_feature_stack(boundary)\n",
    "\n",
    "    df = sample_points(feature_image, boundary, SOIL_DATA_PATH)\n",
    "\n",
    "    df.to_csv(os.path.join(OUTPUT_DIR, 'training_samples.csv'), index=False)\n",
    "\n",
    "    results = train_models(df)\n",
    "\n",
    "\n",
    "    results, X_test, y_test = train_models(df)\n",
    "\n",
    "    best_name = max(results, key=lambda k: results[k]['r2'])\n",
    "    best_model = results[best_name]['model']\n",
    "\n",
    "    best_model = results[best_name]['model']\n",
    "    pred = best_model.predict(X_test)\n",
    "    pd.DataFrame({'y_true': y_test, 'y_pred': pred}).to_csv(\n",
    "        os.path.join(OUTPUT_DIR, f'{best_name}_predictions.csv'), index=False\n",
    ")\n",
    "    if SHAP_AVAILABLE:\n",
    "        explainer = shap.Explainer(best_model, X_test)\n",
    "        shap_values = explainer(X_test)\n",
    "        pd.DataFrame(shap_values.values, columns=X_test.columns).to_csv(\n",
    "            os.path.join(OUTPUT_DIR, f'{best_name}_shap.csv'), index=False\n",
    "        )\n",
    "    print(f'Best model: {best_name}')\n",
    "\n",
    "\n",
    "\n",
    "    # Predict across image\n",
    "\n",
    "    feature_bands = feature_image.bandNames()\n",
    "\n",
    "    predictors = feature_image.select(feature_bands)\n",
    "\n",
    "    model = geemap.sk_export_model(best_model, predictors)\n",
    "\n",
    "    # Export predictions (placeholder, requires geemap>=0.30)\n",
    "\n",
    "    task = geemap.ee_export_image_to_drive(\n",
    "\n",
    "        model,\n",
    "\n",
    "        description='salinity_prediction',\n",
    "\n",
    "        folder='gee_outputs',\n",
    "\n",
    "        region=boundary,\n",
    "\n",
    "        scale=30\n",
    "\n",
    "    )\n",
    "\n",
    "    print('Export task started.')\n",
    "\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    main()"
    "\n",
    "    main()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}