
# Hotspot Prediction with GEE + LSTM (Weekly)

This notebook shows how to:
1. Initialize Google Earth Engine (GEE) in Python.
2. Build **weekly features** over an area of interest (AOI; default: Nan Province, Thailand).
3. Train an **LSTM** to predict weekly hotspot counts from VIIRS active-fire detections (or, optionally, MODIS FIRMS).
4. Evaluate the model on held-out weeks and forecast weekly hotspots for **2025**.


## 1) Setup

In [None]:

# If running in Colab, uncomment:
# !pip install -q earthengine-api geemap pandas numpy scikit-learn tensorflow==2.16.1 matplotlib

import math, warnings
from pathlib import Path
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import ee, geemap

from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from tensorflow import keras
from tensorflow.keras import layers


## 2) Authenticate & Initialize GEE

In [None]:

def _init_earth_engine():
    """Initialise the Earth Engine client, authenticating interactively if needed.

    If the first ``ee.Initialize()`` raises (presumably because no stored
    credentials exist yet), run ``ee.Authenticate()`` and retry once; a
    second failure propagates to the caller.
    """
    try:
        ee.Initialize()
    except Exception:
        ee.Authenticate()
        ee.Initialize()


_init_earth_engine()
print("GEE initialized")


## 3) Parameters

In [None]:

# --- Area of interest (GAUL admin-1 lookup) ---
COUNTRY = 'Thailand'
PROVINCE = 'Nan'

# --- Time span ---
START_TRAIN = '2020-01-01'  # first week of the feature table
END_TRAIN = '2024-12-31'    # last training week (inclusive: split uses <=)
PRED_YEAR = 2025            # calendar year to forecast

# --- Model / target ---
SEQ_LEN = 8                 # weeks of history per LSTM input window
TARGET_MODE = 'count'       # 'count' (regress log1p of count) or 'binary' (fire / no fire)
USE_VIIRS = True            # True: VIIRS (NASA/LANCE/SNPP_VIIRS/C2); False: MODIS FIRMS ('FIRMS')


### 3.1 AOI from GAUL (with fallback)

In [None]:

def get_aoi_from_gaul(country, province):
    """Return the GAUL 2015 level-1 geometry for *province* in *country*.

    Falls back to a fixed bounding box (lon 100.2-101.5, lat 18.0-19.5)
    when no GAUL feature matches the names, or when the lookup request
    itself fails.

    Args:
        country: GAUL ``ADM0_NAME`` value, e.g. ``'Thailand'``.
        province: GAUL ``ADM1_NAME`` value, e.g. ``'Nan'``.

    Returns:
        ee.Geometry: the province geometry, or the fallback rectangle.
    """
    fallback = ee.Geometry.Rectangle([100.2, 18.0, 101.5, 19.5])
    admin1 = ee.FeatureCollection('FAO/GAUL_SIMPLIFIED_500m/2015/level1')
    matches = (admin1
               .filter(ee.Filter.eq('ADM0_NAME', country))
               .filter(ee.Filter.eq('ADM1_NAME', province)))
    try:
        # A filter that matches nothing does not raise: its geometry is just
        # empty (area 0). Count the matches explicitly instead.
        n_matches = matches.size().getInfo()
    except Exception as e:  # best-effort: any request failure -> fallback
        print('GAUL lookup failed, using fallback bbox. Error:', e)
        return fallback
    if n_matches == 0:
        print(f'No GAUL feature for {province!r} in {country!r}; using fallback bbox.')
        return fallback
    return matches.geometry()

AOI = get_aoi_from_gaul(COUNTRY, PROVINCE)
m = geemap.Map()
# Centre the preview on the AOI itself (zoom 7) instead of fixed coordinates,
# so it follows COUNTRY/PROVINCE changes and the fallback bounding box.
m.centerObject(AOI, 7)
m.addLayer(AOI, {'color': 'red'}, 'AOI')
m


## 4) Collections & Helpers

In [None]:

# Earth Engine sources for the weekly feature table.

# Fire detections: exactly one of these feeds the target, chosen by USE_VIIRS.
viirs = ee.ImageCollection('NASA/LANCE/SNPP_VIIRS/C2')
firms_modis = ee.ImageCollection('FIRMS')

