In [None]:

"""
Cell 1:
Final script to convert labels (in matlab format) to pkl files, with 60s before timestamps
"""

import os
import pickle
import numpy as np
from glob import glob
from scipy.io import loadmat
import h5py
from temporaldata import RegularTimeSeries, Interval
    
# === Paths ===
BASE_DIR  = os.path.join(os.path.dirname(os.path.abspath(__file__)), "Matlab_data_processed")
LABEL_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "processed_data", "labels")
os.makedirs(LABEL_DIR, exist_ok=True)

# === MATLAB struct helper for HDF5 mats ===
class MatlabStruct:
    def __init__(self, entries):
        for k, v in entries.items():
            setattr(self, k, v)

# === Load MATLAB ftdata (v7.2 & v7.3) ===
def load_ftdata(patient: str, session: str):
    mat_path = os.path.join(BASE_DIR, patient, f"{session}.mat")
    try:
        mat = loadmat(mat_path, struct_as_record=False, squeeze_me=True)
        ft = mat.get('ftdata') or mat.get('ftData')
        if ft is None:
            raise KeyError("ftdata not found")
        return ft
    except Exception:
        with h5py.File(mat_path, 'r') as f:
            grp = f.get('ftdata') or f.get('ftData')
            if grp is None:
                raise KeyError("ftdata not found")
            trial_refs = grp['trial'][()]
            fs         = float(np.array(grp['fsample']).item())
            won_grp    = grp['wonideets']
            trial_list = [np.array(f[r]).T for r in trial_refs.flatten()]
            # behavioral fields (including timestamp)
            won = {}
            for fld in [
                'ratingObsession', 'ratingCompulsion', 'ratingAnxiety',
                'ratingEnergy', 'ratingDepressions', 'ratingDistress',
                'isstim', 'timestamp'
            ]:
                arr = won_grp.get(fld)
                won[fld] = np.array(arr).flatten() if arr is not None else None
            return MatlabStruct({
                'trial': trial_list,
                'wonideets': MatlabStruct(won),
                'fsample': fs
            })

# === Process a single patient/session ===
def process_session(patient: str, session: str):
    ft     = load_ftdata(patient, session)
    trials = ft.trial
    fs     = ft.fsample

    # build full-time series
    data = np.concatenate(trials, axis=1).T  # samples x channels
    rt   = RegularTimeSeries(
        raw=data,
        sampling_rate=fs,
        domain=Interval(0.0, data.shape[0] / fs)
    )

    # compute intervals from timestamps (timestamp = end sample index)
    ts = getattr(ft.wonideets, 'timestamp', None)
    if ts is not None and len(ts) > 0:
        ts = np.asarray(ts, dtype=float)
        # ends in seconds
        ends   = ts / fs
        # starts are 60 seconds before each end
        starts = (ts - 60 * fs) / fs
    else:
        # fallback: equal-duration intervals from ftdata.time
        times = getattr(ft, 'time', None)
        if times:
            n    = len(times)
            dur  = times[0][-1] - times[0][0]
            starts = np.arange(n) * dur
            ends   = starts + dur
        else:
            starts = np.array([0.])
            ends   = np.array([0.])

    # extract ratings and stim flag
    won         = ft.wonideets
    rating_keys = [
        'ratingObsession', 'ratingCompulsion', 'ratingAnxiety',
        'ratingEnergy', 'ratingDepressions', 'ratingDistress'
    ]
    ratings = []
    for key in rating_keys:
        arr = getattr(won, key, None)
        if arr is not None and len(arr) == len(starts):
            ratings.append(np.asarray(arr, dtype=float))
        else:
            ratings.append(np.full(len(starts), np.nan))
    stim_arr = np.asarray(getattr(won, 'isstim', np.zeros(len(starts))), dtype=int)

    # stack labels and build Interval
    label_arr = np.stack(ratings + [stim_arr], axis=1)
    iv = Interval(
        start=starts,
        end=ends,
        label=label_arr,
        timekeys=['start', 'end']
    )

    # save Interval pickle with new suffix
    out_file = os.path.join(
        LABEL_DIR,
        f"{patient}_{session}_intervals_new.pkl"
    )
    with open(out_file, 'wb') as pf:
        pickle.dump(iv, pf)
    print(f"Saved: {out_file}")

    # preview dump
    loaded = pickle.load(open(out_file, 'rb'))
    print(f"Preview of {patient}_{session}_intervals_new:")
    print(
        "  starts[:3]:", loaded.start[:3],
        "ends[:3]:", loaded.end[:3],
        "labels[0]:", loaded.label[0]
    )

    return rt, iv

