In [None]:
# --- CELL 1: IMPORTS AND SETUP ---
# Standard library
import base64
import concurrent.futures
import re

# Earth Engine, data handling and notebook widgets
import ee
import geemap
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import display, clear_output, Javascript
# ipyleaflet is geemap's map backend; the dashboard uses its map controls.
from ipyleaflet import WidgetControl, ZoomControl

# Machine Learning & Metrics (Scikit-Learn)
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_predict, StratifiedKFold
from sklearn.metrics import roc_curve, auc, confusion_matrix, f1_score, cohen_kappa_score, accuracy_score, precision_recall_curve

# Interactive Plotting (Plotly)
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Initialize Google Earth Engine.
# Set EE_PROJECT to your Google Cloud project ID (e.g. 'your-google-cloud-project')
# to attach to a specific project; with None, ee.Initialize uses its default.
EE_PROJECT = None

try:
    ee.Initialize(project=EE_PROJECT)
    print("✅ Connected to Earth Engine.")
except Exception:
    # Presumably no usable credentials yet: run the interactive auth flow once,
    # then retry. (A bare `except:` would also trap KeyboardInterrupt/SystemExit.)
    ee.Authenticate()
    ee.Initialize(project=EE_PROJECT)
    print("✅ Authentication completed.")

In [None]:
# --- CELL 2: CONFIGURATION ---

# -------------------------------------------------------------------------
# USER CONFIGURATION: Replace these paths with your Earth Engine Assets
# -------------------------------------------------------------------------
# 1. Polygons: the mapping units (slope units or grid cells) carrying the static
#    predictors. Must contain every non-rainfall column listed in PREDICTORS
#    (the Rn* columns are computed by this notebook) plus an 'id' property,
#    whose digits become the numeric join key 'NUM_ID' (see add_numeric_id).
polygons_asset = "projects/ee-your-username/assets/your_study_area_polygons"

# 2. Points: historical landslide inventory points.
#    Must contain a property named exactly 'formatted_date' (the name is
#    hard-coded in the training engine). Its first 10 characters are used as the
#    event day, so they should parse as a date, e.g. 'YYYY-MM-DD'.
points_asset = "projects/ee-your-username/assets/your_landslide_inventory"

# 3. Prediction area: usually the same asset as polygons_asset. It needs the same
#    static predictor columns and the 'id' property. NOTE: map_values() rasterizes
#    results through predictors_polygons by default, so polygons that exist only
#    in a different prediction asset are not drawn on the map.
prediction_asset = "projects/ee-your-username/assets/your_study_area_polygons"

# Forecast date (YYYY-MM-DD) for "Run Prediction". Rainfall of the 7 / 14 days
# before this date is used.
FORECAST_DATE_FIXED = '2025-01-01'

# Predictor variables (feature vector), in the column order fed to the model.
# Static names must exist in 'polygons_asset'. The Rn* values are computed per
# polygon from GSMaP 'hourlyPrecipRateGC' summed over the N days before the date.
PREDICTORS = [
    'Slope_mean',      # Example: Morphological slope
    'Relief_mean',     # Example: Relief/Elevation
    'NDVI_mean',       # Example: Vegetation Index
    'Lithology_Code',  # Example: Categorical lithology (fed to the Random Forest as a plain number)
    'Rn7_m',           # Computed: polygon mean of the 7-day rainfall sum
    'Rn7_s',           # Computed: within-polygon std dev of the 7-day rainfall sum
    'Rn14_m',          # Computed: polygon mean of the 14-day rainfall sum
    'Rn14_s'           # Computed: within-polygon std dev of the 14-day rainfall sum
]

# -------------------------------------------------------------------------
# VISUALIZATION PARAMETERS
# -------------------------------------------------------------------------
# Color palette for probability maps, value 0 -> 1 (Green -> White -> Red)
VIS_PALETTE = [
    '#006b0b', '#1b7b25', '#4e9956', '#dbeadd', '#ffffff',
    '#f0b2ae', '#eb958f', '#df564d', '#d10e00'
]

# Colors for the confusion map, indexed by class code 0..3 = FP, TN, FN, TP
# (the codes returned by calc_confusion_class).
PALETTE_CONFUSION = ['#D10E00', '#DF564D', '#DBEADD', '#006B0B']

# Load Collections
raw_polygons = ee.FeatureCollection(polygons_asset)
landPoints = ee.FeatureCollection(points_asset)
raw_prediction = ee.FeatureCollection(prediction_asset)

