In [None]:
# 🧩 TDT Reader Notebook
# This notebook explains and demonstrates the functionality of `tdt_reader.py`.
# It provides:
# - utilities to load and summarize TDT blocks,
# - auto-selection of stores (LFP, stim, epoc),
# - and helper functions for quick feature sanity checks.

from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Iterable
import numpy as np
import tdt

# Analysis windows, each a (start, end) pair.
# STIM_WIN is added to event-onset times (seconds) in quick_summary to cut the
# stimulation trace around each event.
# BASELINE and RESPONSE are not referenced in this part of the file;
# presumably they are also offsets in seconds relative to event onset --
# TODO confirm against their callers.
BASELINE = (-0.200, 0.000)
RESPONSE = (0.0125, 0.100)
STIM_WIN = (-0.001, 0.005)

def _as_seconds_duration(d) -> float:
    return d.total_seconds() if hasattr(d, "total_seconds") else float(d)

def _first_match(keys: Iterable[str], preds: Iterable) -> Optional[str]:
    for k in keys:
        for p in preds:
            if p(k):
                return k
    return None

def _mean_across_channels(x: np.ndarray) -> np.ndarray:
    x = np.asarray(x)
    if x.ndim == 2:
        return x.mean(axis=0)
    return x

def _slice_by_seconds(data: np.ndarray, fs: float, t0: float, t1: float) -> np.ndarray:
    i0 = max(0, int(np.floor(t0 * fs)))
    i1 = min(len(data), int(np.ceil(t1 * fs)))
    if i1 <= i0:
        return np.asarray([], dtype=float)
    return data[i0:i1]

def _rms(x: np.ndarray) -> float:
    x = np.asarray(x, dtype=float).ravel()
    return float(np.sqrt(np.mean(x**2))) if x.size else np.nan

@dataclass
class AutoStores:
    """Names of the stores picked automatically by ``auto_select_stores``."""

    # Key into ``tdt_obj.epocs`` used as the event source.
    epoc: str
    # Key into ``tdt_obj.streams`` whose name contains "lfp" or "wav".
    lfp: str
    # Key into ``tdt_obj.streams`` for the stimulation signal, or None when
    # no matching stream was found.
    stim: Optional[str]

@dataclass
class StreamInfo:
    """Per-stream metadata record.

    NOTE(review): not instantiated in the visible part of this file;
    ``quick_summary`` builds plain dicts with the same three keys instead.
    """

    # Sampling rate of the stream (presumably in Hz -- TODO confirm).
    fs: float
    # Shape of the stream's data array.
    shape: Tuple[int, ...]
    # Optional scale attribute of the stream; None when absent.
    scale: Optional[float]

@dataclass
class BlockSummary:
    """Typed layout of a block summary.

    NOTE(review): the field names match the keys of the dict returned by
    ``quick_summary``, but that function returns a plain ``dict`` and this
    class is not instantiated in the visible part of the file.
    """

    # Block path as a string.
    path: str
    # Block duration in seconds.
    duration_sec: float
    # Number of stream stores in the block.
    n_streams: int
    # Number of epoc stores in the block.
    n_epocs: int
    # Automatically selected epoc / LFP / stim store names.
    auto: AutoStores
    # Number of onsets in the selected epoc store.
    n_events: int
    # Median over events of the peak absolute stimulation value.
    stim_absmax_median: float
    # Stream name -> metadata for that stream.
    streams: Dict[str, StreamInfo]
    # Epoc name -> integer counts (quick_summary stores the onset count
    # under the key "n").
    epocs: Dict[str, Dict[str, int]]

def read_block(path: str):
    """Load a TDT block and make sure its info carries a ``blockpath``.

    The block is read with ``tdt.read_block``. If reading
    ``info.blockpath`` fails for any reason, the given ``path`` is stored
    there instead; a failure to store it is ignored on purpose so loading
    never breaks over this best-effort annotation.

    Returns:
        Whatever ``tdt.read_block`` returned.
    """
    block = tdt.read_block(path)
    try:
        _ = block.info.blockpath
    except Exception:
        # Attribute missing (or info unusable): try to record the path.
        try:
            block.info.blockpath = str(path)
        except Exception:
            # Best effort only -- leave the block untouched.
            pass
    return block