# === Batch Processing ===
def main(patients=None):
    if patients is None:
        # sort Patient1, Patient2, ... numerically
        patients = sorted(
            [d for d in os.listdir(BASE_DIR) if d.startswith("Patient")],
            key=lambda n: int(n.replace("Patient", ""))
        )
    for patient in patients:
        sessions = sorted(
            f[:-4]
            for f in os.listdir(os.path.join(BASE_DIR, patient))
            if f.endswith('.mat')
        )
        for session in sessions:
            try:
                process_session(patient, session)
            except Exception as e:
                print(f"Error {patient}/{session}: {e}")

if __name__ == '__main__':
    main()

In [None]:
"""
Cell 2:
Script: works to split a large Interval object into smaller subfiles based on specified counts. Used to process if multiple sessions per file 
"""

# === Configuration ===
# Specify patient and day
PATIENT = "Patient2"
DAY     = "day5"
# Specify how many intervals per subfile, in order
split_counts = [2,15,5,5]  # [3,4,13] for a total of 20 intervals

# Paths
BASE_DIR   = "/vol/cortex/cd3/pesaranlab/OCD_Mapping_Foundation"
INPUT_DIR  = os.path.join(BASE_DIR, "processed_data", "labels")
OUTPUT_DIR = os.path.join(BASE_DIR, "labels_jonathan")

# Construct input filename
filename    = f"{PATIENT}_{DAY}_intervals_new.pkl"
input_path  = os.path.join(INPUT_DIR, filename)

# Load the full Interval object
with open(input_path, "rb") as f:
    iv = pickle.load(f)

total_intervals = len(iv.start)
if sum(split_counts) != total_intervals:
    raise ValueError(f"Sum of split_counts ({sum(split_counts)}) != total intervals ({total_intervals})")

# Compute cumulative indices for slicing
boundaries = np.cumsum([0] + split_counts)

# Split and save
for i, count in enumerate(split_counts, start=1):
    start_idx = boundaries[i-1]
    end_idx   = boundaries[i]
    iv_sub = Interval(
        start=iv.start[start_idx:end_idx],
        end=iv.end[start_idx:end_idx],
        label=iv.label[start_idx:end_idx],
        timekeys=["start","end"]
    )
    out_name = f"{PATIENT}_{DAY}_part{i}.pkl"
    out_path = os.path.join(OUTPUT_DIR, out_name)
    with open(out_path, "wb") as of:
        pickle.dump(iv_sub, of)
    print(f"Saved {count} intervals to {out_path}")

In [None]:
"""
Cell 3:
Script that can be used to shift a sample by the correct amount of time.

"""


# === Configuration ===
# Specify the file, offset (in Hz), and sampling rate (Hz)
PKL_FILE = "/vol/cortex/cd3/pesaranlab/OCD_Mapping_Foundation/labels_jonathan/P2_D5_Tsymptom_provocation_SUPENNS001R03.pkl"
OFFSET_HZ = 2592480 + 4320120 + 888480 # how many samples to shift back
FSAMPLE   = 1200.0  #sampling rate in Hz

# 1) Compute the time offset in seconds
offset_sec = OFFSET_HZ / FSAMPLE

# 2) Load the existing Interval object
with open(PKL_FILE, "rb") as f:
    iv = pickle.load(f)

# 3) Shift start and end times
new_starts = iv.start - offset_sec
new_ends   = iv.end   - offset_sec

# 4) (Optional) warn if any intervals go negative
n_neg = (new_starts < 0).sum()
if n_neg:
    print(f"{n_neg} intervals start < 0 after shifting by {offset_sec:.3f}s")

# 5) Build a new shifted Interval
iv_shifted = Interval(
    start=new_starts,
    end=new_ends,
    label=iv.label,
    timekeys=['start','end']
)

# 6) Overwrite the original file with the shifted version
with open(PKL_FILE, "wb") as f:
    pickle.dump(iv_shifted, f)

print(f"Shifted '{os.path.basename(PKL_FILE)}' by {offset_sec:.3f} seconds.")


In [None]:
"""
Cell 4:

quick script to load and inspect an Interval object from a .pkl file

*assumes that pkl files do not have dicts yet.

"""



# === Specify the path to your Interval .pkl file here ===
PKL_PATH = "/vol/cortex/cd3/pesaranlab/OCD_Mapping_Foundation/labels_jonathan/P2_D5_Tsymptom_provocation_SUPENNS001R03.pkl"