# Helper: Ensures polygons have a numeric ID for rasterization
def add_numeric_id(feature):
    """
    Server-side: attach an integer 'NUM_ID' derived from the feature's 'id'.

    Every non-digit character of 'id' is stripped (e.g. 'SU_0042' -> 42); an id
    with no digits at all maps to 0. 'reduceToImage' needs a numeric property,
    and map_values() applies the same digit-stripping locally to join results
    back. Adjust the regex here (and in map_values) if your ID format differs.

    NOTE(review): assumes 'id' is a string property; a numeric 'id' may fail the
    ee.String cast — verify against your asset.
    """
    digits_only = ee.String(feature.get('id')).replace(r'[^0-9]', '', 'g')
    numeric_id = ee.Algorithms.If(
        digits_only.length().gt(0),    # at least one digit survived
        ee.Number.parse(digits_only),
        0,
    )
    return feature.set('NUM_ID', numeric_id)

# Server-side collections with the 'NUM_ID' join key used by map_values().
predictors_polygons = raw_polygons.map(add_numeric_id)
prediction_area_shp = raw_prediction.map(add_numeric_id)

print("✅ Assets and Configuration loaded.")

In [None]:
# --- CELL 3: HELPER FUNCTIONS ---

# Global Application State Holder: mutable state shared by all UI callbacks.
APP = dict(
    df=None,             # training DataFrame (built on first calibration)
    model=None,          # trained Random Forest model
    map=None,            # the geemap map widget
    current_panel=None,  # control of the currently open floating panel
)

# 1. Map Visualization Helper
def map_values(df_res, val_col, layer_name, palette, collection=None):
    """
    Rasterize per-polygon results from a local DataFrame onto the map.

    Rows are joined back to the Earth Engine polygons through a numeric key.
    Locally, all non-digit characters are stripped from the 'id' column.
    Server-side, the same rule produced the polygons' 'NUM_ID' (add_numeric_id).

    Args:
        df_res: DataFrame with an 'id' column and `val_col`. For layers whose
            name starts with "Confusion", rows are de-duplicated per polygon,
            keeping the latest 'date' when that column exists.
        val_col: column holding the value to paint (probability in [0, 1] or a
            confusion class 0-3).
        layer_name: map layer name. An existing layer with this name is replaced.
            A "Confusion" prefix selects the 0-3 range and hides the layer.
        palette: list of colors for the visualization.
        collection: ee.FeatureCollection with a 'NUM_ID' property used for the
            rasterization. Defaults to `predictors_polygons`; pass another
            collection (e.g. the prediction area) when it covers other polygons.

    Does nothing when no map is registered in APP or `df_res` is empty.
    """
    m = APP['map']
    if m is None or df_res is None or df_res.empty:
        return
    if collection is None:
        collection = predictors_polygons

    # Remove an existing layer with the same name (best effort: lookup or
    # removal failures are ignored, as before).
    try:
        layer = m.find_layer(layer_name)
        if layer:
            m.remove_layer(layer)
    except Exception:
        pass

    # Local equivalent of add_numeric_id(): digits only, 0 when there are none.
    def clean_id_py(val):
        digits = re.sub(r'[^0-9]', '', str(val))
        return int(digits) if digits else 0

    df_map = df_res.copy()
    df_map['NUM_ID_PY'] = df_map['id'].apply(clean_id_py)

    is_confusion = layer_name.startswith("Confusion")
    if is_confusion:
        # Several events per polygon: keep the row of the most recent date.
        if 'date' in df_map.columns:
            df_map = df_map.sort_values('date')
        df_flat = df_map.drop_duplicates(subset='NUM_ID_PY', keep='last')
        is_visible = False  # Confusion map hidden by default
    else:
        # Probability layers: keep the maximum value per polygon.
        df_flat = df_map.groupby('NUM_ID_PY')[val_col].max().reset_index()
        is_visible = True

    id_list = df_flat['NUM_ID_PY'].tolist()
    val_list = df_flat[val_col].tolist()

    # Paint every polygon with its NUM_ID.
    polygons_img = collection.reduceToImage(properties=['NUM_ID'], reducer=ee.Reducer.first())

    # Swap IDs for values. Per the Earth Engine docs, remap() masks IDs absent
    # from id_list (no defaultValue is given); gte(0) additionally masks negatives.
    result_img = polygons_img.remap(id_list, val_list).rename('value')
    result_img = result_img.updateMask(result_img.gte(0))

    # Visualization range: 0-1 for probabilities, 0-3 for confusion classes.
    v_max = 3 if is_confusion else 1
    vis = {'palette': palette, 'min': 0, 'max': v_max}

    m.addLayer(result_img, vis, layer_name, shown=is_visible)

# 2. Confusion Class Calculator
def calc_confusion_class(row, pred_col, true_col='P/A'):
    """
    Map a row's (prediction, observation) pair to a confusion-map class code.

    Codes (indices into PALETTE_CONFUSION):
        0 = False Positive, 1 = True Negative, 2 = False Negative, 3 = True Positive.
    Both values are cast with int(). Any pair outside {0, 1} x {0, 1} is reported as 1.
    """
    outcome = (int(row[pred_col]), int(row[true_col]))
    codes = {
        (1, 0): 0,  # False Positive
        (0, 0): 1,  # True Negative
        (0, 1): 2,  # False Negative
        (1, 1): 3,  # True Positive
    }
    return codes.get(outcome, 1)

