In [2]:
# Import important packages and functions
import os
import subprocess
from pathlib import Path
import numpy as np

In [3]:
tools_path = r'D:\NPX_data' # path to the folder holding all the "tools": CatGT, TPrime, etc
input_path = r'D:\NPX_data\DATA\08272024-NTBY4_g0' # path where the recording file is: this needs to be edited
output_path = r'D:\NPX_data\CatGT_OUT' # path to spit out the output of CatGT

base_path = Path('D:/NPX_data/DATA/08272024-NTBY4_g0') 
run_name = '08272024-NTBY4_g0'

# First load the spike_times from KS output

spikeSort_path = base_path / run_name / 'kilosort_output'

In [4]:
spikeSort_path

WindowsPath('D:/NPX_data/DATA/08272024-NTBY4_g0/08272024-NTBY4_g0/kilosort_output')

In [5]:
# Here we are loading in the .npy file from kilosort and then converting this into Hz instead from sample rate
spikeTimes_path = spikeSort_path / 'spike_times.npy'
spikeTimes = np.load(spikeTimes_path)
spikeTimes

array([[      469],
       [      953],
       [     1773],
       ...,
       [136157127],
       [136157168],
       [136157172]], dtype=uint64)

In [6]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tkinter import Tk
from tkinter import filedialog


# Parse ini file returning a dictionary whose keys are the metadata
# left-hand-side-tags, and values are string versions of the right-hand-side
# metadata values. We remove any leading '~' characters in the tags to match
# the MATLAB version of readMeta.
#
# The string values are converted to numbers using the "int" and "float"
# functions. Note that python 3 has no size limit for integers.
#
def readMeta(binFullPath):
    metaName = binFullPath.stem + ".meta"
    metaPath = Path(binFullPath.parent / metaName)
    metaDict = {}
    if metaPath.exists():
        # print("meta file present")
        with metaPath.open() as f:
            mdatList = f.read().splitlines()
            # convert the list entries into key value pairs
            for m in mdatList:
                csList = m.split(sep='=')
                if csList[0][0] == '~':
                    currKey = csList[0][1:len(csList[0])]
                else:
                    currKey = csList[0]
                metaDict.update({currKey: csList[1]})
    else:
        print("no meta file")
    return(metaDict)

# Return sample rate as python float.
# On most systems, this will be implemented as C++ double.
# Use python command sys.float_info to get properties of float on your system.
#
def SampRate(meta):
    if meta['typeThis'] == 'imec':
        srate = float(meta['imSampRate'])
    else:
        srate = float(meta['niSampRate'])
    return(srate)

In [7]:
# Get file from user
root = Tk()         # create the Tkinter widget
root.withdraw()     # hide the Tkinter root window

# Windows specific; forces the window to appear in front
root.attributes("-topmost", True)

binFullPath = Path(filedialog.askopenfilename(title="Select binary file"))

root.destroy()

In [8]:
# This spits out the sample rate from the meta file
meta = readMeta(binFullPath)
sRate = SampRate(meta)
sRate

29999.39932004327

In [9]:
# This converts the sample rate times in spike_times.npy to the spike times in seconds
spikeTimes = np.load(spikeSort_path / 'spike_times.npy')
spikeTimes_adjusted = spikeTimes/sRate # this converts it into seconds
spikeTimes_adjusted

array([[1.56336464e-02],
       [3.17673027e-02],
       [5.91011834e-02],
       ...,
       [4.53866178e+03],
       [4.53866314e+03],
       [4.53866328e+03]])

In [10]:
# Now we save the spike times in seconds to .txt
spike_times_path = os.path.join(spikeSort_path,'spike_times.txt')
with open(spike_times_path, 'w') as f:
    for spiketime in spikeTimes_adjusted:
        f.write(f'{float(spiketime)}\n')
print('converted npy spiketimes to .txt file')

converted npy spiketimes to .txt file


