# Spike-LFP Analysis (PSD, Band-Limited Power, Spectrogram, Spike-Field Coupling, Phase-Amplitude Coupling)

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import spikeinterface.full as si
from scipy.signal import butter, coherence, filtfilt, hilbert, sosfiltfilt, spectrogram, welch
from tqdm.auto import tqdm

# Paths (edit as needed)
# Probe-level SpikeGLX folder; every other path below is derived from it.
spikeglx_probe_folder = Path(r"Z:\Koji\Neuropixels\1818\1818_11202025_g0\1818_11202025_g0_imec0")

# Sorter output folder ("kilosort4") and the files read from it
spike_dir = spikeglx_probe_folder.joinpath("kilosort4")
spike_times_path = spike_dir.joinpath("spike_seconds_adj.npy")
spike_clusters_path = spike_dir.joinpath("spike_clusters.npy")
unit_labels_path = spike_dir.joinpath("..", "kilosort4qMetrics", "templates._bc_unit_labels.tsv")
celltype_path = spike_dir.joinpath("unit_classification_rulebased.csv")
csd_memmap_path = spike_dir.joinpath("csd_tmp.float32.bin")  # created by CSD pipeline

# Full-session spectrogram files stored next to the probe data
spectrogram_meta = spikeglx_probe_folder.joinpath("spectrogram_fullsession_meta.npz")
spectrogram_memmap = spikeglx_probe_folder.joinpath("spectrogram_fullsession.dat")

In [None]:
# Load spikes and metadata
# ravel() flattens (n, 1)-shaped exports so downstream per-cluster boolean masks yield 1-D arrays.
spike_times = np.load(spike_times_path).ravel()  # seconds
spike_clusters = np.load(spike_clusters_path).ravel()
assert spike_times.shape == spike_clusters.shape, "spike_times and spike_clusters differ in length"

# Unit quality mask from qMetrics. Downstream code treats position i of this mask as cluster id i
# (assumes cluster ids are contiguous from 0 — TODO confirm against spike_clusters).
# NOTE(review): an earlier comment described labels "1/2 good, 0/3 bad", yet only label 1 was
# ever selected. Presumably bombcell codes (0 noise, 1 good, 2 MUA, 3 non-somatic) — confirm,
# and add 2 to GOOD_UNIT_LABELS if MUA should be included.
GOOD_UNIT_LABELS = (1,)
unit_labels = np.loadtxt(unit_labels_path, delimiter="\t", dtype=int)
good_units = np.isin(unit_labels, GOOD_UNIT_LABELS)

# Optional rule-based cell-type table: cluster_id -> cell_type
celltypes = {}
if celltype_path.exists():
    import pandas as pd
    df_ct = pd.read_csv(celltype_path)
    celltypes = dict(zip(df_ct["cluster_id"], df_ct["cell_type"]))

print(f"Spikes: {spike_times.size}, units: {np.unique(spike_clusters).size}, good units: {good_units.sum()}")

In [None]:
# Load LFP (imec0.lf) and, if the CSD pipeline has produced it, the CSD memmap
lfp_rec = si.read_spikeglx(spikeglx_probe_folder, stream_id="imec0.lf")
csd_rec = None
if csd_memmap_path.exists():
    from spikeinterface.core import NumpyRecording
    n_frames = lfp_rec.get_num_frames()
    n_channels = lfp_rec.get_num_channels()
    fs = lfp_rec.get_sampling_frequency()
    # The raw memmap has no header, so its layout is inferred from the LFP recording.
    # Fail loudly on a size mismatch instead of silently reading a prefix of the file.
    expected_bytes = n_frames * n_channels * np.dtype(np.float32).itemsize
    actual_bytes = csd_memmap_path.stat().st_size
    if actual_bytes != expected_bytes:
        raise ValueError(
            f"CSD memmap {csd_memmap_path} has {actual_bytes} bytes; expected {expected_bytes} "
            f"for float32 data of shape ({n_frames}, {n_channels})"
        )
    csd_mm = np.memmap(csd_memmap_path, dtype=np.float32, mode="r", shape=(n_frames, n_channels))
    # Channel ids are given at construction so they match the LFP recording's ids.
    csd_rec = NumpyRecording([csd_mm], sampling_frequency=fs, channel_ids=lfp_rec.channel_ids)
    # set_probe returns a new recording (it is not in place by default), so keep its result.
    csd_rec = csd_rec.set_probe(lfp_rec.get_probe())