# 1) Load the Interval object
with open(PKL_PATH, "rb") as f:
    iv = pickle.load(f)

# 2) Inspect available attributes (optional)
print("Loaded object type:", type(iv))
print("Attributes:", [a for a in dir(iv) if not a.startswith("_")])

# 3) Iterate and print each interval’s start, end, and labels
n = len(iv.start)
print(f"\nTotal intervals: {n}\n")
for i in range(n):
    start = iv.start[i]
    end   = iv.end[i]
    label = iv.label[i]
    print(f"Interval {i}: start={start}, end={end}, labels={label}")


NameError: name 'pickle' is not defined

At this point, we have correct pkl files corresponding to sessions. "excel_to_pkl.ipnyb" ends at the same point. Both files meet here.

Together, "matlab_to_pkl" and "excel_to_pkl" represent a preprocessing pipeline.

In [6]:
"""Cell 5: 

script to make splits from raw pkl files for a specific patient
saves to labels_jonathan

* Also filters out the intervals that have start < 0 or end < 60 seconds, as these are invalid.
"""

import os
import pickle
import numpy as np
from temporaldata import Interval

# === Configuration ===
TRAIN_FRAC = 0.8
VAL_FRAC   = 0.1
TEST_FRAC  = 0.1
PATIENT    = '3'

# Paths
BASE_DIR   = "/vol/cortex/cd3/pesaranlab/OCD_Mapping_Foundation"
INPUT_DIR  = os.path.join(BASE_DIR, "processed_data", "labels")
OUTPUT_DIR = os.path.join(BASE_DIR, "labels_jonathan")  # already exists

def slice_iv(iv, idxs):
    return Interval(
        start=iv.start[idxs],
        end=iv.end[idxs],
        label=iv.label[idxs],
        timekeys=['start', 'end']
    )

# Process each pkl for this patient
for fname in os.listdir(INPUT_DIR):
    if not (fname.startswith(f'P{PATIENT}_') and fname.endswith('.pkl')):
        continue

    session_id = fname[:-4]  # strip ".pkl"
    in_path    = os.path.join(INPUT_DIR, fname)
    out_path   = os.path.join(OUTPUT_DIR, fname)

    print(f"Loading intervals from {fname}...")

    # 1) load the original Interval object
    with open(in_path, 'rb') as f:
        iv_orig = pickle.load(f)

    # --- Filter out invalid intervals ---
    mask_valid = (iv_orig.start >= 0) & (iv_orig.end >= 60)
    invalid_count = len(iv_orig.start) - int(mask_valid.sum())
    if invalid_count > 0:
        print(f"  Removed {invalid_count} intervals (start<0 or end<60) from {fname}")
    iv_orig = Interval(
        start=iv_orig.start[mask_valid],
        end=iv_orig.end[mask_valid],
        label=iv_orig.label[mask_valid],
        timekeys=['start', 'end']
    )

    # 2) Determine split sizes
    n = len(iv_orig.start)
    val_count   = int(np.floor(n * VAL_FRAC))
    test_count  = int(np.floor(n * TEST_FRAC))
    train_count = n - val_count - test_count

    # 3) Shuffle indices
    indices   = np.random.permutation(n)
    train_idx = indices[:train_count]
    val_idx   = indices[train_count:train_count + val_count]
    test_idx  = indices[train_count + val_count:train_count + val_count + test_count]

    # 4) Build split Intervals
    iv_train = slice_iv(iv_orig, train_idx)
    iv_val   = slice_iv(iv_orig, val_idx)
    iv_test  = slice_iv(iv_orig, test_idx)

    # 5) Prepare splits dict
    splits = {
        'train': {f'halpern_ocd/{session_id}': iv_train},
        'val':   {f'halpern_ocd/{session_id}': iv_val},
        'test':  {f'halpern_ocd/{session_id}': iv_test},
    }

    # 6) Save to output directory
    with open(out_path, 'wb') as f:
        pickle.dump(splits, f)

    print(f"Saved new splits to {out_path}: {train_count} train, {val_count} val, {test_count} test")

print(f"All splits complete for patient {PATIENT}.")

