# Figure 1 : Tracé de l'activité avec le signal brut

In [1]:
%matplotlib inline

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import butter, lfilter
from tqdm import tqdm
import csv

# Butterworth band-pass filter helpers
def butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a digital Butterworth band-pass filter.

    Parameters
    ----------
    lowcut, highcut : float
        Lower and upper cut-off frequencies, in the same unit as ``fs``.
    fs : float
        Sampling frequency.
    order : int, optional
        Order handed to ``scipy.signal.butter`` (default 5).

    Returns
    -------
    b, a : ndarray
        Numerator and denominator coefficients of the IIR filter.
    """
    # butter() expects cut-offs normalised to the Nyquist frequency.
    nyquist = fs / 2.0
    normalised_edges = [lowcut / nyquist, highcut / nyquist]
    return butter(order, normalised_edges, btype='band')

def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Band-pass filter *data* with a Butterworth IIR filter.

    The coefficients come from ``butter_bandpass`` and are applied once,
    forward only, with ``scipy.signal.lfilter`` (default axis: the last one).

    Parameters
    ----------
    data : array_like
        Signal to filter.
    lowcut, highcut : float
        Pass-band edges, in the same unit as ``fs``.
    fs : float
        Sampling frequency.
    order : int, optional
        Butterworth order (default 5).

    Returns
    -------
    ndarray
        The filtered signal, same shape as ``data``.
    """
    numerator, denominator = butter_bandpass(lowcut, highcut, fs, order=order)
    return lfilter(numerator, denominator, data)

# Find nearest point
def find_nearest(array, value):
    """Return the element of *array* numerically closest to *value*.

    The input is flattened before searching, so arrays of any shape are
    accepted and a single element (not a row) is always returned. Ties
    resolve to the first occurrence in flattened order.

    Parameters
    ----------
    array : array_like
        Candidate values.
    value : scalar
        Target value.

    Returns
    -------
    scalar
        The closest element.

    Raises
    ------
    ValueError
        If *array* is empty.
    """
    # argmin() returns an index into the flattened array, so index the
    # flattened array too (indexing the N-D array would select a row).
    values = np.asarray(array).ravel()
    if values.size == 0:
        raise ValueError("find_nearest() arg is an empty sequence")
    return values[np.abs(values - value).argmin()]


# --- Parameters -------------------------------------------------------------
# Recording directory; file paths below are built by string concatenation,
# so the trailing slash is required.
folder = 'A006_a17/'
num_spikes_to_extract = 1000  # not referenced in the visible notebook cells
window = 30                   # not referenced in the visible notebook cells

# (cluster id, channel number) pairs, one subplot row each, top to bottom.
# Channel numbers are 1-based: the plotting loop subtracts 1 before using
# them as a column index into the raw data.
clust_chan_tuples = [
    (81, 31), (77, 30), (83, 29), (69, 28), (46, 27), (80, 26),
    (96, 25), (91, 24), (95, 23), (2, 22), (94, 21), (97, 20),
    (72, 19), (66, 18), (30, 18), (74, 17), (92, 16), (88, 15),
    (85, 14), (86, 13), (17, 12), (64, 11), (82, 10),
]

# Sampling rate and band-pass settings (same unit as fs, presumably Hz).
fs = 30000.0      # sampling frequency
lowcut = 300.0    # lower pass-band edge
highcut = 3000.0  # upper pass-band edge
order = 6         # Butterworth order

# Open raw file: a flat stream of int16 samples, reshaped below into one row
# per time point and one column per channel (32 channels). Binary mode plus a
# context manager so the handle is closed as soon as the data is read.
with open(folder + 'converted_data.bin', 'rb') as f:
    a = np.fromfile(f, dtype=np.int16)
a = np.reshape(a, (-1, 32))

# Load spikes: one spike time and one cluster label per detected spike.
spike_times = np.load(folder + 'spike_times.npy')
spike_clusters = np.load(folder + 'spike_clusters.npy')

