In [None]:
#cell to make split for label data, only run when we need to make new splits from label data
# import os
# import pickle
# import numpy as np
# from temporaldata import Interval

# # === Configuration ===
# TRAIN_FRAC = 0.8
# VAL_FRAC   = 0.0
# TEST_FRAC  = 0.2
# # Choose patient number
# PATIENT = '5'

# # Paths
# BASE_DIR  = os.getcwd()  # current working directory
# LABEL_DIR = os.path.join(BASE_DIR, 'processed_data', 'labels')

# # Process each pkl for this patient
# for fname in os.listdir(LABEL_DIR):
#     if not fname.startswith(f'P{PATIENT}_') or not fname.endswith('.pkl'):
#         continue

#     session_id = fname[:-4]  # strip '.pkl'
#     path = os.path.join(LABEL_DIR, fname)
#     print(f"Loading intervals from {fname}...")
#     with open(path, 'rb') as f:
#         iv_orig = pickle.load(f)

#     # Total intervals
#     n = len(iv_orig.start)
#     # Determine counts (floor for val & test)
#     test_count  = int(np.floor(n * TEST_FRAC))
#     val_count   = int(np.floor(n * VAL_FRAC))
#     train_count = n - val_count - test_count

#     # Shuffle indices
#     indices = np.random.permutation(n)
#     train_idx = indices[:train_count]
#     val_idx   = indices[train_count:train_count + val_count]
#     test_idx  = indices[train_count + val_count:train_count + val_count + test_count]

#     # Build split Intervals
#     def slice_iv(iv, idxs):
#         return Interval(
#             start=iv.start[idxs],
#             end=iv.end[idxs],
#             label=iv.label[idxs],
#             timekeys=['start', 'end']
#         )

#     iv_train = slice_iv(iv_orig, train_idx)
#     iv_val   = slice_iv(iv_orig, val_idx)
#     iv_test  = slice_iv(iv_orig, test_idx)

#     # Prepare dictionary
#     splits = {
#         'train': {f'halpern_ocd/{session_id}': iv_train},
#         'val':   {f'halpern_ocd/{session_id}': iv_val},
#         'test':  {f'halpern_ocd/{session_id}': iv_test},
#     }

#     # Overwrite the original file with the splits dict
#     with open(path, 'wb') as f:
#         pickle.dump(splits, f)

#     # Confirmation
#     print(f"Saved split for {session_id}: {train_count} train, {val_count} val, {test_count} test")

# print(f"All splits complete for patient {PATIENT}.")



import os
import pickle
import numpy as np
from temporaldata import Interval

# === Config===
TRAIN_FRAC = 0.8
VAL_FRAC   = 0.1
TEST_FRAC  = 0.1
PATIENT    = '5'

# Paths 
BASE_DIR  = os.getcwd()
LABEL_DIR = os.path.join(BASE_DIR, 'processed_data', 'labels')

def slice_iv(iv, idxs):
    return Interval(
        start=iv.start[idxs],
        end=iv.end[idxs],
        label=iv.label[idxs],
        timekeys=['start', 'end']
    )

for fname in os.listdir(LABEL_DIR):
    if not (fname.startswith(f'P{PATIENT}_') and fname.endswith('.pkl')):
        continue

    session_id = fname[:-4]  # strip ".pkl"
    pkl_path   = os.path.join(LABEL_DIR, fname)
    print(f"Loading splits from {fname}...")

    # 1) load the existing splits dict
    with open(pkl_path, 'rb') as f:
        old_splits = pickle.load(f)

    # 2) gather all Interval objects into one combined Interval
    all_intervals = []
    for split in ('train', 'val', 'test'):
        for iv in old_splits.get(split, {}).values():
            all_intervals.append(iv)

    starts = np.concatenate([iv.start for iv in all_intervals])
    ends   = np.concatenate([iv.end   for iv in all_intervals])
    labels = np.concatenate([iv.label for iv in all_intervals])

    iv_orig = Interval(start=starts, end=ends, label=labels, timekeys=['start', 'end'])
    n = len(iv_orig.start)

    # 3) compute new split sizes
    val_count   = int(np.floor(n * VAL_FRAC))
    test_count  = int(np.floor(n * TEST_FRAC))
    train_count = n - val_count - test_count

    # 4) shuffle and partition
    idxs      = np.random.permutation(n)
    train_idx = idxs[:train_count]
    val_idx   = idxs[train_count:train_count + val_count]
    test_idx  = idxs[train_count + val_count:train_count + val_count + test_count]

    iv_train = slice_iv(iv_orig, train_idx)
    iv_val   = slice_iv(iv_orig, val_idx)
    iv_test  = slice_iv(iv_orig, test_idx)

    # 5) assemble the new splits dict (new one doesnt have the halpern_ocd prefix)
    new_splits = {
        'train': {session_id: iv_train},
        'val':   {session_id: iv_val},
        'test':  {session_id: iv_test},
    }

    # 6) overwrite the original .pkl
    with open(pkl_path, 'wb') as f:
        pickle.dump(new_splits, f)

    print(f"Saved new splits for {session_id}:",
          f"{train_count} train,", f"{val_count} val,", f"{test_count} test")