In [11]:
# You need to change the names here to appropriate files that you want to access and treat 
tostreamName = os.path.join(input_path,run_name,(run_name + '_tcat.imec0.ap.xd_384_6_500.txt'))
fromstreamName = os.path.join(input_path,run_name,(run_name + '_tcat.nidq.xd_3_0_500.txt'))
events_1 = os.path.join(input_path,run_name,(run_name + '_tcat.nidq.xia_0_1500.txt'))
events_1_out = os.path.join(input_path,run_name,(run_name + '_tcat.nidq.xia_0_1500_TPrime.txt'))
events_2 = os.path.join(input_path,run_name,(run_name + '_tcat.nidq.xa_0_2000.txt'))
events_2_out = os.path.join(input_path,run_name,(run_name + '_tcat.nidq.xa_0_2000.txt'))
events_3 = os.path.join(input_path,run_name,(run_name + '_tcat.nidq.xd_3_1_0.txt'))
events_3_out = os.path.join(input_path,run_name,(run_name + '_tcat.nidq.xd_3_1_0_TPrime.txt'))
events_4 = os.path.join(input_path,run_name,(run_name + '_tcat.nidq.xd_3_2_0.txt'))
events_4_out = os.path.join(input_path,run_name,(run_name + '_tcat.nidq.xd_3_2_0_TPrime.txt'))
events_5 = os.path.join(input_path,run_name,(run_name + '_tcat.nidq.xd_3_3_0.txt'))
events_5_out = os.path.join(input_path,run_name,(run_name + '_tcat.nidq.xd_3_3_0_TPrime.txt'))
events_6 = os.path.join(input_path,run_name,(run_name + '_tcat.nidq.xd_3_4_0.txt'))
events_6_out = os.path.join(input_path,run_name,(run_name + '_tcat.nidq.xd_3_4_0_TPrime.txt'))
events_7 = os.path.join(input_path,run_name,(run_name + '_tcat.nidq.xd_3_5_0.txt'))
events_7_out = os.path.join(input_path,run_name,(run_name + '_tcat.nidq.xd_3_5_0_TPrime.txt'))
events_8 = os.path.join(input_path,run_name,(run_name + '_tcat.nidq.xd_3_6_0.txt'))
events_8_out = os.path.join(input_path,run_name,(run_name + '_tcat.nidq.xd_3_6_0_TPrime.txt'))
events_10 = os.path.join(input_path,run_name,(run_name + '_tcat.nidq.xid_3_7_0.txt'))
events_10_out = os.path.join(input_path,run_name,(run_name + '_tcat.nidq.xid_3_7_0_TPrime.txt'))
events_9 = os.path.join(input_path,run_name,(run_name + '_tcat.nidq.xd_3_7_0.txt'))
events_9_out = os.path.join(input_path,run_name,(run_name + '_tcat.nidq.xd_3_7_0_TPrime.txt'))
                            

In [12]:
tostreamName

'D:\\NPX_data\\DATA\\08272024-NTBY4_g0\\08272024-NTBY4_g0\\08272024-NTBY4_g0_tcat.imec0.ap.xd_384_6_500.txt'

In [13]:
# This runs TPrime
command = f'runit.bat -syncperiod=1.0 \
    -tostream={tostreamName} \
    -fromstream=1,{fromstreamName} \
    -events=1,{events_1},{events_1_out} \
    -events=1,{events_2},{events_2_out} \
    -events=1,{events_3},{events_3_out} \
    -events=1,{events_4},{events_4_out} \
    -events=1,{events_5},{events_5_out} \
    -events=1,{events_6},{events_6_out} \
    -events=1,{events_7},{events_7_out} \
    -events=1,{events_8},{events_8_out} \
    -events=1,{events_9},{events_9_out} \
    -events=1,{events_10},{events_10_out}'

print(command)
print(os.system(f"cd /d {os.path.join(tools_path, 'TPrime-win')} & {command}"))
print(f'{run_name} done')


runit.bat -syncperiod=1.0     -tostream=D:\NPX_data\DATA\08272024-NTBY4_g0\08272024-NTBY4_g0\08272024-NTBY4_g0_tcat.imec0.ap.xd_384_6_500.txt     -fromstream=1,D:\NPX_data\DATA\08272024-NTBY4_g0\08272024-NTBY4_g0\08272024-NTBY4_g0_tcat.nidq.xd_3_0_500.txt     -events=1,D:\NPX_data\DATA\08272024-NTBY4_g0\08272024-NTBY4_g0\08272024-NTBY4_g0_tcat.nidq.xia_0_1500.txt,D:\NPX_data\DATA\08272024-NTBY4_g0\08272024-NTBY4_g0\08272024-NTBY4_g0_tcat.nidq.xia_0_1500_TPrime.txt     -events=1,D:\NPX_data\DATA\08272024-NTBY4_g0\08272024-NTBY4_g0\08272024-NTBY4_g0_tcat.nidq.xa_0_2000.txt,D:\NPX_data\DATA\08272024-NTBY4_g0\08272024-NTBY4_g0\08272024-NTBY4_g0_tcat.nidq.xa_0_2000.txt     -events=1,D:\NPX_data\DATA\08272024-NTBY4_g0\08272024-NTBY4_g0\08272024-NTBY4_g0_tcat.nidq.xd_3_1_0.txt,D:\NPX_data\DATA\08272024-NTBY4_g0\08272024-NTBY4_g0\08272024-NTBY4_g0_tcat.nidq.xd_3_1_0_TPrime.txt     -events=1,D:\NPX_data\DATA\08272024-NTBY4_g0\08272024-NTBY4_g0\08272024-NTBY4_g0_tcat.nidq.xd_3_2_0.txt,D:\NPX_dat

