In [1]:
# Row of chan_data (one averaged signal per contact, built further down) to
# spike-sort. raw_input returns a string, so convert it to int here: it is later
# used as an array index (chan_data[chan_index, :]).
chan_index = int(raw_input('Enter channel index to analyze:'))

Enter channel index to analyze:5


In [None]:
# Open Ephys session directory (relative to the notebook's working directory).
# It is walked recursively for *CH<n>.continuous files, and the output CSVs are written here.
source_path = '../data/OpenEphys_data/2019-04-26/2019-04-26_12-03-55/'

In [None]:
import time
t0 = time.time()
import OpenEphys
from kaveh.toolbox import common_avg_ref, butter_bandpass_filter
import Kwik
from matplotlib import pyplot as plt
import numpy as np
import os
import re
import scipy.signal

In [None]:
# Interactive (zoom/pan) matplotlib figures in the classic Jupyter Notebook frontend.
%matplotlib notebook
# Re-import edited modules (e.g. the kaveh.* helpers) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
# Collect every "...CH<n>.continuous" file under source_path, sorted by channel
# number <n>. f_names[k] and chans[k] describe the same file.
channel_file_regex = re.compile(r".*CH(\d{1,2})\.continuous$")  # compiled once, not per file
channel_files = []  # (channel number, full path) pairs
for root, _dirnames, filenames in os.walk(source_path):
    for filename in filenames:
        match = channel_file_regex.match(filename)
        if match:
            # Take the channel number from the same regex that selected the file.
            channel_files.append((int(match.group(1)), os.path.join(root, filename)))
if not channel_files:
    # os.walk yields nothing for a missing directory, so fail loudly here
    # instead of with an IndexError when the files are read.
    raise IOError('No *CH<n>.continuous files found under {}'.format(source_path))
channel_files.sort()  # by channel number, ties broken by path
chans = [chan for chan, _path in channel_files]
f_names = np.array([path for _chan, path in channel_files])  # now sorted by channel number

In [None]:
# Build one signal per contact: read each group of consecutive channel files
# (in channel-number order), band-pass filter every channel, and average the group.
# Result: chan_data with n_contacts rows; Fs is the common sample rate.
n_contacts = 7              # number of averaged signals (rows of chan_data)
channels_per_contact = 4    # consecutive channel files averaged into one contact
band_low, band_high = 300, 3000  # passband edges for butter_bandpass_filter (presumably Hz, since Fs is passed)
filter_order = 2

n_files_needed = n_contacts * channels_per_contact
if len(f_names) < n_files_needed:
    raise ValueError('Expected at least {} channel files, found {}'.format(n_files_needed, len(f_names)))

Fs = None
chan_data = []
for contact in range(n_contacts):
    contact_reps = []
    first_file = contact * channels_per_contact
    for i in range(first_file, first_file + channels_per_contact):
        print('Reading {}...'.format(f_names[i]))
        file_content = OpenEphys.load(f_names[i])
        file_fs = float(file_content['header']['sampleRate'])
        # A single Fs is used for every channel and later for dt, so a mismatch
        # would silently mis-scale time.
        if Fs is not None and file_fs != Fs:
            raise ValueError('Sample rate of {} ({}) differs from earlier files ({})'.format(f_names[i], file_fs, Fs))
        Fs = file_fs
        signal_filtered = butter_bandpass_filter(file_content['data'], band_low, band_high, Fs, order=filter_order)
        contact_reps.append(signal_filtered)
    contact_reps = np.array(contact_reps)
    chan_data.append(np.mean(contact_reps, axis=0))
    print('-----------------------------------------------')
chan_data = np.array(chan_data)
contact_reps = None  # release the last group's per-channel copies

In [None]:
# Re-reference the contacts (common average reference, per the helper's name), then detrend.
# NOTE(review): the return value of common_avg_ref is discarded, so this line only
# has an effect if common_avg_ref modifies chan_data in place — confirm against kaveh.toolbox.
common_avg_ref(chan_data)
# scipy.signal.detrend defaults: axis=-1 (each row along its samples), type='linear'.
chan_data = scipy.signal.detrend(chan_data)

In [None]:
# Run the SimpleSpikeSorter pipeline step by step on the selected contact.
# The private steps are called in the same order as before; the order matters.
from kaveh.sorting.spikesorter import SimpleSpikeSorter

# chan_index may still be the raw string typed by the user: coerce it and make
# sure it names an existing row of chan_data before indexing.
chan_row = int(chan_index)
if not 0 <= chan_row < chan_data.shape[0]:
    raise IndexError('chan_index {} out of range; chan_data has {} rows'.format(chan_row, chan_data.shape[0]))

dt = 1.0/Fs
sss = SimpleSpikeSorter(chan_data[chan_row, :], dt)
sss._pre_process()
sss._detect_spikes_minibatch()
sss._align_spikes()
# Presumably the GMM component count for the sorter's own CS clustering;
# it must be set before _cluster_spike_by_feature runs.
sss.cs_num_gmm_components = 5
sss._cluster_spike_by_feature()
sss._cs_post_process()
print(sss.cs_indices.shape)

In [None]:
# Re-cluster the CS waveforms with a Gaussian mixture so the user can pick which
# clusters to keep below, and build a mean SS waveform for comparison.
import random

from sklearn.mixture import GaussianMixture

num_clusters = 4
# Seeds both the SS subsample and the GMM fit. GMM label numbering is arbitrary,
# so without a fixed seed a re-run could renumber the clusters the user picks.
random_seed = 0

pre_time = 0.0005   # window start before each spike index (seconds, assuming sss.dt is in s)
post_time = 0.005   # window end after each spike index

pre_index = int(np.round(pre_time/sss.dt))
post_index = int(np.round(post_time/sss.dt))
n_voltage = len(sss.voltage)