print(f"All splits complete for patient {PATIENT}.")




Loading splits from P5_D4_Tsymptom_provocation_SUPENNS001R02.part2.pkl...
Saved new splits for P5_D4_Tsymptom_provocation_SUPENNS001R02.part2: 16 train, 1 val, 1 test
Loading splits from P5_D4_Tsymptom_provocation_SUPENNS001R02.part3.pkl...
Saved new splits for P5_D4_Tsymptom_provocation_SUPENNS001R02.part3: 8 train, 0 val, 0 test
Loading splits from P5_D4_Tsymptom_provocation_SUPENNS001R02.pkl...
Saved new splits for P5_D4_Tsymptom_provocation_SUPENNS001R02: 17 train, 2 val, 2 test
Loading splits from P5_D5_Tsymptom_provocation_SUPENNS001R01.part2.pkl...
Saved new splits for P5_D5_Tsymptom_provocation_SUPENNS001R01.part2: 16 train, 2 val, 2 test
Loading splits from P5_D5_Tsymptom_provocation_SUPENNS001R01.pkl...
Saved new splits for P5_D5_Tsymptom_provocation_SUPENNS001R01: 17 train, 2 val, 2 test
Loading splits from P5_D3_Tsymptom_provocation_SUPENNS001R01.part2.pkl...
Saved new splits for P5_D3_Tsymptom_provocation_SUPENNS001R01.part2: 15 train, 1 val, 1 test
Loading splits from P5_

In [39]:

"""
Cell to check the splits structure, validate the splits, and print counts.
"""
# 1) Point to your labels directory and a session to check
label_dir = os.path.join(os.getcwd(), 'processed_data', 'labels')
session_id = 'P5_D4_Tsymptom_provocation_SUPENNS001R02.part3'
pkl_path = os.path.join(label_dir, f"{session_id}.pkl")

# 2) Load the splits dict
with open(pkl_path, 'rb') as f:
    splits = pickle.load(f)

# 3) Print out the structure and interval counts
print("Splits found:", list(splits.keys()))
for split, split_dict in splits.items():
    for sid, iv in split_dict.items():
        count = len(iv.start) if hasattr(iv, 'start') else 'N/A'
        print(f"{split:<5} | session = {sid:<40} | intervals = {count}")


Splits found: ['train', 'val', 'test']
train | session = P5_D4_Tsymptom_provocation_SUPENNS001R02.part3 | intervals = 8
val   | session = P5_D4_Tsymptom_provocation_SUPENNS001R02.part3 | intervals = 0
test  | session = P5_D4_Tsymptom_provocation_SUPENNS001R02.part3 | intervals = 0


In [53]:
# Cell to make data loaders

# START HERE FOR MODEL TRAINING

import os
import pickle
import yaml
import torch
from torch.utils.data import DataLoader

from torch_brain.data import Dataset, chain
from torch_brain.data.collate import collate
from torch_brain.data.sampler import (
    RandomFixedWindowSampler,
    SequentialFixedWindowSampler,
)