def auto_select_stores(tdt_obj) -> AutoStores:
    """Choose the epoc, LFP and stimulation stores of a block by name.

    Selection rules (all comparisons are case-insensitive):
      * epoc: chosen by ``_first_match`` from ``tdt_obj.epocs`` using the
        predicates "starts with pc", "starts with pt", "starts with u" and
        a final catch-all.
      * lfp: chosen by ``_first_match`` from ``tdt_obj.streams`` using the
        predicates "contains lfp" and "contains wav".
      * stim: a stream named exactly "izn1", "ssig" or "sout" (checked in
        that order); failing that, a stream whose name starts with "izn";
        failing that, ``None``.

    Raises:
        RuntimeError: if no epoc store or no LFP/Wav-like stream is found.
    """
    epoc_rules = [
        lambda name: name.lower().startswith("pc"),
        lambda name: name.lower().startswith("pt"),
        lambda name: name.lower().startswith("u"),
        lambda name: True,
    ]
    epoc_key = _first_match(tdt_obj.epocs.keys(), preds=epoc_rules)
    if epoc_key is None:
        raise RuntimeError("No epoc stores found in block.")

    lfp_rules = [
        lambda name: "lfp" in name.lower(),
        lambda name: "wav" in name.lower(),
    ]
    lfp_key = _first_match(tdt_obj.streams.keys(), preds=lfp_rules)
    if lfp_key is None:
        raise RuntimeError("No LFP/Wav-like stream found.")

    # Lower-cased name -> original name, for exact case-insensitive lookup.
    by_lower = {name.lower(): name for name in tdt_obj.streams.keys()}
    exact_hits = (by_lower[c] for c in ("izn1", "ssig", "sout") if c in by_lower)
    stim_key = next(exact_hits, None)
    if stim_key is None:
        stim_key = _first_match(
            tdt_obj.streams.keys(),
            [lambda name: name.lower().startswith("izn")],
        )

    return AutoStores(epoc=epoc_key, lfp=lfp_key, stim=stim_key)

def get_stream(tdt_obj, name: str) -> Tuple[float, np.ndarray, Optional[float]]:
    """Fetch one stream store as ``(fs, data, scale)``.

    Args:
        tdt_obj: Object whose ``streams`` mapping holds the stream stores.
        name: Key of the stream in ``tdt_obj.streams``.

    Returns:
        The stream's ``fs`` as a float, its ``data`` as an ndarray, and its
        ``scale`` attribute (``None`` when the stream has none).

    Raises:
        KeyError: if ``name`` is not a stream of the block.
        AttributeError: if the stream lacks ``fs`` or ``data``.
    """
    stream = tdt_obj.streams[name]
    return (
        float(stream.fs),
        np.asarray(stream.data),
        getattr(stream, "scale", None),
    )

def get_event_onsets(tdt_obj, epoc_name: str) -> np.ndarray:
    """Return the onsets of the epoc store ``epoc_name`` as a float array.

    Raises:
        KeyError: if ``epoc_name`` is not a key of ``tdt_obj.epocs``.
    """
    store = tdt_obj.epocs[epoc_name]
    return np.asarray(store.onset, dtype=float)

