In [2]:
#!/usr/bin/env python3
"""
compute_waveforms_and_metrics.py

Given core_dir and a list of session names, this script:
  1. Loads the behavioral (and, if present, tagging) sorting analyzers
  2. Extracts per-unit median & mean waveforms via waveformPrepare_template()
  3. Computes template_metrics, correlograms & isi_histograms on the behavioral analyzer
  4. Saves waveforms and autocorrelogram/ISI dicts as .npy, and template metrics as .csv,
     under each session’s kilosort4 folder
"""

import os
import sys
import pathlib
sys.path.append(str(pathlib.Path.cwd() / "pipeline" / "spikeinterface_waveform_extraction"))

import numpy as np
from spikeinterface_v2 import load_sorting_analyzer
from my_spike_tools import waveformPrepare_template  # adjust import as needed

# 1) Configuration — point these at your setup
core_dir = r"Z:\Koji\Neuropixels\1818"
date_strs = [
    '11202025'
]
session_names = [f"1818_{d}_g0" for d in date_strs]

# Correlogram / ISI-histogram parameters in milliseconds. The saved ACC.npy and
# ISI.npy dicts do not store their bin edges, so keep these values in sync with
# any downstream code that interprets the histogram time axis.
CCG_WINDOW_MS = 1000.0
CCG_BIN_MS = 5.0

for session in session_names:
    sel = os.path.join(core_dir, session, f"{session}_imec0")
    kilo_dir = os.path.join(sel, "kilosort4")

    # 2) Load analyzers. The behavioral analyzer is required; the tagging
    # analyzer is optional and only loaded when its folder exists.
    beh_folder = os.path.join(sel, "analyzer_beh")
    beh_an = load_sorting_analyzer(folder=beh_folder, format="binary_folder")

    tag_folder = os.path.join(sel, "analyzer_tag")
    has_tag = os.path.isdir(tag_folder)
    tag_an = None
    tag_unit_ids = set()
    if has_tag:
        tag_an = load_sorting_analyzer(folder=tag_folder, format="binary_folder")
        # Set for O(1) membership checks in the per-unit loop below.
        tag_unit_ids = set(tag_an.sorting.unit_ids)

    # Unit IDs of the behavioral sorting; every output dict is keyed by them.
    unit_ids = beh_an.sorting.unit_ids

    # Per-unit waveform containers: unit_id -> waveform returned by
    # waveformPrepare_template (shape depends on that helper — TODO confirm).
    wf_beh_med, wf_beh_avg = {}, {}
    wf_tag_med, wf_tag_avg = {}, {}

    # 3) Extract waveforms
    for uid in unit_ids:
        avg, med, _ = waveformPrepare_template(
            beh_an, unit_id=uid, designated=False, use_template=True
        )
        wf_beh_avg[uid] = avg
        wf_beh_med[uid] = med

        if not has_tag:
            continue
        # Only ask the tagging analyzer for units its sorting actually contains.
        if uid not in tag_unit_ids:
            print(f"[{session}] unit {uid} not found in tagging sorting; skipped.")
            continue
        avg_t, med_t, _ = waveformPrepare_template(
            tag_an, unit_id=uid, designated=True, use_template=False
        )
        wf_tag_avg[uid] = avg_t
        wf_tag_med[uid] = med_t

    # 4) Compute metrics on the behavioral analyzer, in memory only (save=False).
    # NOTE(review): the extension folders are pre-created as in the original
    # workflow, but with save=False nothing is written into them — confirm
    # whether this step is still needed.
    for ext in ("template_metrics", "correlograms", "isi_histograms"):
        os.makedirs(os.path.join(beh_folder, "extensions", ext), exist_ok=True)
    tm = beh_an.compute(
        input="template_metrics", include_multi_channel_metrics=True, save=False
    )
    ccg = beh_an.compute(
        input="correlograms",
        window_ms=CCG_WINDOW_MS,
        bin_ms=CCG_BIN_MS,
        method="auto",
        save=False,
    )
    isi = beh_an.compute(
        input="isi_histograms",
        window_ms=CCG_WINDOW_MS,
        bin_ms=CCG_BIN_MS,
        method="auto",
        save=False,
    )

    # Element 0 of get_data() holds the histograms. Fetch it once here rather
    # than once per unit inside the comprehensions.
    ccg_arr = ccg.get_data()[0]
    isi_arr = isi.get_data()[0]

    # Entry [i, i] of the correlogram array is unit i against itself, i.e. its
    # autocorrelogram; the ISI array has one histogram per unit row.
    ccg_dict = {uid: ccg_arr[i, i] for i, uid in enumerate(unit_ids)}
    isi_dict = {uid: isi_arr[i] for i, uid in enumerate(unit_ids)}

    # 5) Save. The .npy outputs are dicts keyed by unit_id, which np.save stores
    # as pickled 0-d object arrays: load them with
    # np.load(path, allow_pickle=True).item().
    os.makedirs(kilo_dir, exist_ok=True)
    tm.get_data().to_csv(os.path.join(kilo_dir, "template_metrics.csv"))
    np.save(os.path.join(kilo_dir, "ACC.npy"), ccg_dict)
    np.save(os.path.join(kilo_dir, "ISI.npy"), isi_dict)

    np.save(os.path.join(kilo_dir, "waveform_beh_average.npy"), wf_beh_avg)
    np.save(os.path.join(kilo_dir, "waveform_beh_median.npy"), wf_beh_med)
    if has_tag:
        np.save(os.path.join(kilo_dir, "waveform_tag_average.npy"), wf_tag_avg)
        np.save(os.path.join(kilo_dir, "waveform_tag_median.npy"), wf_tag_med)

    print(f"[{session}] done.")

