In [None]:
# ***************************************************************************
#    PySTGEE: Python Implementation for Spatio-Temporal GEE Landslide Modeling
#
#        Original Project     : STGEE (JavaScript/GEE)
#        Original Begin       : 2022-04
#        Original Copyright   : (C) 2022 by Giacomo Titti and Gabriele Nicola Napoli
#        Original Contact     : giacomotitti@gmail.com
#
#        Python Refactoring   : 2025-11
#        Updated Copyright    : (C) 2025 by Gabriele Nicola Napoli and Giacomo Titti
#        Contacts             : gabrielenicolanapoli@gmail.com, giacomotitti@gmail.com
#
# ***************************************************************************
# ***************************************************************************
#    PySTGEE
#    Copyright (C) 2022 Giacomo Titti, Gabriele Nicola Napoli
#    Copyright (C) 2025 Gabriele Nicola Napoli, Giacomo Titti
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <https://www.gnu.org/licenses/>.
# ***************************************************************************

In [None]:
# @title
# --- CELL 1: USER CONFIGURATION ---

# 1. Earth Engine Project Configuration
EE_PROJECT = 'your-ee-project-id'  # Enter your GEE Project ID here

# 2. Earth Engine Asset Paths
# Polygons used for training (Slope Units)
polygons_asset = "projects/ee-your-username/assets/your_training_polygons"
# Point data containing landslide events and dates
points_asset = "projects/ee-your-username/assets/your_landslide_points"
# Polygons used for final prediction/forecasting
prediction_asset = "projects/ee-your-username/assets/your_prediction_polygons"

# 3. Data Column Settings
# The exact name of the column in 'points_asset' that contains the date
DATE_COLUMN = 'date_column_name'

# The exact name of the column identifying the landslide (e.g. 'id', 'objectid', 'type')
LANDSLIDE_COLUMN = 'landslide_id'

# 4. CSV Export Settings
# 'BEST_ONLY': CSV contains only the best rainfall column (Max AUC) + static predictors.
# 'ALL_DATA' : CSV contains ALL calculated rainfall columns (1-30 days).
CSV_EXPORT_MODE = 'BEST_ONLY'

# 5. Analysis Parameters
# The specific date to generate the final hazard map for
FORECAST_DATE_FIXED = 'YYYY-MM-DD'

# Static Morphological Predictors Examples (Rainfall will be added dynamically)
STATIC_PREDICTORS = [
    'slope_mean',          # Mean slope
    'elevation_mean',      # Mean elevation
    'curvature_mean',      # Terrain curvature
    'ndvi_mean',           # Vegetation index
    'lithology_class'      # Categorical lithology
]


# Rainfall Window Search Range (Days)
# The model will test every interval between MIN and MAX to find the best AUC
MIN_DAYS = 1
MAX_DAYS = 30

# Fail fast on settings that would otherwise be accepted silently:
# - any CSV_EXPORT_MODE other than 'ALL_DATA' is treated as 'BEST_ONLY' by the
#   export filter, so a typo would go unnoticed;
# - an inverted or non-positive day range makes the window search loop empty.
_VALID_CSV_EXPORT_MODES = ('BEST_ONLY', 'ALL_DATA')
if CSV_EXPORT_MODE not in _VALID_CSV_EXPORT_MODES:
    raise ValueError(
        f"CSV_EXPORT_MODE must be one of {_VALID_CSV_EXPORT_MODES}, got {CSV_EXPORT_MODE!r}"
    )
if not (isinstance(MIN_DAYS, int) and isinstance(MAX_DAYS, int) and 1 <= MIN_DAYS <= MAX_DAYS):
    raise ValueError(
        f"MIN_DAYS/MAX_DAYS must be integers with 1 <= MIN_DAYS <= MAX_DAYS, "
        f"got MIN_DAYS={MIN_DAYS!r}, MAX_DAYS={MAX_DAYS!r}"
    )

print("Configuration Saved.")
print(f"Project: {EE_PROJECT}")
print(f"CSV Mode: {CSV_EXPORT_MODE}")
print(f"Rainfall Optimization Range: {MIN_DAYS} to {MAX_DAYS} days.")

Configuration Saved.
Project: stgee-dataset
CSV Mode: BEST_ONLY
Rainfall Optimization Range: 1 to 30 days.


In [None]:
# @title
# --- CELL 2: IMPORTS AND SYSTEM SETUP ---
import ee
import geemap
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
import re
import concurrent.futures
import os
import json
from IPython.display import display, clear_output

# Machine Learning Imports
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_predict, StratifiedKFold
from sklearn.metrics import roc_curve, auc, confusion_matrix, f1_score, cohen_kappa_score, accuracy_score
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Earth Engine Initialization
# This block automatically detects if it's running as a WebApp (via Service Account)
# or in a standard environment (Colab/Local).
try:
    # 1. WEBAPP MODE: Check for Service Account credentials in environment variables
    if 'EE_SERVICE_ACC_JSON' in os.environ:
        print("Authenticating via Service Account for WebApp deployment...")
        # Parse the JSON credentials stored in the environment variable
        creds_json = os.environ['EE_SERVICE_ACC_JSON']
        creds_dict = json.loads(creds_json)

        # Initialize using Service Account Credentials
        # (the raw JSON string is passed as key_data; only the e-mail is read from the parsed dict)
        credentials = ee.ServiceAccountCredentials(
            creds_dict['client_email'],
            key_data=creds_json
        )
        ee.Initialize(credentials, project=EE_PROJECT)
        print(f"Connected to Earth Engine via Service Account (Project: {EE_PROJECT}).")

    else:
        # 2. STANDARD MODE: Local or Google Colab environment
        try:
            # Attempt direct initialization (works if already authenticated)
            ee.Initialize(project=EE_PROJECT)
            print(f"Connected to Earth Engine (Project: {EE_PROJECT}).")
        except Exception:
            # Fallback to interactive authentication if needed
            print("Standard initialization failed. Attempting interactive authentication...")
            ee.Authenticate()
            ee.Initialize(project=EE_PROJECT)
            print(f"Authentication completed (Project: {EE_PROJECT}).")

