In [2]:
import os
import sys
import json
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [3]:
sys.path.append(r'C:\Users\lesliec\code')

In [29]:
from tbd_eeg.tbd_eeg.data_analysis.eegutils import EEGexp
from ecephys_spike_sorting.ecephys_spike_sorting.modules.quality_metrics import *
from ecephys_spike_sorting.ecephys_spike_sorting.common.utils import load_kilosort_data
from ecephys_spike_sorting.ecephys_spike_sorting.common.epoch import Epoch

#### Functions from metrics

In [30]:
def get_unit_pcs(unit_id,
                 spike_clusters,
                 spike_templates,
                 pc_feature_ind,
                 pc_features,
                 channels_to_use,
                 subsample):

    """ Return PC features for one unit

    Inputs:
    -------
    unit_id : Int
        ID for this unit
    spike_clusters : np.ndarray
        Cluster labels for each spike
    spike_templates : np.ndarry
        Template labels for each spike
    pc_feature_ind : np.ndarray
        Channels used for PC calculation for each unit
    pc_features : np.ndarray
        Array of all PC features
    channels_to_use : np.ndarray
        Channels to use for calculating metrics
    subsample : Int
        maximum number of spikes to return

    Output:
    -------
    unit_PCs : numpy.ndarray (float)
        PCs for one unit (num_spikes x num_PCs x num_channels)

    """


    inds_for_unit = np.where(spike_clusters == unit_id)[0]

    spikes_to_use = np.random.permutation(inds_for_unit)[:subsample]

    unique_template_ids = np.unique(spike_templates[spikes_to_use])

    unit_PCs = []

    for template_id in unique_template_ids:

        index_mask = spikes_to_use[np.squeeze(spike_templates[spikes_to_use]) == template_id]
        these_inds = pc_feature_ind[template_id, :]

        pc_array = []

        for i in channels_to_use:

            if np.isin(i, these_inds):
                channel_index = np.argwhere(these_inds == i)[0][0]
                pc_array.append(pc_features[index_mask, :, channel_index])
            else:
                return None

        unit_PCs.append(np.stack(pc_array, axis=-1))

    if len(unit_PCs) > 0:

        return np.concatenate(unit_PCs)
    else:
        return None

Set directory

In [5]:
rec_dir = r"F:\EEG_exp\mouse543396\estim_vis1_2020-09-18_12-04-46\experiment1\probeC_sorted\continuous\Neuropix-3a-100.0"

Use load_kilosort_data function to get all sorts of data

In [6]:
spike_times, spike_clusters, spike_templates, amplitudes, templates, channel_map, clusterIDs, cluster_quality, pc_features, pc_feature_ind = \
            load_kilosort_data(rec_dir, 30000, use_master_clock = False, include_pcs = True)

In [7]:
epochs = [Epoch('complete_session', 0, np.inf)]

In [8]:
total_units = len(np.unique(spike_clusters))
total_epochs = len(epochs)

In [9]:
print(total_units)
print(epochs)

362
[<ecephys_spike_sorting.ecephys_spike_sorting.common.epoch.Epoch object at 0x00000191125CCA58>]


In [10]:
epoch = epochs[0]

In [11]:
in_epoch = (spike_times > epoch.start_time) * (spike_times < epoch.end_time)

### calculate_pc_metrics

In [12]:
## def of calculate_pc_metrics function
spike_clusters = spike_clusters[in_epoch]
spike_templates = spike_templates[in_epoch]
# total_units = total_units
pc_features = pc_features[in_epoch,:,:]
# pc_feature_ind = pc_feature_ind
num_channels_to_compare = 13
max_spikes_for_cluster = 500
max_spikes_for_nn = 10000
n_neighbors = 4

In [13]:
half_spread = int((num_channels_to_compare - 1) / 2)
print(half_spread)

6


In [14]:
cluster_ids = np.unique(spike_clusters)
template_ids = np.unique(spike_templates)