# Weekly covariates.
chirps = ee.ImageCollection('UCSB-CHG/CHIRPS/DAILY')  # daily precipitation
mod13q1 = ee.ImageCollection('MODIS/061/MOD13Q1')     # NDVI composites
mod11a2 = ee.ImageCollection('MODIS/061/MOD11A2')     # land-surface temperature

# Static terrain: slope derived from the SRTM DEM.
srtm = ee.Image('USGS/SRTMGL1_003')
slope = ee.Terrain.slope(srtm)
slope = slope.rename('slope')

def week_edges(start_date, end_date):
    """Return the start date of each 7-day week from *start_date* to *end_date*.

    Week starts are *start_date*, *start_date* + 7 days, ... and are kept
    only while strictly before *end_date*. The dates are computed
    client-side with pandas. Stepping an ``ee.Date`` and calling
    ``getInfo()`` would cost several blocking server round trips per week
    for plain calendar arithmetic.

    Args:
        start_date: first week start; anything ``pandas.Timestamp`` accepts
            (e.g. ``'2020-01-01'``).
        end_date: exclusive upper bound for week starts.

    Returns:
        list[str]: week start dates formatted ``'YYYY-MM-DD'``; empty when
        ``start_date >= end_date``.
    """
    start = pd.Timestamp(start_date)
    end = pd.Timestamp(end_date)
    candidates = pd.date_range(start=start, end=end, freq='7D')
    # date_range includes `end` itself when it lands on a step; the bound is exclusive.
    return [ts.strftime('%Y-%m-%d') for ts in candidates if ts < end]

def viirs_week_image(start, end):
    """Per-pixel count of VIIRS images flagging fire in [start, end), band ``fire``.

    Each image contributes 1 where its ``confidence`` band is >= 1, and the
    contributions are summed. When the window holds no VIIRS images, a
    fully masked single-band placeholder is used. Examples are dates past
    the latest ingested data or outside the collection's coverage. Without
    the placeholder the sum would be a band-less image, which cannot be
    renamed to ``fire`` and has no ``fire`` key after ``reduceRegion``.
    """
    ic = viirs.filterDate(start, end)
    fire = ic.map(lambda img: img.select('confidence').gte(1).rename('fire'))
    empty = ee.Image.constant(0).rename('fire').selfMask()
    summed = ee.Image(ee.Algorithms.If(ic.size().gt(0), fire.sum(), empty))
    return summed.rename('fire')

def firms_modis_week_image(start, end):
    """Per-pixel count of MODIS FIRMS images flagging fire in [start, end), band ``fire``.

    Each image contributes 1 where its ``confidence`` band is >= 30, and the
    contributions are summed. When the window holds no FIRMS images, a fully
    masked single-band placeholder is returned instead of a band-less image.
    A band-less image could not be renamed or reduced to a ``fire`` value.
    """
    ic = firms_modis.filterDate(start, end)
    fire = ic.map(lambda img: img.select('confidence').gte(30).rename('fire'))
    empty = ee.Image.constant(0).rename('fire').selfMask()
    summed = ee.Image(ee.Algorithms.If(ic.size().gt(0), fire.sum(), empty))
    return summed.rename('fire')

def chirps_week_sum(start, end):
    """Per-pixel total of CHIRPS daily ``precipitation`` in [start, end), band ``precip``.

    When the window holds no CHIRPS images (e.g. dates past the most recent
    ingested day), a fully masked placeholder is returned. A band-less
    image could not be renamed to ``precip``.
    """
    ic = chirps.filterDate(start, end).select('precipitation')
    empty = ee.Image.constant(0).rename('precipitation').selfMask()
    total = ee.Image(ee.Algorithms.If(ic.size().gt(0), ic.sum(), empty))
    return total.rename('precip')

def ndvi_week_mean(start, end):
    """Mean NDVI of the MOD13Q1 composites overlapping [start, end), band ``ndvi``.

    MOD13Q1 is a 16-day composite filtered here by its start time. A plain
    7-day ``filterDate(start, end)`` therefore contains no image for roughly
    half of all weeks. So the window is opened 15 days early, which captures
    every composite that overlaps the week.

    If there is still no image (e.g. past the latest data), a fully masked
    band is used so the result always has exactly one ``ndvi`` band. Raw
    values are scaled by 0.0001.
    """
    start = ee.Date(start)
    ic = mod13q1.filterDate(start.advance(-15, 'day'), end).select('NDVI')
    empty = ee.Image.constant(0).rename('NDVI').selfMask()
    ndvi = ee.Image(ee.Algorithms.If(ic.size().gt(0), ic.mean(), empty))
    return ndvi.multiply(0.0001).rename('ndvi')