except Exception as e:
    print(f"CRITICAL ERROR: Failed to initialize Earth Engine: {e}")
    # Every later cell issues Earth Engine requests, so stop here with the real
    # cause instead of continuing into harder-to-diagnose downstream failures.
    raise

In [None]:
# @title
# --- CELL 3: DATA LOADING & PRE-PROCESSING ---

# Visualization Palettes (Static)
# VIS_PALETTE: nine-step ramp used for the probability layers and the legend gradient.
VIS_PALETTE = [
    '#006b0b',
    '#1b7b25',
    '#4e9956',
    '#dbeadd',
    '#ffffff',
    '#f0b2ae',
    '#eb958f',
    '#df564d',
    '#d10e00',
]
# PALETTE_CONFUSION: one colour per confusion class code, in code order 0..3.
PALETTE_CONFUSION = [
    '#D10E00',  # 0: False Positive
    '#DF564D',  # 1: True Negative
    '#DBEADD',  # 2: False Negative
    '#006B0B',  # 3: True Positive
]

# Load Collections using variables from CELL 1
print("Loading assets...")
raw_polygons, landPoints, raw_prediction = (
    ee.FeatureCollection(asset_path)
    for asset_path in (polygons_asset, points_asset, prediction_asset)
)

# Function to extract numeric IDs from polygon attributes (essential for mapping)
def add_numeric_id(feature):
    """
    Attach a numeric 'NUM_ID' property derived from the feature's 'id' property.

    All non-digit characters are stripped from the 'id' value (server side);
    the remaining digits are parsed as a number. When nothing is left after
    stripping, 'NUM_ID' is set to 0.

    Args:
        feature (ee.Feature): Feature carrying an 'id' property.

    Returns:
        ee.Feature: The same feature with 'NUM_ID' set.
    """
    digits_only = ee.String(feature.get('id')).replace(r'[^0-9]', '', 'g')
    has_digits = digits_only.length().gt(0)
    numeric_id = ee.Algorithms.If(has_digits, ee.Number.parse(digits_only), 0)
    return feature.set('NUM_ID', numeric_id)

# Apply ID extraction to both polygon collections (training and prediction)
predictors_polygons, prediction_area_shp = (
    collection.map(add_numeric_id)
    for collection in (raw_polygons, raw_prediction)
)

# Global App State to store data between button clicks
#   df               -> the training dataframe
#   model            -> the trained Random Forest model
#   map              -> the map widget
#   best_window      -> the optimized rainfall window (e.g., 12 days)
#   final_predictors -> the final list of columns used (Static + Best Rain)
APP = dict(
    df=None,
    model=None,
    map=None,
    best_window=None,
    final_predictors=[],
)

print("Assets configured and ready for analysis.")

In [None]:
# @title
# --- CELL 4: TRAINING ENGINE (DETERMINISTIC & OPTIMIZED) ---

def download_training_data_server_side(log_widget):
    """
    Download the training table for every landslide event date.

    For each distinct date found in DATE_COLUMN of `landPoints`, every training
    polygon is labelled 'P/A' = 1 when at least one point of that date falls
    inside its bounds (0 otherwise) and receives the mean / standard deviation
    of the accumulated rainfall for every window from MIN_DAYS to MAX_DAYS
    (columns 'Rn{i}_m' and 'Rn{i}_s').

    Rows are sorted by ('date', 'id') so the table is in the same order on
    every run regardless of which download thread finishes first.

    Args:
        log_widget (ipywidgets.Output): Widget that receives progress messages.

    Returns:
        pandas.DataFrame: One row per (date, polygon) with NaN replaced by 0;
        an empty DataFrame when no date could be processed.
    """
    # Use the dynamic column name defined in Cell 1
    raw_dates = landPoints.aggregate_array(DATE_COLUMN).distinct().getInfo()

    # Group the raw attribute values under their 10-character date key.
    # Distinct raw values can collapse to the same key (e.g. values that carry a
    # time suffix); grouping avoids processing a date twice and keeps the exact
    # raw values needed to select that date's points.
    raw_values_by_date = {}
    for raw_value in raw_dates:
        raw_values_by_date.setdefault(str(raw_value)[:10], []).append(raw_value)
    dates_list = sorted(raw_values_by_date)
    total_dates = len(dates_list)

    with log_widget:
        print(f"Event Dates found: {total_dates}")
        print(f"Retrieving rainfall data for windows {MIN_DAYS}-{MAX_DAYS} days...")

    # Errors are collected here by the worker threads and reported from the main
    # thread afterwards, so printing stays confined to one thread.
    failures = {}

    def process_date(date_str):
        """Build and download the per-polygon table for one event date (None on failure)."""
        try:
            d = ee.Date(date_str)
            gpm = ee.ImageCollection('JAXA/GPM_L3/GSMaP/v8/operational').select('hourlyPrecipRateGC')

            # One band per window length. filterDate's end bound is exclusive,
            # so each window covers [d - i days, d).
            # NOTE(review): the prediction engine accumulates up to one day
            # after its reference date - confirm both windows are meant to differ.
            rain_bands = [
                gpm.filterDate(d.advance(-i, 'day'), d).sum().unmask(0).rename(f'Rn{i}')
                for i in range(MIN_DAYS, MAX_DAYS + 1)
            ]

            combined = ee.Image.cat(rain_bands)

            # Select this date's points by their raw attribute values, so values
            # longer than the 10-character key still match.
            todays_points = landPoints.filter(
                ee.Filter.inList(DATE_COLUMN, raw_values_by_date[date_str])
            )

            def map_polygons(poly):
                count = todays_points.filterBounds(poly.geometry()).size()
                # Existing properties (including 'id', used for sorting later) are retained
                return poly.set({'P/A': ee.Algorithms.If(count.gt(0), 1, 0), 'date': date_str})

            labeled_polys = predictors_polygons.map(map_polygons)

            stats = combined.reduceRegions(
                collection=labeled_polys,
                reducer=ee.Reducer.mean().combine(ee.Reducer.stdDev(), sharedInputs=True),
                scale=1000, tileScale=16
            )

            df_day = geemap.ee_to_df(stats)
            if df_day.empty:
                return None

            # Shorten the reducer suffixes: Rn{i}_mean -> Rn{i}_m, Rn{i}_stdDev -> Rn{i}_s
            rename_dict = {f'Rn{i}_{suffix}': f'Rn{i}_{m}'
                           for i in range(MIN_DAYS, MAX_DAYS + 1)
                           for suffix, m in [('mean', 'm'), ('stdDev', 's')]}

            return df_day.rename(columns=rename_dict)
        except Exception as exc:
            # Keep the run going for the remaining dates, but remember why this one failed.
            failures[date_str] = exc
            return None

    results = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
        futures = [executor.submit(process_date, d) for d in dates_list]
        for done, f in enumerate(concurrent.futures.as_completed(futures), start=1):
            res = f.result()
            if res is not None:
                results.append(res)
            if done % 5 == 0 or done == total_dates:
                with log_widget:
                    print(f"   ...processed {done}/{total_dates} dates", end='\r')

    if failures:
        with log_widget:
            print(f"WARNING: {len(failures)} of {total_dates} dates failed and were skipped:")
            for failed_date in sorted(failures):
                print(f"   {failed_date}: {failures[failed_date]}")

    if not results:
        return pd.DataFrame()

    final_df = pd.concat(results, ignore_index=True)

    # --- DETERMINISTIC SORTING ---
    # Sort by Date and then by ID to ensure rows are ALWAYS in the exact same order
    if 'date' in final_df.columns and 'id' in final_df.columns:
        final_df = final_df.sort_values(by=['date', 'id']).reset_index(drop=True)

    # Ensure static predictors exist; a missing one is filled with 0, which makes
    # it uninformative for the model, so tell the user instead of doing it silently.
    missing_static = [col for col in STATIC_PREDICTORS if col not in final_df.columns]
    if missing_static:
        with log_widget:
            print(f"WARNING: static predictors not found in the polygons and filled with 0: {missing_static}")
    for col in missing_static:
        final_df[col] = 0

    return final_df.fillna(0)