print(lfp_rec)
if csd_rec is not None:
    print("CSD memmap loaded")

In [None]:
# PSD and band-limited power
def compute_psd_welch_chunked(rec, nperseg, noverlap, max_f=200, duration_s=120):
    """Welch power spectral density of every channel over the start of a recording.

    Despite the name, the whole analysis span is read with a single ``get_traces``
    call and passed to ``scipy.signal.welch`` in one go.

    Parameters
    ----------
    rec : recording exposing get_sampling_frequency / get_num_frames / get_traces.
    nperseg, noverlap : Welch segment length and overlap, in samples.
    max_f : highest frequency (Hz, inclusive) kept in the output.
    duration_s : only the first ``duration_s`` seconds (or fewer, if the recording is shorter) are used.

    Returns
    -------
    (freqs, psd) : frequencies <= max_f, and the PSD with frequency on axis 0 and channel on axis 1.
    """
    sampling_rate = rec.get_sampling_frequency()
    last_frame = min(int(duration_s * sampling_rate), rec.get_num_frames())
    freqs, psd = welch(
        rec.get_traces(start_frame=0, end_frame=last_frame),
        fs=sampling_rate,
        nperseg=nperseg,
        noverlap=noverlap,
        axis=0,
    )
    keep = freqs <= max_f
    return freqs[keep], psd[keep]

def compute_bandpower(rec, band, win_s=1.0, step_s=0.1, duration_s=120):
    """Sliding-window band-limited power of every channel.

    Each window of ``win_s`` seconds (advanced by ``step_s``) within the first
    ``duration_s`` seconds is band-pass filtered on its own (4th-order Butterworth,
    forward-backward, so zero-phase). The window's power is the mean squared filtered amplitude.

    Parameters
    ----------
    rec : recording exposing get_sampling_frequency / get_num_frames /
        get_num_channels / get_traces.
    band : (low, high) pass band in Hz; must satisfy 0 < low < high < fs / 2.
    win_s, step_s : window length and hop in seconds (the hop is at least one sample).
    duration_s : only the first ``duration_s`` seconds are analysed.

    Returns
    -------
    times : (n_windows,) window start times in seconds.
    powers : (n_windows, n_channels) mean squared filtered amplitude, in the squared
        units of whatever ``get_traces`` returns. Both arrays are empty (n_windows == 0)
        when the analysed span is shorter than one window.

    Note: each window is filtered independently, so filter edge transients make up a
    larger share of short windows in low-frequency bands.
    """
    fs = rec.get_sampling_frequency()
    n_frames_total = min(int(duration_s * fs), rec.get_num_frames())
    win = int(win_s * fs)
    step = max(1, int(step_s * fs))
    # Second-order sections: the (b, a) transfer-function form of an 8th-order band-pass
    # is numerically unreliable when the band is low relative to fs (e.g. a few Hz at kHz rates).
    sos = butter(4, band, btype="band", fs=fs, output="sos")
    starts = range(0, n_frames_total - win + 1, step)
    times = np.array([start / fs for start in starts], dtype=float)
    powers = np.empty((len(starts), rec.get_num_channels()))
    for i, start in enumerate(tqdm(starts, desc=f"{band[0]}-{band[1]} Hz")):
        chunk = rec.get_traces(start_frame=start, end_frame=start + win)
        powers[i] = np.mean(sosfiltfilt(sos, chunk, axis=0) ** 2, axis=0)
    return times, powers

