In [7]:
# Setup
import sys
from pathlib import Path
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.losses import FocalLoss
import torch.nn as nn
import torch

project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

from src.data.datasets import CNNIberFireDataset
from torch.utils.data import DataLoader

ZARR_PATH = project_root / "data" / "silver" / "IberFire.zarr"

In [19]:
from src.data.datasets import SimpleIberFireSegmentationDataset
from torch.utils.data import DataLoader

feature_vars = [
    "wind_speed_mean",
    "t2m_mean",
    "RH_mean",
    "total_precipitation_mean",
]
in_channels = len(feature_vars)

train_ds = SimpleIberFireSegmentationDataset(
    zarr_path=ZARR_PATH,
    time_start="2018-01-01",
    time_end="2020-12-31",
    feature_vars=feature_vars,
    label_var="is_near_fire",
    spatial_downsample=4,
    lead_time=1,         # predict tomorrow
    compute_stats=True,  # or precompute & pass stats
)

train_loader = DataLoader(
    train_ds,
    batch_size=4,        # start small, check memory
    shuffle=True,
    num_workers=4,
)

[SimpleDataset] Opening Zarr dataset: /Users/vladimir/catalonia-wildfire-prediction/data/silver/IberFire.zarr
[SimpleDataset] Filtering time range: 2018-01-01 to 2020-12-31
[SimpleDataset] Total usable time steps: 1096
[SimpleDataset] Computing normalization stats from data...
[SimpleDataset] wind_speed_mean: mean=2.4198, std=1.1674
[SimpleDataset] t2m_mean: mean=14.4443, std=7.4886
[SimpleDataset] RH_mean: mean=68.8383, std=16.8957
[SimpleDataset] total_precipitation_mean: mean=0.9676, std=2.5863


In [20]:
X, y = train_ds[0]
print(f"Feature tensor shape: {X.shape}, Label tensor shape: {y.shape}")
print(train_ds.get_time_value(0))


Feature tensor shape: torch.Size([4, 230, 297]), Label tensor shape: torch.Size([1, 230, 297])
2018-01-01T00:00:00.000000000


In [33]:
model = smp.Unet(
    encoder_name="resnet34",      # or "timm-efficientnet-b0", etc.
    encoder_weights="imagenet",          # or None if you don't want pretrained
    in_channels=in_channels,               # IberFire: number of feature channels per pixel
    classes=1,                    # 1 output channel for fire / no-fire probability
    activation=None               # we'll apply sigmoid later in the loss/metrics
)

# Example forward
# x: (batch_size, 64, H, W) -> logits: (batch_size, 1, H, W)

In [24]:
# sanity check
X, y = train_ds[0]                      # X: [C, H, W], y: [1, H, W]
X = X.unsqueeze(0)                      # [1, C, H, W]
with torch.no_grad():
    out = model(X)

print("Input :", X.shape)               # [1, 4, H, W]
print("Target:", y.unsqueeze(0).shape)  # [1, 1, H, W]
print("Output:", out.shape)            # should be [1, 1, H, W]

Input : torch.Size([1, 4, 230, 297])
Target: torch.Size([1, 1, 230, 297])
Output: torch.Size([1, 1, 230, 297])


In [11]:
# test dataset
test_ds = SimpleIberFireSegmentationDataset(
    zarr_path=ZARR_PATH,
    time_start="2021-01-01",
    time_end="2021-12-31",
    feature_vars=feature_vars,
    label_var="is_near_fire",
    spatial_downsample=4,
    lead_time=1,
    compute_stats=True,
)
test_loader = DataLoader(
    test_ds,
    batch_size=4,
    shuffle=False,
    num_workers=4,
)

[SimpleDataset] Opening Zarr dataset: /Users/vladimir/catalonia-wildfire-prediction/data/silver/IberFire.zarr
[SimpleDataset] Filtering time range: 2021-01-01 to 2021-12-31
[SimpleDataset] Total usable time steps: 365
[SimpleDataset] Computing normalization stats from data...
[SimpleDataset] wind_speed_mean: mean=2.4186, std=1.1589
[SimpleDataset] t2m_mean: mean=14.0974, std=7.1943
[SimpleDataset] RH_mean: mean=68.3215, std=16.0762
[SimpleDataset] total_precipitation_mean: mean=1.0134, std=2.3897


In [12]:
import tqdm