# 3. CSV Download Helper
def create_download_link(df, title="Download CSV", filename="data.csv"):
    """
    Return an ipywidgets.HTML button that downloads `df` as a CSV file.

    The CSV (written without the index) is embedded in a base64 `data:` URI.
    The file content is therefore frozen at call time: add every column you want
    exported *before* calling this. Very large frames produce very large URIs,
    which some browsers may refuse to download.

    Args:
        df: DataFrame to export.
        title: button label.
        filename: file name suggested to the browser.
    """
    csv_text = df.to_csv(index=False)
    b64 = base64.b64encode(csv_text.encode('utf-8')).decode('ascii')
    payload = f'data:text/csv;base64,{b64}'
    # Style: Sky blue background, WHITE text
    html = f'<a download="{filename}" href="{payload}" target="_blank"><button style="background-color: #00BFFF; color: #ffffff !important; padding: 6px 12px; border: none; border-radius: 4px; cursor: pointer; font-size: 12px; font-family: sans-serif; font-weight: bold;">{title}</button></a>'
    return widgets.HTML(html)

print("✅ Helpers loaded.")

In [None]:
# --- CELL 4: TRAINING ENGINE ---

def download_training_data_server_side(log_widget):
    """
    Build the training table: one row per (event day, polygon).

    For each distinct event day in the 'formatted_date' property of `landPoints`:
      * sum GSMaP 'hourlyPrecipRateGC' over the 7 and 14 days before the day;
      * label every polygon of `predictors_polygons` with P/A = 1 if it contains
        at least one inventory point of that day (else 0), and record the day in
        'date';
      * extract the per-polygon mean and std dev of both rainfall sums
        (reduceRegions at 1000 m) and rename them to Rn7_m/Rn7_s/Rn14_m/Rn14_s.

    Days are processed concurrently (4 threads). A day that raises is reported in
    `log_widget` and skipped rather than silently dropped.

    Args:
        log_widget: ipywidgets.Output receiving progress and error messages.

    Returns:
        pandas.DataFrame of all processed days, concatenated in ascending day
        order so repeated runs give the same row order (and, with fixed seeds,
        the same model). Returns an empty DataFrame if no day produced data.
    """
    # 1. Group the raw date values by their first 10 characters (the event day).
    #    The raw values are kept so the point filter below matches them exactly,
    #    even when the stored value is longer than the day string.
    raw_dates = landPoints.aggregate_array('formatted_date').distinct().getInfo()
    raw_by_day = {}
    for raw in raw_dates:
        raw_by_day.setdefault(str(raw)[:10], []).append(raw)
    dates_list = sorted(raw_by_day)

    with log_widget:
        print(f"📅 Event Dates found: {len(dates_list)}")
        print("🚀 Starting Training Set calculation (GPM + Morphology)...")

    gpm = ee.ImageCollection('JAXA/GPM_L3/GSMaP/v8/operational').select('hourlyPrecipRateGC')
    reducer = ee.Reducer.mean().combine(ee.Reducer.stdDev(), sharedInputs=True)
    # reduceRegions names its outputs <band>_<reducer>; map them onto PREDICTORS.
    rename = {'Rn7_m_mean': 'Rn7_m', 'Rn7_m_stdDev': 'Rn7_s',
              'Rn14_m_mean': 'Rn14_m', 'Rn14_m_stdDev': 'Rn14_s'}

    def process_date(date_str):
        """Return the labelled feature table for one day, or None if empty. May raise."""
        d = ee.Date(date_str)

        # Antecedent rainfall: sum of the hourly images in [d - N days, d)
        img7 = gpm.filterDate(d.advance(-7, 'day'), d).sum().unmask(0).rename('Rn7_m')
        img14 = gpm.filterDate(d.advance(-14, 'day'), d).sum().unmask(0).rename('Rn14_m')
        combined = img7.addBands(img14)

        # Inventory points of this day (every raw value that maps to it)
        todays_points = landPoints.filter(ee.Filter.inList('formatted_date', raw_by_day[date_str]))

        # Labeling: If polygon contains a point -> P/A = 1, else 0
        def map_polygons(poly):
            count = todays_points.filterBounds(poly.geometry()).size()
            return poly.set({'P/A': ee.Algorithms.If(count.gt(0), 1, 0), 'date': date_str})

        stats = combined.reduceRegions(
            collection=predictors_polygons.map(map_polygons),
            reducer=reducer,
            scale=1000, tileScale=16
        )

        df_day = geemap.ee_to_df(stats)
        if df_day.empty:
            return None
        df_day = df_day.rename(columns=rename)

        # If a rainfall column is entirely absent from the result, treat it as
        # zero rain. Static predictors are deliberately NOT filled here.
        for col in PREDICTORS:
            if col not in df_day.columns and 'Rn' in col:
                df_day[col] = 0
        return df_day

    # Multi-threading for speed; results are keyed by day for a stable order.
    results = {}
    n_dates = len(dates_list)
    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
        future_to_date = {executor.submit(process_date, d): d for d in dates_list}
        for done, future in enumerate(concurrent.futures.as_completed(future_to_date), start=1):
            date_str = future_to_date[future]
            try:
                df_day = future.result()
            except Exception as e:
                with log_widget:
                    print(f"\n⚠️ Skipped {date_str}: {e}")
                df_day = None
            if df_day is not None:
                results[date_str] = df_day
            if done % 5 == 0 or done == n_dates:
                with log_widget:
                    print(f"   ...processed {done}/{n_dates} dates", end='\r')

    if not results:
        return pd.DataFrame()
    # Concatenate in day order, not completion order, for reproducibility.
    return pd.concat([results[d] for d in sorted(results)], ignore_index=True)

