
# Dataset temporal viewer
Mostra esempi così come vengono letti nel training (`MedFullBasinDataset` + `DataLoader`). Legge `config/default.yml` per usare gli stessi parametri del training (`image_size`, `heatmap_stride`, `temporal_T`) e mostra un'animazione con controlli play/pausa per scorrere i `temporal_T` frame temporali di ogni campione (5 con la configurazione di default).


In [1]:

import os
from itertools import islice
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
from IPython.display import display
from torch.utils.data import DataLoader

from cyclone_locator.datasets.med_fullbasin import MedFullBasinDataset

# Global matplotlib defaults for every figure in this notebook: 10x5 inch canvas, no grid.
plt.rcParams.update({"figure.figsize": (10, 5), "axes.grid": False})

def find_repo_root(marker: str = "config/default.yml", max_depth: int = 5) -> Path:
    """Return the closest directory (cwd or an ancestor) that contains ``marker``.

    The search starts at the resolved current working directory and walks up
    through its parents, checking at most ``max_depth`` directories in total.

    Args:
        marker: path, relative to a candidate directory, whose existence
            identifies the repository root.
        max_depth: maximum number of directories to inspect (cwd included).

    Raises:
        FileNotFoundError: if no inspected directory contains ``marker``.
    """
    start = Path.cwd().resolve()
    # cwd first, then each ancestor; `parents` stops at the filesystem root,
    # so the root is never inspected more than once.
    for candidate in [start, *start.parents][: max(max_depth, 0)]:
        if (candidate / marker).exists():
            return candidate
    # Report the marker that was actually searched for, not a hard-coded name.
    raise FileNotFoundError(f"{marker} non trovato; lancia il notebook dalla repo.")

REPO_ROOT = find_repo_root()
# Read via Path.read_text so the config file handle is closed immediately
# (a bare open() passed to safe_load stays open until garbage collection).
CFG = yaml.safe_load((REPO_ROOT / "config/default.yml").read_text(encoding="utf-8"))
data_cfg = CFG["data"]
train_cfg = CFG["train"]
loss_cfg = CFG["loss"]

# Config paths are relative to the repository root.
manifest_path = REPO_ROOT / data_cfg["manifest_train"]
meta_csv = REPO_ROOT / data_cfg["letterbox_meta_csv"]

print(f"Repo root: {REPO_ROOT}")
print(f"Using manifest: {manifest_path}")
print(f"Letterbox meta: {meta_csv}")
print(
    f"temporal_T={train_cfg.get('temporal_T', 1)}, stride={train_cfg.get('temporal_stride', 1)}, "
    f"image_size={train_cfg['image_size']}, heatmap_stride={train_cfg['heatmap_stride']}"
)


Repo root: /home/daniele/Documenti/Progetti/cyc-firstpass
Using manifest: /home/daniele/Documenti/Progetti/cyc-firstpass/manifests/train.csv
Letterbox meta: /home/daniele/Documenti/Progetti/cyc-firstpass/letterbox_meta.csv
temporal_T=5, stride=6, image_size=384, heatmap_stride=4


In [2]:

# Optional alternative manifest (e.g. mini_data_input/medicanes_new_windows.csv).
# Relative paths are resolved against REPO_ROOT, like the manifests from the config,
# so the cell works regardless of the notebook's working directory.
override_manifest = None  # Path("mini_data_input/medicanes_new_windows.csv")

# Number of samples to pre-load for the preview. Each cached sample keeps a full
# (T, H, W, C) clip — (5, 384, 384, 3) with the default config, per the printed
# shapes — so memory grows linearly with this value. The animation call at the end
# of the notebook uses sample_idx=1150, so keep this above 1150 or change that index.
max_samples = 2000
preview_batch_size = min(4, int(train_cfg.get("batch_size", 4)))
num_workers = int(train_cfg.get("num_workers", 0))

# Path's `/` keeps an absolute override unchanged and anchors a relative one at the repo root.
preview_manifest = REPO_ROOT / override_manifest if override_manifest else manifest_path

# Same dataset parameters as training, all read from config/default.yml.
dataset = MedFullBasinDataset(
    csv_path=str(preview_manifest),
    image_size=train_cfg["image_size"],
    heatmap_stride=train_cfg["heatmap_stride"],
    heatmap_sigma_px=loss_cfg["heatmap_sigma_px"],
    use_aug=train_cfg.get("use_aug", False),
    use_pre_letterboxed=data_cfg.get("use_pre_letterboxed", True),
    letterbox_meta_csv=str(meta_csv),
    letterbox_size_assert=data_cfg.get("letterbox_size_assert"),
    temporal_T=train_cfg.get("temporal_T", 1),
    temporal_stride=train_cfg.get("temporal_stride", 1),
)

# shuffle=False: samples are collected in dataset order, so a given sample_idx
# maps to the same manifest row on every run.
loader = DataLoader(
    dataset,
    batch_size=preview_batch_size,
    shuffle=False,
    num_workers=num_workers,
    drop_last=False,
)

print(
    f"Dataset len: {len(dataset)} | batch_size={preview_batch_size} | num_workers={num_workers}"
)


Dataset len: 16064 | batch_size=4 | num_workers=0


In [3]:

# Collect preview samples exactly as the training DataLoader yields them.
stride = train_cfg["heatmap_stride"]  # size of one heatmap cell, in letterbox pixels


