## Dataset creation
Adapted from Leo's code.


In [None]:
import glob
import json
import os
import random
import sys
from collections import Counter
from datetime import datetime
from pathlib import Path

# Third-party
import numpy as np

# Path setup: the notebook is assumed to run from a direct child directory of
# the repo root (cwd's parent is taken as the root). Both <root>/code and
# <root> are prepended to sys.path, with <root>/code taking precedence.
repo_root = Path.cwd().parents[0]
sys.path[0:0] = [str(repo_root / "code"), str(repo_root)]

In [None]:
from utils.datasets import FUSForecastWindowDataset

## Discover cache files + split-manifest utilities (single-cache, per-subject, combined multi-subject, held-out)

The cell below defines functions that find the cached baseline `.npz` files and write train/test split manifests. You can split a single cache directory, split each subject separately, pool all acquisitions from all subjects into one combined split, or train on one subject and test on the other to measure cross-subject generalization.


In [None]:
from utils.manifest_utils import (
    assert_manifest_paths_exist,
    normalize_manifest_path,
    normalize_manifest_payload,
)


def discover_acquisition_npz(cache_dir, pattern="baseline_*.npz"):
    """Find the cached per-acquisition ``.npz`` files in one cache directory.

    Args:
        cache_dir: Directory to search (str or path-like). Glob
            metacharacters in the directory name (``[``, ``]``, ``*``, ``?``)
            are escaped, so only ``pattern`` acts as a wildcard.
        pattern: Glob pattern matched against file names directly inside
            ``cache_dir``.

    Returns:
        Matching paths in sorted order (sorted before normalization), each
        passed through ``normalize_manifest_path``.

    Raises:
        FileNotFoundError: if no file matches.
    """
    # Escape the directory part: a literal "[" in a folder name would otherwise
    # be parsed as a character class and silently match nothing (or the wrong dir).
    search = os.path.join(glob.escape(str(cache_dir)), pattern)
    paths = sorted(glob.glob(search))
    if not paths:
        raise FileNotFoundError(
            f"No acquisition caches found in {cache_dir} with pattern {pattern}"
        )
    return [normalize_manifest_path(p) for p in paths]


def discover_acquisitions_multi(subject_to_cache_dir, pattern="baseline_*.npz"):
    """Discover acquisition caches for several subjects.

    Subjects are visited in sorted key order. Within a subject, files keep the
    sorted order produced by ``discover_acquisition_npz``, so the result is
    deterministic for a given directory state. Paths are not de-duplicated
    across subjects.

    Args:
        subject_to_cache_dir: Non-empty dict mapping subject key -> cache dir.
        pattern: Glob pattern for acquisition files.

    Returns:
        List of ``{"path": str, "subject": str}`` records, where ``subject``
        is ``str(key)``.

    Raises:
        ValueError: if ``subject_to_cache_dir`` is not a non-empty dict.
    """
    if not isinstance(subject_to_cache_dir, dict) or not subject_to_cache_dir:
        raise ValueError("subject_to_cache_dir must be a non-empty dict")
    return [
        {"path": path, "subject": str(subject)}
        for subject in sorted(subject_to_cache_dir)
        for path in discover_acquisition_npz(
            subject_to_cache_dir[subject], pattern=pattern
        )
    ]


def _read_num_frames(npz_path, frames_key="frames"):
    with np.load(npz_path, allow_pickle=True) as z:
        if frames_key not in z:
            raise KeyError(f"{npz_path} missing key '{frames_key}'")
        x = z[frames_key]
        if x.ndim not in (3, 4):
            raise ValueError(
                f"{npz_path} '{frames_key}' must be [T,H,W] or [T,C,H,W], got {x.shape}"
            )
        return int(x.shape[0])


def _count_by_subject(paths, path_to_subject):
    return Counter(path_to_subject.get(p, "UNK") for p in paths)


def _print_split_summary(train_paths, test_paths, path_to_subject, frames_key="frames"):
    """Print acquisition counts, total frames and per-subject counts for a split.

    Every listed ``.npz`` is opened via ``_read_num_frames`` to sum its frames.
    Paths absent from ``path_to_subject`` are reported under ``"UNK"``.
    All values are computed before anything is printed.
    """

    def _frames_total(group):
        return sum(_read_num_frames(path, frames_key=frames_key) for path in group)

    def _per_subject_text(group):
        counts = _count_by_subject(group, path_to_subject)
        return ", ".join(f"{subj}:{counts[subj]}" for subj in sorted(counts))

    n_train_frames = _frames_total(train_paths)
    n_test_frames = _frames_total(test_paths)
    train_breakdown = _per_subject_text(train_paths)
    test_breakdown = _per_subject_text(test_paths)
    print(
        f"Train acquisitions: {len(train_paths)} | total frames: {n_train_frames}"
        f" | per-subject=({train_breakdown})"
    )
    print(
        f"Test acquisitions:  {len(test_paths)} | total frames: {n_test_frames}"
        f" | per-subject=({test_breakdown})"
    )