Loading intervals from P3_D4_Tsymptom_provocation_SUPENNS001R02.pkl...
Saved new splits to /vol/cortex/cd3/pesaranlab/OCD_Mapping_Foundation/labels_jonathan/P3_D4_Tsymptom_provocation_SUPENNS001R02.pkl: 24 train, 2 val, 2 test
Loading intervals from P3_D5_Tsymptom_provocation_SUPENNS001R01.pkl...
Saved new splits to /vol/cortex/cd3/pesaranlab/OCD_Mapping_Foundation/labels_jonathan/P3_D5_Tsymptom_provocation_SUPENNS001R01.pkl: 22 train, 2 val, 2 test
Loading intervals from P3_D5_Tsymptom_provocation_SUPENNS001R01part2.pkl...
Saved new splits to /vol/cortex/cd3/pesaranlab/OCD_Mapping_Foundation/labels_jonathan/P3_D5_Tsymptom_provocation_SUPENNS001R01part2.pkl: 23 train, 2 val, 2 test
Loading intervals from P3_D3_Tsymptom_provocation_SUPENNS001R01part2.pkl...
Saved new splits to /vol/cortex/cd3/pesaranlab/OCD_Mapping_Foundation/labels_jonathan/P3_D3_Tsymptom_provocation_SUPENNS001R01part2.pkl: 15 train, 1 val, 1 test
Loading intervals from P3_D3_Tsymptom_provocation_SUPENNS001R02.pkl...
S

In [None]:
"""Cell 5a: 

script to make sequential splits from raw pkl files for a specific patient
saves to sequential_splits

* Also filters out the intervals that have start < 0 or end < 60 seconds, as these are invalid.
"""

import os
import pickle
import numpy as np
from temporaldata import Interval

# === Configuration ===
TRAIN_FRAC = 0.8
VAL_FRAC   = 0.1
TEST_FRAC  = 0.1
PATIENT    = '5'

# Paths
BASE_DIR   = "/vol/cortex/cd3/pesaranlab/OCD_Mapping_Foundation"
INPUT_DIR  = os.path.join(BASE_DIR, "processed_data", "labels")
OUTPUT_DIR = os.path.join(BASE_DIR, "sequential_splits")  # already exists

def slice_iv(iv, idxs):
    return Interval(
        start=iv.start[idxs],
        end=iv.end[idxs],
        label=iv.label[idxs],
        timekeys=['start', 'end']
    )

# Process each pkl for this patient
for fname in os.listdir(INPUT_DIR):
    if not (fname.startswith(f'P{PATIENT}_') and fname.endswith('.pkl')):
        continue

    session_id = fname[:-4]  # strip ".pkl"
    in_path    = os.path.join(INPUT_DIR, fname)
    out_path   = os.path.join(OUTPUT_DIR, fname)

    print(f"Loading intervals from {fname}...")

    # 1) load the original Interval object
    with open(in_path, 'rb') as f:
        iv_orig = pickle.load(f)

    # --- Filter out invalid intervals ---
    mask_valid = (iv_orig.start >= 0) & (iv_orig.end >= 60)
    invalid_count = len(iv_orig.start) - int(mask_valid.sum())
    if invalid_count > 0:
        print(f"  Removed {invalid_count} intervals (start<0 or end<60) from {fname}")
    iv_orig = Interval(
        start=iv_orig.start[mask_valid],
        end=iv_orig.end[mask_valid],
        label=iv_orig.label[mask_valid],
        timekeys=['start', 'end']
    )

    # 2) Determine split sizes
    n = len(iv_orig.start)
    val_count   = int(np.floor(n * VAL_FRAC))
    test_count  = int(np.floor(n * TEST_FRAC))
    train_count = n - val_count - test_count

    # 3) Sequential split (first TRAIN_FRAC, then VAL_FRAC, then TEST_FRAC)
    indices   = np.arange(n)
    train_idx = indices[:train_count]
    val_idx   = indices[train_count:train_count + val_count]
    test_idx  = indices[train_count + val_count:]

    # 4) Build split Intervals
    iv_train = slice_iv(iv_orig, train_idx)
    iv_val   = slice_iv(iv_orig, val_idx)
    iv_test  = slice_iv(iv_orig, test_idx)

    # 5) Prepare splits dict
    splits = {
        'train': {f'halpern_ocd/{session_id}': iv_train},
        'val':   {f'halpern_ocd/{session_id}': iv_val},
        'test':  {f'halpern_ocd/{session_id}': iv_test},
    }

    # 6) Save to output directory
    with open(out_path, 'wb') as f:
        pickle.dump(splits, f)

    print(f"Saved new splits to {out_path}: {train_count} train, {val_count} val, {test_count} test")

print(f"All splits complete for patient {PATIENT}.")


