In [2]:
# Setup
import sys
from pathlib import Path

# Resolve the repository root as the parent of the current working directory.
# NOTE(review): this presumes the notebook is launched from a first-level
# subdirectory of the project — confirm.
project_root = Path.cwd().parent

# Put the root at the front of sys.path so project-local packages are importable.
# Any existing entry is dropped first so that re-running this cell does not
# stack duplicate entries on sys.path.
_project_root_str = str(project_root)
if _project_root_str in sys.path:
    sys.path.remove(_project_root_str)
sys.path.insert(0, _project_root_str)

from src.data.datasets import CNNIberFireDataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Location of the IberFire Zarr store under <project_root>/data/silver
ZARR_PATH = project_root.joinpath("data", "silver", "IberFire.zarr")

# Echo the resolved root and whether the store is present, so a wrong working
# directory is obvious before any data is opened.
zarr_found = ZARR_PATH.exists()
print(f"Project root: {project_root}")
print(f"Zarr path exists: {zarr_found}")

Project root: /Users/vladimir/catalonia-wildfire-prediction
Zarr path exists: True


In [3]:
# Build the CNN dataset over the 2018-01-01 .. 2020-12-31 window.
# NOTE(review): the recorded run of this cell stops inside the constructor with
# "UFuncTypeError: ufunc 'greater_equal' ... (Int64DType, DateTime64DType)" right
# after "Filtering time range" is logged, so `dataset` is never bound and every
# later cell depends on it. The Int64-vs-DateTime64 comparison presumably means
# the store's time coordinate is integer-typed rather than datetime64 — confirm
# in src/data/datasets.py, where the time filter lives.
dataset_config = dict(
    zarr_path=str(ZARR_PATH),
    time_start="2018-01-01",
    time_end="2020-12-31",
    # Four daily-mean meteorological inputs, one channel each
    feature_vars=[
        "wind_speed_mean",
        "t2m_mean",
        "RH_mean",
        "total_precipitation_mean",
    ],
    label_var="is_near_fire",
    spatial_downsample=4,
    task="tile_classification",
    sample_strategy="stratified",
    fire_oversample_ratio=3.0,
)
dataset = CNNIberFireDataset(**dataset_config)

# Report the dataset size, then the positive-class weight the dataset suggests
n_samples = len(dataset)
print(f"\n✅ Dataset created: {n_samples} samples")
pos_weight = dataset.get_pos_weight()
print(f"Recommended pos_weight: {pos_weight:.2f}")

Opening Zarr dataset: /Users/vladimir/catalonia-wildfire-prediction/data/silver/IberFire.zarr
Filtering time range: 2018-01-01 to 2020-12-31


UFuncTypeError: ufunc 'greater_equal' did not contain a loop with signature matching types (<class 'numpy.dtypes.Int64DType'>, <class 'numpy.dtypes.DateTime64DType'>) -> None

In [None]:
# Smoke-test indexing: fetch the first sample and describe it
first_sample = dataset[0]
X, y = first_sample

print(f"X shape: {X.shape} (C, H, W)")
print(f"y shape: {y.shape}")

# The label is read out as a Python scalar; the message documents its encoding
label_value = y.item()
print(f"y value: {label_value} (1=fire, 0=no fire)")

# Value range across all channels of the feature tensor
x_low, x_high = X.min(), X.max()
print(f"X range: [{x_low:.2f}, {x_high:.2f}]")

In [None]:
# Test DataLoader: pull one shuffled mini-batch and report its shapes.
# num_workers=0 keeps loading in the main process.
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)
X_batch, y_batch = next(iter(loader))

print(f"Batch X shape: {X_batch.shape} (B, C, H, W)")
print(f"Batch y shape: {y_batch.shape}")

# Use the actual batch length as the denominator instead of a literal 8, so the
# summary stays correct if batch_size changes or the dataset holds fewer
# samples than one full batch.
n_fire_tiles = y_batch.sum().item()
print(f"Fire tiles in batch: {n_fire_tiles:.0f}/{len(y_batch)}")

In [None]:
# Visualize a fire sample: one panel per feature channel.
# NOTE(review): assumes dataset.fire_indices holds positions that can be passed
# to dataset[...] and that channel i of X corresponds to feature_vars[i] — confirm.
fire_idx = dataset.fire_indices[0]
X, y = dataset[fire_idx]

# Size the grid from the dataset's feature list rather than a hard-coded 4:
# with a literal, zip() silently drops channels (or leaves blank panels) as soon
# as feature_vars changes. squeeze=False keeps `axes` 2-D even for one feature.
n_features = len(dataset.feature_vars)
fig, axes = plt.subplots(1, n_features, figsize=(4 * n_features, 4), squeeze=False)
fig.suptitle(f"Fire Sample | Label: {y.item()}", fontsize=14)

for i, (ax, var) in enumerate(zip(axes[0], dataset.feature_vars)):
    im = ax.imshow(X[i].numpy(), cmap='viridis')
    ax.set_title(var)
    ax.axis('off')
    # Per-panel colorbar: each variable has its own value range
    fig.colorbar(im, ax=ax, fraction=0.046)

fig.tight_layout()
plt.show()

In [None]:
# Save normalization stats
stats_path = project_root / "data" / "processed" / "stats.json"

# Create data/processed up front: nothing in this notebook guarantees it exists.
# NOTE(review): save_stats is not visible here; if it already creates the
# directory this call is a harmless no-op.
stats_path.parent.mkdir(parents=True, exist_ok=True)

dataset.save_stats(str(stats_path))
print(f"✅ Saved stats to {stats_path}")

In [None]:
# Check class distribution in the first (up to) 100 samples.
# The number of samples actually inspected is kept in n_checked so that the
# printed denominator and percentage stay correct when the dataset holds fewer
# than 100 samples (the original always printed "/100" and reused the raw count
# as the percentage).
n_checked = min(100, len(dataset))
fire_count = sum(dataset[i][1].item() for i in range(n_checked))

# Guard the division for an empty dataset
fire_pct = 100.0 * fire_count / n_checked if n_checked else 0.0
print(f"Fire samples in first {n_checked}: {fire_count:.0f}/{n_checked} ({fire_pct:.1f}%)")

In [None]:
# Visualize a no-fire sample for comparison: one panel per feature channel.
# NOTE(review): assumes dataset.no_fire_indices holds positions that can be
# passed to dataset[...] and that channel i of X corresponds to feature_vars[i]
# — confirm.
no_fire_idx = dataset.no_fire_indices[0]
X, y = dataset[no_fire_idx]

# Size the grid from the dataset's feature list rather than a hard-coded 4:
# with a literal, zip() silently drops channels (or leaves blank panels) as soon
# as feature_vars changes. squeeze=False keeps `axes` 2-D even for one feature.
n_features = len(dataset.feature_vars)
fig, axes = plt.subplots(1, n_features, figsize=(4 * n_features, 4), squeeze=False)
fig.suptitle(f"No-Fire Sample | Label: {y.item()}", fontsize=14)

for i, (ax, var) in enumerate(zip(axes[0], dataset.feature_vars)):
    im = ax.imshow(X[i].numpy(), cmap='viridis')
    ax.set_title(var)
    ax.axis('off')
    # Per-panel colorbar: each variable has its own value range
    fig.colorbar(im, ax=ax, fraction=0.046)

fig.tight_layout()
plt.show()