def lst_week_mean_c(start, end):
    """Mean daytime LST (degC) of MOD11A2 composites overlapping [start, end), band ``lst_c``.

    MOD11A2 is an 8-day composite filtered by its start time, so some 7-day
    windows contain no image. The window is opened 7 days early to capture
    every composite that overlaps the week.

    If there is still no image, a fully masked band is used so the result
    always has exactly one ``lst_c`` band. Raw values are scaled by 0.02 to
    Kelvin, then 273.15 is subtracted.
    """
    start = ee.Date(start)
    ic = mod11a2.filterDate(start.advance(-7, 'day'), end).select('LST_Day_1km')
    empty = ee.Image.constant(0).rename('LST_Day_1km').selfMask()
    raw = ee.Image(ee.Algorithms.If(ic.size().gt(0), ic.mean(), empty))
    lstK = raw.multiply(0.02)
    return lstK.subtract(273.15).rename('lst_c')

def region_mean(img, region, scale):
    """Reduce *img* to its mean over *region* at *scale* metres.

    Returns the ``ee.Dictionary`` produced by ``reduceRegion`` (one entry per
    band); ``maxPixels`` is raised to 1e13 so large AOIs are not rejected.
    """
    reducer = ee.Reducer.mean()
    return img.reduceRegion(reducer=reducer, geometry=region,
                            scale=scale, maxPixels=1e13)

def region_sum(img, region, scale):
    """Reduce *img* to its sum over *region* at *scale* metres.

    Returns the ``ee.Dictionary`` produced by ``reduceRegion`` (one entry per
    band); ``maxPixels`` is raised to 1e13 so large AOIs are not rejected.
    """
    reducer = ee.Reducer.sum()
    return img.reduceRegion(reducer=reducer, geometry=region,
                            scale=scale, maxPixels=1e13)


## 5) Build Weekly Feature Table (server-side → pandas)

In [None]:

def _week_feature(week_start, aoi, slope_mean):
    """Build one geometry-less ee.Feature of AOI statistics for the week starting *week_start*."""
    d = ee.Date(week_start)
    end = d.advance(1, 'week')

    if USE_VIIRS:
        fire_img, fire_scale = viirs_week_image(d, end), 375
    else:
        fire_img, fire_scale = firms_modis_week_image(d, end), 1000
    fire_sum = region_sum(fire_img, aoi, fire_scale).get('fire')

    precip = region_mean(chirps_week_sum(d, end), aoi, 5500).get('precip')
    ndvi = region_mean(ndvi_week_mean(d, end), aoi, 250).get('ndvi')
    lstc = region_mean(lst_week_mean_c(d, end), aoi, 1000).get('lst_c')

    return ee.Feature(None, {
        'date': d.format('YYYY-MM-dd'),
        'year': d.get('year'),
        'week': d.get('week'),
        'viirs_fire_count': fire_sum,
        'fire_binary': ee.Number(fire_sum).gt(0),
        'precip_mm': precip,
        'ndvi': ndvi,
        'lst_c': lstc,
        'slope_deg': slope_mean,
    })


def build_weekly_features(aoi, start_train, end_train, pred_year, batch_weeks=52):
    """Compute the weekly AOI feature table from *start_train* to the end of *pred_year*.

    Every week gets: fire count (VIIRS or MODIS per ``USE_VIIRS``) and a 0/1
    fire flag, mean weekly precipitation total, NDVI, daytime LST, and the
    AOI's mean slope (computed once).

    Weeks are evaluated server-side in batches of *batch_weeks*, each fetched
    with its own request. One request covering every week would bundle
    hundreds of region reductions into a single computation and risk Earth
    Engine's per-request time/memory limits.

    Args:
        aoi: ee.Geometry to summarise over.
        start_train: first week start (``'YYYY-MM-DD'``).
        end_train: unused here; kept for interface compatibility (the table
            always spans through the end of *pred_year*).
        pred_year: last calendar year included.
        batch_weeks: weeks per Earth Engine request (>= 1).

    Returns:
        pandas.DataFrame with one row per week, sorted by ``date``.
    """
    full_end = f"{pred_year}-12-31"
    weekly_starts = week_edges(start_train, full_end)
    slope_mean = region_mean(slope, aoi, 90).get('slope')

    step = max(1, int(batch_weeks))
    frames = []
    for i in range(0, len(weekly_starts), step):
        batch = weekly_starts[i:i + step]
        fc = ee.FeatureCollection([_week_feature(ds, aoi, slope_mean) for ds in batch])
        frames.append(geemap.ee_to_pandas(fc))

    if not frames:
        return pd.DataFrame(columns=['date', 'year', 'week', 'viirs_fire_count',
                                     'fire_binary', 'precip_mm', 'ndvi', 'lst_c',
                                     'slope_deg'])
    return pd.concat(frames, ignore_index=True).sort_values('date')

