In [1]:
import os
import sys

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from phylib.io.model import load_model
from phylib.utils.color import selected_cluster_color

In [2]:
def extractClusterWaveforms(params, clusterID, nSampleWaveforms, plotting, model=None):
    """Extract the per-spike waveforms of one cluster from a phy sorting output.

    Modified from the phy customization "Extracting waveforms" example.

    Parameters
    ----------
    params : str
        Path to the ``params.py`` file of the sorting output; passed to
        ``phylib.io.model.load_model`` when ``model`` is not supplied.
    clusterID : int or str
        Cluster index; converted with ``int()``.
    nSampleWaveforms : int
        Value assigned to ``model.n_samples_waveforms`` before extraction.
    plotting : bool
        If truthy, plot every 100th spike's waveform on (up to) the first four
        channels of the cluster.
    model : optional
        An already-loaded model to reuse. When ``None`` (default) the model is
        loaded from ``params`` on every call. Note that the model's
        ``n_samples_waveforms`` attribute is overwritten either way.

    Returns
    -------
    waveforms : array
        3-D array unpacked as ``(n_spikes, n_samples, n_channels_loc)``.
    channel_ids : array-like
        Channel ids from ``model.get_cluster_channels``; column ``ch`` of the
        last axis of ``waveforms`` is labelled ``channel_ids[ch]`` in the plot.
    """
    # Loading the TemplateModel is expensive; reuse one if the caller has it.
    if model is None:
        model = load_model(params)

    model.n_samples_waveforms = nSampleWaveforms

    cluster_id = int(clusterID)

    # Waveforms of every spike of the cluster; unpacking also checks it is 3-D.
    waveforms = model.get_cluster_spike_waveforms(cluster_id)
    _, _, n_channels_loc = waveforms.shape

    # Channel ids where the cluster's waveforms are located.
    channel_ids = model.get_cluster_channels(cluster_id)

    n_plot = min(4, n_channels_loc)
    if plotting and n_plot > 0:
        # squeeze=False keeps `axes` 2-D even when n_plot == 1; otherwise
        # plt.subplots returns a bare Axes object that cannot be indexed.
        f, axes = plt.subplots(1, n_plot, sharey=True, squeeze=False)
        for ch in range(n_plot):
            ax = axes[0, ch]
            # Every 100th spike, drawn faintly, to show the waveform spread.
            ax.plot(waveforms[::100, :, ch].T, c=selected_cluster_color(0, .05))
            ax.set_title("channel %d" % channel_ids[ch])
        plt.show()

    return waveforms, channel_ids

In [3]:
def getWaveforms(recID, plotting, nSampleWaveforms = 82, nBestChannels = 3, saveAllWaveformIDs = (97, 119)):
    """Average the waveforms of every "good" cluster of one recording and save them.

    Reads ``cluster_group.tsv`` and keeps the rows whose ``group`` column
    equals ``"good"``. For each such cluster, extracts its spike waveforms with
    ``extractClusterWaveforms`` and averages them over spikes on the first
    ``nBestChannels`` channels returned for that cluster.

    Parameters
    ----------
    recID : str
        Recording path *prefix*. File names are built by string concatenation
        (e.g. ``recID + "params.py"``), so it must end with a path separator
        when it names a directory.
    plotting : bool
        Forwarded to ``extractClusterWaveforms``.
    nSampleWaveforms : int, default 82
        Samples per waveform; first axis of the saved average array.
    nBestChannels : int, default 3
        Number of leading channels kept per cluster (presumably ordered
        best-first by phylib -- TODO confirm).
    saveAllWaveformIDs : iterable of int, default (97, 119)
        Clusters whose full per-spike waveform array is additionally saved to
        ``<recID>waveforms/<cid>.npy``. The default keeps the clusters that were
        previously hard-coded; pass ``()`` to skip these large files.

    Notes
    -----
    Writes ``recID + "su_Waveforms.npy"`` with shape
    ``(nSampleWaveforms, nBestChannels, n_good_clusters)`` and
    ``recID + "su_Waveforms_ids.npy"`` with shape
    ``(nBestChannels, n_good_clusters)``. Slots for clusters with fewer than
    ``nBestChannels`` channels are left as NaN.
    """
    paramsFilename = recID + "params.py"
    clusterFilename = recID + "cluster_group.tsv"

    clusterGroups = pd.read_csv(clusterFilename, sep='\t')
    clusterIDs = list(clusterGroups.loc[clusterGroups["group"] == "good", "cluster_id"])

    # NaN-filled rather than np.empty so that unused slots never hold garbage.
    waveforms = np.full([nSampleWaveforms, nBestChannels, len(clusterIDs)], np.nan)
    waveforms_ids = np.full([nBestChannels, len(clusterIDs)], np.nan)

    saveAllWaveformIDs = {int(c) for c in saveAllWaveformIDs}
    allWaveformsDir = recID + "waveforms"

    for i, cid in enumerate(clusterIDs):
        tmpWF, channel_ids = extractClusterWaveforms(paramsFilename, cid, nSampleWaveforms, plotting)

        # Clamp to the channels actually available: a cluster spanning fewer
        # than nBestChannels channels would otherwise fail to broadcast.
        nCh = min(nBestChannels, tmpWF.shape[2])
        waveforms[:, :nCh, i] = np.mean(tmpWF[:, :, :nCh], axis=0)
        nIds = min(nBestChannels, len(channel_ids))
        waveforms_ids[:nIds, i] = channel_ids[:nIds]
        print("Finished extracting cluster " + str(cid))

        # Optionally keep every spike's waveform (large on disk).
        if int(cid) in saveAllWaveformIDs:
            os.makedirs(allWaveformsDir, exist_ok=True)
            np.save(os.path.join(allWaveformsDir, str(cid) + ".npy"), tmpWF)
            print("Saved all waveforms of " + str(cid))

    # Save the spike-averaged waveform and channel ids of every good cluster.
    np.save(recID + "su_Waveforms.npy", waveforms)
    np.save(recID + "su_Waveforms_ids.npy", waveforms_ids)

In [5]:
# Recordings to process. To process every mouse listed in the V1_InVivo.csv
# metadata table instead, load it with pd.read_csv and set
# recIDs = np.unique(tab.MouseID).
mainFolder = "G:\\Nesta_SpikeSorting"   # root folder; one subfolder per recording
recIDs = ["NK272"]

N_SAMPLE_WAVEFORMS = 200   # samples per extracted waveform
PLOTTING = False           # True: plot spike waveforms of each cluster

for rec in recIDs:
    # getWaveforms builds file names by string concatenation, so the folder
    # path must end with a separator (the trailing "" adds it).
    fileDir = os.path.join(mainFolder, rec, "")
    print('Start extracting ' + rec)
    getWaveforms(fileDir, PLOTTING, nSampleWaveforms=N_SAMPLE_WAVEFORMS)
    print("Saved waveforms")

Start extracting NK272
Finished extracting cluster 16
Finished extracting cluster 52
Finished extracting cluster 125
Finished extracting cluster 146
Finished extracting cluster 148
Saved waveforms
