In [63]:
from scipy.io import loadmat
import pandas as pd
import numpy as np
import h5py
import matplotlib.pyplot as plt
import os
import re

In [61]:
def spontaneous_intervals(
    data_path,                   # path to behaviour .mat
    file_id,                     # 'meas00' / 'meas01' ... or int (0 -> meas00)
    output="frames",             # 'frames' or 'time'
    buffers=None,                # per-event guard bands (s): {name: (pre, post)}
    verbose=False
):
    """
    Return Nx2 spontaneous (event-free) intervals for ONE recording (row) of MasterN.

    Pipeline
    --------
    1. Load ``MasterN`` from the .mat file and pick the row whose video filename
       (column 1) ends in ``_<NN>.mp4`` with NN == meas number + 1.
    2. Read that row's TimeStamps table (column 5). Column 0 of the table holds
       the row name, column 2 a numeric 2-column array: ``[time, dt]`` for the
       ``Frames`` row and ``[time, duration]`` for every event row.
    3. Pad every event by its ``(pre, post)`` buffer, merge the padded events,
       and take the complement over ``[first frame time, last frame time]``.
    4. Snap each gap to whole frames and keep only gaps that span at least
       ``ceil(0.5 * fps)`` frames.

    Parameters
    ----------
    data_path : path to the behaviour .mat file (must contain ``MasterN``).
    file_id   : str ending in digits ('meas00', 'meas03', ...) or int (0, 1, ...).
                The video column uses 1-based numbering in the filename:
                meas00 -> ..._01.mp4, meas01 -> ..._02.mp4, etc.
    output    : 'frames' -> int array of half-open frame index pairs [start, stop);
                any other value is treated as 'time' -> float array holding the
                times of the first and last frame of each interval.
    buffers   : optional ``{event_name: (pre, post)}`` overriding the defaults
                below; event names without an entry use (0.10, 0.10).
    verbose   : print a per-event count and a summary of the kept intervals.

    Returns
    -------
    numpy.ndarray of shape (N, 2); N may be 0.

    Raises
    ------
    ValueError
        If ``file_id`` is a string without trailing digits, no MasterN row
        matches the derived video index, the TimeStamps table has no row whose
        name contains 'Frames', or that row holds no frame times.
    """
    # ---------- load and locate the correct row ----------
    mat = loadmat(data_path, squeeze_me=False, struct_as_record=False)
    master = mat["MasterN"]     # (n_rows, 13) cell array

    # desired mp4 index (1-based)
    if isinstance(file_id, str):
        m = re.search(r'(\d+)$', file_id.strip())
        if not m:
            raise ValueError(f"file_id '{file_id}' should end with digits (e.g., 'meas00').")
        meas_num = int(m.group(1))
    else:
        meas_num = int(file_id)
    mp4_num = meas_num + 1  # '..._01.mp4' for meas00

    # helpers to unwrap MATLAB cell strings
    def to_str(x):
        if isinstance(x, str):
            return x
        if isinstance(x, np.ndarray) and x.size == 1:
            v = x.item()
            return v if isinstance(v, str) else str(v)
        return str(x)

    # column indices (0-based): video=1, TimeStamps=5
    video_col = 1
    ts_col    = 5

    row_idx = None
    pat = re.compile(r'_(\d+)\.mp4$', re.IGNORECASE)
    for r in range(master.shape[0]):
        vid = to_str(master[r, video_col])
        m = pat.search(vid)
        if m and int(m.group(1)) == mp4_num:
            row_idx = r
            break
    if row_idx is None:
        raise ValueError(f"No MasterN row found whose video filename ends with _{mp4_num:02d}.mp4 "
                         f"(derived from file_id={file_id}).")

    # ---------- unwrap the 7x4 TimeStamps table for that row ----------
    ts_tbl = master[row_idx, ts_col]
    # loadmat nests cells as 1-element object arrays; peel them off.
    while isinstance(ts_tbl, np.ndarray) and ts_tbl.dtype == object and ts_tbl.size == 1:
        ts_tbl = ts_tbl.item()
    ts_tbl = np.array(ts_tbl, dtype=object).squeeze()   # (7,4)

    def unquote(s):
        if isinstance(s, str):
            return s.strip("'")
        if isinstance(s, np.ndarray) and s.size == 1 and isinstance(s.item(), str):
            return s.item().strip("'")
        return str(s)

    def unwrap_2col(x):
        # Returns the first two columns as float, or an empty (0, 2) array when
        # the cell does not hold a numeric 2-D array with at least two columns.
        a = x
        while isinstance(a, np.ndarray) and a.dtype == object and a.size == 1:
            a = a.item()
        if isinstance(a, np.ndarray) and a.ndim == 2 and a.shape[1] >= 2 and a.dtype != object:
            return a[:, :2].astype(float)
        return np.empty((0,2), float)

    names = [unquote(ts_tbl[r,0]) for r in range(ts_tbl.shape[0])]
    if "Frames" in names:
        i_frames = names.index("Frames")
    else:
        # Fall back to the first row whose name merely contains 'Frames'.
        candidates = [i for i, n in enumerate(names) if "Frames" in n]
        if not candidates:
            raise ValueError(f"No 'Frames' row in the TimeStamps table of MasterN row {row_idx} "
                             f"(row names: {names}).")
        i_frames = candidates[0]

    frames_td = unwrap_2col(ts_tbl[i_frames, 2])  # (n_frames, 2): [time, dt]
    if frames_td.shape[0] == 0:
        raise ValueError(f"The 'Frames' row of MasterN row {row_idx} holds no frame times.")
    frame_times_beh = frames_td[:,0].astype(float)
    # Prefer the recorded per-frame dt; fall back to differences of the frame times.
    fps = 1.0 / np.median(frames_td[:,1]) if np.all(frames_td[:,1] > 0) else 1.0/np.median(np.diff(frame_times_beh))

    # ---------- parameters ----------
    min_len_s  = 0.5
    min_len_fr = int(np.ceil(min_len_s * fps))

    default_buffers = {
        "Vis2": (0.75, 1.00),
        "Con" : (0.25, 0.50),
        "TOs2": (0.25, 0.25),
        "Lik" : (0.10, 0.10),
        "Rew" : (0.20, 0.50),
        "ENT" : (0.25, 0.25)
    }
    if buffers:
        default_buffers.update(buffers)

    # ---------- build busy intervals ----------
    busy, event_counts = [], {}
    for r, name in enumerate(names):
        # Skip the frame clock itself. Compare by index as well as by name so a
        # Frames row located through the substring fallback is never mistaken
        # for an event (which would mark the whole movie as busy).
        if r == i_frames or name == "Frames":
            continue
        td = unwrap_2col(ts_tbl[r, 2])  # [time, duration]
        if td.size == 0:
            continue
        starts = td[:,0]
        ends   = td[:,0] + td[:,1]
        pre, post = default_buffers.get(name, (0.10, 0.10))
        arr = np.column_stack([starts - pre, ends + post])
        busy.append(arr)
        event_counts[name] = len(arr)

    # ---------- complement over movie span ----------
    t0, t1 = frame_times_beh[0], frame_times_beh[-1]
    if busy:
        busy_merged = _merge_intervals(np.vstack(busy))
        spont_times = _complement_intervals(busy_merged, t0, t1)
    else:
        # No events at all: the whole movie is one spontaneous interval.
        spont_times = np.array([[t0, t1]], float)
    result = _snap_and_filter(spont_times, frame_times_beh, output, min_len_fr)

    if verbose:
        total_dur = t1 - t0
        # Column difference keeps a 1-D array even for a single interval
        # (np.diff(...).squeeze() would collapse it to 0-d and break len()).
        spont_durs = (result[:, 1] - result[:, 0]).astype(float)
        spont_sec = spont_durs / fps if output == "frames" else spont_durs
        n_spont = len(spont_sec)
        mean_spont = np.mean(spont_sec) if n_spont > 0 else 0
        max_spont = np.max(spont_sec) if n_spont > 0 else 0
        spont_total = np.sum(spont_sec)
        frac_spont = 100 * spont_total / total_dur if total_dur > 0 else 0.0

        print(f"[{data_path}] row={row_idx} file_id={file_id} (mp4 idx {mp4_num:02d}) fps≈{fps:.2f}")
        for name, count in event_counts.items():
            print(f"  {name:<6}: {count:5d} events")
        print(f"Spontaneous: {n_spont} intervals | mean {mean_spont:.2f}s | max {max_spont:.2f}s "
              f"| total {spont_total:.1f}s ({frac_spont:.1f}% of movie) "
              f"| min kept {min_len_s:.2f}s (~{min_len_fr} fr)")

    return result