def quick_summary(tdt_obj) -> dict:
    """Build a plain-dict overview of a loaded block.

    The summary lists every stream (fs, data shape, scale) and every epoc
    store (onset count), reports which stores ``auto_select_stores`` picked,
    and computes a stimulation sanity metric: for each onset of the selected
    epoc store, the peak absolute value of the stimulation trace inside
    ``STIM_WIN`` around that onset; the median of those peaks is reported.

    Args:
        tdt_obj: Loaded block exposing ``info``, ``streams`` and ``epocs``.

    Returns:
        A dict with keys ``path``, ``duration_sec``, ``n_streams``,
        ``n_epocs``, ``auto``, ``n_events``, ``stim_absmax_median``,
        ``streams`` and ``epocs``. ``stim_absmax_median`` is ``0.0`` when no
        stimulation stream was selected or no event window overlaps it.

    Raises:
        RuntimeError: propagated from ``auto_select_stores`` when the block
            has no epoc store or no LFP/Wav-like stream.
    """
    duration = _as_seconds_duration(getattr(tdt_obj.info, "duration", 0.0))
    auto = auto_select_stores(tdt_obj)

    streams = {
        name: dict(
            fs=float(getattr(s, "fs", np.nan)),
            shape=list(np.asarray(s.data).shape),
            scale=getattr(s, "scale", None),
        )
        for name, s in tdt_obj.streams.items()
    }

    epocs = {k: {"n": int(len(v.onset))} for k, v in tdt_obj.epocs.items()}

    onsets = get_event_onsets(tdt_obj, auto.epoc) if auto.epoc else np.array([])
    n_events = int(onsets.size)

    stim_absmax_median = 0.0
    if auto.stim is not None:
        fs_stim, stim, _ = get_stream(tdt_obj, auto.stim)
        if n_events:
            # Average across channels once: this touches the whole stream,
            # so doing it per event would cost O(n_events * n_samples).
            stim_trace = _mean_across_channels(stim)
            peaks = []
            for t0 in onsets:
                seg = _slice_by_seconds(stim_trace, fs_stim,
                                        t0 + STIM_WIN[0], t0 + STIM_WIN[1])
                # Windows falling entirely outside the recording are skipped.
                if seg.size:
                    peaks.append(float(np.max(np.abs(seg))))
            if peaks:
                stim_absmax_median = float(np.median(peaks))

    return dict(
        path=str(getattr(tdt_obj.info, "blockpath", "")) or "UNKNOWN",
        duration_sec=float(duration),
        n_streams=len(streams),
        n_epocs=len(epocs),
        auto=dict(epoc=auto.epoc.replace("/", "_"),
                  lfp=auto.lfp,
                  stim=auto.stim),
        n_events=n_events,
        stim_absmax_median=stim_absmax_median,
        streams=streams,
        epocs=epocs,
    )

def epoch_lfp(tdt_obj, onsets: np.ndarray, lfp_name: str, window: Tuple[float, float]) -> Tuple[np.ndarray, np.ndarray]:
    """Cut fixed-length epochs of the channel-averaged LFP around onsets.

    Args:
        tdt_obj: Loaded block; the stream is fetched with ``get_stream``.
        onsets: Event times in seconds.
        lfp_name: Key of the stream in ``tdt_obj.streams``.
        window: ``(pre, post)`` offsets in seconds relative to each onset;
            the epoch spans ``[onset + pre, onset + post)``.

    Returns:
        ``(t, epochs)`` where ``t`` holds the ``n_samples`` time offsets of
        one epoch (``post`` excluded) and ``epochs`` has one row per kept
        onset, i.e. shape ``(n_kept, n_samples)``. Onsets whose window would
        extend past either end of the recording are dropped. When nothing is
        kept the result still has shape ``(0, n_samples)`` so axis-based
        code (``epochs.mean(axis=0)``, ``epochs[:, mask]``) keeps working.
    """
    fs, data, _ = get_stream(tdt_obj, lfp_name)
    trace = _mean_across_channels(data)
    pre, post = window
    n_samples = int(round((post - pre) * fs))
    t = np.linspace(pre, post, n_samples, endpoint=False)

    n_total = len(trace)
    kept = []
    for t0 in onsets:
        start = int(round((t0 + pre) * fs))
        stop = start + n_samples
        # Only keep windows lying fully inside the recording; that also
        # guarantees every slice is exactly n_samples long.
        if start < 0 or stop > n_total:
            continue
        kept.append(trace[start:stop])

    if not kept:
        # np.asarray([]) would be 1-D with shape (0,); keep the 2-D contract.
        return t, np.empty((0, n_samples), dtype=float)
    return t, np.asarray(kept, dtype=float)
