## dataset creation
adapted from Leo's code


In [None]:
from pathlib import Path
import sys
import os
import glob
import random
import json
from collections import Counter, OrderedDict
from datetime import datetime
import numpy as np
import torch
from torch.utils.data import Dataset
repo_root = Path.cwd().parents[0]
sys.path.insert(0, str(repo_root))
sys.path.insert(0, str(repo_root / "code"))
from utils import helper_functions as hf
from collections import Counter, OrderedDict
from datetime import datetime
import json

## Part 1: discovery + split-manifest utilities (single-subject, multi-subject, held-out)


In [None]:
def discover_acquisition_npz(cache_dir, pattern="baseline_*.npz"):
    """
    List the cached acquisition files of a single cache directory.

    Globs `pattern` inside `cache_dir` and returns the matching paths as
    strings, in sorted order so the result is deterministic.

    Raises:
        FileNotFoundError: if no file in `cache_dir` matches `pattern`.
    """
    search_expr = os.path.join(str(cache_dir), pattern)
    matches = glob.glob(search_expr)
    matches.sort()
    if len(matches) == 0:
        raise FileNotFoundError(f"No acquisition caches found in {cache_dir} with pattern {pattern}")
    return list(map(str, matches))
def discover_acquisitions_multi(subject_to_cache_dir, pattern="baseline_*.npz"):
    """
    Collect the cached acquisitions of several subjects.

    `subject_to_cache_dir` maps a subject label to that subject's cache
    directory. Subjects are visited in sorted label order and each directory
    is listed with discover_acquisition_npz, so the output order is
    deterministic.

    Returns:
        list of {"path": str, "subject": str} records.

    Raises:
        ValueError: if `subject_to_cache_dir` is not a dict or is empty.
    """
    is_mapping = isinstance(subject_to_cache_dir, dict)
    if not is_mapping or not subject_to_cache_dir:
        raise ValueError("subject_to_cache_dir must be a non-empty dict")
    return [
        {"path": npz_path, "subject": str(subject_id)}
        for subject_id in sorted(subject_to_cache_dir)
        for npz_path in discover_acquisition_npz(subject_to_cache_dir[subject_id], pattern=pattern)
    ]
def _read_num_frames(npz_path, frames_key="frames"):
    """
    Return the number of frames (length of axis 0) stored under `frames_key`.

    The array must be 3-D ([T,H,W]) or 4-D ([T,C,H,W]).

    Raises:
        KeyError: if the archive has no entry named `frames_key`.
        ValueError: if the stored array is neither 3-D nor 4-D.
    """
    with np.load(npz_path, allow_pickle=True) as archive:
        if frames_key not in archive:
            raise KeyError(f"{npz_path} missing key '{frames_key}'")
        frames = archive[frames_key]
        if frames.ndim != 3 and frames.ndim != 4:
            raise ValueError(f"{npz_path} '{frames_key}' must be [T,H,W] or [T,C,H,W], got {frames.shape}")
        n_frames = frames.shape[0]
    return int(n_frames)
def _count_by_subject(paths, path_to_subject):
    """Tally `paths` per subject; paths missing from `path_to_subject` count as "UNK"."""
    tally = Counter()
    for acq_path in paths:
        tally[path_to_subject.get(acq_path, "UNK")] += 1
    return tally
def _print_split_summary(train_paths, test_paths, path_to_subject, frames_key="frames"):
    """
    Print one summary line per split: acquisition count, total frame count,
    and the number of acquisitions per subject (subjects in sorted order).

    Frame totals are obtained by opening every listed file with
    _read_num_frames, so this reads each acquisition once.
    """
    splits = (train_paths, test_paths)

    # Frame totals for both splits are computed before anything is printed.
    frame_totals = [
        sum(_read_num_frames(npz_path, frames_key=frames_key) for npz_path in split_paths)
        for split_paths in splits
    ]

    subject_texts = []
    for split_paths in splits:
        counts = _count_by_subject(split_paths, path_to_subject)
        subject_texts.append(", ".join(f"{subj}:{counts[subj]}" for subj in sorted(counts)))

    print(f"Train acquisitions: {len(train_paths)} | total frames: {frame_totals[0]} | per-subject=({subject_texts[0]})")
    print(f"Test acquisitions:  {len(test_paths)} | total frames: {frame_totals[1]} | per-subject=({subject_texts[1]})")