print("✅ Training Engine ready.")

In [None]:
# --- CELL 5: PREDICTION ENGINE ---

def get_prediction_data_fixed():
    """
    Build the feature table for FORECAST_DATE_FIXED over the prediction area.

    Rainfall predictors follow the training recipe:
      * sum GSMaP 'hourlyPrecipRateGC' over the 7 and 14 days before the date;
      * take the per-polygon mean and std dev (reduceRegions at 1000 m) over
        `prediction_area_shp`.
    No P/A labels are produced.

    Returns:
        pandas.DataFrame with the polygon properties plus Rn7_m/Rn7_s/Rn14_m/Rn14_s.

    Raises:
        ValueError: if the prediction area yields no features, or if a static
            (non-rainfall) predictor in PREDICTORS is missing. Zero-filling a
            static predictor would silently give meaningless predictions; as in
            training, only rainfall columns are zero-filled.
    """
    d = ee.Date(FORECAST_DATE_FIXED)
    gpm = ee.ImageCollection('JAXA/GPM_L3/GSMaP/v8/operational').select('hourlyPrecipRateGC')

    # Calculate rainfall bands (sum of the hourly images in [d - N days, d))
    img7 = gpm.filterDate(d.advance(-7, 'day'), d).sum().unmask(0).rename('Rn7_m')
    img14 = gpm.filterDate(d.advance(-14, 'day'), d).sum().unmask(0).rename('Rn14_m')
    combined = img7.addBands(img14)

    # Extract features
    stats = combined.reduceRegions(
        collection=prediction_area_shp,
        reducer=ee.Reducer.mean().combine(ee.Reducer.stdDev(), sharedInputs=True),
        scale=1000, tileScale=16
    )

    df = geemap.ee_to_df(stats)
    if df.empty:
        raise ValueError(f"No features returned for the prediction area on {FORECAST_DATE_FIXED}.")

    # Standardize column names (reduceRegions outputs <band>_<reducer>)
    rename = {'Rn7_m_mean': 'Rn7_m', 'Rn7_m_stdDev': 'Rn7_s',
              'Rn14_m_mean': 'Rn14_m', 'Rn14_m_stdDev': 'Rn14_s'}
    df = df.rename(columns=rename)

    missing_static = [col for col in PREDICTORS if col not in df.columns and 'Rn' not in col]
    if missing_static:
        raise ValueError(f"Prediction area is missing static predictors: {missing_static}")

    # Only rainfall columns can still be missing here: treat as zero rain.
    for col in PREDICTORS:
        if col not in df.columns:
            df[col] = 0

    return df

print("✅ Prediction Engine ready.")

In [None]:
# --- CELL 6: DASHBOARD AND INTERFACE ---

# ---------------------------------------------------------
# 1. METRICS CALCULATION
# ---------------------------------------------------------
def calculate_advanced_metrics(y_true, y_probs):
    """
    Score binary probability predictions at the Youden-optimal threshold.

    The threshold is the ROC threshold maximising Youden's J = TPR - FPR
    (sensitivity + specificity - 1). np.argmax picks the first maximum.
    Predictions are `y_probs >= threshold`.

    Returns:
        dict with 'auc', 'best_thresh', 'f1', 'kappa', 'youden', 'acc', the ROC
        arrays 'fpr', 'tpr', 'roc_thresh', the chosen index 'best_idx', and the
        0/1 predictions 'y_pred_opt'.
    """
    fpr, tpr, roc_thresh = roc_curve(y_true, y_probs)
    area = auc(fpr, tpr)

    j_stat = tpr - fpr
    best_idx = np.argmax(j_stat)
    cutoff = roc_thresh[best_idx]
    y_hat = (y_probs >= cutoff).astype(int)

    return {
        'auc': area,
        'best_thresh': cutoff,
        'f1': f1_score(y_true, y_hat, zero_division=0),
        'kappa': cohen_kappa_score(y_true, y_hat),
        'youden': j_stat[best_idx],
        'acc': accuracy_score(y_true, y_hat),
        'fpr': fpr,
        'tpr': tpr,
        'roc_thresh': roc_thresh,
        'best_idx': best_idx,
        'y_pred_opt': y_hat,
    }

