
# Dataset temporal viewer
Mostra esempi così come vengono letti nel training (MedFullBasinDataset + DataLoader). Usa `config/default.yml` per i parametri di default.


In [5]:

from pathlib import Path
import os
import sys

import yaml
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

# Notebook-wide matplotlib defaults: 10x5 inch figures with the axes grid disabled.
plt.rcParams.update({"figure.figsize": (10, 5), "axes.grid": False})

def find_repo_root(marker: str = "config/default.yml", max_depth: int = 5) -> Path:
    """Locate the repository root by walking up from the working directory.

    Starting at the resolved current working directory, checks that directory
    and then its ancestors (at most ``max_depth`` directories in total) and
    returns the first one that contains ``marker``.

    Args:
        marker: Path, relative to a candidate directory, whose existence
            identifies the repository root.
        max_depth: Maximum number of directories to inspect, counting the
            working directory itself.

    Returns:
        The first inspected directory that contains ``marker``.

    Raises:
        FileNotFoundError: If none of the inspected directories contains
            ``marker``.
    """
    root = Path.cwd().resolve()
    for _ in range(max_depth):
        if (root / marker).exists():
            return root
        if root.parent == root:
            # Reached the filesystem root: climbing further would only re-check it.
            break
        root = root.parent
    # Report the marker that was actually searched for, not only the default one.
    raise FileNotFoundError(f"{marker} non trovato; lancia il notebook dalla repo.")

# Resolve the repository root once and make the project's `src/` directory importable.
REPO_ROOT = find_repo_root()
sys.path.insert(0, str(REPO_ROOT / "src"))

from cyclone_locator.datasets.med_fullbasin import MedFullBasinDataset

# Load the default configuration. The context manager closes the file handle
# deterministically instead of leaving it open until garbage collection.
with open(REPO_ROOT / "config/default.yml", encoding="utf-8") as cfg_file:
    CFG = yaml.safe_load(cfg_file)
data_cfg = CFG["data"]
train_cfg = CFG["train"]
loss_cfg = CFG["loss"]

# Paths stored in the config are joined onto the repository root.
manifest_path = REPO_ROOT / data_cfg["manifest_train"]
meta_csv = REPO_ROOT / data_cfg["letterbox_meta_csv"]

print(f"Repo root: {REPO_ROOT}")
print(f"Using manifest: {manifest_path}")
print(f"Letterbox meta: {meta_csv}")
print(
    f"temporal_T={train_cfg.get('temporal_T', 1)}, stride={train_cfg.get('temporal_stride', 1)}, "
    f"image_size={train_cfg['image_size']}, heatmap_stride={train_cfg['heatmap_stride']}"
)


Repo root: /media/fenrir/disk1/danieleda/cyc-firstpass
Using manifest: /media/fenrir/disk1/danieleda/cyc-firstpass/manifests/train.csv
Letterbox meta: /media/fenrir/disk1/danieleda/cyc-firstpass/manifests/letterbox_meta.csv
temporal_T=13, stride=4, image_size=224, heatmap_stride=4


In [17]:

# Optionally point to an alternative manifest (e.g. mini_data_input/medicanes_new_windows.csv)
override_manifest = None  # Path("mini_data_input/medicanes_new_windows.csv")

# Window of samples to read (absolute, 0-based index on the manifest)
start_from = 10700  # skip the first start_from samples
stop_at = 11200     # exclude indices >= stop_at; set to None to disable
max_samples = 1000  # maximum number of samples to collect once start_from is reached

# Anchor a relative override at REPO_ROOT, exactly like the default manifest, so the
# notebook behaves the same when launched from a sub-directory of the repository.
# Joining REPO_ROOT with an absolute path yields that absolute path unchanged.
active_manifest = REPO_ROOT / override_manifest if override_manifest else manifest_path

# Cap the preview batch size at 4 regardless of the training batch size.
preview_batch_size = min(4, int(train_cfg.get("batch_size", 4)))
num_workers = int(train_cfg.get("num_workers", 0))

# Build the dataset from the same config values printed by the setup cell.
dataset = MedFullBasinDataset(
    csv_path=str(active_manifest),
    image_size=train_cfg["image_size"],
    heatmap_stride=train_cfg["heatmap_stride"],
    heatmap_sigma_px=loss_cfg["heatmap_sigma_px"],
    use_aug=train_cfg.get("use_aug", False),
    use_pre_letterboxed=data_cfg.get("use_pre_letterboxed", True),
    letterbox_meta_csv=str(meta_csv),
    letterbox_size_assert=data_cfg.get("letterbox_size_assert"),
    temporal_T=train_cfg.get("temporal_T", 1),
    temporal_stride=train_cfg.get("temporal_stride", 1),
)