# Example bandpower: 4-8, 13-30 and 30-80 Hz over the first 120 s (0.5 s windows, 0.1 s hop).
# band_results maps each (low, high) band to the (times, powers) tuple from compute_bandpower.
bands = [(4, 8), (13, 30), (30, 80)]
band_results = {
    band: compute_bandpower(lfp_rec, band=band, win_s=0.5, step_s=0.1, duration_s=120)
    for band in bands
}
print("Bandpower computed for", bands)

In [None]:
# Spike-LFP coupling: phase locking and STA
def compute_phase_locking(spike_times_s, spike_clusters, good_units_mask, lfp_rec, band=(13,30), ch=0):
    """Band-limited LFP phase at every spike of every good unit.

    One LFP channel is band-pass filtered (4th-order Butterworth, forward-backward) in
    ``band``. Its Hilbert phase is then sampled at each spike. The LFP is read and filtered
    once, up to 1 s past the last good-unit spike, and shared by all units.

    Parameters
    ----------
    spike_times_s : spike times in seconds; assumed to share the LFP's clock with 0 at frame 0 (TODO confirm).
    spike_clusters : cluster id of each spike (same length as ``spike_times_s``).
    good_units_mask : boolean array; entry ``i`` marks cluster id ``i`` as good
        (assumes cluster ids are contiguous from 0).
    lfp_rec : recording providing the LFP.
    band : (low, high) pass band in Hz.
    ch : channel id, or (if not a channel id) a positional index into ``lfp_rec.channel_ids``.
        Default 0 is the first channel; ideally choose the channel nearest each unit.

    Returns
    -------
    dict mapping unit id (int) -> phases in radians at that unit's spikes. Spikes falling
    outside the loaded LFP are dropped. Good units with no spikes at all are omitted.
    """
    fs = lfp_rec.get_sampling_frequency()
    good_ids = np.flatnonzero(good_units_mask)
    is_good_spike = np.isin(spike_clusters, good_ids)
    if not is_good_spike.any():
        return {}

    # Resolve a positional index to a channel id (SpikeGLX channel ids are strings).
    ids = list(lfp_rec.channel_ids)
    ch_id = ch if ch in ids else ids[ch]
    end_f = min(lfp_rec.get_num_frames(), int((spike_times_s[is_good_spike].max() + 1) * fs))
    lfp = lfp_rec.get_traces(start_frame=0, end_frame=end_f, channel_ids=[ch_id])[:, 0]
    # SOS form keeps the narrow low-frequency band-pass numerically stable.
    sos = butter(4, band, btype="band", fs=fs, output="sos")
    phase = np.angle(hilbert(sosfiltfilt(sos, lfp)))

    phases = {}
    for unit_id in good_ids:
        st = spike_times_s[spike_clusters == unit_id]
        if st.size == 0:
            continue
        spike_idx = (st * fs).astype(int)
        # Negative indices would silently wrap to the end of the trace.
        spike_idx = spike_idx[(spike_idx >= 0) & (spike_idx < phase.size)]
        phases[int(unit_id)] = phase[spike_idx]
    return phases