Number of spikes for unit 1: 240
Using designated_templates for potential fallback template computation if needed.
Number of spikes for unit 1: 231
Calculating average and median from individual spikes for unit 1...
Number of spikes for unit 12: 1000
Using designated_templates for potential fallback template computation if needed.
Number of spikes for unit 12: 376
Calculating average and median from individual spikes for unit 12...
Number of spikes for unit 13: 455
Using designated_templates for potential fallback template computation if needed.
Number of spikes for unit 13: 3
Calculating average and median from individual spikes for unit 13...
Number of spikes for unit 15: 1000
Using designated_templates for potential fallback template computation if needed.
Number of spikes for unit 15: 4
Calculating average and median from individual spikes for unit 15...
Number of spikes for unit 16: 1000
Using designated_templates for potential fallback template computation if needed.
Number of sp

  slope = ssxym / ssxm
  t = r * np.sqrt(df / ((1.0 - r + TINY)*(1.0 + r + TINY)))
  slope_stderr = np.sqrt((1 - r**2) * ssym / ssxm / df)
  slope = ssxym / ssxm
  t = r * np.sqrt(df / ((1.0 - r + TINY)*(1.0 + r + TINY)))
  slope_stderr = np.sqrt((1 - r**2) * ssym / ssxm / df)
  ptratio = template_single[peak_idx] / template_single[trough_idx]
  MM = MM / np.max(MM)
  slope = ssxym / ssxm
  t = r * np.sqrt(df / ((1.0 - r + TINY)*(1.0 + r + TINY)))
  slope_stderr = np.sqrt((1 - r**2) * ssym / ssxm / df)
  slope = ssxym / ssxm
  t = r * np.sqrt(df / ((1.0 - r + TINY)*(1.0 + r + TINY)))
  slope_stderr = np.sqrt((1 - r**2) * ssym / ssxm / df)
  slope = ssxym / ssxm
  t = r * np.sqrt(df / ((1.0 - r + TINY)*(1.0 + r + TINY)))
  slope_stderr = np.sqrt((1 - r**2) * ssym / ssxm / df)
  slope = ssxym / ssxm
  t = r * np.sqrt(df / ((1.0 - r + TINY)*(1.0 + r + TINY)))
  slope_stderr = np.sqrt((1 - r**2) * ssym / ssxm / df)
  slope = ssxym / ssxm
  t = r * np.sqrt(df / ((1.0 - r + TINY)*(1.0 + r + 

[1818_11202025_g0] done.
