In [11]:
# Cell 1: imports and config
import xarray as xr
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import pandas as pd

DATA_PATH = "data/iberfire.nc"  # adjust if needed
# NEW: output directory structure instead of single file
OUT_DIR = Path("data/minimal_cnn_sharded")
OUT_DIR.mkdir(parents=True, exist_ok=True)
# choose very few variables and short time range
FEATURE_VARS = ["CLC_2006_forest_proportion", "wind_speed_mean", "t2m_mean", "RH_mean", "total_precipitation_mean", "is_holiday"]  # change to actual names in ds
LABEL_VAR = "is_near_fire"
SAMPLES_PER_SHARD = 64  # write every 64 time steps
TIME_START = "2018-06-01"
TIME_END   = "2020-08-31" 
SPATIAL_DOWNSAMPLE = 4      # keep every 4th pixel in y and x

In [12]:
# Cell 2: open dataset & inspect
ds = xr.open_dataset(DATA_PATH, chunks={'time': 64, 'y': 256, 'x': 256})  # use dask for lazy loading

print(ds[FEATURE_VARS + [LABEL_VAR]])

<xarray.Dataset> Size: 123GB
Dimensions:                     (y: 920, x: 1188, time: 6241)
Coordinates:
  * x                           (x) float64 10kB 2.675e+06 ... 3.862e+06
  * y                           (y) float64 7kB 2.492e+06 ... 1.573e+06
  * time                        (time) datetime64[ns] 50kB 2007-12-01 ... 202...
Data variables:
    CLC_2006_forest_proportion  (y, x) float32 4MB dask.array<chunksize=(256, 256), meta=np.ndarray>
    wind_speed_mean             (time, y, x) float32 27GB dask.array<chunksize=(64, 256, 256), meta=np.ndarray>
    t2m_mean                    (time, y, x) float32 27GB dask.array<chunksize=(64, 256, 256), meta=np.ndarray>
    RH_mean                     (time, y, x) float32 27GB dask.array<chunksize=(64, 256, 256), meta=np.ndarray>
    total_precipitation_mean    (time, y, x) float32 27GB dask.array<chunksize=(64, 256, 256), meta=np.ndarray>
    is_holiday                  (time, y, x) uint8 7GB dask.array<chunksize=(64, 256, 256), meta=np.ndarr

  ds = xr.open_dataset(DATA_PATH, chunks={'time': 64, 'y': 256, 'x': 256})  # use dask for lazy loading
  ds = xr.open_dataset(DATA_PATH, chunks={'time': 64, 'y': 256, 'x': 256})  # use dask for lazy loading
  ds = xr.open_dataset(DATA_PATH, chunks={'time': 64, 'y': 256, 'x': 256})  # use dask for lazy loading


In [13]:
# Cell 3: extract small tensor dataset

time_sel = ds.sel(time=slice(TIME_START, TIME_END))
print(f"Selected time steps: {time_sel.dims['time']}")

# NEW: compute stats efficiently using xarray reductions
print("Computing normalization stats...")
stats = {}
for v in FEATURE_VARS:
    var = time_sel[v]
    
    # handle spatial-only variables (no time dimension)
    if "time" not in var.dims:
        # broadcast to time dimension for consistent stats computation
        var = var.expand_dims(time=time_sel["time"])
    
    # forward/backward fill to handle NaNs, then compute stats
    var_filled = var.ffill("time").bfill("time")
    
    # compute mean and std (xarray will handle chunked computation)
    mean = float(var_filled.mean(skipna=True).compute())
    std = float(var_filled.std(skipna=True).compute())
    
    stats[v] = {
        "mean": mean,
        "std": std if std > 1e-6 else 1.0
    }
    print(f"  {v}: mean={mean:.4f}, std={std:.4f}")

# NEW: save stats for later use
import json
with open(OUT_DIR / "stats.json", "w") as f:
    json.dump(stats, f, indent=2)

print(f"Stats saved to {OUT_DIR / 'stats.json'}")


  print(f"Selected time steps: {time_sel.dims['time']}")


Selected time steps: 823
Computing normalization stats...
  CLC_2006_forest_proportion: mean=0.1537, std=0.2813
  wind_speed_mean: mean=2.3934, std=1.1843
  t2m_mean: mean=15.2473, std=7.3298
  RH_mean: mean=67.4048, std=17.0030
  total_precipitation_mean: mean=0.9370, std=2.5542
  is_holiday: mean=0.2993, std=0.4579
Stats saved to data/minimal_cnn_sharded/stats.json


In [None]:
# Cell 4: extract and write shards incrementally

train_dir = OUT_DIR / "train"
train_dir.mkdir(parents=True, exist_ok=True)

shard_idx = 0
X_bucket = []
y_bucket = []
manifest = []

