In [10]:
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
#plt.ion()

#matplotlib.use('Qt5agg')

import os

from scipy.ndimage import gaussian_filter1d

import burst_detector as bd

In [11]:
#rec_time = "0900"
#file_path = "20240914_SD1_test_" + rec_time + "_g0/catgt_20240914_SD1_test_" + rec_time + "_g0/20240914_SD1_test_" + rec_time + "_g0_imec0/imec0_ks25"

# parameters
bin_size = 30000

sample_rate = 30000
sigma = 6

In [12]:
def load_files(spike_times_file, spike_clusters_file, cluster_group_file):
    spike_times = np.load(spike_times_file)
    spike_clusters = np.load(spike_clusters_file)

    clust_group = pd.read_csv(cluster_group_file, sep='\t')
    clust_label = clust_group.index[clust_group["label"] == "good"].tolist()

    return spike_times, spike_clusters, clust_label

In [14]:
def spike_times_by_cluster(spike_times, spike_clusters, clust_label):

    spike_times_all = []
    #clust_ind = clust_label.index[clust_label["label"] == "good"].tolist()

    for i in clust_label:
        spike_times_clust = spike_times[spike_clusters == i]

        if len(spike_times_clust) >= 50:
            spike_times_all.append(spike_times_clust)

    return spike_times_all


In [15]:
spike_times_file = r"Z:\Wu_sleep\m1\SD1\20240914_SD1_test_0900_g0\catgt_20240914_SD1_test_0900_g0\20240914_SD1_test_0900_g0_imec0\imec0_ks25\spike_times.npy"
spike_clusters_file =r"Z:\Wu_sleep\m1\SD1\20240914_SD1_test_0900_g0\catgt_20240914_SD1_test_0900_g0\20240914_SD1_test_0900_g0_imec0\imec0_ks25\spike_clusters.npy"
cluster_group_file = r"Z:\Wu_sleep\m1\SD1\20240914_SD1_test_0900_g0\catgt_20240914_SD1_test_0900_g0\20240914_SD1_test_0900_g0_imec0\imec0_ks25\cluster_group.tsv"

spike_times, spike_clusters, clust_label = load_files(spike_times_file, spike_clusters_file, cluster_group_file)
spikes = spike_times_by_cluster(spike_times, spike_clusters, clust_label)
#spikes = spike_times_by_cluster(spike_times_file, spike_clusters_file, clust_label)

In [None]:
## function for calculating firing rate of each neuron(clsuters)
# Input:
# - spike_time_files = spike_times.npy
# - spike_clusters_file = spike_clusters.npy
# - bin_size = in sample numbers (30kHz), binwidth for calculating firing rate
# - cluster = select for which clusters to calculate firing rate

# Output:
# - fr [cluster #, (total length/ bin width)]: binned firing rate for each cluster
# - t_bins [(total length/ bin width),]: vector of binned time stamp
# - clust_num [int]: number of clusters

def firing_rate_calc(spike_times, spike_clusters, clust_label, bin_size, cluster):

    if cluster == "all":
        clust_tot = len(clust_label)
        clust_ind = clust_label
    else:
        clust_ind = cluster
        
    t_end = np.max(spike_times)
    t_bins = np.arange(0,t_end,bin_size)
    
    clust_num = len(clust_ind)
    fr = np.zeros((len(clust_ind), len(t_bins)-1))
    for i in range(len(clust_ind)):
        spikes_t_ind = spike_times[np.where(spike_clusters == clust_ind[i])[0]]
        spikes_count, edges = np.histogram(spikes_t_ind, t_bins)

        fr[i,:] = spikes_count

    return fr, t_bins[:-1], clust_num, clust_tot

In [None]:
## function for smoothing firing rate via gaussian filter
# Input:
# - fr
# - sigma: standard deviation of gaussian kernel

# Output:
# - fr_smooth

def firing_rate_smooth(fr, sigma):
    fr_smooth = np.zeros(fr.shape)
    for i in range(len(fr_smooth)):
        fr_smooth[i,:] = gaussian_filter1d(fr[i,:], sigma)

    return fr_smooth

In [None]:
cluster_ind = np.unique(spike_clusters)

cluster_spike_count = np.zeros(len(cluster_ind))
for i in range(len(cluster_ind)):
    count = len(np.where(spike_clusters == cluster_ind[i])[0])
    cluster_spike_count[i] = count

firing_rate = cluster_spike_count/(np.max(spike_times)/30000)

df = pd.DataFrame({'cluster_id':cluster_ind, 'firing_rate':firing_rate})

In [None]:
def firing_rate_metics(spike_times_file, spike_clusters_file):
    spike_times = np.load(spike_times_file)
    spike_clusters = np.load(spike_clusters_file)

    clust_ind = np.unique(spike_clusters)
    cluster_spike_count = np.zeros(len(clust_ind))

    for i in range(len(clust_ind)):
        count = len(np.where(spike_clusters == clust_ind[i])[0])
        cluster_spike_count[i] = count

    firing_rate = cluster_spike_count/(np.max(spike_times)/30000)

    df = pd.DataFrame({'cluster_id':clust_ind, 'firing_rate':firing_rate})
    
    return df

