In [38]:
# Cell 1: imports and config
import xarray as xr
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path

DATA_PATH = "data/iberfire.nc"  # adjust if needed
OUT_NPZ = Path("data/minimal_cnn_samples.npz")

# choose very few variables and short time range
FEATURE_VARS = ["CLC_2006_forest_proportion", "wind_speed_mean", "t2m_mean", "RH_mean", "total_precipitation_mean", "is_holiday"]  # change to actual names in ds
LABEL_VAR = "is_near_fire"

TIME_START = "2018-06-01"
TIME_END   = "2020-08-31" 
SPATIAL_DOWNSAMPLE = 4      # keep every 4th pixel in y and x

In [39]:
# Cell 2: open dataset & inspect
ds = xr.open_dataset(DATA_PATH)  # no chunks for now; we keep this small

print(ds[FEATURE_VARS + [LABEL_VAR]])

<xarray.Dataset> Size: 123GB
Dimensions:                     (y: 920, x: 1188, time: 6241)
Coordinates:
  * x                           (x) float64 10kB 2.675e+06 ... 3.862e+06
  * y                           (y) float64 7kB 2.492e+06 ... 1.573e+06
  * time                        (time) datetime64[ns] 50kB 2007-12-01 ... 202...
Data variables:
    CLC_2006_forest_proportion  (y, x) float32 4MB ...
    wind_speed_mean             (time, y, x) float32 27GB ...
    t2m_mean                    (time, y, x) float32 27GB ...
    RH_mean                     (time, y, x) float32 27GB ...
    total_precipitation_mean    (time, y, x) float32 27GB ...
    is_holiday                  (time, y, x) uint8 7GB ...
    is_near_fire                (time, y, x) uint8 7GB ...
Attributes: (12/17)
    title:                IberFire
    description:          Datacube centered in Spain with 1km x 1km spatial r...
    dimensions:           (y: 920, x: 1188, time: 6241)
    spatial_resolution:   1km x 1km
    t

In [37]:
# Cell 3: extract small tensor dataset
time_sel = ds.sel(time=slice(TIME_START, TIME_END))
print("Selected time steps:", time_sel.dims["time"])

Xs = []
ys = []

for t_idx in range(time_sel.dims["time"]):
    # select at time index
    frame = time_sel.isel(time=t_idx)

    # features: list of [H,W]
    feat_arrays = []
    for v in FEATURE_VARS:
        if v not in frame:
            raise ValueError(f"Variable {v} not found in dataset")
        arr = frame[v].values  # [H,W]
        feat_arrays.append(arr)

    # stack channels -> [C,H,W]
    X = np.stack(feat_arrays, axis=0)

    # label
    if LABEL_VAR not in frame:
        raise ValueError(f"Label variable {LABEL_VAR} not found")
    y = frame[LABEL_VAR].values.astype("float32")  # [H,W]

    # optional: spatial downsample
    X = X[:, ::SPATIAL_DOWNSAMPLE, ::SPATIAL_DOWNSAMPLE]
    y = y[::SPATIAL_DOWNSAMPLE, ::SPATIAL_DOWNSAMPLE]

    Xs.append(X)
    ys.append(y)

Xs = np.stack(Xs, axis=0)  # [N,C,H,W]
ys = np.stack(ys, axis=0)  # [N,H,W]

print("Xs shape:", Xs.shape, "ys shape:", ys.shape)

  print("Selected time steps:", time_sel.dims["time"])
  for t_idx in range(time_sel.dims["time"]):


Selected time steps: 823


KeyboardInterrupt: 

In [40]:
# Cell 4: simple normalization + save to disk
# normalize each channel by global mean/std over this mini-dataset
C = Xs.shape[1]
for c in range(C):
    mean = Xs[:, c].mean()
    std = Xs[:, c].std()
    if std < 1e-6:
        std = 1.0
    Xs[:, c] = (Xs[:, c] - mean) / std

# binarize label if needed (0/1)
ys_bin = (ys > 0.5).astype("float32")

OUT_NPZ.parent.mkdir(parents=True, exist_ok=True)
np.savez_compressed(OUT_NPZ, X=Xs, y=ys_bin)
print("Saved to", OUT_NPZ)

AttributeError: 'list' object has no attribute 'shape'