print(f"Processing {time_sel.dims['time']} time steps...")

for t_idx in range(time_sel.dims['time']):
    # Progress indicator every 50 steps
    if t_idx % 50 == 0:
        print(f"  Processing time step {t_idx}/{time_sel.dims['time']}...")
    
    # Extract single time slice (small, fits in RAM)
    frame = time_sel.isel(time=t_idx)
    
    # Extract features -> [C,H,W]
    feat_arrays = []
    for v in FEATURE_VARS:
        if v not in frame:
            raise ValueError(f"Variable {v} not found in dataset")
        arr = frame[v].values  # load small slice into RAM
        feat_arrays.append(arr)
    
    X = np.stack(feat_arrays, axis=0)  # [C,H,W]
    
    # Normalize using pre-computed stats
    for c, v in enumerate(FEATURE_VARS):
        X[c] = (X[c] - stats[v]["mean"]) / stats[v]["std"]
    
    # Spatial downsample
    X = X[:, ::SPATIAL_DOWNSAMPLE, ::SPATIAL_DOWNSAMPLE]
    
    # Extract label
    if LABEL_VAR not in frame:
        raise ValueError(f"Label variable {LABEL_VAR} not found")
    y = frame[LABEL_VAR].values.astype("float32")  # [H,W]
    y = y[::SPATIAL_DOWNSAMPLE, ::SPATIAL_DOWNSAMPLE]
    
    # Binarize
    y_bin = (y > 0.5).astype("float32")
    
    # Add to bucket
    X_bucket.append(X)
    y_bucket.append(y_bin)
    
    # Write shard when bucket is full
    if len(X_bucket) == SAMPLES_PER_SHARD:
        shard_path = train_dir / f"shard_{shard_idx:06d}.npz"
        np.savez_compressed(
            shard_path,
            X=np.stack(X_bucket, axis=0),
            y=np.stack(y_bucket, axis=0)
        )
        manifest.append({
            "shard": shard_path.name,
            "num_samples": len(X_bucket)
        })
        print(f"    Wrote {shard_path.name}")
        
        # Clear bucket
        X_bucket = []
        y_bucket = []
        shard_idx += 1

# Write leftover samples
if X_bucket:
    shard_path = train_dir / f"shard_{shard_idx:06d}.npz"
    np.savez_compressed(
        shard_path,
        X=np.stack(X_bucket, axis=0),
        y=np.stack(y_bucket, axis=0)
    )
    manifest.append({
        "shard": shard_path.name,
        "num_samples": len(X_bucket)
    })
    print(f"    Wrote {shard_path.name} (final)")

# Write manifest
pd.DataFrame(manifest).to_parquet(train_dir / "manifest.parquet", index=False)
print(f"\nDone! Wrote {shard_idx + 1} shards to {train_dir}")
print(f"Total samples: {sum(m['num_samples'] for m in manifest)}")

Processing 823 time steps...
  Processing time step 0/823...


  print(f"Processing {time_sel.dims['time']} time steps...")
  for t_idx in range(time_sel.dims['time']):
  print(f"  Processing time step {t_idx}/{time_sel.dims['time']}...")


  Processing time step 50/823...
    Wrote shard_000000.npz
  Processing time step 100/823...
    Wrote shard_000001.npz
  Processing time step 150/823...
    Wrote shard_000002.npz
  Processing time step 200/823...


In [41]:
# Cell 5: PyTorch Dataset for sharded data

class ShardedIberFireDataset(Dataset):
    """Streams samples from multiple .npz shards without loading all into RAM."""
    
    def __init__(self, split_dir):
        self.split_dir = Path(split_dir)
        
        # Load manifest to know which shards exist
        manifest_path = self.split_dir / "manifest.parquet"
        if not manifest_path.exists():
            raise FileNotFoundError(f"Manifest not found: {manifest_path}")
        
        df = pd.read_parquet(manifest_path)
        self.shards = [
            (self.split_dir / row.shard, int(row.num_samples))
            for _, row in df.iterrows()
        ]
        
        # Compute cumulative sum for global indexing
        self.cum = np.cumsum([n for _, n in self.shards])
        
        # Cache for memory-mapped files (per worker)
        self._cache = {}
    
    def __len__(self):
        return int(self.cum[-1])
    
    def _open_npz(self, path):
        """Open npz with memory mapping, cached per worker."""
        wid = torch.utils.data.get_worker_info()
        worker_id = wid.id if wid else -1
        key = (worker_id, str(path))
        
        if key not in self._cache:
            self._cache[key] = np.load(path, mmap_mode="r", allow_pickle=False)
        
        return self._cache[key]
    
    def __getitem__(self, idx):
        # Find which shard contains this index
        shard_idx = int(np.searchsorted(self.cum, idx, side="right"))
        
        # Compute local index within shard
        base = 0 if shard_idx == 0 else int(self.cum[shard_idx - 1])
        local_idx = int(idx - base)
        
        # Open shard and extract sample
        path, _ = self.shards[shard_idx]
        f = self._open_npz(path)
        
        X = torch.from_numpy(f["X"][local_idx]).float()  # [C,H,W]
        y = torch.from_numpy(f["y"][local_idx]).float()  # [H,W]
        
        return X, y.unsqueeze(0)  # [1,H,W] for consistency