# shuffle=False keeps the loader order aligned with the dataset index order.
loader = DataLoader(
    dataset,
    batch_size=preview_batch_size,
    shuffle=False,
    num_workers=num_workers,
    drop_last=False,
)

print(f"Manifest in use: {active_manifest}")
print(
    f"Dataset len: {len(dataset)} | batch_size={preview_batch_size} | num_workers={num_workers}"
)
print(f"Reading window: start_from={start_from}, stop_at={stop_at}, max_samples={max_samples}")


Dataset len: 16794 | batch_size=4 | num_workers=20
Reading window: start_from=10700, stop_at=11200, max_samples=1000


In [18]:

# Collect samples exactly as they are delivered to training, restricted to the
# configured [start_from, stop_at) window and capped at max_samples.
samples = []
stride = train_cfg["heatmap_stride"]

# Clamp the requested window to the dataset bounds.
window_start = max(0, int(start_from))
window_stop = len(dataset) if stop_at is None else min(int(stop_at), len(dataset))
window_stop = min(window_stop, window_start + int(max_samples))
window_indices = range(window_start, max(window_start, window_stop))

if len(window_indices) > 0:
    # Read only the requested indices. Iterating the full loader from index 0
    # would load every clip before start_from just to throw it away.
    window_loader = DataLoader(
        torch.utils.data.Subset(dataset, window_indices),
        batch_size=preview_batch_size,
        shuffle=False,  # keeps batches in dataset-index order
        num_workers=num_workers,
        drop_last=False,
    )
else:
    window_loader = []

# Dataset index of the next sample; valid because the loader is not shuffled.
global_idx = window_start

for batch in window_loader:
    videos = batch["video"]
    heatmaps = batch["heatmap"]
    presences = batch["presence"]
    paths = batch["image_path"]
    abs_paths = batch.get("image_path_abs", paths)

    for i in range(videos.shape[0]):
        video_np = videos[i].cpu().numpy()  # (C,T,H,W)
        frames = np.transpose(video_np, (1, 2, 3, 0))  # (T,H,W,C)
        hm_np = heatmaps[i].squeeze(0).cpu().numpy()
        target_peak = float(hm_np.max())
        presence = float(presences[i].item())

        # Path fields are indexed per sample when collated as a list/tuple; otherwise used as-is.
        center_path = paths[i] if isinstance(paths, (list, tuple)) else paths
        center_abs = abs_paths[i] if isinstance(abs_paths, (list, tuple)) else abs_paths
        window_paths = dataset.temporal_selector.get_window(center_abs)

        # Map the heatmap argmax cell to input-image pixels via the cell centre.
        # NOTE(review): assumes the dataset renders heatmaps with cell-centre alignment; confirm.
        cy_idx, cx_idx = np.unravel_index(np.argmax(hm_np), hm_np.shape)
        peak_x_lb = (cx_idx + 0.5) * stride
        peak_y_lb = (cy_idx + 0.5) * stride

        samples.append(
            {
                "frames": frames,
                "heatmap": hm_np,
                "presence": presence,
                "image_path": center_path,
                "window_paths": window_paths,
                "peak_xy_lb": (peak_x_lb, peak_y_lb),
                "target_peak": target_peak,
                "global_idx": global_idx,
            }
        )
        global_idx += 1

print(f"Prepared {len(samples)} samples from DataLoader.")
if samples:
    print(
        f"Sample shapes -> frames: {samples[0]['frames'].shape} (T,H,W,C), heatmap: {samples[0]['heatmap'].shape}"
    )
    print(
        f"First sample global idx: {samples[0]['global_idx']}, last: {samples[-1]['global_idx']}"
    )
    # Summarise the distinct presence targets found in the window.
    uniq = sorted({round(s["presence"], 3) for s in samples})
    print(f"Unique presence values (rounded 1e-3): {len(uniq)}")
    print("First 40 values:", uniq[:40])
else:
    print("Nessun campione caricato: controlla i path nel manifest o la finestra start/stop.")


Prepared 500 samples from DataLoader.
Sample shapes -> frames: (13, 224, 224, 3) (T,H,W,C), heatmap: (56, 56)
First sample global idx: 10700, last: 11199
Unique presence values (rounded 1e-3): 50
First 40 values: [0.0, 0.02, 0.041, 0.061, 0.082, 0.102, 0.122, 0.143, 0.163, 0.184, 0.204, 0.224, 0.245, 0.265, 0.286, 0.306, 0.327, 0.347, 0.367, 0.388, 0.408, 0.429, 0.449, 0.469, 0.49, 0.51, 0.531, 0.551, 0.571, 0.592, 0.612, 0.633, 0.653, 0.673, 0.694, 0.714, 0.735, 0.755, 0.776, 0.796]