print("Training Engine ready (Dynamic Date Column).")

In [None]:
# @title
# --- CELL 5: PREDICTION ENGINE ---

def get_prediction_data_dynamic(best_days_n, log_widget=None):
    """
    Build the prediction table for the selected rainfall window.

    Looks up the most recent GSMaP image dated on or before FORECAST_DATE_FIXED.
    When its date differs from the requested one, that image's date is used
    as the reference instead and the fallback is reported in the log. Rainfall
    is summed over the window ending one day after the reference timestamp and
    reduced (mean / stdDev) over the prediction polygons. If no image exists at
    all, a constant zero rainfall image is used.

    Args:
        best_days_n (int): Number of days for rainfall accumulation.
        log_widget (Output): The widget to print logs to (optional).

    Returns:
        pandas.DataFrame: One row per prediction polygon, with columns
        'Rn{best_days_n}_m', 'Rn{best_days_n}_s' and every name in
        STATIC_PREDICTORS guaranteed to exist (filled with 0 when absent).
    """

    def log(msg):
        # Route the message to the widget when one was supplied, else to stdout.
        if not log_widget:
            print(msg)
            return
        with log_widget:
            print(msg)

    rain_band = f'Rn{best_days_n}'
    separator = "-" * 45

    target_date = ee.Date(FORECAST_DATE_FIXED)

    # 1. JAXA GSMaP collection, precipitation-rate band only
    gpm = ee.ImageCollection('JAXA/GPM_L3/GSMaP/v8/operational').select('hourlyPrecipRateGC')

    # --- "FIND MOST RECENT DATE" LOGIC ---
    # Everything up to (and including) the target day, newest image first.
    history = gpm.filterDate('2000-01-01', target_date.advance(1, 'day'))
    available_col = history.sort('system:time_start', False).limit(1)

    # True when at least one image exists in that history
    has_data = available_col.size().getInfo() > 0

    found_date_str = "N/A"

    log(separator)
    log("      SATELLITE DATA REPORT")
    log(separator)
    log(f"Requested Date : {FORECAST_DATE_FIXED}")

    if not has_data:
        log("STATUS         : NO DATA FOUND (Using 0 Rain)")
        log(separator)
        rain_img = ee.Image.constant(0).rename(rain_band)
    else:
        # Newest image and its acquisition date
        latest_img = available_col.first()
        found_date_ms = latest_img.get('system:time_start').getInfo()
        found_date = ee.Date(found_date_ms)
        found_date_str = found_date.format('YYYY-MM-dd').getInfo()

        log("Source Image   : JAXA GSMaP v8 Operational")
        log(f"Available Date : {found_date_str}")

        if found_date_str == FORECAST_DATE_FIXED:
            log("STATUS         : EXACT MATCH")
        else:
            # Report how far the newest image lags behind the requested date
            diff = ee.Date(FORECAST_DATE_FIXED).difference(found_date, 'day').getInfo()
            log("STATUS         : FALLBACK ACTIVATED")
            log(f"   (Data lag of {int(diff)} days due to latency/missing data)")
            log(f"   Using rainfall accumulation ending on {found_date_str}")

        log(separator)

        # --- RAINFALL CALCULATION ---
        window_start = found_date.advance(-best_days_n, 'day')
        window_end = found_date.advance(1, 'day')
        rain_img = gpm.filterDate(window_start, window_end).sum().unmask(0).rename(rain_band)

    # --- REDUCE REGIONS ---
    zonal_reducer = ee.Reducer.mean().combine(ee.Reducer.stdDev(), sharedInputs=True)
    stats = rain_img.reduceRegions(
        collection=prediction_area_shp,
        reducer=zonal_reducer,
        scale=1000, tileScale=16
    )

    df = geemap.ee_to_df(stats)

    # --- ROBUST RENAMING ---
    # The reducer output name can vary; take the first alias that is present
    # and fall back to a zero column when none is.
    column_aliases = (
        (f'{rain_band}_m', [f'{rain_band}_mean', 'hourlyPrecipRateGC_mean', 'mean']),
        (f'{rain_band}_s', [f'{rain_band}_stdDev', 'hourlyPrecipRateGC_stdDev', 'stdDev']),
    )
    for target_name, aliases in column_aliases:
        source_name = next((alias for alias in aliases if alias in df.columns), None)
        df[target_name] = df[source_name] if source_name is not None else 0

    # Guarantee that every static predictor column exists
    for static_name in STATIC_PREDICTORS:
        if static_name not in df.columns:
            df[static_name] = 0

    return df