def _write_manifest(manifest, manifest_path):
    """
    Serialize `manifest` as indented JSON to `manifest_path`.

    Missing parent directories are created first. A confirmation line is
    printed after the file is written.

    Args:
        manifest: JSON-serializable object (the split manifest dict).
        manifest_path: destination file, as str or path-like.

    Returns:
        The destination path as a string.
    """
    parent_dir = os.path.dirname(str(manifest_path))
    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError; in that case the file simply goes to the current
    # working directory and there is nothing to create.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump(manifest, f, indent=2)
    print(f"Manifest saved: {manifest_path}")
    return str(manifest_path)
def create_split_manifest(
    cache_dir,
    split_ratio=0.8,
    seed=42,
    manifest_path=None,
    pattern="baseline_*.npz",
    frames_key="frames",
):
    """
    Build and save a train/test split over the acquisitions of one cache dir.

    Whole acquisitions (files) are assigned to train or test; a seeded
    shuffle makes the assignment reproducible for a given `seed`. When more
    than one acquisition exists, each side receives at least one file; with a
    single acquisition it goes to train and test is empty.

    Args:
        cache_dir: directory containing the cached .npz acquisitions.
        split_ratio: fraction of acquisitions for train, strictly in (0, 1).
        seed: seed of the shuffle.
        manifest_path: output JSON path; defaults to <cache_dir>/splits.json.
        pattern: glob pattern selecting acquisition files.
        frames_key: npz key of the frame stack (recorded in the manifest and
            used for the printed frame totals).

    Returns:
        Path of the written manifest, as a string.

    Raises:
        ValueError: if `split_ratio` is not strictly between 0 and 1.
    """
    if not (0.0 < split_ratio < 1.0):
        raise ValueError("split_ratio must be between 0 and 1")

    discovered = discover_acquisition_npz(cache_dir, pattern=pattern)
    order = list(discovered)
    random.Random(seed).shuffle(order)

    n_total = len(order)
    if n_total > 1:
        # Clamp so that neither side ends up empty.
        n_train = min(n_total - 1, int(split_ratio * n_total))
        n_train = max(1, n_train)
    else:
        n_train = 1

    train_paths = sorted(order[:n_train])
    test_paths = sorted(order[n_train:])

    manifest = {
        "train": train_paths,
        "test": test_paths,
        "seed": int(seed),
        "split_ratio": float(split_ratio),
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "cache_dir": str(cache_dir),
        "pattern": pattern,
        "frames_key": frames_key,
    }

    target_path = manifest_path
    if target_path is None:
        target_path = os.path.join(str(cache_dir), "splits.json")
    written = _write_manifest(manifest, target_path)

    # Single-cache case: the parent folder name stands in for the subject label.
    path_to_subject = {}
    for acq_path in train_paths + test_paths:
        path_to_subject[acq_path] = Path(acq_path).parent.name
    _print_split_summary(train_paths, test_paths, path_to_subject, frames_key=frames_key)
    return written