In [41]:
# Cell 5: PyTorch dataset for the npz file
class MinimalIberFireDataset(Dataset):
    def __init__(self, npz_path):
        data = np.load(npz_path)
        self.X = data["X"]          # [N,C,H,W]
        self.y = data["y"]          # [N,H,W]

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        X = torch.from_numpy(self.X[idx]).float()        # [C,H,W]
        y = torch.from_numpy(self.y[idx]).float()        # [H,W]
        return X, y.unsqueeze(0)                         # [1,H,W]

dataset = MinimalIberFireDataset(OUT_NPZ)
print("Dataset size:", len(dataset))
X0, y0 = dataset[0]
print("Sample X shape:", X0.shape, "y shape:", y0.shape)

# Cell 6: create train/val splits and DataLoaders
from torch.utils.data import random_split

N = len(dataset)
n_train = int(0.8 * N)
n_val = N - n_train

train_ds, val_ds = random_split(dataset, [n_train, n_val])

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)
val_loader   = DataLoader(val_ds, batch_size=4, shuffle=False)

Dataset size: 457
Sample X shape: torch.Size([6, 64, 71]) y shape: torch.Size([1, 64, 71])


In [42]:
# Cell 7: tiny CNN model
class TinyFireCNN(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 16, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 1, 1),  # logits
        )

    def forward(self, x):
        return self.net(x)  # [B,1,H,W]

device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

in_channels = X0.shape[0]
model = TinyFireCNN(in_channels).to(device)

criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

print(model)

TinyFireCNN(
  (net): Sequential(
    (0): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))
  )
)


In [43]:
# Cell 8: training loop
def pixel_metrics(logits, y, thr=0.5):
    with torch.no_grad():
        probs = torch.sigmoid(logits)
        preds = (probs >= thr).float()
        y = y.float()
        tp = (preds * y).sum().item()
        fp = (preds * (1 - y)).sum().item()
        fn = ((1 - preds) * y).sum().item()
        tn = (((1 - preds) * (1 - y))).sum().item()
        eps = 1e-8
        prec = tp / (tp + fp + eps)
        rec = tp / (tp + fn + eps)
        f1 = 2 * prec * rec / (prec + rec + eps)
        acc = (tp + tn) / (tp + tn + fp + fn + eps)
    return acc, prec, rec, f1

EPOCHS = 5

for epoch in range(1, EPOCHS + 1):
    model.train()
    train_loss = 0.0
    for X, y in train_loader:
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(X)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= max(1, len(train_loader))

    model.eval()
    val_loss = 0.0
    acc_sum = prec_sum = rec_sum = f1_sum = 0.0
    with torch.no_grad():
        for X, y in val_loader:
            X, y = X.to(device), y.to(device)
            logits = model(X)
            loss = criterion(logits, y)
            val_loss += loss.item()
            a, p, r, f1 = pixel_metrics(logits, y)
            acc_sum += a; prec_sum += p; rec_sum += r; f1_sum += f1
    val_loss /= max(1, len(val_loader))
    acc = acc_sum / max(1, len(val_loader))
    prec = prec_sum / max(1, len(val_loader))
    rec = rec_sum / max(1, len(val_loader))
    f1 = f1_sum / max(1, len(val_loader))

    print(f"Epoch {epoch} | train_loss={train_loss:.4f} "
          f"val_loss={val_loss:.4f} acc={acc:.3f} prec={prec:.3f} rec={rec:.3f} f1={f1:.3f}")

Epoch 1 | train_loss=0.6593 val_loss=0.5962 acc=0.995 prec=0.000 rec=0.000 f1=0.000
Epoch 2 | train_loss=0.5094 val_loss=0.4145 acc=0.995 prec=0.000 rec=0.000 f1=0.000
Epoch 3 | train_loss=0.3222 val_loss=0.2393 acc=0.995 prec=0.000 rec=0.000 f1=0.000
Epoch 4 | train_loss=0.1803 val_loss=0.1345 acc=0.995 prec=0.000 rec=0.000 f1=0.000
Epoch 5 | train_loss=0.1037 val_loss=0.0835 acc=0.995 prec=0.000 rec=0.000 f1=0.000


In [44]:
# Check fire prevalence in the mini dataset
import numpy as np

data = np.load("data/minimal_cnn_samples.npz")
y = data["y"]  # [N,H,W], already binarized

pos = (y == 1).sum()
neg = (y == 0).sum()
print("Total pixels:", y.size, "positives:", pos, "negatives:", neg, "pos_ratio:", pos / y.size)

Total pixels: 2076608 positives: 8787 negatives: 2067821 pos_ratio: 0.004231419699818165