def _write_manifest(manifest, manifest_path):
    """Normalize ``manifest``, write it as indented UTF-8 JSON, then validate it.

    ``manifest_path`` is user-expanded and resolved, and its parent directories
    are created as needed. After writing, the file is checked with the project
    helper ``assert_manifest_paths_exist``. That helper presumably raises if a
    listed path is missing (confirm); the count it returns is printed.

    Returns:
        The resolved manifest path as a POSIX string.
    """
    payload = normalize_manifest_payload(manifest)
    target = Path(manifest_path).expanduser().resolve(strict=False)
    target.parent.mkdir(parents=True, exist_ok=True)
    with open(target, "w", encoding="utf-8") as handle:
        json.dump(payload, handle, indent=2)
    validated = assert_manifest_paths_exist(target)
    location = target.as_posix()
    print(f"Manifest saved: {location} (validated {validated} manifest path entries)")
    return location


def create_split_manifest(
    cache_dir,
    split_ratio=0.8,
    seed=42,
    manifest_path=None,
    pattern="baseline_*.npz",
    frames_key="frames",
    subject=None,
):
    """Single-cache deterministic acquisition-wise train/test split.

    Acquisitions matching ``pattern`` in ``cache_dir`` are shuffled with a
    seeded RNG. ``int(split_ratio * n)`` of them go to train, clamped so that
    each side keeps at least one acquisition when there are two or more. A
    single acquisition goes to train and leaves the test split empty.

    Args:
        cache_dir: Directory holding one cached ``.npz`` per acquisition.
        split_ratio: Fraction of acquisitions for train; must be in (0, 1).
        seed: RNG seed. Converted with ``int()`` both for shuffling and for
            the value recorded in the manifest, so the recorded seed
            reproduces the split.
        manifest_path: Output JSON path; defaults to ``<cache_dir>/splits.json``.
        pattern: Glob pattern selecting acquisition files.
        frames_key: Key of the frames array inside each ``.npz``.
        subject: Optional subject label. When given, it is recorded in the
            manifest (``subject``/``meta``/``acquisitions``, the same keys the
            per-subject manifests use) and shown in the printed summary.
            Otherwise the summary labels files by their parent directory name.

    Returns:
        The written manifest path as a POSIX string.

    Raises:
        ValueError: if ``split_ratio`` is not strictly between 0 and 1, or
            ``seed`` is not int-convertible.
        FileNotFoundError: if no acquisitions match ``pattern``.
    """
    if not (0.0 < split_ratio < 1.0):
        raise ValueError("split_ratio must be between 0 and 1")
    paths = discover_acquisition_npz(cache_dir, pattern=pattern)
    # Seed with the same integer that is written to the manifest (as the
    # per-subject variant does); random.Random("42") != random.Random(42).
    seed = int(seed)
    shuffled = paths.copy()
    random.Random(seed).shuffle(shuffled)
    n_total = len(shuffled)
    if n_total > 1:
        n_train = max(1, min(n_total - 1, int(split_ratio * n_total)))
    else:
        n_train = 1
    train_paths = sorted(shuffled[:n_train])
    test_paths = sorted(shuffled[n_train:])
    manifest = {
        "train": train_paths,
        "test": test_paths,
        "seed": seed,
        "split_ratio": float(split_ratio),
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "cache_dir": str(cache_dir),
        "pattern": pattern,
        "frames_key": frames_key,
    }
    if subject is not None:
        subject_str = str(subject)
        manifest["subject"] = subject_str
        manifest["meta"] = {p: {"subject": subject_str} for p in paths}
        manifest["acquisitions"] = [{"path": p, "subject": subject_str} for p in paths]
        path_to_subject = {p: subject_str for p in paths}
    else:
        # No explicit label: fall back to each file's parent directory name.
        path_to_subject = {p: Path(p).parent.name for p in paths}
    if manifest_path is None:
        manifest_path = os.path.join(str(cache_dir), "splits.json")
    out = _write_manifest(manifest, manifest_path)
    _print_split_summary(
        train_paths, test_paths, path_to_subject, frames_key=frames_key
    )
    return out