# Create dataset
dataset = ShardedIberFireDataset(OUT_DIR / "train")
print("Dataset size:", len(dataset))

# Test one sample
X0, y0 = dataset[0]
print("Sample X shape:", X0.shape, "y shape:", y0.shape)

# Cell 6: create train/val splits and DataLoaders
from torch.utils.data import random_split

N = len(dataset)
n_train = int(0.8 * N)
n_val = N - n_train

train_ds, val_ds = random_split(dataset, [n_train, n_val])

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)
val_loader   = DataLoader(val_ds, batch_size=4, shuffle=False)

Dataset size: 457
Sample X shape: torch.Size([6, 64, 71]) y shape: torch.Size([1, 64, 71])


In [42]:
# Cell 7: tiny CNN model
class TinyFireCNN(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 16, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 1, 1),  # logits
        )

    def forward(self, x):
        return self.net(x)  # [B,1,H,W]

device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

in_channels = X0.shape[0]
model = TinyFireCNN(in_channels).to(device)

criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

print(model)

TinyFireCNN(
  (net): Sequential(
    (0): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))
  )
)


In [43]:
# Cell 8: training loop
def pixel_metrics(logits, y, thr=0.5):
    with torch.no_grad():
        probs = torch.sigmoid(logits)
        preds = (probs >= thr).float()
        y = y.float()
        tp = (preds * y).sum().item()
        fp = (preds * (1 - y)).sum().item()
        fn = ((1 - preds) * y).sum().item()
        tn = (((1 - preds) * (1 - y))).sum().item()
        eps = 1e-8
        prec = tp / (tp + fp + eps)
        rec = tp / (tp + fn + eps)
        f1 = 2 * prec * rec / (prec + rec + eps)
        acc = (tp + tn) / (tp + tn + fp + fn + eps)
    return acc, prec, rec, f1

EPOCHS = 25

for epoch in range(1, EPOCHS + 1):
    model.train()
    train_loss = 0.0
    print('training on device:', device)
    for X, y in train_loader:
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(X)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= max(1, len(train_loader))

    model.eval()
    val_loss = 0.0
    acc_sum = prec_sum = rec_sum = f1_sum = 0.0
    with torch.no_grad():
        for X, y in val_loader:
            X, y = X.to(device), y.to(device)
            logits = model(X)
            loss = criterion(logits, y)
            val_loss += loss.item()
            a, p, r, f1 = pixel_metrics(logits, y)
            acc_sum += a; prec_sum += p; rec_sum += r; f1_sum += f1
    val_loss /= max(1, len(val_loader))
    acc = acc_sum / max(1, len(val_loader))
    prec = prec_sum / max(1, len(val_loader))
    rec = rec_sum / max(1, len(val_loader))
    f1 = f1_sum / max(1, len(val_loader))

    print(f"Epoch {epoch} | train_loss={train_loss:.4f} "
          f"val_loss={val_loss:.4f} acc={acc:.3f} prec={prec:.3f} rec={rec:.3f} f1={f1:.3f}")

Epoch 1 | train_loss=0.6593 val_loss=0.5962 acc=0.995 prec=0.000 rec=0.000 f1=0.000
Epoch 2 | train_loss=0.5094 val_loss=0.4145 acc=0.995 prec=0.000 rec=0.000 f1=0.000
Epoch 3 | train_loss=0.3222 val_loss=0.2393 acc=0.995 prec=0.000 rec=0.000 f1=0.000
Epoch 4 | train_loss=0.1803 val_loss=0.1345 acc=0.995 prec=0.000 rec=0.000 f1=0.000
Epoch 5 | train_loss=0.1037 val_loss=0.0835 acc=0.995 prec=0.000 rec=0.000 f1=0.000


In [44]:
# Check fire prevalence in the mini dataset
import numpy as np

data = np.load("data/minimal_cnn_samples.npz")
y = data["y"]  # [N,H,W], already binarized

pos = (y == 1).sum()
neg = (y == 0).sum()
print("Total pixels:", y.size, "positives:", pos, "negatives:", neg, "pos_ratio:", pos / y.size)

Total pixels: 2076608 positives: 8787 negatives: 2067821 pos_ratio: 0.004231419699818165
