In [8]:
from pathlib import Path
import shutil
from tqdm import tqdm
import numpy as np
import os
# Environment variable for numexpr's maximum thread count.
# NOTE(review): numexpr presumably reads this when it is first imported, so it is set
# here, before the kilosort import below -- keep it above that import (confirm).
os.environ["NUMEXPR_MAX_THREADS"] = "64"  # or set to "128" if you want to try using all cores

from kilosort import run_kilosort, DEFAULT_SETTINGS
from kilosort.io import load_probe, save_preprocessing, load_ops

# Results sub-directory (created next to the .bin file) for each drift-correction
# mode; any other mode, e.g. 'none', writes to _KS_DEFAULT_RESULTS_SUBDIR.
_KS_RESULTS_SUBDIR = {
    'kilosort': 'kilosort4_kilosort',
    'medicine': 'kilosort4_medicine',
}
_KS_DEFAULT_RESULTS_SUBDIR = 'kilosort4'

# Neuropixels ('np') overrides used when neither 'kilosort' nor 'medicine' drift
# correction is requested. The trailing `# N` comments appear to record Kilosort's
# default value for the setting -- confirm against the Kilosort docs.
_NP_NO_DRIFT_SETTINGS = {
    'nblocks':              0,
    'Th_universal':         8,     # 9
    'Th_learned':           8,     # 8
    'dminx':                103,
    'nearest_chans':        1,
    'max_channel_distance': 103,
    'acg_threshold':        0.01,
    'ccg_threshold':        0,
    'duplicate_spike_ms':   0,
}

# Overrides for 'plex' probes (applied whatever the drift-correction mode).
_PLEX_SETTINGS = {
    'batch_size':        60000 * 2,  # 60000
    'nblocks':           0,
    'Th_universal':      9,          # 9
    'Th_learned':        7,          # 8
    'min_template_size': 10,         # 10
    'nearest_templates': 23,         # 100
}

_KNOWN_PROBE_TYPES = ('np', 'plex')


def kilosort_h2p(save_path, probe_path, drift_correction_type='none', probe_type='np',
                 save_tw=False, fs=30000, tmin=0, tmax=2400):
    """Run Kilosort4 on one recording, writing results next to the .bin file.

    Any existing results directory for the chosen drift-correction mode is
    deleted before sorting starts.

    Parameters
    ----------
    save_path : str or Path
        Path to the raw recording (e.g. a ``*.ap.bin``). Its parent directory is
        passed to Kilosort as ``data_dir`` and also holds the results directory.
    probe_path : str or Path
        Probe / channel-map file, passed to ``kilosort.io.load_probe``.
    drift_correction_type : str
        'kilosort' / 'medicine' write to ``kilosort4_kilosort`` /
        ``kilosort4_medicine`` (and, for 'np' probes, set ``nblocks`` to 1 / 2).
        Any other value writes to ``kilosort4`` and, for 'np' probes, applies
        the tuned no-drift overrides. Also forwarded to ``run_kilosort``;
        NOTE(review): confirm the installed Kilosort build accepts this keyword.
    probe_type : {'np', 'plex'}
        Selects the probe-specific setting overrides.
    save_tw : bool
        Forwarded to ``run_kilosort`` as ``save_preprocessed_copy``.
    fs : int
        Value for Kilosort's ``fs`` setting (default 30000).
    tmin, tmax : float
        Values for Kilosort's ``tmin`` / ``tmax`` settings (defaults 0 and 2400;
        presumably seconds -- confirm against the Kilosort docs).

    Raises
    ------
    ValueError
        If ``probe_type`` is not 'np' or 'plex'.
    """
    # Validate before touching the filesystem: other probe types have neither
    # overrides nor a defined results directory.
    if probe_type not in _KNOWN_PROBE_TYPES:
        raise ValueError(
            f"probe_type must be one of {_KNOWN_PROBE_TYPES}, got {probe_type!r}")

    save_path = Path(save_path)
    probe = load_probe(probe_path)

    # Work on a copy: assigning into DEFAULT_SETTINGS itself would leak these
    # overrides into every later call (and any other user of the dict) in the
    # same kernel.
    settings = dict(DEFAULT_SETTINGS)
    settings['probe'] = probe
    # The .bin is assumed to hold one channel more than the channel map
    # (presumably the sync channel -- confirm for other acquisition setups).
    settings['n_chan_bin'] = probe['n_chan'] + 1
    settings['fs'] = fs

    if probe_type == 'np':
        if drift_correction_type == 'kilosort':
            settings['nblocks'] = 1
        elif drift_correction_type == 'medicine':
            settings['nblocks'] = 2
        else:
            settings.update(_NP_NO_DRIFT_SETTINGS)
    else:  # 'plex'
        settings.update(_PLEX_SETTINGS)

    settings['data_dir'] = save_path.parent
    settings['tmin'] = tmin
    settings['tmax'] = tmax

    subdir = _KS_RESULTS_SUBDIR.get(drift_correction_type, _KS_DEFAULT_RESULTS_SUBDIR)
    kilosort_path = save_path.parent / subdir
    # Start from an empty results directory so stale outputs are not mixed in.
    if kilosort_path.is_dir():
        shutil.rmtree(kilosort_path)

    run_kilosort(settings=settings, probe=probe, data_dtype='int16',
                 results_dir=kilosort_path,
                 save_preprocessed_copy=bool(save_tw),
                 drift_correction_type=drift_correction_type)

    print('%%%%%%%%%%%%%%% KILOSORT DONE RUNNING %%%%%%%%%%%%%%%')