# Every CS window must be complete: cluster_labels has to stay aligned one-to-one
# with sss.cs_indices (used when picking clusters), so edge CS cannot be dropped.
cs_window_ok = (sss.cs_indices >= pre_index) & (sss.cs_indices + post_index <= n_voltage)
if not np.all(cs_window_ok):
    raise ValueError('{} CS lie too close to the recording edge for a full window'.format(np.sum(~cs_window_ok)))
aligned_cs = np.array([sss.voltage[i - pre_index : i + post_index] for i in sss.cs_indices])

# SS = every detected spike that the sorter did not label as CS.
ss_indices = np.setdiff1d(sss.spike_indices, sss.cs_indices)
# Keep only SS whose window fits in the recording (replaces the ad-hoc [1:-2] trim).
ss_window_ok = (ss_indices >= pre_index) & (ss_indices + post_index <= n_voltage)
aligned_ss = np.array([sss.voltage[i - pre_index : i + post_index] for i in ss_indices[ss_window_ok]])

# Mean of a random SS subsample as large as the CS set, capped at the number of SS
# so random.sample cannot raise.
ss_sample_size = min(sss.cs_indices.size, aligned_ss.shape[0])
ss_sample_rows = random.Random(random_seed).sample(range(aligned_ss.shape[0]), ss_sample_size)
mean_ss = np.mean(aligned_ss[ss_sample_rows], axis=0)

gmm = GaussianMixture(num_clusters, covariance_type='full', random_state=random_seed).fit(aligned_cs)

# One label per row of aligned_cs, i.e. per entry of sss.cs_indices.
cluster_labels = gmm.predict(aligned_cs)

# clusters[cn] holds the aligned CS waveforms assigned to GMM component cn.
clusters = [aligned_cs[cluster_labels == cn] for cn in range(num_clusters)]

In [None]:
# Overlay each CS cluster's mean waveform with the mean of the random SS subsample,
# on a time axis in ms relative to the aligned spike index.
import gc
gc.collect()

colors = plt.cm.nipy_spectral(np.linspace(0, 1, num_clusters))
legend_labels = ['c{}({}) '.format(cn, clusters[cn].shape[0]) for cn in range(num_clusters)]

# Window sample k lies (k - pre_index) samples after the spike index.
time_ms = (np.arange(aligned_cs.shape[1]) - pre_index) * sss.dt * 1000.0

fig, ax = plt.subplots(figsize=(8, 5))
for cn in range(num_clusters):
    ax.plot(time_ms, np.mean(clusters[cn], axis=0), color=colors[cn], label=legend_labels[cn])
ax.plot(time_ms, mean_ss, '--', label='Mean SS({})'.format(aligned_ss.shape[0]))
ax.set_xlabel('Time from spike (ms)')
ax.set_ylabel('Voltage')
ax.set_title('Mean CS waveform per GMM cluster vs. mean SS')
ax.legend();

In [None]:
# Ask which GMM clusters (see the plot above) are real CS; the next cell parses
# clusters_to_pick. Valid labels are 0 .. num_clusters - 1.
clusters_to_pick = raw_input("Enter CS clusters (comma separated, 0-{}; example: 2,0): ".format(num_clusters - 1))

In [None]:
# Parse the user's cluster choice and keep the CS indices assigned to those clusters.
# Safe to re-run: an already-parsed list is left as-is.
if isinstance(clusters_to_pick, str):
    clusters_to_pick = [int(c) for c in clusters_to_pick.split(',') if c.strip()]
invalid_clusters = [c for c in clusters_to_pick if not 0 <= c < num_clusters]
if invalid_clusters:
    raise ValueError('Invalid CS cluster(s) {}; valid labels are 0-{}'.format(invalid_clusters, num_clusters - 1))
# cluster_labels is aligned element-for-element with sss.cs_indices. np.unique
# returns sorted unique indices (as union1d did) but keeps the integer dtype and
# yields an empty array rather than a list when nothing is picked.
cs_indices_to_pick = np.unique(sss.cs_indices[np.isin(cluster_labels, clusters_to_pick)])
cs_indices = cs_indices_to_pick

In [None]:
# Wall-clock seconds since t0 was taken in the imports cell. This includes the
# time spent waiting at the interactive prompts, not just computation.
t1 = time.time()
elapsed_seconds = t1 - t0
print(elapsed_seconds)

## Save detected CS and SS spike indices to CSV

In [None]:
# Write the picked CS indices and the SS indices as one-column CSV files in
# source_path, named after the analysed channel.
# NOTE(review): ss_indices excludes every spike the sorter labelled CS
# (sss.cs_indices), including those in clusters NOT picked above, so such spikes
# appear in neither file — confirm this is intended.
import csv
import sys

CS_csv_filename = os.path.join(source_path, 'channel_{}.CS.csv'.format(chan_index))
SS_csv_filename = os.path.join(source_path, 'channel_{}.SS.csv'.format(chan_index))


def _write_index_column(path, indices):
    """Overwrite `path` with `indices` as a one-column CSV, one integer per row."""
    print('writing {} ... '.format(path))
    # csv needs a binary file on Python 2 and newline='' on Python 3; otherwise
    # its '\r\n' row terminator turns into '\r\r\n' on Windows.
    if sys.version_info[0] < 3:
        f = open(path, 'wb')
    else:
        f = open(path, 'w', newline='')
    with f:
        csvwriter = csv.writer(f, delimiter=',')
        # Cast to int so indices held as floats are written as "123", not "123.0".
        csvwriter.writerows([int(i)] for i in np.asarray(indices).ravel())


_write_index_column(CS_csv_filename, cs_indices)
_write_index_column(SS_csv_filename, ss_indices)