In [None]:
# MAIN SCRIPT FOR CALCULATING FIRING RATE

# going through all the different recording sessions
t_stamp = ["0900", "1000", "1100", "1200", "1300", "1400", "1500", "1600", "1700", "1800", "1900", "2000", "2100"]
file_dir = "Z:/Wu_sleep/m2/SD1/"
file_date = "20240923"

mean_fr_sessions = []

for s in range(len(t_stamp)):
    file_path = file_date + "_SD1_test_" + t_stamp[s] + "_g0/catgt_" + file_date + "_SD1_test_" + t_stamp[s] + "_g0/" + file_date + "_SD1_test_" + t_stamp[s] + "_g0_imec0/imec0_ks25"
    spike_times_file = file_dir + file_path + "/spike_times.npy"
    spike_clusters_file = file_dir + file_path + "/spike_clusters.npy"
    cluster_label_file = file_dir + file_path + "/cluster_group.tsv"
    

    fr, t_bins, clust_num, clust_tot = firing_rate_calc(spike_times_file, spike_clusters_file, cluster_label_file, bin_size, "all")
    fr_smooth = firing_rate_smooth(fr, sigma)

    # taking average of binned firing rate for each cluster
    fr_smooth_avg = np.average(fr_smooth, axis=1)

    # save total cluster fr data in vector for each session
    mean_fr_sessions.append(fr_smooth_avg)

    # update post-curation metrics file with firing rate for neuron tracking
    fr_df = firing_rate_metics(spike_times_file, spike_clusters_file)
    fr_df.to_csv(os.path.join(file_dir + file_path, r"metrics_curated.csv"))

    

In [None]:
t_stamp = ["0900", "1000", "1100", "1200", "1300", "1400", "1500", "1600", "1700", "1800", "1900", "2000", "2100"]
file_dir = "Z:/Wu_sleep/m1/SD1/"
file_date = "20240914"

clust_num_all = np.zeros(13)
clust_tot_all = np.zeros(13)
for s in range(len(t_stamp)):
    file_path = file_date + "_SD1_test_" + t_stamp[s] + "_g0/catgt_" + file_date + "_SD1_test_" + t_stamp[s] + "_g0/" + file_date + "_SD1_test_" + t_stamp[s] + "_g0_imec0/imec0_ks25"
    spike_times_file = file_dir + file_path + "/spike_times.npy"
    spike_clusters_file = file_dir + file_path + "/spike_clusters.npy"
    cluster_label_file = file_dir + file_path + "/cluster_group.tsv"
    

    fr, t_bins, clust_num, clust_tot= firing_rate_calc(spike_times_file, spike_clusters_file, cluster_label_file, bin_size, "all")
    clust_num_all[s] = clust_num
    clust_tot_all[s] = clust_tot


In [16]:
spikes_t = np.array(spikes[127]/sample_rate, dtype=float)
bursts, q, a = bd.find_bursts(spikes_t)
a[q]

print(np.max(q))
print(a)

1
[0.17036998 1.05768911]


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


In [21]:
a[q]