Loading intervals from P5_D4_Tsymptom_provocation_SUPENNS001R02.part2.pkl...
Saved new splits to /vol/cortex/cd3/pesaranlab/OCD_Mapping_Foundation/sequential_splits/P5_D4_Tsymptom_provocation_SUPENNS001R02.part2.pkl: 16 train, 1 val, 1 test
Loading intervals from P5_D4_Tsymptom_provocation_SUPENNS001R02.part3.pkl...
Saved new splits to /vol/cortex/cd3/pesaranlab/OCD_Mapping_Foundation/sequential_splits/P5_D4_Tsymptom_provocation_SUPENNS001R02.part3.pkl: 8 train, 0 val, 0 test
Loading intervals from P5_D4_Tsymptom_provocation_SUPENNS001R02.pkl...
Saved new splits to /vol/cortex/cd3/pesaranlab/OCD_Mapping_Foundation/sequential_splits/P5_D4_Tsymptom_provocation_SUPENNS001R02.pkl: 17 train, 2 val, 2 test
Loading intervals from P5_D5_Tsymptom_provocation_SUPENNS001R01.part2.pkl...
Saved new splits to /vol/cortex/cd3/pesaranlab/OCD_Mapping_Foundation/sequential_splits/P5_D5_Tsymptom_provocation_SUPENNS001R01.part2.pkl: 16 train, 2 val, 2 test
Loading intervals from P5_D5_Tsymptom_provocation

In [15]:
"""Cell 5b: 

script to make splits from raw pkl files for a specific patient
saves to labels_jonathan

* Also filters out the intervals that have start < 0 or end < 60 seconds, as these are invalid.
"""

import os
import pickle
import numpy as np
from temporaldata import Interval

# === Configuration ===
TRAIN_FRAC = 0
VAL_FRAC   = 0
TEST_FRAC  = 1
PATIENT    = '5'
DAY = "5"

# Paths
BASE_DIR   = "/vol/cortex/cd3/pesaranlab/OCD_Mapping_Foundation"
INPUT_DIR  = os.path.join(BASE_DIR, "processed_data", "labels")
OUTPUT_DIR = os.path.join(BASE_DIR, "holdout_splits")  # already exists

def slice_iv(iv, idxs):
    return Interval(
        start=iv.start[idxs],
        end=iv.end[idxs],
        label=iv.label[idxs],
        timekeys=['start', 'end']
    )

# Process each pkl for this patient
for fname in os.listdir(INPUT_DIR):
    if not (fname.startswith(f'P{PATIENT}_D{DAY}') and fname.endswith('.pkl')):
        continue

    session_id = fname[:-4]  # strip ".pkl"
    in_path    = os.path.join(INPUT_DIR, fname)
    out_path   = os.path.join(OUTPUT_DIR, fname)

    print(f"Loading intervals from {fname}...")

    # 1) load the original Interval object
    with open(in_path, 'rb') as f:
        iv_orig = pickle.load(f)

    # --- Filter out invalid intervals ---
    mask_valid = (iv_orig.start >= 0) & (iv_orig.end >= 60)
    invalid_count = len(iv_orig.start) - int(mask_valid.sum())
    if invalid_count > 0:
        print(f"  Removed {invalid_count} intervals (start<0 or end<60) from {fname}")
    iv_orig = Interval(
        start=iv_orig.start[mask_valid],
        end=iv_orig.end[mask_valid],
        label=iv_orig.label[mask_valid],
        timekeys=['start', 'end']
    )

    # 2) Determine split sizes
    n = len(iv_orig.start)
    val_count   = int(np.floor(n * VAL_FRAC))
    test_count  = int(np.floor(n * TEST_FRAC))
    train_count = n - val_count - test_count

    # 3) Shuffle indices
    indices   = np.random.permutation(n)
    train_idx = indices[:train_count]
    val_idx   = indices[train_count:train_count + val_count]
    test_idx  = indices[train_count + val_count:train_count + val_count + test_count]

    # 4) Build split Intervals
    iv_train = slice_iv(iv_orig, train_idx)
    iv_val   = slice_iv(iv_orig, val_idx)
    iv_test  = slice_iv(iv_orig, test_idx)

    # 5) Prepare splits dict
    splits = {
        'train': {f'halpern_ocd/{session_id}': iv_train},
        'val':   {f'halpern_ocd/{session_id}': iv_val},
        'test':  {f'halpern_ocd/{session_id}': iv_test},
    }

    # 6) Save to output directory
    with open(out_path, 'wb') as f:
        pickle.dump(splits, f)

    print(f"Saved new splits to {out_path}: {train_count} train, {val_count} val, {test_count} test")

