<a href="https://colab.research.google.com/github/dookda/cmu_udfire_gee/blob/main/predict_hp_using_lstm_gee_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip -q install -U geemap

# !pip -q install -U geemap earthengine-api pandas numpy scikit-learn tensorflow geopandas shapely rasterio


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/631.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m631.5/631.5 kB[0m [31m41.1 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.6 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m85.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [3]:
import ee
import folium
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from datetime import datetime
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dropout, Dense, Bidirectional
from sklearn.model_selection import train_test_split
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.regularizers import l2
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import ee, geemap, time
from datetime import datetime

ee.Authenticate()
try:
    ee.Initialize(project="ee-sakda-451407")
except Exception as e:
    ee.Authenticate()
    ee.Initialize(project="ee-sakda-451407")

In [2]:
import ee
import geemap
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from datetime import datetime, timedelta
import folium
from folium.raster_layers import ImageOverlay
from PIL import Image
import io
import matplotlib.pyplot as plt

# Authenticate and initialize Earth Engine (run this once)
# ee.Authenticate()
ee.Authenticate()
try:
    ee.Initialize(project="ee-sakda-451407")
except Exception as e:
    ee.Authenticate()
    ee.Initialize(project="ee-sakda-451407")

In [7]:
# Define the area of interest (approximate bounding box for Tambon Suthep)
min_lon = 98.89
min_lat = 18.77
max_lon = 98.98
max_lat = 18.82
aoi = ee.Geometry.Rectangle([min_lon, min_lat, max_lon, max_lat])

# MODIS collection for fire data
collection = ee.ImageCollection('MODIS/061/MOD14A1') \
    .filterDate('2020-01-01', '2024-12-31') \
    .filterBounds(aoi)

# Function to create weekly composites (max fire mask)
def weekly_composite(start_date):
    start = ee.Date(start_date)
    end = start.advance(7, 'day')
    weekly_coll = collection.filterDate(start, end)

    # Check if the collection is empty
    count = weekly_coll.size().getInfo()
    if count == 0:
        print(f"No data for week starting {start_date}, returning zero image")
        return ee.Image.constant(0).rename('FireMask').clip(aoi).reproject(crs='EPSG:4326', scale=1000)

    weekly = weekly_coll.max()
    # Check if the image has the FireMask band
    band_names = weekly.bandNames().getInfo()
    if 'FireMask' not in band_names:
        print(f"No FireMask band for week starting {start_date}, returning zero image")
        return ee.Image.constant(0).rename('FireMask').clip(aoi).reproject(crs='EPSG:4326', scale=1000)

    weekly = weekly.select('FireMask').clip(aoi).reproject(crs='EPSG:4326', scale=1000)
    # Mask fire: 1 if fire (FireMask >=7), 0 otherwise
    fire = weekly.gte(7)
    return fire.set('system:time_start', start.millis())

# Generate list of start dates for weeks from 2020 to 2024
start_date = datetime(2020, 1, 1)
end_date = datetime(2024, 12, 31)
weekly_dates = []
current = start_date
while current < end_date:
    weekly_dates.append(current.strftime('%Y-%m-%d'))
    current += timedelta(days=7)

# Create weekly image collection (server-side)
weekly_images = [weekly_composite(date) for date in weekly_dates]
weekly_collection = ee.ImageCollection.fromImages(weekly_images)

# Get spatial dimensions (approximate, based on 1km resolution)
scale = 1000  # 1km in meters
dx = (max_lon - min_lon) / (scale / 111000)  # Approx 1 deg = 111km
dy = (max_lat - min_lat) / (scale / 111000)
ncols = int(np.ceil((max_lon - min_lon) / dx))
nrows = int(np.ceil((max_lat - min_lat) / dy))

# Download all weekly data as numpy arrays
data = []
image_list = weekly_collection.toList(weekly_collection.size())
for i in range(len(weekly_dates)):
    try:
        image = ee.Image(image_list.get(i))
        # Use date as a fallback identifier for logging
        img_id = weekly_dates[i]
        array = geemap.ee_to_numpy(image, region=aoi, scale=1000)
        if array is not None and array.shape[0] > 0 and array.shape[1] > 0:
            data.append(array[:, :, 0])  # Single band (FireMask)
        else:
            print(f"Warning: Empty array for week {img_id}, using zero array")
            data.append(np.zeros((nrows, ncols)))
    except Exception as e:
        print(f"Error processing week {img_id}: {str(e)}")
        data.append(np.zeros((nrows, ncols)))

# Ensure consistent shape
data = [d[:nrows, :ncols] for d in data if d.shape == (nrows, ncols)]
if len(data) == 0:
    raise ValueError("No valid data retrieved. Check MODIS data availability or AOI.")
data = np.stack(data, axis=0)  # Shape: (time, height, width)

# Data preparation for CNN-LSTM
seq_length = 4
X, y = [], []
for i in range(len(data) - seq_length):
    X.append(data[i:i+seq_length])
    y.append(data[i+seq_length])
X = np.array(X)  # Shape: (samples, seq, h, w)
y = np.array(y)  # Shape: (samples, h, w)
X = np.expand_dims(X, axis=-1)  # Add channel: (samples, seq, h, w, 1)

# Split train/test
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Dataset
class FireDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

train_ds = FireDataset(X_train, y_train)
test_ds = FireDataset(X_test, y_test)
train_loader = DataLoader(train_ds, batch_size=8, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=8)

# CNN-LSTM Model
class CNNLSTM(nn.Module):
    def __init__(self, height, width):
        super(CNNLSTM, self).__init__()
        self.conv = nn.Conv3d(in_channels=1, out_channels=16, kernel_size=(seq_length, 3, 3), padding=(0, 1, 1))
        self.pool = nn.MaxPool3d((1, 2, 2))
        conv_output_size = 16 * (height // 2) * (width // 2)
        self.lstm = nn.LSTM(input_size=conv_output_size, hidden_size=128, num_layers=1, batch_first=True)
        self.fc = nn.Linear(128, height * width)
        self.height = height
        self.width = width

    def forward(self, x):
        x = x.permute(0, 4, 1, 2, 3)  # (batch, seq, h, w, ch) -> (batch, ch, seq, h, w)
        x = self.conv(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)  # Flatten
        x, _ = self.lstm(x.unsqueeze(1))
        x = self.fc(x.squeeze(1))
        x = torch.sigmoid(x)  # Binary output
        return x.view(-1, self.height, self.width)

model = CNNLSTM(nrows, ncols)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train
epochs = 20
for epoch in range(epochs):
    model.train()
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

# Evaluate
model.eval()
y_pred = []
y_true = []
with torch.no_grad():
    for inputs, targets in test_loader:
        outputs = model(inputs)
        pred = (outputs > 0.5).float()
        y_pred.append(pred.numpy().flatten())
        y_true.append(targets.numpy().flatten())

y_pred = np.concatenate(y_pred)
y_true = np.concatenate(y_true)

# Evaluation metrics
print('Accuracy:', accuracy_score(y_true, y_pred))
print('Precision:', precision_score(y_true, y_pred, average='macro', zero_division=0))
print('Recall:', recall_score(y_true, y_pred, average='macro', zero_division=0))
print('F1:', f1_score(y_true, y_pred, average='macro', zero_division=0))

# Predict next week
last_seq = data[-seq_length:]  # Last 4 weeks
last_seq = np.expand_dims(last_seq, axis=(0, -1))  # Shape: (1, seq, h, w, 1)
last_seq = torch.tensor(last_seq, dtype=torch.float32)
with torch.no_grad():
    next_pred = model(last_seq)
    next_pred = (next_pred > 0.5).float().squeeze().numpy()  # Shape: (h, w)

# Create Folium map
m = folium.Map(location=[(min_lat + max_lat)/2, (min_lon + max_lon)/2], zoom_start=13)

# Create image from prediction (binary, black/white)
pred_img = (next_pred * 255).astype(np.uint8)
img = Image.fromarray(pred_img, mode='L')
img_data = io.BytesIO()
img.save(img_data, 'PNG')
img_data.seek(0)

# Bounds for overlay
bounds = [[min_lat, min_lon], [max_lat, max_lon]]

ImageOverlay(
    image=img_data,
    bounds=bounds,
    opacity=0.6,
).add_to(m)

# Save or display map
# m.save('next_week_hotspot_prediction.html')
# print('Folium map saved as next_week_hotspot_prediction.html')
m

No data for week starting 2022-10-12, returning zero image


ValueError: No valid data retrieved. Check MODIS data availability or AOI.