def create_split_manifest_multi(
    subject_to_cache_dir,
    split_ratio=0.8,
    seed=42,
    pattern="baseline_*.npz",
    frames_key="frames",
    manifest_path=None,
):
    """Combined acquisition-wise split pooling all subjects.

    All subjects' acquisitions are shuffled together with a seeded RNG and
    split by ``split_ratio``. The split is not stratified, so a subject may
    land in both splits or in only one. Subject labels are preserved in
    ``meta`` and ``acquisitions``.

    Args:
        subject_to_cache_dir: Non-empty dict mapping subject -> cache dir.
        split_ratio: Fraction of acquisitions for train; must be in (0, 1).
        seed: RNG seed. Converted with ``int()`` both for shuffling and for
            the recorded value, so the manifest's seed reproduces the split.
        pattern: Glob pattern selecting acquisition files.
        frames_key: Key of the frames array inside each ``.npz``.
        manifest_path: Output JSON path; defaults to ``splits_multi.json`` in
            the cache dir of the first subject (sorted by key).

    Returns:
        The written manifest path as a POSIX string.

    Raises:
        ValueError: if ``split_ratio`` is not strictly between 0 and 1, if
            ``subject_to_cache_dir`` is not a non-empty dict, or if one
            acquisition file is discovered under more than one subject. Such
            a file could otherwise land in both train and test.
        FileNotFoundError: if a subject's cache has no matching acquisitions.
    """
    if not (0.0 < split_ratio < 1.0):
        raise ValueError("split_ratio must be between 0 and 1")
    acqs = discover_acquisitions_multi(subject_to_cache_dir, pattern=pattern)
    # Overlapping cache dirs would yield duplicate paths: one copy could be
    # drawn into train and another into test, and the subject map would
    # silently keep only one label.
    path_counts = Counter(d["path"] for d in acqs)
    duplicated = sorted(p for p, n in path_counts.items() if n > 1)
    if duplicated:
        raise ValueError(
            f"{len(duplicated)} acquisition file(s) are listed under more than one "
            f"subject (e.g. {duplicated[0]}); check subject_to_cache_dir for "
            "overlapping cache directories"
        )
    # Same integer seed for the RNG and the manifest (matches the per-subject variant).
    seed = int(seed)
    shuffled = acqs.copy()
    random.Random(seed).shuffle(shuffled)
    n_total = len(shuffled)
    if n_total > 1:
        n_train = max(1, min(n_total - 1, int(split_ratio * n_total)))
    else:
        n_train = 1
    train_acqs = sorted(shuffled[:n_train], key=lambda d: d["path"])
    test_acqs = sorted(shuffled[n_train:], key=lambda d: d["path"])
    train_paths = [d["path"] for d in train_acqs]
    test_paths = [d["path"] for d in test_acqs]
    path_to_subject = {d["path"]: d["subject"] for d in acqs}
    manifest = {
        "train": train_paths,
        "test": test_paths,
        "meta": {p: {"subject": m} for p, m in path_to_subject.items()},
        "acquisitions": acqs,
        "seed": seed,
        "split_ratio": float(split_ratio),
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "subject_to_cache_dir": {
            str(k): str(v) for k, v in subject_to_cache_dir.items()
        },
        "pattern": pattern,
        "frames_key": frames_key,
    }
    if manifest_path is None:
        first_dir = str(subject_to_cache_dir[sorted(subject_to_cache_dir.keys())[0]])
        manifest_path = os.path.join(first_dir, "splits_multi.json")
    out = _write_manifest(manifest, manifest_path)
    _print_split_summary(
        train_paths, test_paths, path_to_subject, frames_key=frames_key
    )
    return out