# ---------- small utilities ----------
def _merge_intervals(iv):
    if iv.size == 0:
        return np.empty((0,2), float)
    iv = iv[np.argsort(iv[:,0])]
    out = []
    for s,e in iv:
        if not out or s > out[-1][1]:
            out.append([s,e])
        else:
            out[-1][1] = max(out[-1][1], e)
    return np.array(out, float)

def _complement_intervals(busy, t0, t1):
    if busy.size == 0:
        return np.array([[t0,t1]], float)
    gaps, cur = [], t0
    for s,e in busy:
        if s > cur: gaps.append([cur, s])
        cur = max(cur, e)
        if cur >= t1: break
    if cur < t1: gaps.append([cur, t1])
    return np.array(gaps, float) if gaps else np.empty((0,2), float)

def _snap_and_filter(spont_times, frame_times_beh, output, min_len_fr):
    ft = frame_times_beh
    start_idx = np.searchsorted(ft, spont_times[:,0], side="left")
    end_idx   = np.searchsorted(ft, spont_times[:,1], side="right") - 1
    start_idx = np.clip(start_idx, 0, len(ft)-1)
    end_idx   = np.clip(end_idx,   -1, len(ft)-1)

    keep = (end_idx - start_idx + 1) >= min_len_fr
    start_idx, end_idx = start_idx[keep], end_idx[keep]
    if start_idx.size == 0:
        return np.empty((0,2), int) if output=="frames" else np.empty((0,2), float)

    if output == "frames":
        return np.column_stack([start_idx, end_idx + 1]).astype(int)  # half-open
    return np.column_stack([ft[start_idx], ft[end_idx]]).astype(float)