print("Prediction Engine ready (Log Integration Enabled).")

In [None]:
# @title
# --- CELL 6: DASHBOARD AND INTERFACE ---
from ipyleaflet import WidgetControl, ZoomControl, FullScreenControl
import base64
import ipywidgets as widgets
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from sklearn.metrics import confusion_matrix
from IPython.display import display, Javascript

# --- HELPER FUNCTIONS ---
def calc_confusion_class(row, pred_col, true_col='P/A'):
    """
    Encode one row's (prediction, truth) pair as a confusion class code.

    Codes: 0 = False Positive, 1 = True Negative, 2 = False Negative,
    3 = True Positive. Any other combination of values falls back to 1.

    Args:
        row: Mapping-like row (e.g. a pandas Series) indexed by column name.
        pred_col (str): Name of the column holding the predicted label.
        true_col (str): Name of the column holding the observed label.

    Returns:
        int: The confusion class code.
    """
    outcome = (int(row[pred_col]), int(row[true_col]))
    codes = {
        (1, 0): 0,  # FP
        (0, 0): 1,  # TN
        (0, 1): 2,  # FN
        (1, 1): 3,  # TP
    }
    return codes.get(outcome, 1)

def create_download_link(df, title="Download CSV", filename="data.csv"):
    """
    Wrap a DataFrame as a styled download button.

    The frame is serialised to CSV (without the index), base64-encoded and
    embedded in a data URI, so the download needs no server-side file.

    Args:
        df (pandas.DataFrame): Table to export.
        title (str): Text shown on the button.
        filename (str): File name suggested to the browser.

    Returns:
        ipywidgets.HTML: Widget containing the link and button markup.
    """
    encoded = base64.b64encode(df.to_csv(index=False).encode()).decode()
    button_css = (
        "background-color: #00BFFF; color: #ffffff !important; padding: 6px 12px; "
        "border: none; border-radius: 4px; cursor: pointer; font-size: 12px; "
        "font-family: sans-serif; font-weight: bold;"
    )
    markup = (
        f'<a download="{filename}" href="data:text/csv;base64,{encoded}" target="_blank">'
        f'<button style="{button_css}">{title}</button></a>'
    )
    return widgets.HTML(markup)

def filter_dataframe_for_export(df):
    """
    Select the columns to export according to CSV_EXPORT_MODE (Cell 1).

    'ALL_DATA' returns the frame untouched. Any other mode keeps, in this
    order, the identification columns, the predictors stored in
    APP['final_predictors'] and the known result columns - restricted to the
    ones that actually exist in `df`.

    Args:
        df (pandas.DataFrame): Table about to be exported.

    Returns:
        pandas.DataFrame: `df` itself or a column subset of it.
    """
    if CSV_EXPORT_MODE == 'ALL_DATA':
        return df

    # Mode is 'BEST_ONLY'
    wanted = [
        # identification columns
        'id', 'date', 'P/A', 'NUM_ID',
        # predictors used in the final model (Static + Best Rain)
        *APP.get('final_predictors', []),
        # result columns written by calibration / validation / prediction
        'calib_prob', 'calib_pred', 'conf_class',
        'valid_prob', 'valid_pred', 'valid_conf',
        'SI', 'Prediction',
    ]

    # Keep only the columns that are present in the dataframe
    present = [name for name in wanted if name in df.columns]
    return df[present]

def map_values(df_res, val_col, layer_name, palette, polygons=None):
    """
    Paint one value per polygon onto the map as a raster layer.

    An existing layer with the same name is removed first. Each row's 'id' is
    reduced to its digits so it can be matched against the polygons' 'NUM_ID'
    property. When a polygon has several rows, "Confusion..." layers keep the
    row with the latest 'date' (added hidden); all other layers keep the
    maximum value (added visible).

    Args:
        df_res (pandas.DataFrame): Table with an 'id' column and `val_col`
            ("Confusion..." layers also need a 'date' column).
        val_col (str): Column holding the value to display.
        layer_name (str): Name of the map layer to (re)create.
        palette (list[str]): Colour palette for the layer.
        polygons (ee.FeatureCollection, optional): Polygons to paint. Defaults
            to the prediction polygons for layers whose name starts with
            "Prediction" and to the training polygons otherwise, so the values
            are drawn on the same geometries they were computed on.
    """
    m = APP['map']
    try:
        layer = m.find_layer(layer_name)
        if layer: m.remove_layer(layer)
    except Exception:
        # Best effort: a stale layer that cannot be removed must not block the redraw.
        pass

    if polygons is None:
        polygons = prediction_area_shp if layer_name.startswith("Prediction") else predictors_polygons

    def clean_id_py(val):
        # Client-side counterpart of add_numeric_id: keep digits only, 0 when none remain.
        s = str(val)
        digits = re.sub(r'[^0-9]', '', s)
        return int(digits) if digits else 0

    df_map = df_res.copy()
    df_map['NUM_ID_PY'] = df_map['id'].apply(clean_id_py)

    if layer_name.startswith("Confusion"):
        df_flat = df_map.sort_values('date').drop_duplicates(subset='NUM_ID_PY', keep='last')
        is_visible = False
    else:
        df_flat = df_map.groupby('NUM_ID_PY')[val_col].max().reset_index()
        is_visible = True

    id_list = df_flat['NUM_ID_PY'].tolist()
    val_list = df_flat[val_col].tolist()

    # Rasterise the polygons by NUM_ID, then swap each ID for its value.
    polygons_img = polygons.reduceToImage(properties=['NUM_ID'], reducer=ee.Reducer.first())
    result_img = polygons_img.remap(id_list, val_list).rename('value')
    result_img = result_img.updateMask(result_img.gte(0))

    vis = {'palette': palette, 'min': 0, 'max': 3 if layer_name.startswith("Confusion") else 1}
    m.addLayer(result_img, vis, layer_name, shown=is_visible)