# ---------------------------------------------------------
# 2. CUSTOM UI WIDGETS
# ---------------------------------------------------------
def create_legend_widget():
    """
    Build the fixed map legend as a single HTML widget.

    Shows:
      * the VIS_PALETTE gradient (0 -> 1) used by the calibration / validation /
        prediction layers;
      * one swatch per confusion class, in PALETTE_CONFUSION order
        (False Positive, True Negative, False Negative, True Positive = codes 0..3).

    Text is forced black on a white, bordered box. The box styling is inline CSS
    because ipywidgets.Layout has no background_color / border_radius traits, so
    passing them to a Layout would not style anything.
    """
    text_css = "color: #000000 !important; font-family: sans-serif;"
    title_css = f"font-weight: bold; font-size: 12px; {text_css}"

    gradient = ", ".join(VIS_PALETTE)
    labels = ["False Positive", "True Negative", "False Negative", "True Positive"]
    swatches = "".join(
        f'<div style="display: flex; align-items: center; margin: 2px 0;">'
        f'<div style="width: 15px; height: 15px; background-color: {color}; border: 1px solid #999; margin-right: 8px;"></div>'
        f'<span style="font-size: 11px; {text_css}">{label}</span>'
        f'</div>'
        for color, label in zip(PALETTE_CONFUSION, labels)
    )

    html = f"""
    <div style="width: 250px; box-sizing: border-box; padding: 10px; background-color: #ffffff; border: 2px solid black; border-radius: 4px;">
        <div style="{title_css}">Calibration/Validation/Prediction map</div>
        <div style="width: 100%; margin-top: 5px;">
            <div style="display: flex; justify-content: space-between; width: 200px; font-size: 11px; margin-bottom: 2px; {text_css}">
                <span>0</span><span>1</span>
            </div>
            <div style="height: 12px; width: 200px; background: linear-gradient(to right, {gradient}); border: 1px solid #999;"></div>
        </div>
        <div style="{title_css} margin-top: 15px;">Confusion map</div>
        {swatches}
    </div>
    """
    return widgets.HTML(value=html)

def show_floating_panel(content_widget, title="Results", position='topright', width='400px'):
    """
    Show `content_widget` in a collapsible floating panel on the map.

    Any panel previously opened through this function is removed first, so only
    one results panel exists at a time. The panel has a header with `title` and a
    minimise button. Minimising swaps the panel for a single button that restores it.

    Args:
        content_widget: widget displayed inside the panel.
        title: header text (also shown on the collapsed button).
        position: WidgetControl position (e.g. 'topright', 'bottomleft').
        width: CSS width of the expanded panel.

    Does nothing when no map is registered in APP.
    """
    m = APP['map']
    if m is None:
        return

    if APP['current_panel'] is not None:
        try:
            m.remove_control(APP['current_panel'])
        except Exception:
            pass  # best effort: the old control may no longer be removable
        APP['current_panel'] = None

    # ipywidgets.Layout has no background-color / border-radius / box-shadow
    # traits (and border_bottom exists only in ipywidgets 8), so that styling is
    # applied through CSS classes defined by this hidden <style> widget.
    panel_css = widgets.HTML(
        "<style>"
        ".gr-floating-panel {background-color: #ffffff; border-radius: 8px;"
        " box-shadow: 0 4px 10px rgba(0,0,0,0.2);}"
        ".gr-floating-panel-header {border-bottom: 1px solid #eee;}"
        "</style>",
        layout=widgets.Layout(display='none'),
    )

    # Toggle Buttons (Expanded/Minimized)
    btn_expand = widgets.Button(description=f"📂 {title}", style=widgets.ButtonStyle(button_color='#00BFFF', text_color='#ffffff', font_weight='bold'), layout=widgets.Layout(width='auto', padding='5px'))
    btn_minimize = widgets.Button(icon='minus', layout=widgets.Layout(width='30px', height='30px'), style=widgets.ButtonStyle(button_color='transparent', text_color='#000000'))

    header = widgets.HBox([widgets.HTML(f'<span style="font-weight: bold; color: black; font-family: sans-serif;">{title}</span>'), btn_minimize],
                          layout=widgets.Layout(justify_content='space-between', width='100%', align_items='center'))
    header.add_class('gr-floating-panel-header')

    # `overflow` is accepted by ipywidgets 7 and 8 (overflow_y is not in 8).
    expanded_content = widgets.VBox([panel_css, header, content_widget],
                                    layout=widgets.Layout(width=width, padding='10px', border='2px solid black', max_height='900px', overflow='auto'))
    expanded_content.add_class('gr-floating-panel')

    main_container = widgets.VBox([expanded_content])

    # Toggle Logic: swap between the full panel and the collapsed button.
    def toggle_view(_):
        panel_is_open = main_container.children[0] is expanded_content
        main_container.children = [btn_expand] if panel_is_open else [expanded_content]

    btn_minimize.on_click(toggle_view)
    btn_expand.on_click(toggle_view)

    control = WidgetControl(widget=main_container, position=position)
    m.add_control(control)
    APP['current_panel'] = control