df = build_weekly_features(AOI, START_TRAIN, END_TRAIN, PRED_YEAR)
df['date'] = pd.to_datetime(df['date'])

# Earth Engine properties may arrive as strings/objects; coerce the numeric
# columns, turning anything unparsable (e.g. missing values) into NaN.
numeric_cols = ['year', 'week', 'viirs_fire_count', 'precip_mm', 'ndvi', 'lst_c', 'slope_deg']
df = df.assign(**{col: pd.to_numeric(df[col], errors='coerce') for col in numeric_cols})
df.head()


## 6) Target & Scaling

In [None]:

# Target: log1p(count) for regression (missing counts -> 0), or a 0/1 flag.
if TARGET_MODE == 'count':
    df['target'] = df['viirs_fire_count'].fillna(0.0)
    df['target_log1p'] = np.log1p(df['target'])
    y_col = 'target_log1p'
else:
    df['target_bin'] = (df['viirs_fire_count'] > 0).astype(int)
    y_col = 'target_bin'

feature_cols = ['precip_mm', 'ndvi', 'lst_c', 'slope_deg']

# Covariates can be missing for some weeks (e.g. fully masked regions or
# weeks without imagery). MinMaxScaler passes NaN through unchanged, and a
# single NaN input makes the LSTM loss NaN. So carry the last observation
# forward, then backward for a leading gap. An all-NaN column stays NaN.
df[feature_cols] = df[feature_cols].ffill().bfill()

# Fit the scaler on the training period only, so min/max statistics of the
# held-out/forecast weeks do not leak into training inputs. If no row falls
# in the training period, fall back to fitting on everything.
fit_rows = df['date'] <= pd.to_datetime(END_TRAIN)
fit_frame = df.loc[fit_rows, feature_cols] if fit_rows.any() else df[feature_cols]

scaler = MinMaxScaler()
scaler.fit(fit_frame.values.astype('float32'))
X_scaled = scaler.transform(df[feature_cols].values.astype('float32'))
df_scaled = df.copy()
df_scaled[feature_cols] = X_scaled


## 7) Train/Test Split & Sequence Builder

In [None]:

train_mask = df_scaled['date'].le(pd.to_datetime(END_TRAIN))  # rows up to END_TRAIN (inclusive)
def make_sequences(frame, seq_len, y_column, columns=None):
    """Slice *frame* into sliding windows for one-step-ahead prediction.

    Window ``i`` holds feature rows ``i .. i+seq_len-1`` and is paired with the
    target and date of row ``i+seq_len``. The model thus predicts a week from
    the ``seq_len`` rows that precede it.

    Args:
        frame: DataFrame with a ``date`` column, the feature columns and
            *y_column*; rows are assumed to be in chronological order.
        seq_len: number of rows per input window.
        y_column: name of the target column.
        columns: feature column names; defaults to the notebook-level
            ``feature_cols``.

    Returns:
        ``(X, y, dates)``. X is float32 of shape ``(n, seq_len, n_features)``,
        y is float32 of shape ``(n,)``, and dates lists the n target dates,
        where ``n = max(len(frame) - seq_len, 0)``. X keeps its 3-D shape
        when n is 0.
    """
    cols = feature_cols if columns is None else list(columns)
    vals = frame[cols].values.astype('float32')
    yvals = frame[y_column].values.astype('float32')
    n = max(len(frame) - seq_len, 0)

    X = np.empty((n, seq_len, len(cols)), dtype='float32')
    for i in range(n):
        X[i] = vals[i:i + seq_len]
    ys = yvals[seq_len:seq_len + n].copy()
    dates = list(frame['date'].iloc[seq_len:seq_len + n])
    return X, ys, dates