array([0.17036998, 0.17036998, 0.17036998, 0.17036998, 0.17036998,
       0.17036998, 0.17036998, 0.17036998, 0.17036998, 0.17036998,
       0.17036998, 1.05768911, 1.05768911, 0.17036998, 0.17036998,
       0.17036998, 0.17036998, 0.17036998, 0.17036998, 0.17036998,
       1.05768911, 1.05768911, 1.05768911, 1.05768911, 1.05768911,
       1.05768911, 1.05768911, 1.05768911, 1.05768911, 1.05768911,
       0.17036998, 0.17036998, 0.17036998, 0.17036998, 0.17036998,
       0.17036998, 0.17036998, 0.17036998, 0.17036998, 0.17036998,
       0.17036998, 0.17036998, 0.17036998, 0.17036998, 0.17036998,
       0.17036998, 0.17036998, 0.17036998, 0.17036998, 1.05768911,
       0.17036998, 0.17036998, 0.17036998, 0.17036998, 0.17036998,
       0.17036998, 0.17036998, 1.05768911, 0.17036998, 0.17036998,
       0.17036998, 1.05768911, 1.05768911, 0.17036998, 0.17036998,
       0.17036998, 0.17036998, 1.05768911, 0.17036998, 0.17036998,
       0.17036998, 0.17036998, 0.17036998, 0.17036998, 0.17036

In [None]:
pp = np.array(spikes[313]/sample_rate, dtype=float)
bursts, q ,a = bd.find_bursts(pp)

In [6]:
sp = spikes[272]
burst, q, a = bd.find_bursts(sp)

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


In [None]:
q_all = []
burst_thres = 20

for i in range(len(spikes)):
    spike_t_clust = np.array(spikes[i]/sample_rate, dtype=float)
    bursts, q ,a = bd.find_bursts(spike_t_clust)

    q_all.append(np.sum(a[q] > burst_thres))


In [9]:
spike_t_delta = []

for i in range(len(spikes)):
    clust_t_delta = np.diff(spikes[i])
    spike_t_delta.append(sum(clust_t_delta <= (0.5*sample_rate)))

In [10]:
plt.plot(spike_t_delta)
plt.show()

  plt.show()


In [38]:
q_all

[1302,
 41769,
 550,
 6,
 1610,
 897,
 1106,
 0,
 559,
 1617,
 797,
 796,
 947,
 1546,
 1828,
 2149,
 1299,
 737,
 0,
 1675,
 1882,
 8,
 0,
 582,
 0,
 1500,
 1131,
 0,
 1956,
 1447,
 1191,
 8,
 0,
 3,
 1608,
 1642,
 1593,
 1663,
 1903,
 15045,
 0,
 3,
 1998,
 2068,
 927,
 878,
 14590,
 0,
 1327,
 3217,
 2527,
 0,
 2083,
 2405,
 35234,
 2301,
 1,
 2,
 1964,
 3046,
 0,
 1902,
 6,
 1941,
 2,
 1445,
 0,
 2039,
 2371,
 0,
 2678,
 793,
 0,
 3069,
 3599,
 2347,
 7,
 2784,
 1695,
 535,
 2996,
 2291,
 0,
 564,
 546,
 1718,
 2574,
 2724,
 707,
 1,
 6,
 0,
 13292,
 2898,
 1,
 194,
 20,
 2643,
 419,
 3442,
 871,
 1,
 2272,
 10,
 2394,
 3197,
 41134,
 2257,
 9,
 1991,
 1839,
 2,
 3,
 0,
 1638,
 1783,
 2976,
 0,
 2550,
 2265,
 2245,
 1729,
 16806,
 2172,
 0,
 0,
 2149,
 0,
 844,
 1163,
 528,
 1130,
 610,
 416,
 1870,
 1328,
 887,
 2189,
 659,
 1218,
 1477,
 0,
 0,
 0,
 1888,
 1866,
 1969,
 537,
 1066,
 8,
 678,
 970,
 750,
 1,
 14788,
 2891,
 1876,
 12506,
 13233,
 2131,
 1731,
 1,
 14034,
 0,
 1483

In [36]:
spike_t_delta

[8309,
 41769,
 2106,
 1720,
 6490,
 4099,
 5912,
 627,
 3625,
 5375,
 5338,
 5296,
 7998,
 6539,
 6671,
 5433,
 2759,
 6275,
 2,
 5665,
 6167,
 1607,
 2,
 4419,
 1,
 7249,
 6619,
 3,
 8265,
 7014,
 9329,
 1594,
 9,
 9,
 7998,
 9331,
 8177,
 4420,
 7861,
 15043,
 24,
 8,
 5605,
 7540,
 3604,
 3955,
 14577,
 112,
 2850,
 11095,
 7485,
 6,
 9351,
 11459,
 35234,
 6931,
 6,
 7,
 8978,
 11342,
 2,
 9382,
 1321,
 9133,
 8,
 2284,
 10,
 12133,
 11681,
 20,
 9159,
 3324,
 1,
 11497,
 12096,
 10513,
 38,
 11951,
 7165,
 3604,
 10700,
 8044,
 13,
 869,
 3596,
 7187,
 10268,
 10661,
 3844,
 6,
 59,
 4,
 13282,
 10419,
 3,
 311,
 25,
 6636,
 2001,
 8720,
 3702,
 4,
 7172,
 867,
 10372,
 9169,
 41134,
 9335,
 25,
 6382,
 6622,
 9,
 7,
 6,
 12055,
 5797,
 10801,
 6,
 9878,
 7055,
 7980,
 6888,
 16788,
 8178,
 5,
 0,
 7702,
 7,
 6083,
 3615,
 2053,
 4438,
 2136,
 8735,
 4069,
 5101,
 8900,
 10087,
 3430,
 8099,
 3248,
 7,
 3,
 2,
 4248,
 5145,
 5606,
 2122,
 4418,
 1584,
 2818,
 8378,
 2375,
 13,
 1

In [None]:
plt.violinplot(mean_fr_sessions,showmedians=True, showextrema=False)

plt.ylim([0,50])
plt.title("M2 Firing Rate")
plt.xticks(np.arange(13)+1, ["SD1", "SD2", "SD3", "SD4", "SD5", "SD6", "RE1", "RE2", "RE3", "RE4", "RE5", "RE6", "RE7"])
plt.xlabel("Sessions")
plt.ylabel("FR (Hz)")
plt.show()

In [None]:
plt.boxplot(mean_fr_sessions, sym="")

plt.ylim([0,50])
plt.title("M1 Firing Rate")
plt.xticks(np.arange(13)+1, ["SD1", "SD2", "SD3", "SD4", "SD5", "SD6", "RE1", "RE2", "RE3", "RE4", "RE5", "RE6", "RE7"])
plt.xlabel("Sessions")
plt.ylabel("FR (Hz)")
plt.show()