cfg = yaml.safe_load(open("config.yaml"))
# cfg is a list with one element; that element has a key "selection" mapping to a dict,
# and that dict has your "sessions" list.
session_list = cfg[0]["selection"][0]["sessions"]

# 2) Define file roots (update these to where your files actually live)
base_root  = "/vol/cortex/cd3/pesaranlab/OCD_Mapping_Foundation"
h5_root    = os.path.join(
    base_root,
    "processed_data_upd",
    "processed_data"
)
label_root = os.path.join(
    base_root,
    "processed_data",
    "labels"
)

# 3) Create the Dataset object
total_dataset = Dataset(
    root=h5_root,
    split=None,
    config="config.yaml",  
    # drop the subject prefix — only use the session.id
    session_id_prefix_fn=lambda d: "",
)   



# 4) Read each .pkl to collect train/val/test intervals
sampling_intervals = {"train": {}, "val": {}, "test": {}}

for session_id in session_list:
    pkl_path = os.path.join(label_root, session_id + ".pkl")
    with open(pkl_path, "rb") as f:
        splits = pickle.load(f)

    for split in ("train", "val", "test"):
        iv = splits.get(split, {}).get(session_id)
        if iv is not None:
            sampling_intervals[split][session_id] = iv


# 5) Hyperparameters (hard‑coded)
batch_size     = 10
num_workers    = 1
prefetch_factor= 1
window_length  = 1.0  # seconds

# 6) Create samplers for each split
train_sampler = RandomFixedWindowSampler(
    sampling_intervals=sampling_intervals["train"],
    window_length=1,
    generator=torch.Generator().manual_seed(0),
)
val_sampler = SequentialFixedWindowSampler(
    sampling_intervals=sampling_intervals["val"],
    window_length=1,
)
test_sampler = SequentialFixedWindowSampler(
    sampling_intervals=sampling_intervals["test"],
    window_length=1,
)

# 7) Wrap samplers in DataLoaders
train_loader = DataLoader(
    total_dataset, batch_size=batch_size, sampler=train_sampler,
    collate_fn=collate, num_workers=num_workers,
    drop_last=True, persistent_workers=True,
    prefetch_factor=prefetch_factor,
)
val_loader = DataLoader(
    total_dataset, batch_size=batch_size, sampler=val_sampler,
    collate_fn=collate, num_workers=num_workers,
    drop_last=False, persistent_workers=True,
    prefetch_factor=prefetch_factor,
)
test_loader = DataLoader(
    total_dataset, batch_size=batch_size, sampler=test_sampler,
    collate_fn=collate, num_workers=num_workers,
    drop_last=False, persistent_workers=True,
    prefetch_factor=prefetch_factor,
)

# — now train_loader / val_loader / test_loader pull exactly from:
#    - HD5 files under processed_data_upd
#    - interval splits under processed_data/labels




# for session in total_dataset.get_session_ids():
#     with open(filename, 'rb') as f:
#         intervals = pickle.load(f)
#         train_sampling_intervals[session] = intervals['train'][session]

# # Define train/val/test data loaders
# train_sampler = RandomFixedWindowSampler(
#     sampling_intervals=train_sampling_intervals,
#     window_length=1.0,  # seconds
#     generator=torch.Generator().manual_seed(0),
# )
# train_loader = DataLoader(
#     dataset=total_dataset,
#     batch_size=10,
#     sampler=train_sampler,
#     collate_fn=collate,
#     num_workers=1,
#     drop_last=True,
#     persistent_workers=True,
#     prefetch_factor=1,
# )


In [54]:
# Diagnostic — list every recording your Dataset knows about
print("Dataset sessions:")
for sid in total_dataset.get_session_ids():
    print("  ", sid)


Dataset sessions:
   P5_D3_Tsymptom_provocation_SUPENNS001R01
   P5_D3_Tsymptom_provocation_SUPENNS001R01.part2
   P5_D3_Tsymptom_provocation_SUPENNS001R01.part3
   P5_D3_Tsymptom_provocation_SUPENNS001R01.part4
   P5_D4_TSymptpm_prov_SUPENNS001R02
   P5_D4_TSymptpm_prov_SUPENNS001R02.part2
   P5_D4_TSymptpm_prov_SUPENNS001R02.part3
   P5_D5_Tsymptom_provocation_SUPENNS001R01
   P5_D5_Tsymptom_provocation_SUPENNS001R01.part2
   P5_D5_Tsymptom_provocation_SUPENNS001R01.part3