print(f"All splits complete for patient {PATIENT} day {DAY}.")

Loading intervals from P5_D5_Tsymptom_provocation_SUPENNS001R01.part2.pkl...
Saved new splits to /vol/cortex/cd3/pesaranlab/OCD_Mapping_Foundation/holdout_splits/P5_D5_Tsymptom_provocation_SUPENNS001R01.part2.pkl: 0 train, 0 val, 20 test
Loading intervals from P5_D5_Tsymptom_provocation_SUPENNS001R01.pkl...
Saved new splits to /vol/cortex/cd3/pesaranlab/OCD_Mapping_Foundation/holdout_splits/P5_D5_Tsymptom_provocation_SUPENNS001R01.pkl: 0 train, 0 val, 21 test
Loading intervals from P5_D5_Tsymptom_provocation_SUPENNS001R01.part3.pkl...
  Removed 1 intervals (start<0 or end<60) from P5_D5_Tsymptom_provocation_SUPENNS001R01.part3.pkl
Saved new splits to /vol/cortex/cd3/pesaranlab/OCD_Mapping_Foundation/holdout_splits/P5_D5_Tsymptom_provocation_SUPENNS001R01.part3.pkl: 0 train, 0 val, 6 test
All splits complete for patient 5 day 5.


In [None]:
"""
Cell 6: 


cell to make splits from data that already has splits. Now also filters out splits 
"""
import os
import pickle
import numpy as np
from temporaldata import Interval

# === Config===
TRAIN_FRAC = 0.8
VAL_FRAC   = 0.1
TEST_FRAC  = 0.1
PATIENT    = '5'

# Paths 
BASE_DIR  = os.getcwd()
LABEL_DIR = os.path.join(BASE_DIR, 'labels_jonathan')

def slice_iv(iv, idxs):
    return Interval(
        start=iv.start[idxs],
        end=iv.end[idxs],
        label=iv.label[idxs],
        timekeys=['start', 'end']P5_D5_Tsymptom_provocation_SUPENNS001R01
    )

for fname in os.listdir(LABEL_DIR):
    if not (fname.startswith(f'P{PATIENT}_') and fname.endswith('.pkl')):
        continue

    session_id = fname[:-4]  # strip ".pkl"
    pkl_path   = os.path.join(LABEL_DIR, fname)
    print(f"Loading splits from {fname}...")

    # 1) load the existing splits dict
    with open(pkl_path, 'rb') as f:
        old_splits = pickle.load(f)

    # 2) gather all Interval objects into one combined Interval
    all_intervals = []
    for split in ('train', 'val', 'test'):
        for iv in old_splits.get(split, {}).values():
            all_intervals.append(iv)

    starts = np.concatenate([iv.start for iv in all_intervals])
    ends   = np.concatenate([iv.end   for iv in all_intervals])
    labels = np.concatenate([iv.label for iv in all_intervals])

    iv_orig = Interval(start=starts, end=ends, label=labels, timekeys=['start', 'end'])



    n = len(iv_orig.start)

    # 3) compute new split sizes
    val_count   = int(np.floor(n * VAL_FRAC))
    test_count  = int(np.floor(n * TEST_FRAC))
    train_count = n - val_count - test_count

    # 4) shuffle and partition
    idxs      = np.random.permutation(n)
    train_idx = idxs[:train_count]
    val_idx   = idxs[train_count:train_count + val_count]
    test_idx  = idxs[train_count + val_count:train_count + val_count + test_count]

    iv_train = slice_iv(iv_orig, train_idx)
    iv_val   = slice_iv(iv_orig, val_idx)
    iv_test  = slice_iv(iv_orig, test_idx)

    # 5) assemble the new splits dict (new one doesnt have the halpern_ocd prefix)
    splits = {
        'train': {f'halpern_ocd/{session_id}': iv_train},
        'val':   {f'halpern_ocd/{session_id}': iv_val},
        'test':  {f'halpern_ocd/{session_id}': iv_test},
    }

    # 6) overwrite the original .pkl
    with open(pkl_path, 'wb') as f:
        pickle.dump(splits, f)

    print(f"Saved new splits for {session_id}:",
          f"{train_count} train,", f"{val_count} val,", f"{test_count} test")

print(f"All splits complete for patient {PATIENT}.")


SyntaxError: invalid syntax. Perhaps you forgot a comma? (3183398640.py, line 29)