def create_split_manifest_multi(
    subject_to_cache_dir,
    split_ratio=0.8,
    seed=42,
    pattern="baseline_*.npz",
    frames_key="frames",
    manifest_path=None,
):
    """
    Build and save one train/test split pooled over all subjects.

    Acquisitions of every subject are gathered, shuffled with a seeded RNG and
    cut acquisition-wise; the subject of each path is stored in the manifest
    under both "meta" and "acquisitions". When more than one acquisition
    exists, each side receives at least one file.

    Args:
        subject_to_cache_dir: dict mapping subject label -> cache directory.
        split_ratio: fraction of acquisitions for train, strictly in (0, 1).
        seed: seed of the shuffle.
        pattern: glob pattern selecting acquisition files.
        frames_key: npz key of the frame stack.
        manifest_path: output JSON path; defaults to splits_multi.json inside
            the cache directory of the first subject in sorted order.

    Returns:
        Path of the written manifest, as a string.

    Raises:
        ValueError: if `split_ratio` is not strictly between 0 and 1.
    """
    if not (0.0 < split_ratio < 1.0):
        raise ValueError("split_ratio must be between 0 and 1")

    acqs = discover_acquisitions_multi(subject_to_cache_dir, pattern=pattern)
    order = list(acqs)
    random.Random(seed).shuffle(order)

    n_total = len(order)
    if n_total > 1:
        # Clamp so that neither side ends up empty.
        n_train = min(n_total - 1, int(split_ratio * n_total))
        n_train = max(1, n_train)
    else:
        n_train = 1

    def by_path(record):
        return record["path"]

    train_paths = [rec["path"] for rec in sorted(order[:n_train], key=by_path)]
    test_paths = [rec["path"] for rec in sorted(order[n_train:], key=by_path)]

    path_to_subject = {}
    for rec in acqs:
        path_to_subject[rec["path"]] = rec["subject"]

    manifest = {
        "train": train_paths,
        "test": test_paths,
        "meta": {p: {"subject": subj} for p, subj in path_to_subject.items()},
        "acquisitions": acqs,
        "seed": int(seed),
        "split_ratio": float(split_ratio),
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "subject_to_cache_dir": {str(k): str(v) for k, v in subject_to_cache_dir.items()},
        "pattern": pattern,
        "frames_key": frames_key,
    }

    target_path = manifest_path
    if target_path is None:
        first_subject = sorted(subject_to_cache_dir.keys())[0]
        target_path = os.path.join(str(subject_to_cache_dir[first_subject]), "splits_multi.json")
    written = _write_manifest(manifest, target_path)
    _print_split_summary(train_paths, test_paths, path_to_subject, frames_key=frames_key)
    return written
def create_subject_heldout_manifests(
    subject_to_cache_dir,
    pattern="baseline_*.npz",
    frames_key="frames",
    output_dir=None,
):
    """
    Write the two cross-subject manifests for a pair of subjects.

    With subjects A and B (sorted label order) this produces:
      - <A>_train_<B>_test.json : train on all of A, test on all of B
      - <B>_train_<A>_test.json : train on all of B, test on all of A

    Args:
        subject_to_cache_dir: dict with exactly two subject -> cache dir entries.
        pattern: glob pattern selecting acquisition files.
        frames_key: npz key of the frame stack.
        output_dir: where the manifests are written; defaults to the cache
            directory of subject A.

    Returns:
        (path_of_first_manifest, path_of_second_manifest), both strings.

    Raises:
        ValueError: if `subject_to_cache_dir` is not a dict of exactly 2 subjects.
    """
    has_two_subjects = isinstance(subject_to_cache_dir, dict) and len(subject_to_cache_dir) == 2
    if not has_two_subjects:
        raise ValueError("subject_to_cache_dir must contain exactly 2 subjects")

    a, b = sorted(subject_to_cache_dir.keys())
    a_acqs = discover_acquisitions_multi({a: subject_to_cache_dir[a]}, pattern=pattern)
    b_acqs = discover_acquisitions_multi({b: subject_to_cache_dir[b]}, pattern=pattern)
    a_paths = sorted(rec["path"] for rec in a_acqs)
    b_paths = sorted(rec["path"] for rec in b_acqs)

    path_to_subject = {rec["path"]: rec["subject"] for rec in a_acqs + b_acqs}
    meta = {}
    for acq_path, subj in path_to_subject.items():
        meta[acq_path] = {"subject": subj}

    if output_dir is None:
        output_dir = str(subject_to_cache_dir[a])
    os.makedirs(output_dir, exist_ok=True)
    first_path = os.path.join(output_dir, f"{a}_train_{b}_test.json")
    second_path = os.path.join(output_dir, f"{b}_train_{a}_test.json")

    def build(train_paths, test_paths, train_subject, test_subject):
        # Both directions share the same metadata; only the roles swap.
        return {
            "train": train_paths,
            "test": test_paths,
            "meta": meta,
            "acquisitions": a_acqs + b_acqs,
            "heldout_mode": f"train={train_subject},test={test_subject}",
            "timestamp": datetime.now().isoformat(timespec="seconds"),
            "pattern": pattern,
            "frames_key": frames_key,
        }

    # Build both manifests before touching the disk.
    first_manifest = build(a_paths, b_paths, a, b)
    second_manifest = build(b_paths, a_paths, b, a)

    out1 = _write_manifest(first_manifest, first_path)
    print(f"Summary for {Path(out1).name}")
    _print_split_summary(a_paths, b_paths, path_to_subject, frames_key=frames_key)

    out2 = _write_manifest(second_manifest, second_path)
    print(f"Summary for {Path(out2).name}")
    _print_split_summary(b_paths, a_paths, path_to_subject, frames_key=frames_key)
    return out1, out2