Parameter sweep results: unit counts per Kilosort run. From kilosort4_1 onward, the setting that differs from the previous run is highlighted.

**kilosort4_0**: 694 units, 57 good units
```python
settings['nblocks']              =  0
settings['Th_universal']         =  9    # 9
settings['Th_learned']           =  8    # 8
settings['dminx']                =  103 
settings['nearest_chans']        =  5
settings['max_channel_distance'] =  103 
settings['acg_threshold']        =  0.01
settings['ccg_threshold']        =  0 
settings['duplicate_spike_ms']   =  0
settings['tmin'] = 0
settings['tmax'] = 600
```

**kilosort4_1**: 763 units, 47 good units
```python
settings['nblocks']              =  0
settings['Th_universal']         =  9    # 9
settings['Th_learned']           =  8    # 8
settings['dminx']                =  103 
settings['nearest_chans']        =  5
settings['max_channel_distance'] =  103 
settings['acg_threshold']        =  0.01
settings['ccg_threshold']        =  0 
settings['duplicate_spike_ms']   =  0
settings['tmin'] = 0
settings['tmax'] = 2400  # <- changed
```

**kilosort4_2**: 999 units, 73 good units
```python
settings['nblocks']              =  0
settings['Th_universal']         =  9    # 9
settings['Th_learned']           =  8    # 8
settings['dminx']                =  103 
settings['nearest_chans']        =  3  # <- changed
settings['max_channel_distance'] =  103 
settings['acg_threshold']        =  0.01
settings['ccg_threshold']        =  0 
settings['duplicate_spike_ms']   =  0
settings['tmin'] = 0
settings['tmax'] = 2400
```

**kilosort4_3**: 1003 units, 77 good units
```python
settings['nblocks']              =  0
settings['Th_universal']         =  9    # 9
settings['Th_learned']           =  8    # 8
settings['dminx']                =  103 
settings['nearest_chans']        =  1  # <- changed
settings['max_channel_distance'] =  103 
settings['acg_threshold']        =  0.01
settings['ccg_threshold']        =  0 
settings['duplicate_spike_ms']   =  0
settings['tmin'] = 0
settings['tmax'] = 2400
```

**kilosort4_4**: 1003 units, 77 good units
```python
settings['nblocks']              =  0
settings['Th_universal']         =  8    # 9  <- changed
settings['Th_learned']           =  8    # 8
settings['dminx']                =  103 
settings['nearest_chans']        =  1
settings['max_channel_distance'] =  103 
settings['acg_threshold']        =  0.01
settings['ccg_threshold']        =  0 
settings['duplicate_spike_ms']   =  0
settings['tmin'] = 0
settings['tmax'] = 2400
```