def calculate_advanced_metrics(y_true, y_probs):
    """
    Compute ROC-based metrics at the threshold that maximises Youden's J.

    The ROC curve is built from `y_probs`; the threshold with the largest
    (TPR - FPR) is used to binarise the probabilities, and F1, Cohen's kappa
    and accuracy are evaluated on those labels.

    Args:
        y_true: Observed binary labels.
        y_probs: Predicted probabilities for the positive class (array-like
            supporting `>=` and `.astype`, e.g. a NumPy array).

    Returns:
        dict: Keys 'auc', 'best_thresh', 'f1', 'kappa', 'acc', 'youden',
        'fpr', 'tpr', 'best_idx' and 'y_pred_opt'.
    """
    fpr, tpr, thresholds = roc_curve(y_true, y_probs)

    # Youden's J for every candidate threshold; pick the largest.
    youden_j = tpr - fpr
    best = np.argmax(youden_j)
    cutoff = thresholds[best]
    labels_at_cutoff = (y_probs >= cutoff).astype(int)

    return dict(
        auc=auc(fpr, tpr),
        best_thresh=cutoff,
        f1=f1_score(y_true, labels_at_cutoff, zero_division=0),
        kappa=cohen_kappa_score(y_true, labels_at_cutoff),
        acc=accuracy_score(y_true, labels_at_cutoff),
        youden=youden_j[best],
        fpr=fpr,
        tpr=tpr,
        best_idx=best,
        y_pred_opt=labels_at_cutoff,
    )

# --- FLOATING PANEL LOGIC ---
# One slot per result panel; each holds the WidgetControl currently on the map (or None).
APP['panels'] = {'calib': None, 'valid': None, 'pred': None}

def show_floating_panel(content_widget, title="Results", panel_key='calib', width='400px'):
    """
    Show `content_widget` in a collapsible panel at the bottom-right of the map.

    A panel previously registered under `panel_key` is removed first, so each
    key has at most one panel on the map. The panel has a header with a
    minimise button; when minimised it collapses to a single button that
    restores it.

    Args:
        content_widget (ipywidgets.Widget): Body of the panel.
        title (str): Text for the header and for the collapsed button.
        panel_key (str): Slot in APP['panels'] ('calib', 'valid' or 'pred').
        width (str): CSS width of the expanded panel.
    """
    if APP['panels'][panel_key] is not None:
        try:
            APP['map'].remove_control(APP['panels'][panel_key])
        except Exception:
            # Best effort: the control may already be gone from the map.
            pass
        APP['panels'][panel_key] = None

    btn_expand = widgets.Button(
        description=f"{title}",
        style=widgets.ButtonStyle(button_color='#00BFFF', text_color='#ffffff', font_weight='bold'),
        layout=widgets.Layout(width='auto', padding='5px')
    )

    btn_minimize = widgets.Button(icon='minus', layout=widgets.Layout(width='30px', height='30px'), style=widgets.ButtonStyle(button_color='transparent', text_color='#000000'))

    header_title = widgets.HTML(f'<span style="font-weight: bold; color: black; font-family: sans-serif;">{title}</span>')
    header = widgets.HBox([header_title, btn_minimize], layout=widgets.Layout(justify_content='space-between', width='100%', align_items='center', border_bottom='1px solid #eee'))

    expanded_content = widgets.VBox(
        [header, content_widget],
        layout=widgets.Layout(
            width=width,
            background_color='white',
            padding='10px',
            border_radius='8px',
            border='2px solid black',
            box_shadow='0 4px 10px rgba(0,0,0,0.2)',
            max_height='600px',
            overflow_y='auto'
        )
    )
    main_container = widgets.VBox([expanded_content])

    def toggle_view(b):
        # Swap between the full panel and the single "expand" button.
        if main_container.children[0] == expanded_content:
            main_container.children = [btn_expand]
        else:
            main_container.children = [expanded_content]

    btn_minimize.on_click(toggle_view)
    btn_expand.on_click(toggle_view)

    control = WidgetControl(widget=main_container, position='bottomright')
    APP['map'].add_control(control)
    APP['panels'][panel_key] = control