def compute_sta(spike_times_s, spike_clusters, unit_id, lfp_rec, ch=0, window_s=0.1):
    """Spike-triggered average (STA) of one LFP channel for one unit.

    Each segment is ``lfp[idx - half : idx + half]`` with ``idx = int(t * fs)`` and
    ``half = int(window_s * fs / 2)``, so the spike sits at sample ``half`` of the average.
    Spikes whose full segment does not fit inside the recording are skipped.

    Parameters
    ----------
    spike_times_s : spike times in seconds (assumed on the LFP's clock — TODO confirm).
    spike_clusters : cluster id of each spike.
    unit_id : cluster whose spikes trigger the average.
    lfp_rec : recording providing the LFP (the whole channel is read).
    ch : channel id, or (if not a channel id) a positional index into ``lfp_rec.channel_ids``.
    window_s : total segment length in seconds.

    Returns
    -------
    1-D array of length ``2 * half``, or None when no spike has a complete segment.
    """
    fs = lfp_rec.get_sampling_frequency()
    half = int(window_s * fs / 2)
    # Resolve a positional index to a channel id (SpikeGLX channel ids are strings).
    ids = list(lfp_rec.channel_ids)
    ch_id = ch if ch in ids else ids[ch]
    lfp = lfp_rec.get_traces(channel_ids=[ch_id])[:, 0]
    centers = (spike_times_s[spike_clusters == unit_id] * fs).astype(int)
    # A segment ending exactly at lfp.size is still complete (slice end is exclusive).
    centers = centers[(centers - half >= 0) & (centers + half <= lfp.size)]
    if centers.size == 0:
        return None
    segments = lfp[centers[:, None] + np.arange(-half, half)]
    return segments.mean(axis=0)

# Beta-band (13-30 Hz) spike phases for every good unit; the count includes only units that have spikes.
phases_beta = compute_phase_locking(
    spike_times_s=spike_times,
    spike_clusters=spike_clusters,
    good_units_mask=good_units,
    lfp_rec=lfp_rec,
    band=(13, 30),
)
print(f"Phase locking computed for {len(phases_beta)} good units (beta band)")

In [None]:
# Spike-field coherence (simple)
def spike_field_coherence(unit_id, spike_times_s, spike_clusters, lfp_rec, ch=0, duration_s=120, nperseg=1024, noverlap=512, max_f=200):
    """Magnitude-squared coherence between a unit's binary spike train and one LFP channel.

    The spike train is binned at the LFP sampling rate (1.0 in any sample holding at least
    one spike) over the first ``duration_s`` seconds. It is then compared with the LFP
    using ``scipy.signal.coherence``.

    Parameters
    ----------
    unit_id : cluster whose spikes form the point process.
    spike_times_s : spike times in seconds (assumed on the LFP's clock — TODO confirm).
    spike_clusters : cluster id of each spike.
    lfp_rec : recording providing the LFP.
    ch : channel id, or (if not a channel id) a positional index into ``lfp_rec.channel_ids``.
    duration_s : analysed span from the recording start, in seconds.
    nperseg, noverlap : Welch segment parameters in samples.
    max_f : highest frequency (Hz, inclusive) kept in the output.

    Returns
    -------
    (f, Cxy) restricted to ``f <= max_f``.
    """
    fs = lfp_rec.get_sampling_frequency()
    n_frames = min(int(duration_s * fs), lfp_rec.get_num_frames())
    idx = (spike_times_s[spike_clusters == unit_id] * fs).astype(int)
    # Keep spikes inside [0, n_frames); negative indices would otherwise wrap to the end.
    idx = idx[(idx >= 0) & (idx < n_frames)]
    spikes_bin = np.zeros(n_frames, dtype=float)
    spikes_bin[idx] = 1.0
    # Resolve a positional index to a channel id (SpikeGLX channel ids are strings).
    ids = list(lfp_rec.channel_ids)
    ch_id = ch if ch in ids else ids[ch]
    lfp = lfp_rec.get_traces(start_frame=0, end_frame=n_frames, channel_ids=[ch_id])[:, 0]
    f, Cxy = coherence(spikes_bin, lfp, fs=fs, nperseg=nperseg, noverlap=noverlap)
    mask = f <= max_f
    return f[mask], Cxy[mask]

# Example: spike-field coherence of the first good unit against the first LFP channel
good_candidates = np.where(good_units)[0]
first_good = good_candidates[0] if good_candidates.size else None
if first_good is None:
    print("No good units available for coherence example")