In [None]:
# Recording to sort: session name (without the _g0 suffix) and IMEC probe index.
SESSION = 'kendra_scrappy_0138a'
IMEC_NUM = 1

# All per-recording paths derive from one data root, so relocating the data
# only requires changing DATA_ROOT.
DATA_ROOT = Path('/ix1/pmayo/lab_NHPdata')
IMEC_DIR = DATA_ROOT / f'{SESSION}_g0' / f'{SESSION}_g0_imec{IMEC_NUM}'
RECORDING_STEM = f'{SESSION}_g0_t0.imec{IMEC_NUM}'

SAVE_PATH = IMEC_DIR / f'{RECORDING_STEM}.ap.bin'
PROBE_PATH = IMEC_DIR / f'{RECORDING_STEM}.ap_kilosortChanMap.mat'

kilosort_h2p(SAVE_PATH, PROBE_PATH, drift_correction_type='none', probe_type='np', save_tw=False)

kilosort.run_kilosort: Kilosort version 0.1.dev1164+gdfaaf6b.d20240502
[INFO] - Kilosort version 0.1.dev1164+gdfaaf6b.d20240502
kilosort.run_kilosort: Python version 3.10.14
[INFO] - Python version 3.10.14
kilosort.run_kilosort: ----------------------------------------
[INFO] - ----------------------------------------
kilosort.run_kilosort: System information:
[INFO] - System information:
kilosort.run_kilosort: Linux-3.10.0-1160.71.1.el7.x86_64-x86_64-with-glibc2.17 x86_64
[INFO] - Linux-3.10.0-1160.71.1.el7.x86_64-x86_64-with-glibc2.17 x86_64
kilosort.run_kilosort: x86_64
[INFO] - x86_64
kilosort.run_kilosort: Using GPU for PyTorch computations. Specify `device` to change this.
[INFO] - Using GPU for PyTorch computations. Specify `device` to change this.
kilosort.run_kilosort: Using CUDA device: NVIDIA A100-SXM4-80GB 79.14GB
[INFO] - Using CUDA device: NVIDIA A100-SXM4-80GB 79.14GB
kilosort.run_kilosort: ----------------------------------------
[INFO] - -------------------------------

In [None]:
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt

from kilosort.io import load_ops
from kilosort.data_tools import (
    mean_waveform, cluster_templates, get_good_cluster, get_cluster_spikes,
    get_spike_waveforms, get_best_channels
    )


# Sorting results for this session/probe. 'kilosort4' is the directory that
# kilosort_h2p writes to when drift_correction_type is neither 'kilosort' nor
# 'medicine'; point this at kilosort4_kilosort / kilosort4_medicine otherwise.
results_dir = Path(f'/ix1/pmayo/lab_NHPdata/{SESSION}_g0/{SESSION}_g0_imec{IMEC_NUM}/kilosort4')

# Pick a random good cluster. NOTE(review): because the pick is random, re-running
# this cell can plot a different cluster; fix cluster_id for a reproducible figure.
cluster_id = get_good_cluster(results_dir, n=1)

# Mean spike waveform (n_spikes=100) and the mean template over the same spike
# subset, both on a single channel (best=True).
mean_wv, spike_subset = mean_waveform(cluster_id, results_dir, n_spikes=100,
                                      bfile=None, best=True)
mean_temp = cluster_templates(cluster_id, results_dir, mean=True,
                              best=True, spike_subset=spike_subset)

# Time axis in ms for the ops['nt']-sample window, sampled at ops['fs'].
ops = load_ops(results_dir / 'ops.npy')
t = (np.arange(ops['nt']) / ops['fs']) * 1000

fig, ax = plt.subplots(1, 1)
ax.plot(t, mean_wv, c='black', linestyle='dashed', linewidth=2, label='waveform')
ax.plot(t, mean_temp, linewidth=1, label='template')
ax.set_title(f'Mean single-channel template and spike waveform for cluster {cluster_id}')
ax.set_xlabel('Time (ms)')
ax.set_ylabel('Amplitude')
ax.legend()
# Render the figure and avoid echoing the Legend object as the cell's output.
plt.show()