# ---------------------------------------------------------
# 3. EVENT HANDLERS
# ---------------------------------------------------------
# Operation log shown under the map. `overflow` is used instead of `overflow_y`,
# which ipywidgets 8 no longer provides; this keeps long logs scrollable in 7 and 8.
out_log = widgets.Output(layout={'border': '1px solid #ccc', 'height': '150px', 'overflow': 'auto', 'padding': '10px', 'margin': '10px 0'})
dashboard_ui = widgets.VBox([widgets.HTML('<div style="font-weight: bold; color: black; font-family: sans-serif;">OPERATION LOG</div>'), out_log])

def on_calib_click(b):
    """
    'Run Calibration' handler: fit the Random Forest and show calibration results.

    Steps:
      1. Build the training table on first use. It is cached in APP['df'] only
         when non-empty, so a failed or empty download is retried on the next click.
      2. Fit a class-balanced RandomForestClassifier and store it in APP['model'].
      3. Score its out-of-bag probabilities at the Youden-optimal threshold.
      4. Add calib_prob (in-sample probability), calib_pred (class at that
         threshold) and conf_class to APP['df].
      5. Show the metrics panel, whose CSV now includes these columns, and add
         the calibration / confusion layers.

    Any exception is written to the operation log instead of being lost in the
    widget callback.
    """
    out_log.clear_output()
    try:
        if APP['df'] is None or APP['df'].empty:
            df_new = download_training_data_server_side(out_log)
            if df_new.empty:
                with out_log: print("No training data was produced; nothing to calibrate.")
                return
            APP['df'] = df_new
        df = APP['df']

        y = df['P/A']
        if y.nunique() < 2:
            # predict_proba(...)[:, 1] and the ROC curve both need both classes.
            with out_log: print("Training data has a single P/A class; cannot calibrate.")
            return

        with out_log: print(f"Training Random Forest ({len(df)} samples)...")
        X = df[PREDICTORS].fillna(0)

        # Train Model (Balanced Class Weight to handle imbalance)
        rf = RandomForestClassifier(n_estimators=100, max_depth=10, random_state=42, n_jobs=-1, oob_score=True, class_weight='balanced')
        rf.fit(X, y)
        APP['model'] = rf

        # Metrics use out-of-bag probabilities when available, otherwise the
        # in-sample probabilities.
        y_probs = rf.oob_decision_function_[:, 1] if hasattr(rf, "oob_decision_function_") else rf.predict_proba(X)[:, 1]
        m = calculate_advanced_metrics(y, y_probs)
        imp = pd.Series(rf.feature_importances_, index=PREDICTORS).sort_values()

        # Store results BEFORE creating the download link: the link embeds the
        # CSV exactly as the DataFrame is at creation time.
        df['calib_prob'] = rf.predict_proba(X)[:, 1]
        df['calib_pred'] = m['y_pred_opt']
        df['conf_class'] = df.apply(lambda r: calc_confusion_class(r, 'calib_pred'), axis=1)

        # Create Charts
        fig = make_subplots(rows=2, cols=1, subplot_titles=("Feature Importance", "ROC Curve"), vertical_spacing=0.15)
        fig.add_trace(go.Bar(x=imp.values, y=imp.index, orientation='h', marker=dict(color='orange'), name="Imp"), row=1, col=1)
        fig.add_trace(go.Scatter(x=m['fpr'], y=m['tpr'], mode='lines', line=dict(color='royalblue', width=2), fill='tozeroy', name="ROC"), row=2, col=1)
        fig.add_trace(go.Scatter(x=[m['fpr'][m['best_idx']]], y=[m['tpr'][m['best_idx']]], mode='markers', marker=dict(color='red', size=12, symbol='star'), name="Opt"), row=2, col=1)
        fig.update_layout(template="plotly_white", height=350, showlegend=False, margin=dict(l=10, r=10, t=30, b=10))

        # Metrics HTML Widget
        metrics_html = widgets.HTML(f"""
        <div style="color: black; font-family: sans-serif; font-size: 11px; margin: 10px 0; line-height: 1.6;">
            <strong>Training overall accuracy:</strong> {m['acc']:.4f}<br>
            <strong>Area under curve:</strong> {m['auc']:.4f}<br>
            <strong>Youden Index:</strong> {m['youden']:.4f}<br>
            <strong>Cohen's Kappa:</strong> {m['kappa']:.4f}<br>
            <strong>F1 Score:</strong> {m['f1']:.4f}<br>
            <em>Scores computed from out-of-bag predictions.</em>
        </div>
        """)

        download_btn = create_download_link(df, "Download Calibration CSV", "suscFitData.csv")
        show_floating_panel(widgets.VBox([metrics_html, go.FigureWidget(fig), download_btn], layout=widgets.Layout(align_items='stretch')), title="Calibration Results")

        # Map Update
        with out_log:
            print("Calibration Done.")
            map_values(df, 'calib_prob', 'Calibration Map', VIS_PALETTE)
            map_values(df, 'conf_class', 'Confusion Calibration', PALETTE_CONFUSION)
    except Exception as e:
        with out_log: print(f"Error: {e}")