In [35]:
# Alternative without ipywidgets: pre-rendered JS animation (FuncAnimation)
import matplotlib.animation as animation
from IPython.display import HTML, display


def show_sample_animation(sample_idx: int = 0, interval_ms: int = 200, repeat: bool = True):
    """Render one collected sample as an inline JS animation with its targets.

    Draws three panels: the clip frames (animated), the target heatmap (static)
    and the weight map ``(1 - target) ** (beta + 1)`` computed from the heatmap.

    Args:
        sample_idx: Position in the module-level ``samples`` list.
        interval_ms: Delay between frames, in milliseconds.
        repeat: Whether the animation loops.

    Raises:
        RuntimeError: If ``samples`` is empty.
        ValueError: If ``sample_idx`` is outside ``[0, len(samples) - 1]``.
    """
    if not samples:
        raise RuntimeError("Nessun campione disponibile da mostrare.")
    if not (0 <= sample_idx < len(samples)):
        raise ValueError(f"sample_idx deve essere in [0, {len(samples) - 1}]")

    sample = samples[sample_idx]
    frames = sample["frames"]
    hm = sample["heatmap"]
    presence = sample["presence"]
    peak_x_lb, peak_y_lb = sample["peak_xy_lb"]
    target_peak = float(sample.get("target_peak", float(hm.max())))
    window_paths = sample.get("window_paths")
    frame_names = [os.path.basename(p) for p in window_paths] if window_paths else None
    last_t = frames.shape[0] - 1

    # `or 4.0` also replaces a null/zero config value with the default.
    beta = float(loss_cfg.get("heatmap_focal_beta", 4.0) or 4.0)
    # Per the original author's note: in HeatmapFocal (soft variant) the negative
    # term uses a weight ~(1-target)^(beta+1).
    neg_weight = np.power(1.0 - np.clip(hm, 0.0, 1.0), beta + 1.0)

    def frame_label(t):
        # Fall back to the centre image name when no per-frame name exists for t.
        if frame_names and t < len(frame_names):
            return frame_names[t]
        return os.path.basename(sample["image_path"])

    def panel_title(t):
        return (
            f"{frame_label(t)} (t={t}/{last_t}) | presence={presence:.3f} | hm_peak={target_peak:.3f}"
        )

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # Clip (animated)
    frame_im = axes[0].imshow(np.clip(frames[0], 0.0, 1.0))
    if target_peak > 1e-3:
        # Mark the heatmap peak only when the target is not (near-)empty.
        axes[0].scatter(
            [peak_x_lb],
            [peak_y_lb],
            c="lime",
            s=60,
            marker="x",
            label="heatmap peak",
        )
        axes[0].legend(loc="lower left")
    axes[0].set_title(panel_title(0))
    axes[0].axis("off")

    # Target heatmap (static)
    hm_im = axes[1].imshow(hm, cmap="magma", vmin=0, vmax=1)
    axes[1].set_title("Target heatmap (static per clip)")
    axes[1].axis("off")
    fig.colorbar(hm_im, ax=axes[1], fraction=0.046, pad=0.04)

    # Weight map (negative term) of HeatmapFocal
    w_im = axes[2].imshow(neg_weight, cmap="viridis")
    axes[2].set_title(f"HeatmapFocal neg weight (beta={beta:g})")
    axes[2].axis("off")
    fig.colorbar(w_im, ax=axes[2], fraction=0.046, pad=0.04)

    plt.tight_layout()

    def update(t):
        frame_im.set_array(np.clip(frames[t], 0.0, 1.0))
        axes[0].set_title(panel_title(t))
        return (frame_im,)

    anim = animation.FuncAnimation(
        fig,
        update,
        frames=frames.shape[0],
        interval=interval_ms,
        blit=False,
        repeat=repeat,
    )
    display(HTML(anim.to_jshtml()))
    # Close the figure so only the JS animation is shown, not a duplicate static plot.
    plt.close(fig)


# Example: change the index according to the loaded window. It is clamped to the
# last available sample so a smaller window does not raise ValueError.
EXAMPLE_SAMPLE_IDX = 270
show_sample_animation(sample_idx=min(EXAMPLE_SAMPLE_IDX, len(samples) - 1), interval_ms=200)