def create_split_manifests_per_subject(
    subject_to_cache_dir,
    split_ratio=0.8,
    seed=42,
    pattern="baseline_*.npz",
    frames_key="frames",
    output_dir=None,
):
    """Split each subject's acquisitions independently; one manifest per subject.

    Subjects are visited in sorted key order. The subject at position ``i``
    is shuffled with ``random.Random(int(seed) + i)``, and that derived seed
    is what its manifest records. Each manifest is written as
    ``splits_<subject>.json``. It goes inside ``output_dir`` (created if
    needed) when given, otherwise inside that subject's own cache directory.

    Returns:
        dict mapping ``str(subject)`` to the written manifest path.

    Raises:
        ValueError: if ``subject_to_cache_dir`` is not a non-empty dict, or
            ``split_ratio`` is not strictly between 0 and 1.
        FileNotFoundError: if a subject's cache has no matching acquisitions.
    """
    if not isinstance(subject_to_cache_dir, dict) or not subject_to_cache_dir:
        raise ValueError("subject_to_cache_dir must be a non-empty dict")
    if not (0.0 < split_ratio < 1.0):
        raise ValueError("split_ratio must be between 0 and 1")

    written = {}
    for position, subject in enumerate(sorted(subject_to_cache_dir)):
        source_dir = subject_to_cache_dir[subject]
        acquisitions = discover_acquisition_npz(source_dir, pattern=pattern)
        subject_seed = int(seed) + position
        order = list(acquisitions)
        random.Random(subject_seed).shuffle(order)
        total = len(order)
        # At least one acquisition per side when possible; a lone file goes to train.
        if total > 1:
            cut = max(1, min(total - 1, int(split_ratio * total)))
        else:
            cut = 1
        train_part = sorted(order[:cut])
        test_part = sorted(order[cut:])
        label = str(subject)
        manifest = {
            "train": train_part,
            "test": test_part,
            "meta": {p: {"subject": label} for p in acquisitions},
            "acquisitions": [{"path": p, "subject": label} for p in acquisitions],
            "seed": subject_seed,
            "split_ratio": float(split_ratio),
            "timestamp": datetime.now().isoformat(timespec="seconds"),
            "cache_dir": str(source_dir),
            "subject": label,
            "pattern": pattern,
            "frames_key": frames_key,
        }
        if output_dir is None:
            target_dir = str(source_dir)
        else:
            os.makedirs(output_dir, exist_ok=True)
            target_dir = str(output_dir)
        out = _write_manifest(
            manifest, os.path.join(target_dir, f"splits_{label}.json")
        )
        print(f"Summary for {label} -> {Path(out).name}")
        _print_split_summary(
            train_part,
            test_part,
            dict.fromkeys(acquisitions, label),
            frames_key=frames_key,
        )
        written[label] = out
    return written


def create_subject_heldout_manifests(
    subject_to_cache_dir,
    pattern="baseline_*.npz",
    frames_key="frames",
    output_dir=None,
):
    """Write both cross-subject held-out manifests for exactly two subjects.

    With sorted subject keys ``A`` and ``B``, this writes two manifests into
    ``output_dir``. It defaults to ``A``'s cache directory and is created if
    needed.

    - ``<A>_train_<B>_test.json``: train on every acquisition of A, test on B
    - ``<B>_train_<A>_test.json``: train on every acquisition of B, test on A

    Args:
        subject_to_cache_dir: dict with exactly two entries, subject -> cache dir.
        pattern: Glob pattern selecting acquisition files.
        frames_key: Key of the frames array inside each ``.npz``.
        output_dir: Directory for the two manifests.

    Returns:
        Tuple ``(A_train_B_test_path, B_train_A_test_path)`` of POSIX strings.

    Raises:
        ValueError: if there are not exactly two subjects, or if the two
            subjects share any acquisition file. Shared files would put the
            same data in train and test and defeat the held-out evaluation.
        FileNotFoundError: if either subject has no matching acquisitions.
    """
    if not isinstance(subject_to_cache_dir, dict) or len(subject_to_cache_dir) != 2:
        raise ValueError("subject_to_cache_dir must contain exactly 2 subjects")
    a, b = sorted(subject_to_cache_dir.keys())
    a_acqs = discover_acquisitions_multi({a: subject_to_cache_dir[a]}, pattern=pattern)
    b_acqs = discover_acquisitions_multi({b: subject_to_cache_dir[b]}, pattern=pattern)
    a_paths = sorted(d["path"] for d in a_acqs)
    b_paths = sorted(d["path"] for d in b_acqs)

    # Both subjects resolving to the same files (e.g. the same cache dir given
    # twice) would silently produce train == test.
    shared = sorted(set(a_paths) & set(b_paths))
    if shared:
        raise ValueError(
            f"Subjects {a!r} and {b!r} share {len(shared)} acquisition file(s) "
            f"(e.g. {shared[0]}); held-out manifests would leak test data into train"
        )

    path_to_subject = {d["path"]: d["subject"] for d in a_acqs + b_acqs}
    meta = {p: {"subject": m} for p, m in path_to_subject.items()}
    if output_dir is None:
        output_dir = str(subject_to_cache_dir[a])
    os.makedirs(output_dir, exist_ok=True)

    def _heldout_manifest(train_subject, test_subject, train_paths, test_paths):
        # Same schema for both directions; only train/test and the mode label differ.
        return {
            "train": train_paths,
            "test": test_paths,
            "meta": meta,
            "acquisitions": a_acqs + b_acqs,
            "heldout_mode": f"train={train_subject},test={test_subject}",
            "timestamp": datetime.now().isoformat(timespec="seconds"),
            "pattern": pattern,
            "frames_key": frames_key,
        }

    # Both manifests are built before either is written.
    plans = [
        (
            f"{a}_train_{b}_test.json",
            _heldout_manifest(a, b, a_paths, b_paths),
            a_paths,
            b_paths,
        ),
        (
            f"{b}_train_{a}_test.json",
            _heldout_manifest(b, a, b_paths, a_paths),
            b_paths,
            a_paths,
        ),
    ]
    outputs = []
    for filename, manifest, train_paths, test_paths in plans:
        out = _write_manifest(manifest, os.path.join(output_dir, filename))
        print(f"Summary for {Path(out).name}")
        _print_split_summary(
            train_paths, test_paths, path_to_subject, frames_key=frames_key
        )
        outputs.append(out)
    return tuple(outputs)