else:
    f_coh, coh = spike_field_coherence(first_good, spike_times, spike_clusters, lfp_rec, ch=0, duration_s=120)
    fig, ax = plt.subplots(figsize=(6, 3))
    ax.plot(f_coh, coh)
    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Coherence')
    ax.set_xlim(0, 100)
    ax.set_title(f'Spike-field coherence unit {first_good} vs ch0')
    fig.tight_layout()
    plt.show()

In [None]:
# Field-field coherence
def field_field_coherence(lfp_rec, ch_a=0, ch_b=1, duration_s=120, nperseg=2048, noverlap=1024, max_f=200):
    """Magnitude-squared coherence between two LFP channels over the recording start.

    Parameters
    ----------
    lfp_rec : recording providing the LFP.
    ch_a, ch_b : channel ids, or (if not channel ids) positional indices into ``lfp_rec.channel_ids``.
    duration_s : analysed span from the recording start, in seconds.
    nperseg, noverlap : Welch segment parameters in samples.
    max_f : highest frequency (Hz, inclusive) kept in the output.

    Returns
    -------
    (f, Cxy) restricted to ``f <= max_f``.
    """
    fs = lfp_rec.get_sampling_frequency()
    n_frames = min(int(duration_s * fs), lfp_rec.get_num_frames())
    # Resolve positional indices to channel ids (SpikeGLX channel ids are strings).
    ids = list(lfp_rec.channel_ids)
    id_a = ch_a if ch_a in ids else ids[ch_a]
    id_b = ch_b if ch_b in ids else ids[ch_b]
    lfp = lfp_rec.get_traces(start_frame=0, end_frame=n_frames, channel_ids=[id_a, id_b])
    f, Cxy = coherence(lfp[:, 0], lfp[:, 1], fs=fs, nperseg=nperseg, noverlap=noverlap)
    mask = f <= max_f
    return f[mask], Cxy[mask]

# Example: coherence between LFP channels 0 and 10 over the first 120 s
f_ff, coh_ff = field_field_coherence(lfp_rec, ch_a=0, ch_b=10, duration_s=120)
fig, ax = plt.subplots(figsize=(6, 3))
ax.plot(f_ff, coh_ff)
ax.set(xlabel='Frequency (Hz)', ylabel='Coherence', xlim=(0, 100),
       title='Field-field coherence ch0 vs ch10')
fig.tight_layout()
plt.show()

In [None]:
# PAC (single band) and comodulogram
def compute_pac(lfp_rec, phase_band=(13,30), amp_band=(30,80), duration_s=60, nbins=18):
    """Tort modulation index (MI) for one phase-band / amplitude-band pair, per channel.

    Thin wrapper around ``pac_comodulogram`` with a single band pair.

    Returns
    -------
    (n_channels,) array of MI values; NaN for a channel if any phase bin is empty.
    """
    return pac_comodulogram(lfp_rec, [phase_band], [amp_band], duration_s=duration_s, nbins=nbins)[:, 0, 0]


def _pac_phase_bins(phase, nbins):
    """Map phases in [-pi, pi] to equal-width bin indices 0..nbins-1 (compact integer dtype)."""
    edges = np.linspace(-np.pi, np.pi, nbins + 1)
    # digitize puts phase == pi into bin nbins; fold it into the last bin instead of dropping it.
    bins = np.clip(np.digitize(phase, edges) - 1, 0, nbins - 1)
    return bins.astype(np.min_scalar_type(nbins - 1))


def _tort_modulation_index(bin_idx, amp, nbins):
    """Normalised KL divergence of the phase-binned mean amplitude from uniform, per channel.

    bin_idx, amp : arrays with samples on axis 0 and channels on axis 1.
    """
    log_n = np.log(nbins)
    mi = np.empty(amp.shape[1])
    for ch in range(amp.shape[1]):
        counts = np.bincount(bin_idx[:, ch], minlength=nbins)
        sums = np.bincount(bin_idx[:, ch], weights=amp[:, ch], minlength=nbins)
        with np.errstate(invalid="ignore", divide="ignore"):
            mean_amp = sums / counts  # an empty bin gives NaN, which propagates to MI
        p = mean_amp / mean_amp.sum()
        mi[ch] = (log_n + np.sum(p * np.log(p + 1e-12))) / log_n
    return mi