In [None]:
def task_state_intervals(
    csv_path,
    date,                  # e.g., 240510 (int or str)
    file_id,               # e.g., 'meas00'
    short_run_threshold=3, # ≤ this many trials becomes "unclear" (internal runs only)
    verbose=False
):
    """
    Build engaged / attrition / unclear intervals from a TrialInfo CSV.

    The CSV must provide the columns 'Date', 'File', 'ValidTrial?', 'TrialType',
    'BFMTime' and 'Attrition?'. Rows are restricted to the requested date/file,
    invalid trials and trials of type 'Error (0)' are dropped, and the remaining
    trials are sorted by 'BFMTime'. Consecutive trials sharing the same
    'Attrition?' value form a run (0 -> engaged, 1 -> attrition). A run that is
    neither the first nor the last and contains at most ``short_run_threshold``
    trials is relabelled 'unclear'. Each run becomes the interval from its first
    to its last trial time; runs with zero duration (e.g. a single trial) are
    dropped, so consecutive intervals are separated by a gap.

    Returns
    -------
    {'engaged': Nx2, 'attrition': Mx2, 'unclear': Kx2}  # times in seconds
    Every value is a float array with exactly two columns, including when it
    is empty (shape (0, 2)), so callers can always index ``arr[:, 0]``.
    """
    def as_nx2(rows):
        # np.array([]) would be 1-D with shape (0,); force the documented Nx2 layout.
        return np.array(rows, dtype=float).reshape(-1, 2)

    df = pd.read_csv(csv_path)

    # Filter to Date & File (compare as strings to be forgiving about types)
    sub = df[(df['Date'].astype(str) == str(date)) & (df['File'].astype(str) == str(file_id))].copy()

    # Drop invalid trials and errors
    sub = sub[(sub['ValidTrial?'] == True) & (sub['TrialType'] != 'Error (0)')]

    if sub.empty:
        out = {'engaged': np.empty((0,2)), 'attrition': np.empty((0,2)), 'unclear': np.empty((0,2))}
        if verbose:
            print(f"No valid trials for Date={date}, File={file_id}.")
        return out

    # Extract time & state, sorted by time
    times = pd.to_numeric(sub['BFMTime']).to_numpy()
    attr  = sub['Attrition?'].astype(int).to_numpy()
    order = np.argsort(times)
    times, attr = times[order], attr[order]

    # Run-length encode Attrition?
    runs = []
    cur = attr[0]
    start = 0
    for i in range(1, len(attr)):
        if attr[i] != cur:
            runs.append((cur, start, i-1))  # (label, i0, i1)
            cur, start = attr[i], i
    runs.append((cur, start, len(attr)-1))

    # Relabel short INTERNAL runs (≤ threshold) as unclear (label 2)
    labeled = []
    for j, (lab, i0, i1) in enumerate(runs):
        length = i1 - i0 + 1
        is_internal = (j > 0) and (j < len(runs)-1)
        if is_internal and length <= short_run_threshold:
            labeled.append((2, i0, i1))  # unclear
        else:
            labeled.append((lab, i0, i1))  # 0 engaged, 1 attrition

    # Convert to [start_time, end_time] intervals
    buckets = {0: [], 1: [], 2: []}
    for lab, i0, i1 in labeled:
        t_start = times[i0]
        t_end   = times[i1]  # end at THIS run's last trial time (creates a gap to next run)
        if t_end > t_start:
            buckets[lab].append([t_start, t_end])

    out = {
        'engaged'  : as_nx2(buckets[0]),
        'attrition': as_nx2(buckets[1]),
        'unclear'  : as_nx2(buckets[2]),
    }

    if verbose:
        def stats(arr):
            if len(arr) == 0:
                return (0, 0.0, 0.0, 0.0)
            d = arr[:,1] - arr[:,0]
            return (len(arr), d.sum(), d.mean(), d.max())
        nE,sE,mE,xE = stats(out['engaged'])
        nA,sA,mA,xA = stats(out['attrition'])
        nU,sU,mU,xU = stats(out['unclear'])
        print(f"[TrialInfo → intervals] Date={date}, File={file_id}")
        print(f"  Engaged  : {nE:3d} intervals | total {sE:7.2f}s | mean {mE:5.2f}s | max {xE:5.2f}s")
        print(f"  Attrition: {nA:3d} intervals | total {sA:7.2f}s | mean {mA:5.2f}s | max {xA:5.2f}s")
        print(f"  Unclear  : {nU:3d} intervals | total {sU:7.2f}s | mean {mU:5.2f}s | max {xU:5.2f}s")
        print(f"  Short-run threshold (internal only): ≤ {short_run_threshold} trials")

    return out

