
# Dataset temporal viewer
Mostra esempi così come vengono letti nel training (MedFullBasinDataset + DataLoader). Legge `config/default.yml` per usare gli stessi parametri del training (image_size, heatmap_stride, temporal_T, temporal_stride) e fornisce un widget con play per scorrere i `temporal_T` frame temporali di ogni campione (5 con la config di default).


In [1]:

from pathlib import Path
import os
import yaml
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from cyclone_locator.datasets.med_fullbasin import MedFullBasinDataset

# Notebook-wide matplotlib defaults: wide figures, no background grid.
plt.rcParams.update({"figure.figsize": (10, 5), "axes.grid": False})

def find_repo_root(
    marker: str = "config/default.yml",
    max_depth: int = 5,
    start: "Path | None" = None,
) -> Path:
    """Return the closest directory, walking upwards, that contains ``marker``.

    The search starts at ``start`` (default: the current working directory) and
    inspects at most ``max_depth`` directories, the starting one included. It
    stops early at the filesystem root, whose parent is itself.

    Args:
        marker: relative path whose presence identifies the repository root.
        max_depth: maximum number of directories to inspect.
        start: directory to start the search from; ``None`` means ``Path.cwd()``.

    Returns:
        The resolved directory that contains ``marker``.

    Raises:
        FileNotFoundError: if none of the inspected directories contains ``marker``.
    """
    root = Path(start if start is not None else Path.cwd()).resolve()
    for _ in range(max_depth):
        if (root / marker).exists():
            return root
        if root.parent == root:
            # Filesystem root reached: climbing further would re-check the same dir.
            break
        root = root.parent
    # Report the marker actually searched for, not a hard-coded default.
    raise FileNotFoundError(f"{marker} non trovato; lancia il notebook dalla repo.")

# Locate the repository and load the same YAML config the training pipeline reads.
REPO_ROOT = find_repo_root()
CONFIG_PATH = REPO_ROOT / "config/default.yml"
# Context manager closes the handle deterministically instead of leaking it until GC;
# explicit utf-8 avoids depending on the machine's locale encoding.
with open(CONFIG_PATH, encoding="utf-8") as cfg_file:
    CFG = yaml.safe_load(cfg_file)
data_cfg = CFG["data"]
train_cfg = CFG["train"]
loss_cfg = CFG["loss"]

# Data paths in the config are relative to the repository root.
manifest_path = REPO_ROOT / data_cfg["manifest_train"]
meta_csv = REPO_ROOT / data_cfg["letterbox_meta_csv"]

print(f"Repo root: {REPO_ROOT}")
print(f"Using manifest: {manifest_path}")
print(f"Letterbox meta: {meta_csv}")
print(
    f"temporal_T={train_cfg.get('temporal_T', 1)}, stride={train_cfg.get('temporal_stride', 1)}, "
    f"image_size={train_cfg['image_size']}, heatmap_stride={train_cfg['heatmap_stride']}"
)


Repo root: /media/fenrir/disk1/danieleda/cyc-firstpass
Using manifest: /media/fenrir/disk1/danieleda/cyc-firstpass/manifests/train.csv
Letterbox meta: /media/fenrir/disk1/danieleda/cyc-firstpass/letterbox_meta.csv
temporal_T=5, stride=6, image_size=384, heatmap_stride=4


In [2]:

# Optionally point the viewer at an alternative manifest
# (e.g. mini_data_input/medicanes_new_windows.csv).
override_manifest = None  # Path("mini_data_input/medicanes_new_windows.csv")

max_samples = 12  # how many samples to pre-load for the preview
preview_batch_size = min(4, int(train_cfg.get("batch_size", 4)))
num_workers = int(train_cfg.get("num_workers", 0))

# Relative overrides are anchored at REPO_ROOT, like the config-derived paths, so the
# result does not depend on the kernel's cwd. Joining an absolute path onto REPO_ROOT
# returns the absolute path unchanged.
active_manifest = REPO_ROOT / override_manifest if override_manifest else manifest_path

# Same constructor arguments as training, so samples look exactly like training inputs.
dataset = MedFullBasinDataset(
    csv_path=str(active_manifest),
    image_size=train_cfg["image_size"],
    heatmap_stride=train_cfg["heatmap_stride"],
    heatmap_sigma_px=loss_cfg["heatmap_sigma_px"],
    use_aug=train_cfg.get("use_aug", False),
    use_pre_letterboxed=data_cfg.get("use_pre_letterboxed", True),
    letterbox_meta_csv=str(meta_csv),
    letterbox_size_assert=data_cfg.get("letterbox_size_assert"),
    temporal_T=train_cfg.get("temporal_T", 1),
    temporal_stride=train_cfg.get("temporal_stride", 1),
)

# Deterministic order (no shuffle) so the same samples show up on every re-run.
loader = DataLoader(
    dataset,
    batch_size=preview_batch_size,
    shuffle=False,
    num_workers=num_workers,
    drop_last=False,
)

print(
    f"Dataset len: {len(dataset)} | batch_size={preview_batch_size} | num_workers={num_workers}"
)


Dataset len: 16794 | batch_size=4 | num_workers=0


In [3]:

# Collect a few samples exactly as they reach the training loop.
stride = train_cfg["heatmap_stride"]  # heatmap-grid -> letterbox-pixel scale factor