# Flatten both to 1-D so they can be masked together. spike_times is expected
# as a (n_spikes, 1) column (TODO confirm); its first column is used, and a
# plain 1-D array is accepted as well.
flat_spike_times = spike_times[:, 0] if spike_times.ndim > 1 else spike_times
flat_spike_clusters = np.asarray(spike_clusters).reshape(-1)

# Ids of clusters labelled 'good' in the tab-separated cluster_groups.csv
# (column 0: cluster id, column 1: group label). The header row is skipped
# because its label is not 'good'; blank or short rows are ignored.
cluster_groups = []
with open(folder + 'cluster_groups.csv', 'r', newline='') as csvFile:
    for row in csv.reader(csvFile, delimiter='\t'):
        if len(row) >= 2 and row[1].strip() == 'good':
            cluster_groups.append(int(row[0]))

# good_spikes[k] holds the spike times of cluster cluster_groups[k]; a single
# vectorised mask per cluster avoids a Python loop over every spike.
good_spikes = [flat_spike_times[flat_spike_clusters == cluster_id]
               for cluster_id in cluster_groups]

# Plot window in samples: [plot_window_beg, plot_window_end), 30000 samples.
plot_window_beg = 56000
plot_window_end = plot_window_beg + 30000

# Main loop: one subplot per (cluster, channel) pair showing the raw trace in
# the plot window, with a vertical tick at each spike of that cluster.
mean_waveforms = []  # not filled in this cell
carac_points = []    # not filled in this cell
fig = plt.figure(figsize=(5, 15))
for it, cluster_tuple in enumerate(clust_chan_tuples):
    cluster_channel = cluster_tuple[1] - 1  # channel numbers are 1-based
    # good_spikes is ordered like cluster_groups, not like clust_chan_tuples,
    # so the spike list must be found by cluster id, not by the loop counter.
    index_cluster = cluster_groups.index(cluster_tuple[0])

    # Band-passed channel; only the last one survives the loop and a later
    # cell inspects ``y``. The figure itself shows the raw signal.
    y = butter_bandpass_filter(a[:, cluster_channel], lowcut, highcut, fs, order)

    plt.subplot(len(clust_chan_tuples), 1, it + 1)
    plt.plot(a[plot_window_beg:plot_window_end, cluster_channel], c='gray', alpha=.4)

    # Spikes in the same half-open window [beg, end) as the plotted slice.
    cluster_spikes = np.asarray(good_spikes[index_cluster])
    in_window = cluster_spikes[(cluster_spikes >= plot_window_beg)
                               & (cluster_spikes < plot_window_end)]
    for spike_time in in_window:
        plt.axvline(spike_time - plot_window_beg, ymin=.5, ymax=.8, c='k')

    plt.axhline(0, c='gray', linewidth=2, alpha=.8)
    plt.axis('off')

plt.subplots_adjust(bottom=0, wspace=0, hspace=0)
# Fixed file name: it has no '%' placeholder, so it must not be %-formatted
# (doing so raises TypeError and the figure is never written).
fig.savefig('raw_and_spikes.svg', format='svg', transparent=True)
plt.show()

In [None]:
# Print the in-window spike times of the cluster drawn in subplot row ``it``.
# good_spikes is ordered like cluster_groups, so look the cluster up by id.
it = 0
cluster_spikes = np.asarray(good_spikes[cluster_groups.index(clust_chan_tuples[it][0])])
for spike_time in cluster_spikes[(cluster_spikes >= plot_window_beg)
                                 & (cluster_spikes < plot_window_end)]:
    print(spike_time)

In [None]:
# Number of samples of the last filtered channel (y) inside the plot window.
y[plot_window_beg:plot_window_end].shape[0]

In [None]:
# Display the first sample index of the plot window.
plot_window_beg

In [None]:
# Display the (exclusive) end sample index of the plot window.
plot_window_end

In [None]:
# Close every open pyplot figure so pyplot no longer keeps them alive.
plt.close('all')