def on_valid_click(b):
    """
    'Run Validation' handler: 5-fold stratified cross-validation.

    Requires a prior calibration (APP['model']). Out-of-fold probabilities come
    from cross_val_predict, which fits a clone of the calibrated model's
    configuration on each training fold. They are scored at their own
    Youden-optimal threshold and stored in APP['df'] as valid_prob / valid_pred /
    valid_conf. This happens before the CSV link is built, so the download
    contains them. Results are then summarised in a panel (ROC + confusion
    matrix) and mapped.

    Any exception is written to the operation log.
    """
    if APP['model'] is None:
        with out_log: print("Run Calibration first.")
        return
    df = APP['df']
    with out_log: print("Running Cross-Validation (5-Folds)...")
    try:
        X = df[PREDICTORS].fillna(0)
        y = df['P/A']

        # 5-Fold Stratified CV
        cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
        y_probs = cross_val_predict(APP['model'], X, y, cv=cv, method='predict_proba', n_jobs=-1)[:, 1]

        m = calculate_advanced_metrics(y, y_probs)
        # Fixed labels keep the matrix 2x2 in (0, 1) order, matching the heatmap axes.
        cm = confusion_matrix(y, m['y_pred_opt'], labels=[0, 1])

        # Store results BEFORE creating the download link (it freezes the CSV).
        df['valid_prob'] = y_probs
        df['valid_pred'] = m['y_pred_opt']
        df['valid_conf'] = df.apply(lambda r: calc_confusion_class(r, 'valid_pred'), axis=1)

        metrics_html = widgets.HTML(f"""
        <div style="color: black; font-family: sans-serif; font-size: 11px; margin: 10px 0; line-height: 1.6;">
            <strong>Validation overall accuracy:</strong> {m['acc']:.4f}<br>
            <strong>Area under curve:</strong> {m['auc']:.4f}<br>
            <strong>Youden Index:</strong> {m['youden']:.4f}<br>
            <strong>Cohen's Kappa:</strong> {m['kappa']:.4f}<br>
            <strong>F1 Score:</strong> {m['f1']:.4f}
        </div>
        """)

        fig = make_subplots(rows=2, cols=1, subplot_titles=("CV ROC Curve", "Confusion Matrix"), vertical_spacing=0.15)
        fig.add_trace(go.Scatter(x=m['fpr'], y=m['tpr'], mode='lines', line=dict(color='darkorange', width=2), fill='tozeroy', name="ROC"), row=1, col=1)
        fig.add_trace(go.Scatter(x=[m['fpr'][m['best_idx']]], y=[m['tpr'][m['best_idx']]], mode='markers', marker=dict(color='red', size=12, symbol='star'), name="Opt"), row=1, col=1)
        cm_text = [[str(count) for count in cm_row] for cm_row in cm]
        fig.add_trace(go.Heatmap(z=cm, x=['Pred:0', 'Pred:1'], y=['True:0', 'True:1'], colorscale='Blues', showscale=False, text=cm_text, texttemplate="%{text}"), row=2, col=1)
        fig.update_yaxes(autorange="reversed", row=2, col=1)
        fig.update_layout(template="plotly_white", height=350, showlegend=False, margin=dict(l=10, r=10, t=30, b=10))

        download_btn = create_download_link(df, "Download Validation CSV", "suscValidData.csv")
        show_floating_panel(widgets.VBox([metrics_html, go.FigureWidget(fig), download_btn], layout=widgets.Layout(align_items='stretch')), title="Validation Results")

        with out_log:
            print("Validation Done.")
            map_values(df, 'valid_prob', 'Validation Map', VIS_PALETTE)
            map_values(df, 'valid_conf', 'Confusion Validation', PALETTE_CONFUSION)
    except Exception as e:
        with out_log: print(f"Error: {e}")

