In [None]:
import os
import glob
import numpy as np
from datetime import datetime
import random

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
from torchvision.ops import sigmoid_focal_loss
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

import pandas as pd
import xarray as xr
from skimage.util import view_as_windows
import geopandas as gpd
from shapely.geometry import Point

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, precision_score, recall_score
from tqdm import tqdm
import matplotlib.pyplot as plt
from scipy.ndimage import label
from matplotlib import gridspec

## Architecture

In [None]:
class ConvLSTMCell(nn.Module):
    """Single convolutional LSTM cell.

    All four gates are produced by one convolution over the channel-wise
    concatenation of the input and the previous hidden state. Padding is
    ``kernel_size // 2``, which preserves spatial size for odd kernels.

    Args:
        input_dim: number of channels of the input ``x``.
        hidden_dim: number of channels of the hidden and cell states.
        kernel_size: side length of the square gate convolution.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size=3):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=input_dim + hidden_dim,
            out_channels=4 * hidden_dim,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
        )
        self.hidden_dim = hidden_dim

    def forward(self, x, h_prev, c_prev):
        """Advance the cell by one step and return the new ``(h, c)`` pair."""
        gates = self.conv(torch.cat([x, h_prev], dim=1))
        # Gate order in the conv output: input, forget, output, candidate.
        in_gate, forget_gate, out_gate, candidate = torch.chunk(gates, 4, dim=1)
        c_next = torch.sigmoid(forget_gate) * c_prev + torch.sigmoid(in_gate) * torch.tanh(candidate)
        h_next = torch.sigmoid(out_gate) * torch.tanh(c_next)
        return h_next, c_next

class SqueezeExcite(nn.Module):
    """Squeeze-and-excitation channel gate for 4-D ``(B, C, H, W)`` feature maps.

    Global-average-pools each channel, passes the channel vector through a
    bottleneck MLP (``channel -> channel // reduction -> channel``) ending in
    a sigmoid, and rescales every channel of the input by its weight.
    """

    def __init__(self, channel, reduction=4):
        super().__init__()
        squeezed = channel // reduction
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, squeezed, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(squeezed, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels, _, _ = x.size()
        weights = self.fc(self.pool(x).view(batch, channels))
        return x * weights.view(batch, channels, 1, 1)

class ResConvLSTM(nn.Module):
    """Single-layer ConvLSTM binary classifier over patch time series.

    Each time slice ``x[:, :, t]`` is encoded by 3x3 conv -> BatchNorm ->
    ReLU -> squeeze-excite, then fed to one ConvLSTM cell. The final hidden
    state is global-average-pooled and projected by a 1x1 conv to one logit
    per sample.

    Input: 5-D tensor unpacked as ``(B, C, T, H, W)``. Output: logits of shape ``(B,)``.

    NOTE(review): ``patch_size`` and ``dropout`` are stored as attributes but
    nothing in this class reads them (no dropout is applied).
    """

    def __init__(self, in_channels, hidden_dim=32, patch_size=7, dropout=0.3):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.patch_size = patch_size
        self.dropout = dropout

        # Per-timestep encoder.
        self.input_conv = nn.Conv2d(in_channels, hidden_dim, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(hidden_dim)
        self.se = SqueezeExcite(hidden_dim)

        # Recurrence over time.
        self.lstm_cell = ConvLSTMCell(hidden_dim, hidden_dim)

        # Classification head.
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1x1 = nn.Conv2d(hidden_dim, 1, kernel_size=1)

    def _encode(self, frame):
        """Encode one ``(B, C, H, W)`` time slice into ``hidden_dim`` channels."""
        return self.se(F.relu(self.bn(self.input_conv(frame))))

    def forward(self, x):
        batch, _, _, height, width = x.size()
        state_shape = (batch, self.hidden_dim, height, width)
        h = torch.zeros(state_shape, device=x.device)
        c = torch.zeros(state_shape, device=x.device)

        # unbind(dim=2) yields the slices x[:, :, t] for t = 0 .. T-1, in order.
        for frame in x.unbind(dim=2):
            h, c = self.lstm_cell(self._encode(frame), h, c)

        logits = self.conv1x1(self.global_pool(h))
        return logits.view(batch)

In [None]:
class ConvLSTMCell(nn.Module):
    """Convolutional LSTM cell: one conv over ``[x, h_prev]`` yields all four gates.

    NOTE(review): this cell re-declares ``ConvLSTMCell`` (the architecture
    cell above defines an identical class); the later definition wins.

    Args:
        input_dim: channels of ``x``.
        hidden_dim: channels of the hidden / cell state.
        kernel_size: square gate-convolution size; padding is ``kernel_size // 2``.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size=3):
        super().__init__()
        same_padding = kernel_size // 2
        self.conv = nn.Conv2d(input_dim + hidden_dim, 4 * hidden_dim, kernel_size, padding=same_padding)
        self.hidden_dim = hidden_dim

    def forward(self, x, h_prev, c_prev):
        """One recurrent step; returns ``(hidden, cell)``."""
        stacked = self.conv(torch.cat([x, h_prev], dim=1))
        pre_i, pre_f, pre_o, pre_g = torch.chunk(stacked, 4, dim=1)
        input_gate, forget_gate, output_gate = map(torch.sigmoid, (pre_i, pre_f, pre_o))
        cell = forget_gate * c_prev + input_gate * torch.tanh(pre_g)
        hidden = output_gate * torch.tanh(cell)
        return hidden, cell

class SqueezeExcite(nn.Module):
    """Channel-attention gate: pool -> bottleneck MLP -> sigmoid -> rescale channels.

    Expects a 4-D ``(B, C, H, W)`` input; the MLP hidden width is ``channel // reduction``.
    """

    def __init__(self, channel, reduction=4):
        super().__init__()
        hidden = channel // reduction
        self.pool = nn.AdaptiveAvgPool2d(1)
        layers = [
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        n, ch, _, _ = x.size()
        descriptor = self.pool(x).view(n, ch)
        scale = self.fc(descriptor).view(n, ch, 1, 1)
        return x * scale

class SpatialAttention(nn.Module):
    """Spatial attention gate (CBAM-style).

    At every pixel, the mean and the max across channels are stacked into a
    2-channel map. A ``kernel_size`` conv with sigmoid turns that map into a
    ``(B, 1, H, W)`` weight, and the weight rescales all channels of the input.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Both reductions run over dim=1 (channels), so each pixel keeps its own value.
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max, _ = x.max(dim=1, keepdim=True)
        gate = self.sigmoid(self.conv(torch.cat([channel_mean, channel_max], dim=1)))
        return x * gate

class ResConvLSTMwithAttention(nn.Module):
    """Two-layer stacked ConvLSTM binary classifier with channel + spatial attention.

    Each time slice ``x[:, :, t]`` is encoded by 3x3 conv -> BatchNorm ->
    ReLU -> squeeze-excite -> spatial attention. The result feeds ``lstm1``;
    ``lstm1``'s hidden state feeds ``lstm2``. The final ``lstm2`` hidden state
    is global-average-pooled and projected by a 1x1 conv to one logit per sample.

    Input: 5-D tensor unpacked as ``(B, C, T, H, W)``. Output: logits of shape ``(B,)``.

    NOTE(review): ``patch_size`` and ``dropout`` are stored but not read by this class.
    """

    def __init__(self, in_channels, hidden_dim=32, patch_size=7, dropout=0.3):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.patch_size = patch_size
        self.dropout = dropout

        # Per-timestep encoder.
        self.input_conv = nn.Conv2d(in_channels, hidden_dim, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(hidden_dim)
        self.se = SqueezeExcite(hidden_dim)

        # Stacked recurrence. forward() uses lstm1/lstm2; previously they were
        # never created (and an unused single `lstm_cell` was), so forward crashed.
        self.lstm1 = ConvLSTMCell(hidden_dim, hidden_dim)
        self.lstm2 = ConvLSTMCell(hidden_dim, hidden_dim)

        self.spatial_attn = SpatialAttention()

        # Classification head.
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1x1 = nn.Conv2d(hidden_dim, 1, kernel_size=1)

    def forward(self, x):
        B, C, T, H, W = x.size()
        # Zero initial states, matching x's device and dtype.
        h1, c1 = x.new_zeros(B, self.hidden_dim, H, W), x.new_zeros(B, self.hidden_dim, H, W)
        h2, c2 = x.new_zeros(B, self.hidden_dim, H, W), x.new_zeros(B, self.hidden_dim, H, W)

        for t in range(T):
            x_t = x[:, :, t, :, :]
            x_t = self.input_conv(x_t)
            x_t = self.bn(x_t)
            x_t = F.relu(x_t)
            x_t = self.se(x_t)
            x_t = self.spatial_attn(x_t)

            h1, c1 = self.lstm1(x_t, h1, c1)   # 1st ConvLSTM layer
            h2, c2 = self.lstm2(h1, h2, c2)    # 2nd ConvLSTM layer, fed by layer 1

        out = self.global_pool(h2)   # B x hidden_dim x 1 x 1
        out = self.conv1x1(out)      # B x 1 x 1 x 1
        return out.view(B)

## Dataset

In [None]:
class BurnPatchDataset(Dataset):
    """Map-style dataset over ``.npz`` patch files.

    Each file must contain an array ``x`` and a scalar label ``y``. Only the
    first 4 entries along axis 0 of ``x`` are kept. The array is then transposed
    with ``(3, 0, 1, 2)`` and cast to float32. Given the models' ``(B, C, T, H, W)``
    input, this presumably means ``x`` is stored as ``(T, H, W, C)``; TODO confirm.

    Args:
        file_paths: sequence of ``.npz`` paths.
        normalize: if True, standardise each slice along the last axis of ``x``
            over the other axes, using NaN-aware mean and std. Any remaining NaN
            is then set to 0 (the per-channel mean after standardisation), so NaNs
            never reach the model. If False, raw values, including NaNs, are returned.
    """

    def __init__(self, file_paths, normalize=True):
        self.file_paths = file_paths
        self.normalize = normalize

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        path = self.file_paths[idx]
        # NpzFile holds an open file handle; `with` closes it deterministically
        # instead of relying on garbage collection.
        with np.load(path) as data:
            x = data["x"][:4]
            y = int(data["y"])

        if self.normalize:
            mean = np.nanmean(x, axis=(0, 1, 2), keepdims=True)
            std = np.nanstd(x, axis=(0, 1, 2), keepdims=True) + 1e-6
            x = (x - mean) / std
            # nanmean/nanstd skip NaNs but do not remove them. Without this fill,
            # a single NaN pixel makes the logits and the loss NaN.
            x = np.nan_to_num(x, nan=0.0)

        x = np.transpose(x, (3, 0, 1, 2)).astype(np.float32)
        return torch.tensor(x), torch.tensor(y, dtype=torch.float32)

In [None]:
# 1. Extract date from filename
def extract_date_from_filename(filename):
    """Parse the date from the second ``_``-separated token of a file's basename.

    The token must be formatted ``YYYY-MM-DD``; e.g. the basename
    ``"<prefix>_2024-06-22_<rest>"`` gives ``datetime(2024, 6, 22)``.
    """
    tokens = os.path.basename(filename).split('_')
    return datetime.strptime(tokens[1], "%Y-%m-%d")

# 2. Define test date ranges (inclusive); files from these dates form the held-out test set
test_date_ranges = [
    (datetime(2024, 6, 15), datetime(2024, 6, 19))
]

# 3. Load all .npz files. glob's order is filesystem-dependent, so sort it;
#    otherwise seeded sampling/shuffling downstream is still not reproducible.
PATCH_DIR = "Training_data/patches/"
all_files = sorted(glob.glob(os.path.join(PATCH_DIR, "*.npz")))
assert all_files, f"No .npz patches found in {PATCH_DIR!r}"

# 4-5. Route each file to the test set or the train/val pool by its date
test_files = []
train_val_files = []
for f in all_files:
    file_date = extract_date_from_filename(f)
    if any(start <= file_date <= end for (start, end) in test_date_ranges):
        test_files.append(f)
    else:
        train_val_files.append(f)

# 6. Extract label of each file
def get_label_from_file(path):
    """Return the label stored under key ``"y"`` in the ``.npz`` file at *path*, as an ``int``."""
    # The context manager closes the NpzFile's file handle immediately
    # instead of relying on garbage collection.
    with np.load(path) as data:
        return int(data["y"])

# 7. Read each train/val label once (each read opens the .npz), then split files by class
train_val_labels = {f: get_label_from_file(f) for f in train_val_files}
ones = [f for f in train_val_files if train_val_labels[f] == 1]
zeros = [f for f in train_val_files if train_val_labels[f] == 0]

# 8. Undersample negatives to NEG_PER_POS x the number of positives. The count is
#    capped at the number available because random.sample raises if asked for more.
SPLIT_SEED = 42
NEG_PER_POS = 2
split_rng = random.Random(SPLIT_SEED)  # local RNG: reproducible, leaves global `random` untouched
num_ones = len(ones)
sampled_zeros = split_rng.sample(zeros, min(len(zeros), NEG_PER_POS * num_ones))

# 9. Merge 0 and 1 (with label) and shuffle
labeled_ones = [(f, 1) for f in ones]
labeled_zeros = [(f, 0) for f in sampled_zeros]

balanced_labeled = labeled_ones + labeled_zeros
split_rng.shuffle(balanced_labeled)

file_paths = [f for f, label in balanced_labeled]
labels = [label for f, label in balanced_labeled]

# 10. Stratified train/val split
train_files, val_files, train_labels, val_labels = train_test_split(
    file_paths, labels, test_size=0.1, random_state=SPLIT_SEED, stratify=labels
)

# 11. Dataset / DataLoader. Only the training loader is shuffled; evaluation order stays fixed.
train_ds = BurnPatchDataset(train_files)
val_ds = BurnPatchDataset(val_files)
test_ds = BurnPatchDataset(test_files)

train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=32, shuffle=False)
test_dl = DataLoader(test_ds, batch_size=32, shuffle=False)

In [None]:
class FocalLoss(torch.nn.Module):
    """Binary focal loss computed from raw logits.

    ``loss = alpha * (1 - p_t) ** gamma * BCE(logits, targets)``, where
    ``p_t = exp(-BCE)``. For 0/1 targets, ``p_t`` is the predicted probability
    of the true class.

    NOTE: ``alpha`` multiplies every element uniformly. Unlike
    ``torchvision.ops.sigmoid_focal_loss``, it is not split into
    ``alpha`` / ``1 - alpha`` for positives / negatives, so here it only
    rescales the loss.

    Args:
        alpha: uniform loss scale.
        gamma: focusing exponent; larger values down-weight well-classified examples more.
        reduction: ``"mean"``, ``"sum"``, or ``"none"`` (returns the per-element loss).

    Raises:
        ValueError: if ``reduction`` is not one of the values above.
    """

    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}")
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the focal loss of logits ``inputs`` against ``targets`` (same shape)."""
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
        pt = torch.exp(-BCE_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss

        if self.reduction == "mean":
            return focal_loss.mean()
        if self.reduction == "sum":
            return focal_loss.sum()
        # "none": previously fell through to sum().
        return focal_loss

## ResConvLSTM

In [None]:
# Baseline single-layer ResConvLSTM, moved to the GPU when one is available.
# NOTE(review): in_channels must equal the channel count of the patches (last axis of `x`); confirm 10.
model = ResConvLSTM(in_channels=10, hidden_dim=64, patch_size=15)
if torch.cuda.is_available():
    model = model.cuda()

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = FocalLoss(alpha=0.1, gamma=3)
# Cut the LR by 10x after 2 epochs without improvement in the mean validation loss.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2, verbose=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

MAX_EPOCHS = 200
# Relative to the notebook's working directory. The previous "/best_resconvlstm.pt"
# targeted the filesystem root, which is normally not writable.
CHECKPOINT_PATH = "best_resconvlstm.pt"

best_loss, patience_counter = float("inf"), 0  # inf: the first epoch always counts as an improvement
patience = 5  # epochs without val-loss improvement before early stopping
th = 0.4      # probability threshold that turns sigmoid outputs into 0/1 predictions

# Initialize history
history = {
    "train_loss": [],
    "val_loss": [],
}

for epoch in range(1, MAX_EPOCHS + 1):
    print(f"\nEpoch {epoch}")

    # ---- training pass ----
    model.train()
    all_preds, all_targets, losses = [], [], []

    for x, y in tqdm(train_dl, desc="Train", leave=False):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        preds = (torch.sigmoid(logits) > th).float()
        all_preds.extend(preds.cpu().numpy())
        all_targets.extend(y.cpu().numpy())
        losses.append(loss.item())

    train_loss = np.mean(losses)
    train_f1 = f1_score(all_targets, all_preds, zero_division=0)
    train_precision = precision_score(all_targets, all_preds, zero_division=0)
    train_recall = recall_score(all_targets, all_preds, zero_division=0)
    print(f"Train loss: {train_loss:.4f}")
    print(f"  F1 {train_f1:.3f} | precision {train_precision:.3f} | recall {train_recall:.3f}")

    # ---- validation pass ----
    model.eval()
    all_preds, all_targets, probs, val_losses = [], [], [], []

    with torch.no_grad():
        for x, y in tqdm(val_dl, desc="Val", leave=False):
            x, y = x.to(device), y.to(device)
            logits = model(x)
            prob = torch.sigmoid(logits)
            pred = (prob > th).float()

            all_preds.extend(pred.cpu().numpy())
            all_targets.extend(y.cpu().numpy())
            probs.extend(prob.cpu().numpy())
            val_losses.append(criterion(logits, y).item())

    val_loss = np.mean(val_losses)
    val_f1 = f1_score(all_targets, all_preds, zero_division=0)
    val_precision = precision_score(all_targets, all_preds, zero_division=0)
    val_recall = recall_score(all_targets, all_preds, zero_division=0)
    # ROC-AUC is undefined (sklearn raises) when only one class is present.
    val_auc = roc_auc_score(all_targets, probs) if len(set(all_targets)) > 1 else float("nan")
    print(f"Val loss: {val_loss:.4f}")
    print(f"  F1 {val_f1:.3f} | precision {val_precision:.3f} | recall {val_recall:.3f} | AUC {val_auc:.3f}")

    # Step on the epoch's mean validation loss. Previously this used only the
    # loss of the last validation batch.
    scheduler.step(val_loss)

    # Logging
    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)

    # Checkpointing / early stopping on mean validation loss
    if val_loss < best_loss:
        best_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), CHECKPOINT_PATH)
        print("Saved new best model.")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping.")
            break