## per-acquisition forecasting Dataset


In [19]:
class FUSForecastWindowDataset(Dataset):
    """
    Per-acquisition lazy sliding-window forecasting dataset.

    Every item is a pair of consecutive frame blocks cut from one acquisition:
      - context: float32 tensor [window_size, C, H, W]
      - target:  float32 tensor [pred_horizon, C, H, W]
    If return_meta=True, items are (context, target, info) with
    info = {"subject", "path", "end_ctx"}, where end_ctx is the index of the
    last context frame inside its acquisition.

    Valid windows are enumerated once at construction (`index_map`). Frame
    data is loaded on demand and kept in a small per-instance LRU cache of
    whole acquisitions.

    NOTE: the returned tensors are built with torch.from_numpy on slices of
    the cached array, so they share memory with the cache. Clone them before
    modifying them in place.
    """

    def __init__(
        self,
        manifest_path=None,
        acq_paths=None,
        split=None,
        window_size=8,
        pred_horizon=1,
        stride=1,
        frames_key="frames",
        labels_key=None,
        mask_key=None,
        exclude_label=-1,
        lru_cache_size=2,
        target_size=112,
        return_meta=False,
    ):
        """
        Args:
            manifest_path: JSON manifest with "train"/"test" path lists and
                optional "meta"/"acquisitions" subject info. Takes precedence
                over `acq_paths` when both are given.
            acq_paths: acquisition .npz paths (an iterable, or a single
                str/path-like); used only when `manifest_path` is None.
            split: "train", "test", or None (train followed by test). Only
                used together with `manifest_path`.
            window_size: number of context frames per item.
            pred_horizon: number of target frames per item.
            stride: step between consecutive window positions.
            frames_key: npz key of the frame stack.
            labels_key: optional npz key of per-frame labels; frames whose
                label equals `exclude_label` are treated as invalid.
            mask_key: optional npz key of a per-frame mask; truthy entries
                mark invalid frames.
            exclude_label: label value that invalidates a frame.
            lru_cache_size: number of acquisitions kept in memory (at least 1).
            target_size: forwarded to hf.np_pad_or_crop_to_square
                (presumably the output height/width; confirm against helper).
            return_meta: also return the info dict with each item.

        Raises:
            ValueError: non-positive sizes, or no input source/paths given.
            FileNotFoundError: a listed acquisition file does not exist.
            RuntimeError: no valid window could be built.
        """
        # Coerce first, validate second: a fractional value such as 0.5 would
        # pass a "> 0" test and then truncate to 0.
        self.window_size = int(window_size)
        self.pred_horizon = int(pred_horizon)
        self.stride = int(stride)
        if self.window_size <= 0 or self.pred_horizon <= 0 or self.stride <= 0:
            raise ValueError("window_size, pred_horizon, and stride must all be > 0")
        if manifest_path is None and acq_paths is None:
            raise ValueError("Provide either manifest_path or acq_paths")
        self.frames_key = frames_key
        self.labels_key = labels_key
        self.mask_key = mask_key
        self.exclude_label = exclude_label
        self.lru_cache_size = int(max(1, lru_cache_size))
        self.target_size = int(target_size)
        self.return_meta = bool(return_meta)
        self.acq_paths, self.acq_subjects = self._resolve_inputs(manifest_path, acq_paths, split)
        for p in self.acq_paths:
            if not os.path.exists(p):
                raise FileNotFoundError(f"Missing acquisition file: {p}")
        self._acq_cache = OrderedDict()  # acq_idx -> normalized frames, oldest first
        self.index_map = []              # (acq_idx, end_ctx) per valid window
        self.acq_meta = []               # per-acquisition {"T", "shape", "n_windows"}
        self.expected_chw = None         # (C, H, W) shared by all acquisitions
        self._build_index_map()
        if len(self.index_map) == 0:
            raise RuntimeError("No valid windows found. Check window/pred/stride and exclusion rules.")

    def _resolve_inputs(self, manifest_path, acq_paths, split):
        """
        Turn the constructor inputs into parallel lists (paths, subjects).

        Subjects come from the manifest's "meta" map and "acquisitions" list
        when present ("acquisitions" wins on conflict); any path without an
        entry falls back to the name of its parent directory.
        """
        path_to_subject = {}
        if manifest_path is not None:
            with open(manifest_path, "r", encoding="utf-8") as f:
                m = json.load(f)
            # Optional metadata map: meta[path] -> {subject: ...}
            meta = m.get("meta", {})
            if isinstance(meta, dict):
                for p, info in meta.items():
                    if isinstance(info, dict) and info.get("subject") is not None:
                        path_to_subject[str(p)] = str(info["subject"])
            # Optional acquisitions list: [{path, subject}, ...]
            acqs = m.get("acquisitions", [])
            if isinstance(acqs, list):
                for item in acqs:
                    if isinstance(item, dict) and "path" in item and "subject" in item:
                        path_to_subject[str(item["path"])] = str(item["subject"])
            if split is None:
                paths = list(m.get("train", [])) + list(m.get("test", []))
            else:
                if split not in ("train", "test"):
                    raise ValueError("split must be 'train' or 'test'")
                paths = list(m.get(split, []))
        elif isinstance(acq_paths, (str, os.PathLike)):
            # A single path must not be exploded into characters by list().
            paths = [acq_paths]
        else:
            paths = list(acq_paths)
        if not paths:
            raise ValueError("No acquisition files provided")
        paths = [str(p) for p in paths]
        subjects = [path_to_subject.get(p, Path(p).parent.name) for p in paths]
        return paths, subjects

    def _get_invalid_mask(self, z, Tn, path):
        """
        Return a bool array of length Tn, True where a frame must not appear
        in any window (label == exclude_label, or truthy mask entry).

        `z` is the opened archive (anything supporting `in` and `[]`); keys
        that are unset or absent from `z` are ignored.
        """
        invalid = np.zeros(Tn, dtype=bool)
        if self.labels_key is not None and self.labels_key in z:
            # Flatten instead of squeeze: (T,), (T,1) and (1,T) behave as
            # before, while scalar or genuinely 2-D inputs now hit the
            # length check below rather than an obscure indexing error.
            labels = np.asarray(z[self.labels_key]).reshape(-1)
            if labels.shape[0] != Tn:
                raise ValueError(f"{path}: labels length {labels.shape[0]} != T {Tn}")
            invalid |= (labels == self.exclude_label)
        if self.mask_key is not None and self.mask_key in z:
            mask = np.asarray(z[self.mask_key]).reshape(-1)
            if mask.shape[0] != Tn:
                raise ValueError(f"{path}: mask length {mask.shape[0]} != T {Tn}")
            invalid |= mask.astype(bool)
        return invalid

    def _normalize_frames(self, frames, path):
        """
        Bring a raw frame stack to float32 [T, 1, H', W'].

        [T,H,W] input gets a channel axis; the stack is then passed through
        hf.np_pad_or_crop_to_square with target_size=self.target_size.
        Only single-channel data is accepted.
        """
        if frames.ndim == 3:
            frames = frames[:, np.newaxis, :, :]
        if frames.ndim != 4:
            raise ValueError(f"{path}: frames must be [T,H,W] or [T,C,H,W], got {frames.shape}")
        if frames.shape[1] != 1:
            raise ValueError(f"{path}: np_pad_or_crop_to_square expects channel=1, got C={frames.shape[1]}")
        frames = hf.np_pad_or_crop_to_square(frames, target_size=self.target_size).astype(np.float32, copy=False)
        return frames

    def _build_index_map(self):
        """
        Scan every acquisition once and record all valid window positions.

        Fills `index_map` with (acq_idx, end_ctx) pairs, `acq_meta` with one
        summary dict per acquisition, and `expected_chw` with the common
        (C, H, W). Each acquisition is fully loaded and normalized here; the
        arrays are not retained.
        """
        for acq_idx, path in enumerate(self.acq_paths):
            # NOTE(review): allow_pickle=True lets np.load unpickle object
            # arrays, which can execute arbitrary code; only open trusted files.
            with np.load(path, allow_pickle=True) as z:
                if self.frames_key not in z:
                    raise KeyError(f"{path} missing frames key '{self.frames_key}'")
                frames = self._normalize_frames(z[self.frames_key], path)
                Tn, Cn, Hn, Wn = frames.shape
                if self.expected_chw is None:
                    self.expected_chw = (Cn, Hn, Wn)
                elif self.expected_chw != (Cn, Hn, Wn):
                    raise ValueError(
                        f"Inconsistent frame shape across acquisitions: expected {self.expected_chw}, got {(Cn, Hn, Wn)} in {path}"
                    )
                if Tn < self.window_size + self.pred_horizon:
                    # Too short for even one context+target span.
                    self.acq_meta.append({"T": Tn, "shape": (Cn, Hn, Wn), "n_windows": 0})
                    continue
                invalid = self._get_invalid_mask(z, Tn, path)
                n_windows = 0
                # The range bounds guarantee start_ctx >= 0 and
                # target_end <= Tn - 1, so no per-window bounds test is needed.
                for end_ctx in range(self.window_size - 1, Tn - self.pred_horizon, self.stride):
                    start_ctx = end_ctx - self.window_size + 1
                    target_end = end_ctx + self.pred_horizon
                    # Drop the window if any context or target frame is invalid.
                    if invalid[start_ctx: target_end + 1].any():
                        continue
                    self.index_map.append((acq_idx, end_ctx))
                    n_windows += 1
                self.acq_meta.append({"T": Tn, "shape": (Cn, Hn, Wn), "n_windows": n_windows})

    def __len__(self):
        """Number of valid windows across all acquisitions."""
        return len(self.index_map)

    def _load_acquisition(self, acq_idx):
        """
        Return the normalized frames of acquisition `acq_idx`, using the LRU
        cache; the least recently used entry is evicted once the cache holds
        more than `lru_cache_size` acquisitions.
        """
        if acq_idx in self._acq_cache:
            # Mark as most recently used.
            self._acq_cache.move_to_end(acq_idx)
            return self._acq_cache[acq_idx]
        path = self.acq_paths[acq_idx]
        # NOTE(review): see _build_index_map regarding allow_pickle=True.
        with np.load(path, allow_pickle=True) as z:
            frames = self._normalize_frames(z[self.frames_key], path)
        _, Cn, Hn, Wn = frames.shape
        if self.expected_chw != (Cn, Hn, Wn):
            raise ValueError(f"{path}: shape changed vs expected {self.expected_chw}, got {(Cn, Hn, Wn)}")
        self._acq_cache[acq_idx] = frames
        if len(self._acq_cache) > self.lru_cache_size:
            self._acq_cache.popitem(last=False)
        return frames

    def __getitem__(self, idx):
        """
        Return (context, target) for window `idx`, plus the info dict when
        return_meta is set. Tensors are views on the cached acquisition.
        """
        acq_idx, end_ctx = self.index_map[idx]
        frames = self._load_acquisition(acq_idx)
        start_ctx = end_ctx - self.window_size + 1
        target_start = end_ctx + 1
        target_end = end_ctx + self.pred_horizon
        # Defensive: guards against an index_map that no longer matches the file.
        if start_ctx < 0 or target_end >= frames.shape[0]:
            raise IndexError("Computed window is out of bounds")
        context = frames[start_ctx:end_ctx + 1]            # [window_size, C, H, W]
        target = frames[target_start:target_end + 1]      # [pred_horizon, C, H, W]
        if context.shape[0] != self.window_size:
            raise RuntimeError(f"Context length mismatch: expected {self.window_size}, got {context.shape[0]}")
        if target.shape[0] != self.pred_horizon:
            raise RuntimeError(f"Target length mismatch: expected {self.pred_horizon}, got {target.shape[0]}")
        x = torch.from_numpy(context)
        y = torch.from_numpy(target)
        if x.dtype != torch.float32 or y.dtype != torch.float32:
            raise TypeError("Dataset outputs must be float32 tensors")
        if self.return_meta:
            info = {
                "subject": self.acq_subjects[acq_idx],
                "path": self.acq_paths[acq_idx],
                "end_ctx": int(end_ctx),
            }
            return x, y, info
        return x, y