In [None]:
def classify_spontaneous_intervals(
    spont_times,           # Nx2 array of spontaneous intervals [t_start, t_end] (seconds)
    state_intervals,       # dict with 'engaged' / 'attrition' Nx2 arrays (e.g. from task_state_intervals)
    verbose=False
):
    """
    Classify spontaneous intervals as 'engaged' or 'attrition'
    based on which task-state interval they fall inside.

    A spontaneous interval is assigned to a state only when it lies FULLY
    inside one interval of that state (start >= state start and end <= state
    end). 'engaged' is checked first; intervals contained in neither state
    (including ones that straddle a boundary) are dropped.

    Assumes:
      - Engaged and attrition intervals are non-overlapping and disjoint.
      - spont_times and state_intervals are expressed on the same clock;
        this function does not convert between clocks.
        NOTE(review): confirm that the frame times used to build spont_times
        and the trial times used to build state_intervals share a time base.

    A missing or empty 'engaged' / 'attrition' entry is treated as "no
    intervals of that state".

    Returns
    -------
    (engaged_spont, attrition_spont): two float arrays with two columns each,
    shape (0, 2) when empty, rows in the order they appear in spont_times.

    Raises
    ------
    ValueError
        If a non-empty input cannot be interpreted as rows of [start, end].
    """
    def as_nx2(values, label):
        # Normalise to (N, 2); empty inputs of any shape become (0, 2) so the
        # column indexing below never fails.
        arr = np.asarray(values, float)
        if arr.size == 0:
            return np.empty((0, 2), float)
        arr = np.atleast_2d(arr)
        if arr.ndim != 2 or arr.shape[1] != 2:
            raise ValueError(f"{label} must be an Nx2 array of [start, end] rows, got shape {arr.shape}.")
        return arr

    spont        = as_nx2(spont_times, "spont_times")
    engaged_iv   = as_nx2(state_intervals.get('engaged',   []), "state_intervals['engaged']")
    attrition_iv = as_nx2(state_intervals.get('attrition', []), "state_intervals['attrition']")

    def contained_in(iv):
        # (N, 1) against (M,) broadcasts to an (N, M) containment table;
        # a row is True when the spontaneous interval fits inside any state interval.
        starts_ok = spont[:, :1] >= iv[:, 0]
        ends_ok   = spont[:, 1:] <= iv[:, 1]
        return np.any(starts_ok & ends_ok, axis=1)

    in_engaged   = contained_in(engaged_iv)
    in_attrition = contained_in(attrition_iv) & ~in_engaged  # engaged takes priority

    engaged_spont   = spont[in_engaged]
    attrition_spont = spont[in_attrition]

    if verbose:
        def stats(arr):
            if len(arr) == 0: return (0, 0.0, 0.0, 0.0)
            d = arr[:,1] - arr[:,0]
            return (len(arr), d.sum(), d.mean(), d.max())
        nE,sE,mE,xE = stats(engaged_spont)
        nA,sA,mA,xA = stats(attrition_spont)
        print(f"[Spontaneous classification]")
        print(f"  Engaged  : {nE:3d} intervals | total {sE:7.2f}s | mean {mE:5.2f}s | max {xE:5.2f}s")
        print(f"  Attrition: {nA:3d} intervals | total {sA:7.2f}s | mean {mA:5.2f}s | max {xA:5.2f}s")

    return engaged_spont, attrition_spont