# --- LEGEND ---
def create_legend_widget():
    """
    Build the map legend.

    The legend has two sections: a 0-1 gradient bar drawn from VIS_PALETTE for
    the calibration / validation / prediction maps, and one colour swatch per
    confusion class drawn from PALETTE_CONFUSION.

    Returns:
        ipywidgets.VBox: The assembled legend.
    """
    def swatch_row(col, txt):
        # One legend line: coloured square followed by its label.
        return widgets.HTML(f"""
        <div style="display: flex; align-items: center; margin: 2px 0;">
            <div style="width: 15px; height: 15px; background-color: {col}; border: 1px solid #999; margin-right: 8px;"></div>
            <span style="font-size: 11px; color: #000000 !important; font-family: sans-serif;">{txt}</span>
        </div>
        """)

    # Section 1: continuous probability ramp
    title1 = widgets.HTML(value='<div style="font-weight: bold; font-size: 12px; color: #000000 !important; font-family: sans-serif;">Calibration/Validation/Prediction map</div>')
    colors_css = ", ".join(VIS_PALETTE)
    gradient_widget = widgets.HTML(f"""
    <div style="width: 100%; margin-top: 5px;">
        <div style="display: flex; justify-content: space-between; font-size: 11px; margin-bottom: 2px; color: #000000 !important; font-family: sans-serif;">
            <span>0</span><span>1</span>
        </div>
        <div style="height: 12px; width: 200px; background: linear-gradient(to right, {colors_css}); border: 1px solid #999;"></div>
    </div>
    """)

    # Section 2: confusion classes, labels listed in class-code order 0..3
    title2 = widgets.HTML(value='<div style="font-weight: bold; font-size: 12px; margin-top: 15px; color: #000000 !important; font-family: sans-serif;">Confusion map</div>')
    class_labels = ["False Positive", "True Negative", "False Negative", "True Positive"]
    swatches = [swatch_row(col, txt) for col, txt in zip(PALETTE_CONFUSION, class_labels)]

    return widgets.VBox(
        [title1, gradient_widget, title2, *swatches],
        layout=widgets.Layout(width='250px', padding='10px', background_color='white', border='2px solid black', border_radius='4px')
    )

# --- BUTTON HANDLERS ---
# Shared log area: every button handler prints its progress and errors here.
out_log = widgets.Output(layout={'border': '1px solid #ccc', 'height': '150px', 'overflow_y': 'scroll', 'padding': '10px', 'margin': '10px 0'})

def _run_calibration():
    """
    Download the training table if needed, select the best rainfall window,
    train the final Random Forest and publish the calibration results
    (floating panel, CSV link and two map layers).

    Side effects: fills APP['df'], APP['best_window'], APP['final_predictors']
    and APP['model'], and adds 'calib_prob', 'calib_pred' and 'conf_class'
    columns to APP['df'].
    """
    if APP['df'] is None:
        downloaded = download_training_data_server_side(out_log)
        if downloaded.empty:
            # Do not cache an empty table: leaving APP['df'] as None lets the
            # next click retry the download.
            with out_log: print("ERROR: No training data could be downloaded. Check the assets and DATE_COLUMN, then run Calibration again.")
            return
        APP['df'] = downloaded
    df = APP['df']
    if df.empty: return

    y = df['P/A']
    best_auc = 0
    best_days = MIN_DAYS
    windows_tested = 0

    with out_log:
        print(f"Starting Optimization: Scanning windows {MIN_DAYS}-{MAX_DAYS} days...")
        print("-" * 30)

    # NOTE(review): each window is scored on the same rows it was fitted on
    # (fit and predict_proba both use X_temp), so these AUCs are resubstitution
    # scores - confirm this is the intended selection criterion.
    for days in range(MIN_DAYS, MAX_DAYS + 1):
        current_rain = [f'Rn{days}_m', f'Rn{days}_s']
        current_predictors = current_rain
        if not all(col in df.columns for col in current_rain): continue
        windows_tested += 1
        X_temp = df[current_predictors].fillna(0)
        rf_temp = RandomForestClassifier(n_estimators=50, max_depth=8, random_state=42, n_jobs=-1, class_weight='balanced')
        rf_temp.fit(X_temp, y)
        probs = rf_temp.predict_proba(X_temp)[:, 1]
        fpr, tpr, _ = roc_curve(y, probs)
        current_auc = auc(fpr, tpr)

        with out_log: print(f"   > Day {days}: AUC = {current_auc:.4f}")

        if current_auc > best_auc:
            best_auc = current_auc
            best_days = days

    if windows_tested == 0:
        # Without any rainfall column there is nothing to build the final model on.
        with out_log: print(f"ERROR: No rainfall columns Rn{MIN_DAYS}..Rn{MAX_DAYS} found in the training data.")
        return

    APP['best_window'] = best_days
    APP['final_predictors'] = STATIC_PREDICTORS + [f'Rn{best_days}_m', f'Rn{best_days}_s']

    with out_log:
        print("-" * 30)
        print(f"FINAL SELECTION:")
        print(f"   Best Window: {best_days} Days")
        print(f"   Max AUC:     {best_auc:.4f}")
        print("-" * 30)
        print("Training final robust model...")

    X = df[APP['final_predictors']].fillna(0)
    rf_final = RandomForestClassifier(n_estimators=100, max_depth=10, random_state=42, n_jobs=-1, oob_score=True, class_weight='balanced')
    rf_final.fit(X, y)
    APP['model'] = rf_final

    # Metrics use the out-of-bag probabilities when the forest exposes them.
    y_probs = rf_final.oob_decision_function_[:, 1] if hasattr(rf_final, "oob_decision_function_") else rf_final.predict_proba(X)[:, 1]
    m = calculate_advanced_metrics(y, y_probs)
    imp = pd.Series(rf_final.feature_importances_, index=APP['final_predictors']).sort_values()

    cm = confusion_matrix(y, m['y_pred_opt'])

    # --- CALIBRATION PLOT ---
    fig = make_subplots(
        rows=3, cols=1,
        subplot_titles=(f"Feature Importance (Best: {best_days}d)", "Confusion Matrix", "ROC Curve"),
        vertical_spacing=0.1
    )

    fig.add_trace(go.Bar(x=imp.values, y=imp.index, orientation='h', marker=dict(color='orange'), name="Imp"), row=1, col=1)

    fig.add_trace(go.Heatmap(z=cm, x=['Pred:0', 'Pred:1'], y=['True:0', 'True:1'], colorscale='Blues', showscale=False, text=[[str(cell) for cell in cm_row] for cm_row in cm], texttemplate="%{text}"), row=2, col=1)
    fig.update_yaxes(autorange="reversed", row=2, col=1)

    fig.add_trace(go.Scatter(x=m['fpr'], y=m['tpr'], mode='lines', line=dict(color='royalblue', width=2), fill='tozeroy', name="ROC"), row=3, col=1)
    fig.add_trace(go.Scatter(x=[m['fpr'][m['best_idx']]], y=[m['tpr'][m['best_idx']]], mode='markers', marker=dict(color='red', size=12, symbol='star'), name="Opt"), row=3, col=1)
    fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], mode='lines', line=dict(color='black', dash='dash'), showlegend=False), row=3, col=1)

    fig.update_layout(template="plotly_white", height=900, showlegend=False, margin=dict(l=10, r=10, t=30, b=10))

    metrics_html = widgets.HTML(f"""
    <div style="color: black; font-family: sans-serif; font-size: 11px; margin: 10px 0; line-height: 1.6;">
        <strong>Selected Rainfall Window:</strong> {best_days} Days<br>
        <strong>Training overall accuracy:</strong> {m['acc']:.4f}<br>
        <strong>Area under curve:</strong> {m['auc']:.4f}<br>
        <strong>Youden Index:</strong> {m['youden']:.4f}<br>
        <strong>Cohen's Kappa:</strong> {m['kappa']:.4f}<br>
        <strong>F1 Score:</strong> {m['f1']:.4f}
    </div>
    """)

    # The mapped probability is the in-sample prediction; the class labels come
    # from the thresholded probabilities used for the metrics above.
    df['calib_prob'] = rf_final.predict_proba(X)[:, 1]
    df['calib_pred'] = m['y_pred_opt']
    df['conf_class'] = df.apply(lambda r: calc_confusion_class(r, 'calib_pred'), axis=1)

    # APPLY FILTER BEFORE DOWNLOAD
    df_export = filter_dataframe_for_export(df)
    download_btn = create_download_link(df_export, "Download Calibration CSV", "calibration.csv")

    show_floating_panel(
        widgets.VBox([metrics_html, go.FigureWidget(fig), download_btn], layout=widgets.Layout(align_items='stretch')),
        title="Calibration Results",
        panel_key='calib'
    )

    map_values(df, 'calib_prob', 'Calibration Map', VIS_PALETTE)
    map_values(df, 'conf_class', 'Confusion Calibration', PALETTE_CONFUSION)