## use it


In [22]:
# Cache directory of each subject, keyed by the label stored in the manifest.
subject_to_cache_dir = {
    label: repo_root / "derivatives" / "preprocessing" / folder / "baseline_only"
    for label, folder in (("A", "secundo"), ("B", "gus"))
}

# Pooled acquisition-wise 80/20 split over both subjects.
combined_manifest = create_split_manifest_multi(
    subject_to_cache_dir=subject_to_cache_dir,
    manifest_path=repo_root / "derivatives" / "preprocessing" / "splits_multi.json",
    pattern="baseline_*.npz",
    seed=42,
    split_ratio=0.8,
)

# Train and test datasets share the same windowing: 8 context frames -> 2 target frames.
ds_train_combined = FUSForecastWindowDataset(
    window_size=8,
    pred_horizon=2,
    stride=1,
    split="train",
    manifest_path=combined_manifest,
)
ds_test_combined = FUSForecastWindowDataset(
    window_size=8,
    pred_horizon=2,
    stride=1,
    split="test",
    manifest_path=combined_manifest,
)
print(f"combined lens: {len(ds_train_combined)} {len(ds_test_combined)}")

# Sanity check on the first training window.
x, y = ds_train_combined[0]
print(f"combined sample: {tuple(x.shape)} {tuple(y.shape)}")