In [64]:
# Behaviour file for the session being analysed.
data_path = 'Z:/Voltage/VisualConsciousness/Analysis/VDT/cfm002mjr/20240510/BehaviorData.mat'

# Event-free stretches of the recording, as [t_start, t_end] times.
spont_times = spontaneous_intervals(
    data_path,
    file_id="meas01",
    output="time",
    verbose=True,
)

# Engaged / attrition / unclear task-state intervals for the same recording.
intervals = task_state_intervals(
    csv_path="trial_info/TrialInfo_cfm002mjr.csv",
    date=240510,
    file_id="meas01",
    short_run_threshold=3,
    verbose=True,
)

# Classify spontaneous intervals by task state
engaged_spont, attrition_spont = classify_spontaneous_intervals(
    spont_times,
    state_intervals=intervals,
    verbose=True,
)

[Z:/Voltage/VisualConsciousness/Analysis/VDT/cfm002mjr/20240510/BehaviorData.mat] row=2 file_id=meas01 (mp4 idx 02) fps≈157.18
  Vis2  :    86 events
  Con   :    24 events
  TOs2  :    20 events
  Lik   :   220 events
  Rew   :    24 events
  ENT   :    13 events
Spontaneous: 92 intervals | mean 2.13s | max 5.22s | total 195.6s (42.4% of movie) | min kept 0.50s (~79 fr)
[TrialInfo → intervals] Date=240510, File=meas01
  Engaged  :   2 intervals | total  283.10s | mean 141.55s | max 240.26s
  Attrition:   1 intervals | total  144.49s | mean 144.49s | max 144.49s
  Unclear  :   0 intervals | total    0.00s | mean  0.00s | max  0.00s
  Short-run threshold (internal only): ≤ 3 trials
[Spontaneous classification]
  Engaged  :  51 intervals | total  112.21s | mean  2.20s | max  5.08s
  Attrition:  33 intervals | total   68.09s | mean  2.06s | max  5.22s