def on_calib_click(b):
    """
    "Run Calibration" button handler.

    Clears the log, runs the calibration workflow and reports any failure in
    the log widget, like the validation and prediction handlers do.

    Args:
        b (ipywidgets.Button): The clicked button (unused).
    """
    out_log.clear_output()
    try:
        _run_calibration()
    except Exception as e:
        with out_log: print(f"SYSTEM ERROR: {e}")

def on_valid_click(b):
    """
    "Run Validation" button handler.

    Runs a stratified 10-fold cross-validation of the calibrated model on the
    training table, then publishes the metrics panel, the CSV link and the
    validation / confusion map layers. Requires a previous calibration run.

    Side effects: adds 'valid_prob', 'valid_pred' and 'valid_conf' columns to
    APP['df'].

    Args:
        b (ipywidgets.Button): The clicked button (unused).
    """
    out_log.clear_output()
    with out_log: print("Starting Validation process...")

    if APP['model'] is None:
        with out_log: print("ERROR: You must run Calibration first!")
        return

    if 'final_predictors' not in APP or not APP['final_predictors']:
        with out_log: print("ERROR: Predictors not found. Run Calibration again.")
        return

    try:
        df = APP['df']
        X = df[APP['final_predictors']].fillna(0)
        y = df['P/A']

        with out_log: print("Running Cross-Validation (10-Folds)... please wait.")
        # A fixed seed for the shuffled folds keeps the validation metrics
        # identical from one click (and one session) to the next; 42 matches
        # the seed already used for the forests.
        folds = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
        y_probs = cross_val_predict(APP['model'], X, y, cv=folds, method='predict_proba', n_jobs=-1)[:, 1]
        m = calculate_advanced_metrics(y, y_probs)

        cm = confusion_matrix(y, m['y_pred_opt'])

        # The accuracy shown here comes from the cross-validated predictions,
        # so it is labelled accordingly rather than as a training score.
        metrics_html = widgets.HTML(f"""
        <div style="color: black; font-family: sans-serif; font-size: 11px; margin: 10px 0; line-height: 1.6;">
            <strong>Cross-validated overall accuracy:</strong> {m['acc']:.4f}<br>
            <strong>Area under curve:</strong> {m['auc']:.4f}<br>
            <strong>Youden Index:</strong> {m['youden']:.4f}<br>
            <strong>Cohen's Kappa:</strong> {m['kappa']:.4f}<br>
            <strong>F1 Score:</strong> {m['f1']:.4f}
        </div>
        """)

        fig = make_subplots(rows=2, cols=1, subplot_titles=("Confusion Matrix", "CV ROC Curve"), vertical_spacing=0.15)

        fig.add_trace(go.Heatmap(z=cm, x=['Pred:0', 'Pred:1'], y=['True:0', 'True:1'], colorscale='Blues', showscale=False, text=[[str(cell) for cell in cm_row] for cm_row in cm], texttemplate="%{text}"), row=1, col=1)
        fig.update_yaxes(autorange="reversed", row=1, col=1)

        fig.add_trace(go.Scatter(x=m['fpr'], y=m['tpr'], mode='lines', line=dict(color='darkorange', width=2), fill='tozeroy', name="ROC"), row=2, col=1)
        fig.add_trace(go.Scatter(x=[m['fpr'][m['best_idx']]], y=[m['tpr'][m['best_idx']]], mode='markers', marker=dict(color='red', size=12, symbol='star'), name="Opt"), row=2, col=1)
        fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], mode='lines', line=dict(color='black', dash='dash'), showlegend=False), row=2, col=1)

        fig.update_layout(template="plotly_white", height=500, showlegend=False, margin=dict(l=10, r=10, t=30, b=10))

        df['valid_prob'] = y_probs
        df['valid_pred'] = m['y_pred_opt']
        df['valid_conf'] = df.apply(lambda r: calc_confusion_class(r, 'valid_pred'), axis=1)

        # APPLY FILTER BEFORE DOWNLOAD
        df_export = filter_dataframe_for_export(df)
        download_btn = create_download_link(df_export, "Download Validation CSV", "validation.csv")

        show_floating_panel(
            widgets.VBox([metrics_html, go.FigureWidget(fig), download_btn], layout=widgets.Layout(align_items='stretch')),
            title="Validation Results",
            panel_key='valid'
        )

        map_values(df, 'valid_prob', 'Validation Map', VIS_PALETTE)
        map_values(df, 'valid_conf', 'Confusion Validation', PALETTE_CONFUSION)
        with out_log: print("Validation Done.")

    except Exception as e:
        with out_log: print(f"SYSTEM ERROR: {e}")