In [34]:
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")
model = model.to(device)
#criterion = FocalLoss(mode="binary")
pos_weight = torch.tensor([10.0], device=device)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

Using device: mps


In [32]:
# model = torch.compile(model, mode="max-autotune-no-cudagraphs")

In [28]:
num_epochs = 10

In [35]:
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", ncols = 100)
    for X_batch, y_batch in pbar:
        X_batch = X_batch.to(device).float()
        y_batch = y_batch.to(device).float()

        optimizer.zero_grad()
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * X_batch.size(0)

        pbar.set_postfix({"loss": loss.item()})

    train_loss /= len(train_loader.dataset)
    print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {train_loss:.4f}")

Epoch 1/10: 100%|█████████████████████████████████████| 548/548 [15:35<00:00,  1.71s/it, loss=0.245]


Epoch 1/10, Training Loss: 0.3919


Epoch 2/10: 100%|█████████████████████████████████████| 548/548 [10:00<00:00,  1.10s/it, loss=0.231]


Epoch 2/10, Training Loss: 0.2784


Epoch 3/10: 100%|█████████████████████████████████████| 548/548 [10:08<00:00,  1.11s/it, loss=0.125]


Epoch 3/10, Training Loss: 0.2600


Epoch 4/10: 100%|█████████████████████████████████████| 548/548 [10:53<00:00,  1.19s/it, loss=0.455]


Epoch 4/10, Training Loss: 0.2526


Epoch 5/10: 100%|█████████████████████████████████████| 548/548 [10:16<00:00,  1.13s/it, loss=0.468]


Epoch 5/10, Training Loss: 0.2488


Epoch 6/10: 100%|█████████████████████████████████████| 548/548 [10:21<00:00,  1.13s/it, loss=0.523]


Epoch 6/10, Training Loss: 0.2461


Epoch 7/10: 100%|█████████████████████████████████████| 548/548 [10:24<00:00,  1.14s/it, loss=0.209]


Epoch 7/10, Training Loss: 0.2442


Epoch 8/10: 100%|█████████████████████████████████████| 548/548 [11:54<00:00,  1.30s/it, loss=0.298]


Epoch 8/10, Training Loss: 0.2435


Epoch 9/10: 100%|█████████████████████████████████████| 548/548 [19:17<00:00,  2.11s/it, loss=0.176]


Epoch 9/10, Training Loss: 0.2420


Epoch 10/10: 100%|████████████████████████████████████| 548/548 [18:05<00:00,  1.98s/it, loss=0.297]

Epoch 10/10, Training Loss: 0.2411





In [38]:
# save model 
import os
model_path = project_root / "models" / "unet_iberfire.pth"
os.makedirs(model_path.parent, exist_ok=True)
torch.save(model.state_dict(), model_path)

# Load model

In [9]:
# load model
import os
from pathlib import Path
import torch
# Setup
import sys
from pathlib import Path
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.losses import FocalLoss
import torch.nn as nn
import torch

from src.data.datasets import SimpleIberFireSegmentationDataset
from torch.utils.data import DataLoader

feature_vars = [
    "wind_speed_mean",
    "t2m_mean",
    "RH_mean",
    "total_precipitation_mean",
]
in_channels = len(feature_vars)

project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

from src.data.datasets import CNNIberFireDataset
from torch.utils.data import DataLoader

ZARR_PATH = project_root / "data" / "silver" / "IberFire.zarr"

project_root = Path.cwd().parent
model_path = project_root / "models" / "unet_iberfire.pth"
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,
    in_channels=in_channels,
    classes=1,
    activation=None,
)
device = "mps"
model.load_state_dict(torch.load(model_path, map_location=device))

<All keys matched successfully>

In [13]:
model.eval()
# Generate ONE prediction: a heatmap of probabilities
with torch.no_grad():
    X_test, y_test = test_ds[0]
    X_test = X_test.unsqueeze(0).to(device).float()
    y_test = y_test.unsqueeze(0).to(device).float()
    logits = model(X_test)
    probs = torch.sigmoid(logits)
    print(f"Predicted probabilities shape: {probs.shape}")

import matplotlib.pyplot as plt
plt.figure(figsize=(10, 10))
probs_np = probs.squeeze().cpu().numpy()
plt.imshow(probs_np, cmap="hot")

RuntimeError: Input type (MPSFloatType) and weight type (torch.FloatTensor) should be the same