def on_pred_click(b):
    """
    'Run Prediction' handler: apply the calibrated model on FORECAST_DATE_FIXED.

    Builds the prediction table (get_prediction_data_fixed) and stores the
    class-1 probability of APP['model'] in column 'SI'. It then draws 'SI' as
    'Prediction Map' and opens a compact panel with the maximum score and a CSV
    download.

    Any exception is printed to the operation log.
    """
    if APP['model'] is None:
        with out_log: print("Model missing.")
        return

    out_log.clear_output()
    with out_log: print(f"Forecasting on {FORECAST_DATE_FIXED}...")

    try:
        df_pred = get_prediction_data_fixed()
        features = df_pred[PREDICTORS].fillna(0)
        scores = APP['model'].predict_proba(features)[:, 1]

        df_pred['SI'] = scores
        map_values(df_pred, 'SI', 'Prediction Map', VIS_PALETTE)

        peak = scores.max()
        risk_html = widgets.HTML(f"""<div style="color: black; font-family: sans-serif; font-size: 14px; margin-bottom: 10px; font-weight: bold; text-align: center;">Max Risk Score: {peak:.2f}</div>""")
        download_btn = create_download_link(df_pred, "Download Prediction CSV", "suscPredData.csv")
        panel_body = widgets.VBox([widgets.HTML('<span style="color:black;">Prediction Ready</span>'), risk_html, download_btn])
        show_floating_panel(panel_body, title="Prediction", width='220px')

        with out_log: print(f"Max Risk Score: {peak:.2f}")
    except Exception as e:
        with out_log: print(f"Error: {e}")

# ---------------------------------------------------------
# 4. MAP INITIALIZATION
# ---------------------------------------------------------
# Map height 1400px leaves ample vertical space for the floating panels.
Map = geemap.Map(height='1400px', zoom_control=False, draw_control=False, fullscreen_control=True)
Map.centerObject(prediction_area_shp, 9)
# The default zoom control is disabled above; re-add it at the bottom-left.
Map.add_control(ZoomControl(position='bottomleft'))
# Hide geemap's layer control when this geemap version exposes one (best effort).
layer_control = getattr(Map, 'layer_control', None)
if layer_control is not None:
    try:
        Map.remove_control(layer_control)
    except Exception as e:
        print(f"Could not remove the layer control: {e}")
APP['map'] = Map

# ---------------------------------------------------------
# 5. CONTROL LAYOUT
# ---------------------------------------------------------
# Button Styling: Sky Blue with White Text
btn_style = widgets.ButtonStyle(button_color='#00BFFF', text_color='#ffffff', font_weight='bold')
layout = widgets.Layout(width='180px', margin='2px')

b_init = widgets.Button(description="Run Analysis", layout=layout, style=btn_style)
b_calib = widgets.Button(description="Run Calibration", layout=layout, style=btn_style)
b_valid = widgets.Button(description="Run Validation", layout=layout, style=btn_style)
b_pred = widgets.Button(description="Run Prediction", layout=layout, style=btn_style)

b_calib.on_click(on_calib_click)
b_valid.on_click(on_valid_click)
b_pred.on_click(on_pred_click)

# White, rounded background for the two button bars. ipywidgets.Layout has no
# background_color / border_radius traits (such kwargs are not applied), so the
# look comes from a CSS class defined in a hidden <style> widget inside each bar.
ctrl_css = widgets.HTML('<style>.gr-ctrl-box {background-color: white; border-radius: 8px;}</style>',
                        layout=widgets.Layout(display='none'))
start_box = widgets.VBox([ctrl_css, b_init], layout=widgets.Layout(padding='5px'))
main_box = widgets.HBox([ctrl_css, b_calib, b_valid, b_pred], layout=widgets.Layout(padding='5px'))
start_box.add_class('gr-ctrl-box')
main_box.add_class('gr-ctrl-box')

start_ctrl = WidgetControl(widget=start_box, position='bottomright')
main_ctrl = WidgetControl(widget=main_box, position='bottomright')

def activate_analysis(b):
    """
    'Run Analysis' handler: switch the map from the start button to the dashboard.

    Adds the 'Study Area' outline and replaces the start control with the
    calibration / validation / prediction bar. It then adds the legend at the
    bottom-left. Failures are reported in the operation log.
    """
    out_log.clear_output()
    # Colab-only JavaScript: let the output iframe grow so the tall map needs no
    # inner scrollbar. Outside Colab `google.colab` is presumably undefined, so
    # the script has no effect.
    display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 5000});'''))

    with out_log: print("Initializing...")
    try:
        Map.addLayer(predictors_polygons.style(**{'color': 'gray', 'fillColor': '00000000'}), {}, 'Study Area')
        Map.remove_control(start_ctrl)
        Map.add_control(main_ctrl)

        # Add persistent legend to bottom left
        legend_control = WidgetControl(widget=create_legend_widget(), position='bottomleft')
        Map.add_control(legend_control)
    except Exception as e:
        with out_log: print(f"Error: {e}")

# Show the start control on the map, wire its button, then render map and log.
Map.add_control(start_ctrl)
b_init.on_click(activate_analysis)

display(Map, dashboard_ui)