# Window the whole chronological table once, then split windows by the date of
# the week they predict. Windowing each split separately would drop the first
# SEQ_LEN test weeks, because their inputs lie in the training period. Only
# covariates (never targets) come from those earlier weeks.
X_all, y_all, dates_all = make_sequences(df_scaled, SEQ_LEN, y_col)
cutoff = pd.to_datetime(END_TRAIN)
is_train_target = np.array([d <= cutoff for d in dates_all], dtype=bool)

X_train, y_train = X_all[is_train_target], y_all[is_train_target]
X_test, y_test = X_all[~is_train_target], y_all[~is_train_target]
dates_train = [d for d, keep in zip(dates_all, is_train_target) if keep]
dates_test = [d for d, keep in zip(dates_all, is_train_target) if not keep]

X_train.shape, X_test.shape


## 8) LSTM Model

In [None]:

# Sequence model: one LSTM layer, dropout, a small dense head and a single
# output unit whose activation and loss follow TARGET_MODE.
keras.backend.clear_session()

count_mode = TARGET_MODE == 'count'
head_activation = 'linear' if count_mode else 'sigmoid'
loss_fn = 'mse' if count_mode else 'binary_crossentropy'

model = keras.Sequential([
    layers.Input(shape=(SEQ_LEN, len(feature_cols))),
    layers.LSTM(64),
    layers.Dropout(0.2),
    layers.Dense(32, activation='relu'),
    layers.Dense(1, activation=head_activation),
])
model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss=loss_fn,
    metrics=['mae'],
)

# NOTE: the held-out windows are only monitored as validation data here;
# no callback uses them to select or stop the model.
val_data = None
if len(X_test) > 0:
    val_data = (X_test, y_test)

history = model.fit(
    X_train, y_train,
    validation_data=val_data,
    epochs=50,
    batch_size=16,
    verbose=1,
)
model.summary()


## 9) Evaluation

In [None]:

y_pred_test = model.predict(X_test).flatten() if len(X_test) > 0 else np.array([])

if TARGET_MODE == 'count':
    # The model predicts log1p(count); convert both series back to counts.
    # Clip predictions at 0 first, because a negative log1p output would
    # become a negative (impossible) hotspot count after expm1.
    y_test_lin = np.expm1(y_test)
    y_pred_lin = np.expm1(np.clip(y_pred_test, 0.0, None))
    y_true_eval, y_pred_eval = y_test_lin, y_pred_lin
else:
    y_true_eval, y_pred_eval = y_test, y_pred_test

if len(y_pred_eval) > 0:
    mae = mean_absolute_error(y_true_eval, y_pred_eval)
    rmse = math.sqrt(mean_squared_error(y_true_eval, y_pred_eval))
    r2 = r2_score(y_true_eval, y_pred_eval)
else:
    mae = rmse = r2 = np.nan

print(f"MAE: {mae:.3f}, RMSE: {rmse:.3f}, R2: {r2:.3f}")
if len(y_pred_eval) > 0:
    plt.figure(figsize=(10, 4))
    plt.plot(dates_test, y_true_eval, label='Actual')
    plt.plot(dates_test, y_pred_eval, label='Predicted')
    plt.title('Weekly Hotspot Prediction (Test)')
    plt.xlabel('Date'); plt.ylabel('Fire count' if TARGET_MODE == 'count' else 'Probability')
    plt.legend(); plt.grid(True); plt.show()


## 10) Forecast 2025

In [None]:

df_2025 = df_scaled.loc[df_scaled['date'].dt.year.eq(PRED_YEAR)].copy()  # weeks starting in PRED_YEAR
def make_sequences_simple(frame, seq_len, y_column, columns=None):
    """Sliding windows for one-step-ahead prediction, with an explicit empty case.

    Window ``i`` is feature rows ``i .. i+seq_len-1``. Its target and date
    come from row ``i+seq_len``.

    Args:
        frame: chronologically ordered DataFrame with ``date``, the feature
            columns and *y_column*.
        seq_len: rows per input window.
        y_column: target column name.
        columns: feature column names; defaults to the notebook-level
            ``feature_cols``.

    Returns:
        ``(X, y, dates)`` with X float32 ``(n, seq_len, n_features)``, y
        float32 ``(n,)`` and a list of n target dates. When
        ``len(frame) <= seq_len``, empty float32 arrays of those shapes and
        ``[]`` are returned.
    """
    cols = feature_cols if columns is None else list(columns)
    n = len(frame) - seq_len
    if n <= 0:
        return (np.empty((0, seq_len, len(cols)), dtype='float32'),
                np.empty((0,), dtype='float32'),
                [])
    vals = frame[cols].values.astype('float32')
    yvals = frame[y_column].values.astype('float32')
    X = np.stack([vals[i:i + seq_len] for i in range(n)])
    return X, yvals[seq_len:].copy(), list(frame['date'].iloc[seq_len:])