def _iter_loader_samples(data_loader, hm_stride):
    """Yield one preview dict per sample from batches produced by ``data_loader``.

    Each yielded dict holds:
      - "frames": the sample's video transposed from (C, T, H, W) to (T, H, W, C)
        so each frame can be passed straight to ``imshow``;
      - "heatmap": the target heatmap with its leading singleton axis squeezed away;
      - "presence": the presence label as a Python float;
      - "image_path": the path entry for this sample;
      - "peak_xy_lb": (x, y) of the heatmap argmax cell centre scaled by ``hm_stride``
        into letterbox pixels.

    NOTE(review): the +0.5 cell-centre convention assumes MedFullBasinDataset renders
    the Gaussian in cell-centre coordinates -- TODO confirm. For presence == 0 the
    heatmap is presumably all zeros, so the argmax (and the peak) is meaningless there.
    """
    for batch in data_loader:
        videos = batch["video"]
        heatmaps = batch["heatmap"]
        presences = batch["presence"]
        paths = batch["image_path"]
        for i in range(videos.shape[0]):
            frames = np.transpose(videos[i].cpu().numpy(), (1, 2, 3, 0))
            hm_np = heatmaps[i].squeeze(0).cpu().numpy()
            cy_idx, cx_idx = np.unravel_index(np.argmax(hm_np), hm_np.shape)
            yield {
                "frames": frames,
                "heatmap": hm_np,
                "presence": float(presences[i].item()),
                # Default collate turns a batch of strings into a list.
                "image_path": paths[i] if isinstance(paths, (list, tuple)) else paths,
                "peak_xy_lb": ((cx_idx + 0.5) * hm_stride, (cy_idx + 0.5) * hm_stride),
            }


samples = []
# Honour max_samples <= 0 as "load nothing" (the limit is checked before loading).
if max_samples > 0:
    for sample in _iter_loader_samples(loader, stride):
        samples.append(sample)
        if len(samples) >= max_samples:
            break

print(f"Prepared {len(samples)} samples from DataLoader.")
if samples:
    print(
        f"Sample shapes -> frames: {samples[0]['frames'].shape} (T,H,W,C), heatmap: {samples[0]['heatmap'].shape}"
    )
else:
    print("Nessun campione caricato: controlla i path nel manifest.")


Prepared 12 samples from DataLoader.
Sample shapes -> frames: (5, 384, 384, 3) (T,H,W,C), heatmap: (96, 96)


In [4]:

import ipywidgets as widgets
import numpy as np
from IPython.display import display

# The viewer indexes samples[0] below, so an empty preview is a hard error.
if not samples:
    raise RuntimeError("Nessun campione disponibile da mostrare.")

# Frame count is read from the first sample only.
# NOTE(review): assumes every sample has the same T (single temporal_T) -- TODO confirm.
num_frames = samples[0]["frames"].shape[0]
last_frame = num_frames - 1

sample_slider = widgets.IntSlider(
    description="sample",
    value=0, min=0, max=len(samples) - 1, step=1,
    continuous_update=False,
)
frame_slider = widgets.IntSlider(
    description="frame",
    value=0, min=0, max=last_frame, step=1,
    continuous_update=False,
)
# Play steps through the frame indices; jslink mirrors its value onto frame_slider
# on the browser side.
play = widgets.Play(
    description="play",
    value=0, min=0, max=last_frame, step=1,
    interval=500,  # milliseconds between steps
    disabled=False,
)
widgets.jslink((play, "value"), (frame_slider, "value"))

out = widgets.Output()
controls = widgets.HBox([sample_slider, widgets.VBox([play, frame_slider])])

def render(_=None):
    """Redraw the preview for the currently selected sample and frame.

    Left panel: the selected frame (clipped to [0, 1] for imshow), plus a marker at
    the heatmap peak in letterbox pixels when presence > 0.5. Right panel: the target
    heatmap with a colorbar. The unused positional argument lets the function serve
    directly as an ``observe`` callback, which receives a change dict.
    """
    current = samples[sample_slider.value]
    img = current["frames"][frame_slider.value]
    px, py = current["peak_xy_lb"]
    presence_val = current["presence"]

    with out:
        out.clear_output(wait=True)
        fig, (ax_img, ax_hm) = plt.subplots(1, 2, figsize=(12, 5))

        ax_img.imshow(np.clip(img, 0.0, 1.0))
        if presence_val > 0.5:
            ax_img.scatter(
                [px], [py],
                c="lime", s=60, marker="x",
                label="heatmap peak (lb px)",
            )
            ax_img.legend(loc="lower left")
        title = f"{os.path.basename(current['image_path'])} | presence={presence_val:.1f}"
        ax_img.set_title(title)
        ax_img.axis("off")

        hm_artist = ax_hm.imshow(current["heatmap"], cmap="magma")
        ax_hm.set_title("Target heatmap (downsampled)")
        ax_hm.axis("off")
        fig.colorbar(hm_artist, ax=ax_hm, fraction=0.046, pad=0.04)

        # fig is the current figure here, so this matches plt.tight_layout().
        fig.tight_layout()
        display(fig)
        # Close after displaying so repeated redraws do not accumulate open figures.
        plt.close(fig)

# Redraw only when the sample or the frame index changes. `play` is jslink-ed to
# `frame_slider`, so every play tick (and every manual frame change) updates both
# values. Observing `play` as well would trigger a second, redundant matplotlib
# render each time.
for control in (sample_slider, frame_slider):
    control.observe(render, names="value")

render()
display(widgets.VBox([controls, out]))


VBox(children=(HBox(children=(IntSlider(value=0, continuous_update=False, description='sample', max=11), VBox(…