Manifest saved: c:\Users\ESPCI\Documents\GitHub\fUSPredict\derivatives\preprocessing\splits_multi.json
Train acquisitions: 37 | total frames: 15518 | per-subject=(A:26, B:11)
Test acquisitions:  10 | total frames: 4363 | per-subject=(A:8, B:2)
Function pad shape:  (534, 1, 112, 112)
Function pad shape:  (549, 1, 112, 112)
Function pad shape:  (507, 1, 112, 112)
Function pad shape:  (268, 1, 112, 112)
Function pad shape:  (516, 1, 112, 112)
Function pad shape:  (539, 1, 112, 112)
Function pad shape:  (530, 1, 112, 112)
Function pad shape:  (528, 1, 112, 112)
Function pad shape:  (264, 1, 112, 112)
Function pad shape:  (269, 1, 112, 112)
Function pad shape:  (258, 1, 112, 112)
Function pad shape:  (526, 1, 112, 112)
Function pad shape:  (502, 1, 112, 112)
Function pad shape:  (538, 1, 112, 112)
Function pad shape:  (537, 1, 112, 112)
Function pad shape:  (252, 1, 112, 112)
Function pad shape:  (518, 1, 112, 112)
Function pad shape:  (509, 1, 112, 112)
Function pad shape:  (213, 1, 112, 1