def _iter_preview_samples(data_loader, ds, heatmap_stride):
    """Yield one preview dict per clip by unpacking each DataLoader batch.

    Each dict holds the clip frames as a (T, H, W, C) numpy array, the target
    heatmap, the presence value, the centre-frame path, the temporal window
    paths returned by ``ds.temporal_selector.get_window`` and the heatmap
    argmax mapped to letterbox pixel coordinates (``peak_xy_lb``).
    """
    for batch in data_loader:
        videos = batch["video"]
        heatmaps = batch["heatmap"]
        presences = batch["presence"]
        paths = batch["image_path"]
        # Fall back to the plain paths when the batch has no absolute ones.
        abs_paths = batch.get("image_path_abs", paths)

        for i in range(videos.shape[0]):
            video_np = videos[i].cpu().numpy()  # (C, T, H, W)
            frames = np.transpose(video_np, (1, 2, 3, 0))  # -> (T, H, W, C) for imshow
            hm_np = heatmaps[i].squeeze(0).cpu().numpy()

            # Batched string fields arrive as lists; anything else is used as-is.
            center_path = paths[i] if isinstance(paths, (list, tuple)) else paths
            center_abs = abs_paths[i] if isinstance(abs_paths, (list, tuple)) else abs_paths

            # Heatmap argmax -> centre of that heatmap cell in letterbox pixels.
            cy_idx, cx_idx = np.unravel_index(np.argmax(hm_np), hm_np.shape)
            yield {
                "frames": frames,
                "heatmap": hm_np,
                "presence": float(presences[i].item()),
                "image_path": center_path,
                "window_paths": ds.temporal_selector.get_window(center_abs),
                "peak_xy_lb": (
                    (cx_idx + 0.5) * heatmap_stride,
                    (cy_idx + 0.5) * heatmap_stride,
                ),
            }


# islice stops pulling batches as soon as max_samples clips are collected
# (max_samples=0 now really yields nothing; max_samples=None collects the whole dataset).
samples = list(islice(_iter_preview_samples(loader, dataset, stride), max_samples))

print(f"Prepared {len(samples)} samples from DataLoader.")
if samples:
    print(
        f"Sample shapes -> frames: {samples[0]['frames'].shape} (T,H,W,C), heatmap: {samples[0]['heatmap'].shape}"
    )
else:
    print("Nessun campione caricato: controlla i path nel manifest.")


Prepared 2000 samples from DataLoader.
Sample shapes -> frames: (5, 384, 384, 3) (T,H,W,C), heatmap: (96, 96)


In [8]:

# Alternativa senza ipywidgets: animazione JS pre-renderizzata (FuncAnimation)
import matplotlib.animation as animation
from IPython.display import HTML

def show_sample_animation(sample_idx: int = 0, interval_ms: int = 200, repeat: bool = True):
    """Render one cached preview clip as an inline JS/HTML animation.

    Left panel: the clip frames (clipped to [0, 1] for ``imshow``), with the
    heatmap peak marked when ``presence > 0.5``. Right panel: the target
    heatmap, which is the same for every frame of the clip.

    Args:
        sample_idx: index into the global ``samples`` list.
        interval_ms: delay between animation frames, in milliseconds.
        repeat: whether the animation loops.

    Raises:
        RuntimeError: if ``samples`` is empty.
        ValueError: if ``sample_idx`` is out of range.
    """
    if not samples:
        raise RuntimeError("Nessun campione disponibile da mostrare.")
    if not (0 <= sample_idx < len(samples)):
        raise ValueError(f"sample_idx deve essere in [0, {len(samples) - 1}]")

    sample = samples[sample_idx]
    frames = sample["frames"]
    hm = sample["heatmap"]
    presence = sample["presence"]
    peak_x_lb, peak_y_lb = sample["peak_xy_lb"]
    window_paths = sample.get("window_paths")
    frame_names = [os.path.basename(p) for p in window_paths] if window_paths else []
    fallback_name = os.path.basename(sample["image_path"])
    last_t = frames.shape[0] - 1

    def frame_title(t: int) -> str:
        # Per-frame file name when the window provides one for index t; the bound
        # check keeps a window shorter than T from raising IndexError.
        name = frame_names[t] if t < len(frame_names) else fallback_name
        return f"{name} (t={t}/{last_t}) | presence={presence:.1f}"

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    try:
        # NOTE(review): clipping hides values outside [0, 1]; the dataset's
        # normalization is not visible from here — confirm frames are in [0, 1].
        frame_im = axes[0].imshow(np.clip(frames[0], 0.0, 1.0))
        if presence > 0.5:
            axes[0].scatter(
                [peak_x_lb],
                [peak_y_lb],
                c="lime",
                s=60,
                marker="x",
                label="heatmap peak (lb px)",
            )
            axes[0].legend(loc="lower left")
        axes[0].set_title(frame_title(0))
        axes[0].axis("off")

        hm_im = axes[1].imshow(hm, cmap="magma")
        axes[1].set_title("Target heatmap (static per clip)")
        axes[1].axis("off")
        fig.colorbar(hm_im, ax=axes[1], fraction=0.046, pad=0.04)
        plt.tight_layout()

        def update(t):
            frame_im.set_array(np.clip(frames[t], 0.0, 1.0))
            axes[0].set_title(frame_title(t))
            return (frame_im,)

        anim = animation.FuncAnimation(
            fig,
            update,
            frames=frames.shape[0],
            interval=interval_ms,
            blit=False,
            repeat=repeat,
        )
        display(HTML(anim.to_jshtml()))
    finally:
        # Close the figure even if rendering fails, so it is never leaked.
        plt.close(fig)

# Animate one cached clip: index into `samples` (must be < len(samples)) and frame delay in ms.
preview_sample_idx = 1150
show_sample_animation(sample_idx=preview_sample_idx, interval_ms=200)