In [14]:
# Cwaves variables
binName = os.path.join(input_path,run_name,(run_name + '_tcat.imec0.ap.bin'))
clusTableName = os.path.join(spikeSort_path,'clus_Table.npy')
clusTimeName = os.path.join(spikeSort_path,'spike_times.npy')
clusLblName = os.path.join(spikeSort_path,'spike_clusters.npy')
destinationCwaves = os.path.join(spikeSort_path,'C_waves')

clusTableName

'D:\\NPX_data\\DATA\\08272024-NTBY4_g0\\08272024-NTBY4_g0\\kilosort_output\\clus_Table.npy'

In [15]:
spike_clusters = np.load(clusLblName)
spike_clusters

spike_clusters = np.uint32(spike_clusters)
spike_clusters

clusLblName_new = os.path.join(spikeSort_path,'spike_clusters_NEW.npy')
np.save(clusLblName_new,spike_clusters)

In [18]:
getSortResults(spikeSort_path,0)


(1698, 1796402)

In [17]:
def getSortResults(output_dir, clu_version):
    # load results from phy for run logging and creation of the table for C_Waves

    cluLabel = np.load(os.path.join(output_dir, 'spike_clusters.npy'))
    spkTemplate = np.load(os.path.join(output_dir,'spike_templates.npy'))
    cluLabel = np.squeeze(cluLabel)
    spkTemplate = np.squeeze(spkTemplate)

    unqLabel, labelCounts = np.unique(cluLabel, return_counts = True)
    nTot = cluLabel.shape[0]
    nLabel = unqLabel.shape[0]
    maxLabel = np.max(unqLabel)

    templates = np.load(os.path.join(output_dir, 'templates.npy'))
    channel_map = np.load(os.path.join(output_dir, 'channel_map.npy'))
    channel_map = np.squeeze(channel_map)
    
    # read in inverse of whitening matrix
    w_inv = np.load((os.path.join(output_dir, 'whitening_mat_inv.npy')))
    nTemplate = templates.shape[0]
    
    # initialize peak_channels array
    peak_channels = np.zeros([nLabel,],'uint32')
    
   
    # After manual splits or merges, some labels will have spikes found with
    # different templats.
    # for each label in the list unqLabel, get the most common template
    # For that template (nt x nchan), multiply the the transpose (nchan x nt) by inverse of 
    # the whitening matrix (nchan x nchan); get max and min along tthe time axis (1)
    # to find the peak channel
    for i in np.arange(0,nLabel):
        curr_spkTemplate = spkTemplate[np.where(cluLabel==unqLabel[i])]
        template_mode = np.argmax(np.bincount(curr_spkTemplate))
        currT = templates[template_mode,:].T
        curr_unwh = np.matmul(w_inv, currT)
        currdiff = np.max(curr_unwh,1) - np.min(curr_unwh,1)
        peak_channels[i] = channel_map[np.argmax(currdiff)]

    clus_Table = np.zeros((maxLabel+1, 2), dtype='uint32')
    clus_Table[unqLabel, 0] = labelCounts
    clus_Table[unqLabel, 1] = peak_channels

    if clu_version == 0:
        np.save(os.path.join(output_dir, 'clus_Table.npy'), clus_Table)
    else:
        clu_Name = 'clus_Table_' + repr(clu_version) + '.npy'
        np.save(os.path.join(output_dir, clu_Name), clus_Table)
 
    return nTemplate, nTot

In [19]:
# This runs Cwaves
command = f'runit.bat -spikeglx_bin={binName} \
    -clus_table_npy={clusTableName} \
    -clus_time_npy={clusTimeName} \
    -clus_lbl_npy={clusLblName_new} \
    -dest={destinationCwaves} \
    -samples_per_spike=82 \
    -pre_samples=30 \
    -num_spikes=100 \
    -snr_radius=8 \
    -snr_radius_um=140'

print(command)
print(os.system(f"cd /d {os.path.join(tools_path, 'C_Waves-win')} & {command}"))
print(f'{run_name} done')


runit.bat -spikeglx_bin=D:\NPX_data\DATA\08272024-NTBY4_g0\08272024-NTBY4_g0\08272024-NTBY4_g0_tcat.imec0.ap.bin     -clus_table_npy=D:\NPX_data\DATA\08272024-NTBY4_g0\08272024-NTBY4_g0\kilosort_output\clus_Table.npy     -clus_time_npy=D:\NPX_data\DATA\08272024-NTBY4_g0\08272024-NTBY4_g0\kilosort_output\spike_times.npy     -clus_lbl_npy=D:\NPX_data\DATA\08272024-NTBY4_g0\08272024-NTBY4_g0\kilosort_output\spike_clusters_NEW.npy     -dest=D:\NPX_data\DATA\08272024-NTBY4_g0\08272024-NTBY4_g0\kilosort_output\C_waves     -samples_per_spike=82     -pre_samples=30     -num_spikes=100     -snr_radius=8     -snr_radius_um=140
0
08272024-NTBY4_g0 done
