# Replay, HMM on LFP, and Joint Spike–LFP–Behavior Features

Skeleton implementations for the remaining TODO items:
- Decode replay with place fields + shuffles
- HMM on LFP bandpower
- Joint spike–LFP–behavior feature extraction and clustering
- Positional clustering of units
- Real-time stimulation outline

Edit the paths below so that they point to your dataset.

In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
from scipy.signal import butter, filtfilt
from tqdm.auto import tqdm
from sklearn.mixture import GaussianMixture
try:
    from hmmlearn import hmm
    have_hmm = True
except ImportError:
    have_hmm = False
    print("hmmlearn not installed; HMM cell will error unless installed.")

# Paths (edit)
# Root folder of one SpikeGLX probe recording; every neural-data path below is derived from it.
spikeglx_probe_folder = Path(r"Z:\Koji\Neuropixels\1818\1818_11202025_g0\1818_11202025_g0_imec0")
# Sorter output folder ("kilosort4") inside the probe folder.
spike_dir = spikeglx_probe_folder / "kilosort4"
# Per-spike arrays loaded with np.load in the loading cell.
# NOTE(review): the file name suggests spike times in seconds on an adjusted clock — confirm.
spike_times_path = spike_dir / "spike_seconds_adj.npy"
spike_clusters_path = spike_dir / "spike_clusters.npy"
# Unit-label table in a sibling "kilosort4qMetrics" folder; the loading cell treats label == 1 as a good unit.
unit_labels_path = spike_dir / ".." / "kilosort4qMetrics" / "templates._bc_unit_labels.tsv"
# Optional cell-type table; the loading cell skips it when the file does not exist.
celltype_path = spike_dir / "unit_classification_rulebased.csv"
# Spectrogram files; not referenced by any other cell shown in this notebook section.
spectrogram_meta = spikeglx_probe_folder / "spectrogram_fullsession_meta.npz"
spectrogram_memmap = spikeglx_probe_folder / "spectrogram_fullsession.dat"

# Behavior paths (edit if needed)
# Event tables; the behavior-loading cell reads each CSV and concatenates them row-wise.
event_csvs = [
    Path(r"Z:\Koji\NP_Coh3\Recording\Day27_1818_Clockwise_corner_2025-11-20T15_11_30.csv"),
    Path(r"Z:\Koji\NP_Coh3\Recording\Day27_1818_Clockwise_licking_2025-11-20T15_11_30.csv"),
]
# Pose-tracking CSV read into `dlc` by the behavior-loading cell.
dlc_csv = Path(r"Z:\Koji\NP_Coh3\Recording\Day27_1818_Clockwise2025-11-20T15_50_10DLC_HrnetW32_openfield_v3Sep10shuffle2_detector_170_snapshot_160.csv")
# Camera frame rate; the behavior-loading cell divides the row index by it to build a camera time axis.
cam_fps = 60.0

print("Paths set. Edit as needed.")

In [None]:
# Load spikes and metadata
# One entry per spike: spike time and the cluster id the spike was assigned to.
spike_times = np.load(spike_times_path)
spike_clusters = np.load(spike_clusters_path)
# NOTE(review): loadtxt with dtype=int requires a purely numeric, header-less TSV — confirm the
# label file layout. Later cells index `good_units` by position (enumerate / np.where), which
# presumes entry i describes cluster id i — verify against the label file.
unit_labels = np.loadtxt(unit_labels_path, delimiter="\t", dtype=int)
good_units = unit_labels == 1
# Optional cluster_id -> cell_type lookup; stays empty when no classification file exists.
celltypes = {}
if celltype_path.exists():
    # pandas is already imported as `pd` in the notebook's import cell, so no local re-import.
    df_ct = pd.read_csv(celltype_path)
    celltypes = dict(zip(df_ct["cluster_id"], df_ct["cell_type"]))
print(f"Spikes: {spike_times.size}, units: {np.unique(spike_clusters).size}, good units: {good_units.sum()}")

In [None]:
# Load behavior (DLC) and events
# NOTE(review): header=None makes pandas treat every CSV row as data, so any header rows in the
# tracking file are counted in len(dlc) and in the time axis below — confirm the file layout.
dlc = pd.read_csv(dlc_csv, header=None)
# Camera clock: one timestamp per row of `dlc`, spaced 1 / cam_fps apart and starting at 0.
dlc_time_cam = np.arange(dlc.shape[0]) / cam_fps
# Read every event CSV and stack the tables row-wise under a fresh 0..N-1 index.
events = pd.concat(list(map(pd.read_csv, event_csvs)), ignore_index=True)
print("DLC shape:", dlc.shape, "Events shape:", events.shape)

## Alignment helper (camera time to NP time)
Reuse the mapping from `Event_DLC_DA_alignment.ipynb` if available. Set `a_cam_np` and `b_cam_np` accordingly.

In [None]:
# Set mapping from camera time to NP time (t_np = a*t_cam + b)
a_cam_np = None  # fill from alignment notebook
b_cam_np = None

def cam_to_np_time(cam_times):
    """Convert camera-clock times to the NP clock via t_np = a_cam_np * t_cam + b_cam_np.

    Parameters
    ----------
    cam_times : array-like
        Times on the camera clock; converted with ``np.asarray``.

    Returns
    -------
    numpy.ndarray
        The linearly mapped times.

    Raises
    ------
    RuntimeError
        If either module-level coefficient (``a_cam_np`` or ``b_cam_np``) is still None.
    """
    coefficients_ready = a_cam_np is not None and b_cam_np is not None
    if not coefficients_ready:
        raise RuntimeError("Set a_cam_np and b_cam_np from alignment notebook")
    cam_array = np.asarray(cam_times)
    return a_cam_np * cam_array + b_cam_np

# Tracking-frame times on the NP clock; stays None until the slope has been filled in.
if a_cam_np is None:
    dlc_time_np = None
else:
    dlc_time_np = cam_to_np_time(dlc_time_cam)

## Place fields and replay decoding (Bayesian skeleton)
- Compute 1D position from DLC (choose a body part / axis)
- Build tuning curves during movement
- Decode position in candidate events and test vs shuffles

In [None]:
def compute_tuning(pos, pos_time, spikes_s, spikes_clu, good_mask, nbins=50, speed=None, speed_thresh=2.0):
    """Estimate 1D position tuning curves (firing rate per position bin) for the good units.

    Parameters
    ----------
    pos : array-like
        Position samples (one value per tracking sample).
    pos_time : array-like
        Time of each position sample, on the same clock as ``spikes_s``. Must be increasing
        (required by ``np.interp`` / ``np.searchsorted``) and the same length as ``pos``.
    spikes_s : array-like
        Spike times.
    spikes_clu : array-like
        Cluster id of each spike (same length as ``spikes_s``).
    good_mask : sequence of bool
        Entry ``i`` says whether cluster id ``i`` is used (the mask is indexed by cluster id).
    nbins : int
        Number of equal-width position bins spanning ``pos.min()`` to ``pos.max()``.
    speed : array-like or None
        Optional per-sample speed. When given, both occupancy and spikes are restricted to
        samples with ``speed > speed_thresh``.
    speed_thresh : float
        Movement threshold applied to ``speed``.

    Returns
    -------
    centers : numpy.ndarray
        Bin centres, shape ``(nbins,)``.
    tc : dict
        ``{cluster_id: rate_per_bin}``; units without usable spikes are omitted.

    Raises
    ------
    ValueError
        If ``pos`` and ``pos_time`` differ in shape or hold fewer than two samples.
    """
    pos = np.asarray(pos, dtype=float)
    pos_time = np.asarray(pos_time, dtype=float)
    spikes_s = np.asarray(spikes_s)
    spikes_clu = np.asarray(spikes_clu)
    if pos.shape != pos_time.shape or pos.size < 2:
        raise ValueError("pos and pos_time must have the same shape and at least 2 samples")

    dt = np.diff(pos_time).mean()  # mean sampling interval of the position trace
    edges = np.linspace(pos.min(), pos.max(), nbins + 1)
    centers = 0.5 * (edges[:-1] + edges[1:])

    # Per-sample movement mask. A speed trace of a different length (e.g. from np.diff)
    # is tolerated by clipping the sample index.
    moving = None
    if speed is not None:
        speed = np.asarray(speed)
        sample_idx = np.clip(np.arange(pos.size), 0, speed.size - 1)
        moving = speed[sample_idx] > speed_thresh

    # Occupancy must use the same samples the spikes are drawn from, otherwise rates are
    # diluted by time spent immobile.
    occ_samples = pos if moving is None else pos[moving]
    occ_counts, _ = np.histogram(occ_samples, bins=edges)
    occ = occ_counts * dt  # sample counts -> time spent in each bin (multiply, not divide)

    tc = {}
    for uid, good in enumerate(good_mask):
        if not good:
            continue
        st = spikes_s[spikes_clu == uid]
        # np.interp clamps out-of-range times onto the first/last position, so spikes fired
        # outside the tracked interval are dropped instead of piling onto those bins.
        st = st[(st >= pos_time[0]) & (st <= pos_time[-1])]
        if moving is not None:
            # Most recent position sample at or before each spike (works for any time origin).
            spike_sample = np.clip(np.searchsorted(pos_time, st, side="right") - 1, 0, moving.size - 1)
            st = st[moving[spike_sample]]
        if st.size == 0:
            continue
        # bin spikes by position via interpolation
        pos_at_spikes = np.interp(st, pos_time, pos)
        spk_counts, _ = np.histogram(pos_at_spikes, bins=edges)
        # Small constants keep the ratio finite in unvisited bins (those evaluate to 1.0).
        tc[uid] = (spk_counts + 1e-3) / (occ + 1e-3)  # FR estimate
    return centers, tc

def decode_position(tuning_curves, pos_bins, spike_bin_counts, dt=0.02):
    """Bayesian (Poisson, flat prior) decoding of position from binned spike counts.

    Parameters
    ----------
    tuning_curves : dict
        ``{unit: rate_per_position_bin}``; every value must have the same length.
    pos_bins : array-like
        Position bin centres. Accepted for interface compatibility; not used in the computation.
    spike_bin_counts : array-like, shape (n_units, n_timebins)
        Spike counts per unit and time bin. Rows must follow the iteration order of
        ``tuning_curves``.
    dt : float
        Width of one decoding time bin, in the time unit the rates are expressed in
        (default 0.02, the value that was previously hard-coded).

    Returns
    -------
    numpy.ndarray, shape (n_timebins, nbins)
        Posterior over position bins for every time bin; each row sums to 1.

    Raises
    ------
    ValueError
        If ``tuning_curves`` is empty or the count matrix does not have one row per unit.
    """
    units = list(tuning_curves.keys())
    if not units:
        raise ValueError("tuning_curves is empty; nothing to decode with")
    lam = np.stack([np.asarray(tuning_curves[u], dtype=float) for u in units])  # (n_units, nbins)
    counts = np.asarray(spike_bin_counts, dtype=float)
    if counts.ndim != 2 or counts.shape[0] != lam.shape[0]:
        raise ValueError("spike_bin_counts must have shape (n_units, n_timebins) matching tuning_curves")

    # Poisson log-likelihood up to terms that do not depend on position:
    #   sum_u n[u, t] * log(lam[u, x]) - dt * sum_u lam[u, x]
    # The matrix product contracts over units and yields (n_timebins, nbins) directly.
    log_post = counts.T @ np.log(lam + 1e-12) - dt * lam.sum(axis=0)[None, :]
    # Subtract each row's maximum before exponentiating so large populations cannot
    # underflow to an all-zero row (which would normalise to NaN).
    log_post -= log_post.max(axis=1, keepdims=True)
    post = np.exp(log_post)
    post /= post.sum(axis=1, keepdims=True)
    return post  # (n_timebins, nbins)

# Placeholder usage:
# pos = dlc.iloc[:,0].values (choose appropriate body part/axis)
# centers, tc = compute_tuning(pos, dlc_time_np, spike_times, spike_clusters, good_units)
# Build spike counts in small bins during events, then decode_position(tc, centers, spike_counts)

## HMM on LFP bandpower
- Extract bandpower features per channel (or PCA of channels)
- Fit a Gaussian HMM to uncover latent LFP states

In [None]:
def lfp_bandpower_features(lfp_rec, bands, win_s=0.5, step_s=0.1, duration_s=300):
    """Compute sliding-window band power for every channel of an LFP recording.

    Parameters
    ----------
    lfp_rec : object
        Recording exposing ``get_sampling_frequency()``, ``get_num_frames()`` and
        ``get_traces(start_frame=..., end_frame=...)``.
        NOTE(review): presumably a SpikeInterface recording returning (frames, channels) — confirm.
    bands : sequence of (low, high)
        Band edges in the same frequency unit as the sampling frequency.
    win_s, step_s : float
        Window length and hop, in seconds.
    duration_s : float
        Only the first ``duration_s`` seconds (or the whole recording, if shorter) are used.

    Returns
    -------
    times : numpy.ndarray, shape (n_windows,)
        Start time of each window in seconds.
    feats : numpy.ndarray, shape (n_windows, n_bands * n_channels)
        Mean squared band-passed signal, grouped band by band (all channels of the first
        band, then all channels of the second, ...).

    Raises
    ------
    ValueError
        If no complete window fits into the analysed span.
    """
    # Local import: sosfiltfilt is not among the notebook's top-level imports.
    from scipy.signal import sosfiltfilt

    fs = lfp_rec.get_sampling_frequency()
    n_frames = min(int(duration_s * fs), lfp_rec.get_num_frames())
    win = int(win_s * fs)
    step = max(1, int(step_s * fs))
    starts = range(0, n_frames - win + 1, step)
    if win < 1 or len(starts) == 0:
        raise ValueError("No complete analysis window fits: check win_s, duration_s and recording length")

    # Second-order sections: a 4th-order Butterworth bandpass in (b, a) form loses precision
    # for narrow low-frequency bands at typical LFP sampling rates.
    nyq = fs / 2
    sos_bank = [butter(4, [lo / nyq, hi / nyq], btype='band', output='sos') for lo, hi in bands]

    feats = []
    times = []
    for start in tqdm(starts, desc="LFP bandpower feats"):
        x = lfp_rec.get_traces(start_frame=start, end_frame=start + win)
        # Zero-phase filter along time (axis 0), then average the squared signal per channel.
        band_pwr = [np.mean(sosfiltfilt(sos, x, axis=0) ** 2, axis=0) for sos in sos_bank]
        feats.append(np.hstack(band_pwr))
        times.append(start / fs)
    return np.asarray(times), np.vstack(feats)

def fit_hmm_bandpower(features, n_states=3):
    """Fit a full-covariance Gaussian HMM to ``features`` and decode a state per sample.

    Parameters
    ----------
    features : array-like
        Passed unchanged to the model's ``fit`` and ``predict``.
        NOTE(review): presumably (n_samples, n_features) as produced by
        ``lfp_bandpower_features`` — confirm.
    n_states : int
        Number of hidden states.

    Returns
    -------
    (model, states)
        The fitted ``hmm.GaussianHMM`` and the output of ``model.predict(features)``.

    Raises
    ------
    RuntimeError
        If hmmlearn could not be imported (``have_hmm`` is False).
    """
    if have_hmm:
        gaussian_hmm = hmm.GaussianHMM(
            n_components=n_states,
            covariance_type='full',
            n_iter=100,
        )
        gaussian_hmm.fit(features)
        decoded = gaussian_hmm.predict(features)
        return gaussian_hmm, decoded
    raise RuntimeError("hmmlearn not installed")

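# Example (uncomment when ready; needs `lfp_rec` and `import matplotlib.pyplot as plt`, neither of which is defined in the cells above):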
# bands = [(4,8),(13,30),(30,80)]
# times_feat, feats = lfp_bandpower_features(lfp_rec, bands, win_s=0.5, step_s=0.1, duration_s=120)
# model, states = fit_hmm_bandpower(feats, n_states=3)
# plt.plot(times_feat, states); plt.title('LFP HMM states'); plt.show()

## Spike–LFP–behavior joint features
- Bin spikes, LFP bandpower, and DLC kinematics into a feature matrix
- Cluster or classify to find motifs

In [None]:
def build_joint_features(spike_times_s, spike_clusters, good_units, lfp_rec, dlc_pos, dlc_time_np, bands, bin_s=0.05, duration_s=300):
    """Bin spikes, LFP band power and behavior onto one time grid and stack them column-wise.

    Parameters
    ----------
    spike_times_s : array-like
        Spike times in seconds.
    spike_clusters : array-like
        Cluster id of each spike.
    good_units : array-like of bool
        Entry ``i`` selects cluster id ``i`` (the mask is indexed by cluster id).
    lfp_rec : object
        Recording exposing ``get_sampling_frequency()``, ``get_num_frames()`` and
        ``get_traces(start_frame=..., end_frame=..., channel_ids=...)``.
        NOTE(review): channel id ``0`` is requested literally — confirm the recording uses
        integer channel ids.
    dlc_pos : array-like
        1D position trace.
    dlc_time_np : array-like
        Time of each position sample on the spike clock; must be increasing (``np.interp``).
    bands : sequence of (low, high)
        Band edges for the LFP power features.
    bin_s : float
        Bin width in seconds.
    duration_s : float
        Span to analyse, starting at t = 0. Only whole bins are kept, so a trailing partial
        bin is dropped.

    Returns
    -------
    t : numpy.ndarray, shape (nbins,)
        Left edge of every bin.
    feats : numpy.ndarray, shape (nbins, n_units + n_bands + 2)
        Columns: firing rate per good unit, mean band power per band, position, speed.
    units : numpy.ndarray
        Cluster ids of the spike-rate columns.

    Raises
    ------
    ValueError
        If ``dlc_time_np`` is None or ``duration_s`` does not cover one full bin.
    """
    # Local import: sosfiltfilt is not among the notebook's top-level imports.
    from scipy.signal import sosfiltfilt

    if dlc_time_np is None:
        raise ValueError("dlc_time_np is None; set the camera-to-NP time mapping first")

    fs = lfp_rec.get_sampling_frequency()
    # Bins are exactly bin_s wide so that counts / bin_s is a true rate. The tiny tolerance
    # guards against ratios such as 0.3 / 0.1 evaluating to 2.999...
    nbins = int(np.floor(duration_s / bin_s + 1e-9))
    if nbins < 1:
        raise ValueError("duration_s must cover at least one bin of width bin_s")
    t_bins = np.arange(nbins + 1) * bin_s
    t_end = t_bins[-1]

    # spike bins (good units only)
    spike_times_s = np.asarray(spike_times_s)
    spike_clusters = np.asarray(spike_clusters)
    units = np.where(good_units)[0]
    spk_mat = np.zeros((nbins, len(units)))
    for ui, u in enumerate(units):
        st = spike_times_s[spike_clusters == u]
        st = st[(st >= 0) & (st < t_end)]
        counts, _ = np.histogram(st, bins=t_bins)
        spk_mat[:, ui] = counts / bin_s

    # LFP bandpower per bin for channel 0 (extend as needed)
    n_lfp = min(int(duration_s * fs), lfp_rec.get_num_frames())
    lfp = np.asarray(lfp_rec.get_traces(start_frame=0, end_frame=n_lfp, channel_ids=[0])[:, 0], dtype=float)
    t_lfp = np.arange(len(lfp)) / fs
    # Number of LFP samples per bin is the same for every band, so compute it once.
    samples_per_bin, _ = np.histogram(t_lfp, bins=t_bins)
    nyq = fs / 2
    band_cols = []
    for lo, hi in bands:
        # Second-order sections stay numerically stable for narrow low-frequency bands,
        # unlike the (b, a) form of the same 4th-order bandpass.
        sos = butter(4, [lo / nyq, hi / nyq], btype='band', output='sos')
        power = sosfiltfilt(sos, lfp) ** 2
        summed, _ = np.histogram(t_lfp, bins=t_bins, weights=power)
        # Bins beyond the end of the recording hold no samples and evaluate to 0.
        band_cols.append(summed / (samples_per_bin + 1e-6))
    band_feats = np.vstack(band_cols).T if band_cols else np.zeros((nbins, 0))  # (nbins, nbands)

    # Behavior: position at bin centres and absolute speed between consecutive bins
    bin_centers = (t_bins[:-1] + t_bins[1:]) / 2
    pos_interp = np.interp(bin_centers, dlc_time_np, dlc_pos)
    speed = np.concatenate([[0], np.abs(np.diff(pos_interp)) / bin_s])
    beh_feats = np.vstack([pos_interp, speed]).T

    # Concatenate
    feats = np.hstack([spk_mat, band_feats, beh_feats])
    return t_bins[:-1], feats, units

def cluster_features(feats, n_clusters=5):
    """Cluster the rows of ``feats`` with a full-covariance Gaussian mixture.

    Parameters
    ----------
    feats : array-like
        Feature matrix handed unchanged to ``GaussianMixture.fit_predict``.
    n_clusters : int
        Number of mixture components.

    Returns
    -------
    (gm, labels)
        The fitted ``GaussianMixture`` and the component label returned for each row.
    """
    mixture = GaussianMixture(
        n_components=n_clusters,
        covariance_type='full',
        n_init=3,  # three initialisations, as configured in the original call
    )
    assignments = mixture.fit_predict(feats)
    return mixture, assignments

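# Example usage (requires dlc_time_np to be set, plus `lfp_rec` and `import matplotlib.pyplot as plt`, neither of which is defined in the cells above):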
# t_feat, feats, units = build_joint_features(spike_times, spike_clusters, good_units, lfp_rec, dlc.iloc[:,0].values, dlc_time_np, bands=[(13,30),(30,80)], bin_s=0.05, duration_s=120)
# gm, labels = cluster_features(feats, n_clusters=5)
# plt.plot(t_feat, labels); plt.title('Joint feature clusters'); plt.show()

## Positional clustering of units
- Cluster units based on their place-field tuning (correlations)

In [None]:
def cluster_place_fields(tuning_curves, n_clusters=4):
    """Group units by the shape of their tuning curves using a Gaussian mixture.

    Parameters
    ----------
    tuning_curves : dict
        ``{unit: rate_per_bin}``; all curves must have the same length to be stacked.
    n_clusters : int
        Number of mixture components.

    Returns
    -------
    (labels, gm)
        ``labels`` maps each unit to its component label and ``gm`` is the fitted
        ``GaussianMixture``. Returns ``(None, None)`` when ``tuning_curves`` is empty.
    """
    unit_ids = list(tuning_curves.keys())
    if not unit_ids:
        return None, None

    curves = np.stack([tuning_curves[uid] for uid in unit_ids])
    # Z-score every curve across bins so clustering reflects shape rather than overall rate;
    # the small constant avoids division by zero for flat curves.
    row_mean = curves.mean(axis=1, keepdims=True)
    row_std = curves.std(axis=1, keepdims=True)
    zscored = (curves - row_mean) / (row_std + 1e-6)

    mixture = GaussianMixture(n_components=n_clusters, covariance_type='full', n_init=3)
    unit_labels = mixture.fit_predict(zscored)
    return dict(zip(unit_ids, unit_labels)), mixture

# Example: after `centers, tc = compute_tuning(...)` -> `labels, gm = cluster_place_fields(tc, n_clusters=4)` (returns a unit->label dict and the fitted mixture)

## Real-time stimulation outline (concept)
- In practice this is implemented outside the notebook; below is a high-level outline only.
- Steps:
  1) Stream spikes from AP band (threshold/template) with low latency.
  2) Detect pattern (e.g., high firing of specific units or phase-locked bursts).
  3) Trigger DAQ line for optogenetic stim with safety limits.
  4) Start with DAT-Cre for prototyping, then adapt to AnxA1-Cre.
- Use a compiled/real-time environment (e.g., Open Ephys plugins or custom C++/Python with NI-DAQ).