def on_pred_click(b):
    """
    "Run Prediction" button handler.

    Downloads the rainfall statistics for the calibrated window over the
    prediction polygons, applies the calibrated model and publishes the
    'Prediction Map' layer, the maximum score and the CSV link. Requires a
    previous calibration run.

    Args:
        b (ipywidgets.Button): The clicked button (unused).
    """
    # Clear first so the "Model missing." message is not appended to a stale
    # log (same order as the validation handler).
    out_log.clear_output()
    if APP['model'] is None:
        with out_log: print("Model missing.")
        return
    best_w = APP['best_window']

    try:
        df_pred = get_prediction_data_dynamic(best_w, out_log)
        if df_pred.empty:
            # Nothing to score; say so instead of surfacing a model input error.
            with out_log: print("Error: the prediction asset returned no polygons.")
            return

        X_map = df_pred[APP['final_predictors']].fillna(0)
        probs = APP['model'].predict_proba(X_map)[:, 1]
        df_pred['SI'] = probs
        map_values(df_pred, 'SI', 'Prediction Map', VIS_PALETTE)

        risk_html = widgets.HTML(f"""<div style="color: black; font-family: sans-serif; font-size: 14px; margin-bottom: 10px; font-weight: bold; text-align: center;">Max Risk Score: {probs.max():.2f}</div>""")

        # Prediction CSV already only has 1 rainfall column, but we filter just to be consistent with 'BEST_ONLY' setting
        df_export = filter_dataframe_for_export(df_pred)
        download_btn = create_download_link(df_export, "Download Prediction CSV", "prediction.csv")

        show_floating_panel(
            widgets.VBox([risk_html, download_btn]),
            title="Prediction", width='220px', panel_key='pred'
        )
        with out_log: print(f"Prediction Done. Max Risk: {probs.max():.2f}")
    except Exception as e:
        with out_log: print(f"Error: {e}")

# --- INIT MAP ---
# Map widget: built-in zoom/draw controls are disabled; a zoom control is
# re-added at the bottom-left so the right-hand side stays free for the panels.
Map = geemap.Map(height='900px', zoom_control=False, draw_control=False, fullscreen_control=True)
Map.centerObject(prediction_area_shp, 10)
Map.add_control(ZoomControl(position='bottomleft'))
try:
    if hasattr(Map, 'layer_control'): Map.remove_control(Map.layer_control)
except Exception:
    # Best effort: the map is still usable if the layer control cannot be removed.
    pass
APP['map'] = Map

# Buttons share one style and layout.
btn_style = widgets.ButtonStyle(button_color='#00BFFF', text_color='#ffffff', font_weight='bold')
layout = widgets.Layout(width='180px', margin='2px')
b_init = widgets.Button(description="Run Analysis", layout=layout, style=btn_style)
b_calib = widgets.Button(description="Run Calibration", layout=layout, style=btn_style)
b_valid = widgets.Button(description="Run Validation", layout=layout, style=btn_style)
b_pred = widgets.Button(description="Run Prediction", layout=layout, style=btn_style)

b_calib.on_click(on_calib_click)
b_valid.on_click(on_valid_click)
b_pred.on_click(on_pred_click)

# start_ctrl holds the single "Run Analysis" button; main_ctrl holds the three
# workflow buttons that replace it once the analysis is activated.
start_ctrl = WidgetControl(widget=widgets.VBox([b_init], layout=widgets.Layout(padding='5px', background_color='white', border_radius='8px')), position='bottomright')
main_ctrl = WidgetControl(widget=widgets.HBox([b_calib, b_valid, b_pred], layout=widgets.Layout(padding='5px', background_color='white', border_radius='8px')), position='bottomright')

def activate_analysis(b):
    """
    "Run Analysis" button handler.

    Draws the study-area outline, swaps the start button for the
    calibration / validation / prediction buttons and adds the legend.
    Failures while swapping the controls are reported in the log widget.

    Args:
        b (ipywidgets.Button): The clicked button (unused).
    """
    import sys  # local import: only needed for the Colab check below

    out_log.clear_output()
    # The output-frame resize call exists only in the Colab front end; emitting
    # it elsewhere would just raise a JavaScript error in the browser.
    if 'google.colab' in sys.modules:
        display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 5000});'''))

    Map.addLayer(predictors_polygons.style(**{'color': 'gray', 'fillColor': '00000000'}), {}, 'Study Area')
    try:
        Map.remove_control(start_ctrl)
        Map.add_control(main_ctrl)
        legend_widget = create_legend_widget()
        legend_control = WidgetControl(widget=legend_widget, position='bottomleft')
        Map.add_control(legend_control)
    except Exception as e:
        # Send the error to the log widget, the same place the other handlers report to.
        with out_log: print(e)

b_init.on_click(activate_analysis)
Map.add_control(start_ctrl)

# Render the log box above the map.
display(widgets.VBox([widgets.HTML('<div style="font-weight: bold; color: black; font-family: sans-serif;">OPERATION LOG</div>'), out_log]))
display(Map)