# Window a history that also contains the weeks before PRED_YEAR, then keep
# only windows whose target week falls in PRED_YEAR. Windowing the PRED_YEAR
# rows alone would leave its first SEQ_LEN weeks without a prediction.
# NOTE: inputs are the weekly covariates already in df_scaled, so predictions
# are only meaningful for weeks whose covariates have actually been observed.
history_rows = df_scaled[df_scaled['date'].dt.year <= PRED_YEAR]
X_hist, y_hist, dates_hist = make_sequences_simple(history_rows, SEQ_LEN, y_col)
in_year = np.array([d.year == PRED_YEAR for d in dates_hist], dtype=bool)

X_2025, y_2025 = X_hist[in_year], y_hist[in_year]
dates_2025 = [d for d, keep in zip(dates_hist, in_year) if keep]
yhat_2025 = model.predict(X_2025).flatten() if len(X_2025) > 0 else np.array([])

if TARGET_MODE == 'count':
    # Model output is log1p(count); clip at 0 so expm1 never yields negative counts.
    series = pd.DataFrame({'date': dates_2025,
                           'pred_fire_count': np.expm1(np.clip(yhat_2025, 0.0, None))})
else:
    series = pd.DataFrame({'date': dates_2025, 'pred_fire_prob': yhat_2025})

series_path = "/mnt/data/hotspot_2025_predictions.csv"
# The output folder is not guaranteed to exist (e.g. on Colab); create it.
Path(series_path).parent.mkdir(parents=True, exist_ok=True)
series.to_csv(series_path, index=False)
series.head(), series_path


## 11) Save Artifacts

In [None]:

model_path = "/mnt/data/lstm_hotspot_model.keras"
scaler_path = "/mnt/data/feature_scaler.npy"

# The output folders are not guaranteed to exist (e.g. on Colab); create them.
for out_dir in {Path(model_path).parent, Path(scaler_path).parent}:
    out_dir.mkdir(parents=True, exist_ok=True)

model.save(model_path)

# np.save stores this dict as a pickled 0-d object array. Reload it with
# np.load(scaler_path, allow_pickle=True).item(). feature_range is recorded
# so the MinMax transform can be rebuilt exactly.
np.save(scaler_path, {'feature_cols': feature_cols,
                      'data_min_': getattr(scaler, 'data_min_', None),
                      'data_max_': getattr(scaler, 'data_max_', None),
                      'feature_range': getattr(scaler, 'feature_range', None)})
(model_path, scaler_path)



## 12) Notes & References

- VIIRS NRT active fires (375 m): `ee.ImageCollection("NASA/LANCE/SNPP_VIIRS/C2")`
- FIRMS (MODIS) rasterized fires: `ee.ImageCollection("FIRMS")`
- CHIRPS daily precipitation: `ee.ImageCollection("UCSB-CHG/CHIRPS/DAILY")`
- MOD13Q1 NDVI v061 (16-day composites): `ee.ImageCollection("MODIS/061/MOD13Q1")`
- MOD11A2 daytime LST v061 (8-day composites): `ee.ImageCollection("MODIS/061/MOD11A2")`
- SRTM DEM: `ee.Image("USGS/SRTMGL1_003")`
- TFRecord with Earth Engine: https://developers.google.com/earth-engine/guides/tfrecord

**Extensions**
- Switch to `TARGET_MODE='binary'` to predict presence/absence.
- Add more predictors: land cover (MCD12Q1), soil moisture (SMAP), wind (ERA5), drought indices.
- For pixel/patch‑level modeling and ConvLSTM, export gridded tensors to **TFRecord** and train on image sequences.
