In [None]:
# Setup: make the project's `src` package importable and locate the IberFire store.
# NOTE(review): project_root is derived from the kernel's working directory, so this
# assumes Jupyter was started with cwd one level below the project root
# (e.g. `<root>/notebooks`) — TODO confirm for other launch locations.
import sys
from pathlib import Path

project_root = Path.cwd().parent
# Only prepend once: re-running this cell must not keep growing sys.path with duplicates.
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

from src.data.datasets import CNNIberFireDataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

ZARR_PATH = project_root / "data" / "silver" / "IberFire.zarr"
print(f"Project root: {project_root}")
print(f"Zarr path exists: {ZARR_PATH.exists()}")

In [None]:
# Build the training dataset (2018-01-01 .. 2020-12-31) and a shuffled DataLoader over it.
# `DataLoader` is already imported in the setup cell; only the dataset class is new here.
from src.data.datasets import SimpleIberFireSegmentationDataset
import torch  # only needed for the seeded shuffle generator below

# Seed for the DataLoader's shuffle so batch order is reproducible across
# Restart & Run All. Change it to get a different, still reproducible, order.
SEED = 42

# Input variables read from the Zarr store as model features.
feature_vars = [
    "wind_speed_mean",
    "t2m_mean",
    "RH_mean",
    "total_precipitation_mean",
]

train_ds = SimpleIberFireSegmentationDataset(
    zarr_path=ZARR_PATH,
    time_start="2018-01-01",
    time_end="2020-12-31",
    feature_vars=feature_vars,
    label_var="is_near_fire",
    spatial_downsample=4,
    lead_time=1,         # predict tomorrow
    compute_stats=True,  # or precompute & pass stats
)

train_loader = DataLoader(
    train_ds,
    batch_size=2,        # start small, check memory
    shuffle=True,
    num_workers=0,
    # Without an explicit generator, shuffle order depends on torch's global RNG state.
    generator=torch.Generator().manual_seed(SEED),
)

In [None]:
# Sanity-check a single sample: shapes of the feature and label tensors, plus
# whatever `get_time_value` reports for the same index (presumably the sample's
# date — confirm against the dataset implementation).
X, y = train_ds[0]
print(
    f"Feature tensor shape: {X.shape}, "
    f"Label tensor shape: {y.shape}"
)
print(train_ds.get_time_value(0))

In [None]:
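# Exploratory check (disabled): how many days have at least one fire pixel?
# import xarray as xr
# import numpy as np
# from pathlib import Path

# # Reuse ZARR_PATH from the setup cell: a bare relative Path("data/silver/IberFire.zarr")
# # resolves against the notebook's working directory, which is one level below the project root.
# zarr_path = ZARR_PATH

# ds = xr.open_zarr(zarr_path, consolidated=True)

# label_var = "is_near_fire"

# # Boolean: for each day, did fire occur anywhere?
# fire_any = ds[label_var].any(dim=("y", "x"))  # -> (time,)

# # Extract indices
# fire_days_idx = np.where(fire_any.values)[0]
# no_fire_days_idx = np.where(~fire_any.values)[0]

# # Use ds.sizes for dimension lengths; indexing ds.dims like a mapping is deprecated in recent xarray.
# print("Total days:", ds.sizes["time"])
# print("Days with fire   :", len(fire_days_idx))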


In [None]:
# Continuation of the disabled fire-day count in the previous cell (uses its `no_fire_days_idx`):
# print("Days with no fire:", len(no_fire_days_idx))