In [56]:
with open("labels_jonathan/P5_D4_Tsymptom_provocation_SUPENNS001R02.pkl", "rb") as f:
    d = pickle.load(f)
print("train keys:", d["train"].keys())


train keys: dict_keys(['P5_D4_Tsymptom_provocation_SUPENNS001R02'])


In [None]:
# Cell — inspect train interval durations

# — Diagnostic cell — verify sampling_intervals exists and has data
try:
    print("sampling_intervals keys:", sampling_intervals.keys())
    print("  train sessions:", list(sampling_intervals["train"].keys()))
    print("  # train sessions:", len(sampling_intervals["train"]))
except NameError:
    print("sampling_intervals is not defined in this scope.")


for session_id, intervals in sampling_intervals["train"].items():
    # compute duration of each interval
    durations = [end - start for start, end in intervals]
    if not durations:
        print(f"{session_id}: no intervals")
        continue
    print(f"Session {session_id}:")
    print(f"  # intervals = {len(durations)}")
    print(f"  Example durations (first 5) = {durations[:5]}")
    print(f"  Min = {min(durations)}, Max = {max(durations)}, Mean = {sum(durations)/len(durations):.3f}")
    print()
    # break after first session if you only want one example
    break


sampling_intervals keys: dict_keys(['train', 'val', 'test'])
  train sessions: ['P5_D5_Tsymptom_provocation_SUPENNS001R01', 'P5_D5_Tsymptom_provocation_SUPENNS001R01.part3', 'P5_D5_Tsymptom_provocation_SUPENNS001R01.part2', 'P5_D4_Tsymptom_provocation_SUPENNS001R02', 'P5_D4_Tsymptom_provocation_SUPENNS001R02.part3', 'P5_D4_Tsymptom_provocation_SUPENNS001R02.part2', 'P5_D3_Tsymptom_provocation_SUPENNS001R01', 'P5_D3_Tsymptom_provocation_SUPENNS001R01.part4', 'P5_D3_Tsymptom_provocation_SUPENNS001R01.part3', 'P5_D3_Tsymptom_provocation_SUPENNS001R01.part2']
  # train sessions: 10
Session P5_D5_Tsymptom_provocation_SUPENNS001R01:
  # intervals = 17
  Example durations (first 5) = [60.0, 60.0, 60.0, 60.0, 60.0]
  Min = 60.0, Max = 60.0, Mean = 60.000



In [58]:
# Cell — run one iteration of train_loader
train_iter = iter(train_loader)
try:
    batch = next(train_iter)
except StopIteration:
    raise RuntimeError("train_loader is empty—check your sampler and intervals.")

# If your Dataset returns (inputs, labels):
if isinstance(batch, (list, tuple)) and len(batch) == 2:
    inputs, labels = batch
    print("Inputs type:", type(inputs), "shape:", getattr(inputs, "shape", None))
    print("Labels type:", type(labels), "shape:", getattr(labels, "shape", None))
# If your Dataset returns a dict:
elif isinstance(batch, dict):
    for k, v in batch.items():
        print(f"{k}: type={type(v)}, shape={getattr(v, 'shape', None)}")
else:
    print("Batch content:", batch)


KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/jmehrotra/jmehrotra_ocd_venv/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/home/jmehrotra/jmehrotra_ocd_venv/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/jmehrotra/jmehrotra_ocd_venv/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 52, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/jmehrotra/jmehrotra_ocd_venv/lib/python3.10/site-packages/torch_brain/data/dataset.py", line 457, in __getitem__
    sample = self.get(index.recording_id, index.start, index.end)
  File "/home/jmehrotra/jmehrotra_ocd_venv/lib/python3.10/site-packages/torch_brain/data/dataset.py", line 298, in get
    data = copy.copy(self._data_objects[recording_id])
KeyError: 'P5_D4_Tsymptom_provocation_SUPENNS001R02.part3'