## Usage: build split manifests and datasets


In [None]:
# Pick the preprocessed cache variant: "raw" uses <subject_dir>/baseline_only;
# "mean_divide" / "zscore" use <subject_dir>/baseline_only_standardized/<data_mode>.
# Subject "A" maps to the "secundo" directory and "B" to "gus".
data_mode = "zscore"
if data_mode not in {"raw", "mean_divide", "zscore"}:
    raise ValueError(f"Unsupported data_mode: {data_mode}")
subject_to_cache_dir = {
    subject: (
        repo_root / "derivatives" / "preprocessing" / subject_dir / "baseline_only"
        if data_mode == "raw"
        else repo_root
        / "derivatives"
        / "preprocessing"
        / subject_dir
        / "baseline_only_standardized"
        / data_mode
    )
    for subject, subject_dir in (("A", "secundo"), ("B", "gus"))
}

In [None]:
# Combined split: pool every subject's acquisitions, then split 80/20.
combined_manifest = create_split_manifest_multi(
    subject_to_cache_dir=subject_to_cache_dir,
    split_ratio=0.8,
    seed=42,
    pattern="baseline_*.npz",
    manifest_path=repo_root / "derivatives" / "preprocessing" / "splits_multi.json",
)
# Identical windowing settings for both splits (window_size=8, pred_horizon=2,
# stride=1; their exact semantics are defined by FUSForecastWindowDataset).
ds_train_combined, ds_test_combined = (
    FUSForecastWindowDataset(
        manifest_path=combined_manifest,
        split=which,
        window_size=8,
        pred_horizon=2,
        stride=1,
    )
    for which in ("train", "test")
)
print("combined lens:", len(ds_train_combined), len(ds_test_combined))
x, y = ds_train_combined[0]
print("combined sample:", tuple(x.shape), tuple(y.shape))

In [None]:
# Single-subject split for subject "A" only.
# The cache dir is taken from subject_to_cache_dir so it follows the data_mode
# chosen above. Hard-coding baseline_only_standardized/<data_mode> pointed at
# the wrong directory when data_mode == "raw" (raw caches live in baseline_only).
single_manifest = create_split_manifest(
    cache_dir=subject_to_cache_dir["A"],
    split_ratio=0.8,
    seed=42,
    pattern="baseline_*.npz",
    manifest_path=repo_root
    / "derivatives"
    / "preprocessing"
    / "splits_single_secundo.json",
)
ds_train_single = FUSForecastWindowDataset(
    manifest_path=single_manifest,
    split="train",
    window_size=8,
    pred_horizon=2,
    stride=1,
)
ds_test_single = FUSForecastWindowDataset(
    manifest_path=single_manifest,
    split="test",
    window_size=8,
    pred_horizon=2,
    stride=1,
)
print("single lens:", len(ds_train_single), len(ds_test_single))
x, y = ds_train_single[0]
print("single sample:", tuple(x.shape), tuple(y.shape))