def pac_comodulogram(lfp_rec, phase_bands, amp_bands, duration_s=60, nbins=18):
    """Tort modulation index for every (phase band, amplitude band) pair on every channel.

    Over the first ``duration_s`` seconds, every band is filtered once (4th-order
    Butterworth in second-order sections, forward-backward). The phase or amplitude then
    comes from the analytic signal along the time axis (axis 0). Phases are binned into
    ``nbins`` equal-width bins.

    Returns
    -------
    mi : array of shape (n_channels, len(phase_bands), len(amp_bands)); NaN where a
        phase bin is empty.
    """
    fs = lfp_rec.get_sampling_frequency()
    n_frames = min(int(duration_s * fs), lfp_rec.get_num_frames())
    lfp = lfp_rec.get_traces(start_frame=0, end_frame=n_frames)

    # Bin the phase of each phase band once; small integer bins are cheap to keep around.
    phase_bins = []
    for pb in phase_bands:
        sos = butter(4, pb, btype="band", fs=fs, output="sos")
        # hilbert must run along time (axis 0); its default axis=-1 would run across channels.
        phase = np.angle(hilbert(sosfiltfilt(sos, lfp, axis=0), axis=0))
        phase_bins.append(_pac_phase_bins(phase, nbins))

    mi = np.zeros((lfp.shape[1], len(phase_bands), len(amp_bands)))
    for ia, ab in enumerate(amp_bands):
        sos = butter(4, ab, btype="band", fs=fs, output="sos")
        amp = np.abs(hilbert(sosfiltfilt(sos, lfp, axis=0), axis=0))  # envelope computed once per band
        for ip, bins_p in enumerate(phase_bins):
            mi[:, ip, ia] = _tort_modulation_index(bins_p, amp, nbins)
    return mi

# Single-pair PAC (13-30 Hz phase, 30-80 Hz amplitude) on the first 60 s, one MI per channel
pac_beta_gamma = compute_pac(lfp_rec, phase_band=(13, 30), amp_band=(30, 80), duration_s=60)
print("PAC beta-gamma (MI) computed for", pac_beta_gamma.size, "channels")

# Small comodulogram on the first 30 s; the heat map shows channel 0 only
phase_bands = [(4, 8), (13, 30)]
amp_bands = [(30, 80), (80, 150)]
mi_grid = pac_comodulogram(lfp_rec, phase_bands, amp_bands, duration_s=30)

fig, ax = plt.subplots(figsize=(5, 4))
im = ax.imshow(mi_grid[0], origin='lower', aspect='auto', cmap='magma',
               extent=[0, len(amp_bands), 0, len(phase_bands)])
ax.set_xticks(np.arange(len(amp_bands)) + 0.5)
ax.set_xticklabels([f"{lo}-{hi}" for lo, hi in amp_bands])
ax.set_yticks(np.arange(len(phase_bands)) + 0.5)
ax.set_yticklabels([f"{lo}-{hi}" for lo, hi in phase_bands])
fig.colorbar(im, ax=ax, label='MI')
ax.set_title('PAC (ch0)')
fig.tight_layout()
plt.show()

In [None]:
# Replay detection skeleton (fill in with behavior alignment).
# Nothing is computed yet; this cell only records the planned analysis steps:
# 1) Build spatial tuning curves from DLC-derived position during movement.
# 2) Detect candidate quiescent/ripple events (e.g., a band-power threshold combined with low speed).
# 3) Decode position within each event (Bayesian / HMM) and test against shuffled spikes / tuning fields.
# 4) Assess phase precession (beta band) across decoded trajectories.
print("Replay skeleton: implement after behavior alignment is in place.")