template_peak_channels = np.zeros((len(template_ids),), dtype='uint16')
cluster_peak_channels = np.zeros((len(cluster_ids),), dtype='uint16')

In [15]:
for idx, template_id in enumerate(template_ids):
    for_template = np.squeeze(spike_templates == template_id)
    pc_max = np.argmax(np.mean(pc_features[for_template, 0, :], 0))
    template_peak_channels[idx] = pc_feature_ind[template_id, pc_max]

for idx, cluster_id in enumerate(cluster_ids):
    for_unit = np.squeeze(spike_clusters == cluster_id)
    templates_for_unit = np.unique(spike_templates[for_unit])
    template_positions = np.where(np.isin(template_ids, templates_for_unit))[0]
    cluster_peak_channels[idx] = np.median(template_peak_channels[template_positions])

### calculate_pc_metrics_one_cluster

In [45]:
cluster_id = 478 ## have ahd problems with 47? and 478
idx = np.squeeze(np.argwhere(cluster_ids == cluster_id))

# cluster_peak_channels = cluster_peak_channels
# idx = idx
# cluster_id = cluster_id
# cluster_ids = cluster_ids
# half_spread
# pc_features
# pc_feature_ind
# spike_clusters
# spike_templates
# max_spikes_for_cluster
# max_spikes_for_nn
# n_neighbors

In [46]:
peak_channel = cluster_peak_channels[idx]
num_spikes_in_cluster = np.sum(spike_clusters == cluster_id)

half_spread_down = peak_channel \
    if peak_channel < half_spread \
    else half_spread

half_spread_up = np.max(pc_feature_ind) - peak_channel \
    if peak_channel + half_spread > np.max(pc_feature_ind) \
    else half_spread

channels_to_use = np.arange(peak_channel - half_spread_down, peak_channel + half_spread_up + 1)
units_in_range = cluster_ids[np.isin(cluster_peak_channels, channels_to_use)]

spike_counts = np.zeros(units_in_range.shape)

In [47]:
print(units_in_range)

[  7  34  35  38  39  40  41  42 475 476 478]


In [48]:
for idx2, cluster_id2 in enumerate(units_in_range):
    spike_counts[idx2] = np.sum(spike_clusters == cluster_id2)

if num_spikes_in_cluster > max_spikes_for_cluster:
    relative_counts = spike_counts / num_spikes_in_cluster * max_spikes_for_cluster
else:
    relative_counts = spike_counts

In [49]:
print(spike_counts)

[1.2742e+04 3.4920e+04 1.0851e+04 3.3390e+03 6.8300e+02 3.5100e+03
 1.6837e+04 3.8470e+03 2.7000e+01 1.0310e+03 2.4660e+03]


In [50]:
all_pcs = np.zeros((0, pc_features.shape[1], channels_to_use.size))
all_labels = np.zeros((0,))

In [51]:
for idx2, cluster_id2 in enumerate(units_in_range):

    subsample = int(relative_counts[idx2])

    pcs = get_unit_pcs(cluster_id2, spike_clusters, spike_templates,
                       pc_feature_ind, pc_features, channels_to_use,
                       subsample)

    if pcs is not None and len(pcs.shape) == 3:

        labels = np.ones((pcs.shape[0],)) * cluster_id2

        all_pcs = np.concatenate((all_pcs, pcs),0)
        all_labels = np.concatenate((all_labels, labels),0)

all_pcs = np.reshape(all_pcs, (all_pcs.shape[0], pc_features.shape[1]*channels_to_use.size))

In [52]:
print(all_pcs.shape)

(500, 39)


In [53]:
print(all_labels)
print(len(all_labels))

[478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478.
 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478.
 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478.
 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478.
 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478.
 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478.
 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478.
 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478.
 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478.
 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478.
 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478.
 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478.
 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478.
 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478. 478.
 478. 

In [54]:
if ((all_pcs.shape[0] > 10) and (cluster_id in all_labels) and (len(channels_to_use) > 0) and (len(units_in_range) > 1)):
    print('Passed, continue with d prime')

Passed, continue with d prime
