In [10]:
import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.ndimage import gaussian_filter1d

In [11]:
# parameters
sample_rate = 30_000  # samples per second (30 kHz); used to convert seconds to sample counts
bin_size = 30_000     # width of one firing-rate bin, in samples

sigma = 6             # standard deviation of the gaussian smoothing kernel, in bins

In [12]:
def load_files(spike_times_file, spike_clusters_file, cluster_group_file):
    """Load the spike arrays and the IDs of the clusters labelled "good".

    Parameters
    ----------
    spike_times_file : str or path-like
        ``.npy`` file read with ``np.load``.
    spike_clusters_file : str or path-like
        ``.npy`` file read with ``np.load``.
    cluster_group_file : str or path-like
        Tab-separated table with a ``label`` column and, normally, a
        ``cluster_id`` column.

    Returns
    -------
    spike_times, spike_clusters : np.ndarray
        The two arrays exactly as loaded.
    clust_label : list
        IDs of the clusters whose ``label`` is ``"good"``.
    """
    spike_times = np.load(spike_times_file)
    spike_clusters = np.load(spike_clusters_file)

    clust_group = pd.read_csv(cluster_group_file, sep='\t')
    is_good = clust_group["label"] == "good"

    if "cluster_id" in clust_group.columns:
        # The returned values are compared against the entries of
        # spike_clusters, so they must be real cluster IDs. Row positions only
        # coincide with the IDs while the table is contiguous and starts at 0.
        clust_label = clust_group.loc[is_good, "cluster_id"].tolist()
    else:
        # No explicit ID column: keep the historical behaviour (row position).
        clust_label = clust_group.index[is_good].tolist()

    return spike_times, spike_clusters, clust_label

In [14]:
def spike_times_by_cluster(spike_times, spike_clusters, clust_label, min_spikes=50):
    """Split the spike times into one array per selected cluster.

    Parameters
    ----------
    spike_times : np.ndarray
        Spike times, indexed in parallel with ``spike_clusters``.
    spike_clusters : np.ndarray
        Cluster ID of every spike.
    clust_label : iterable
        Cluster IDs to extract, in the order they should be returned.
    min_spikes : int, optional
        Clusters with fewer spikes than this are dropped (default 50, the
        threshold that used to be hard-coded).

    Returns
    -------
    list of np.ndarray
        Spike times of every kept cluster. Because sparse clusters are
        skipped, the list can be shorter than ``clust_label``.
    """
    per_cluster = (spike_times[spike_clusters == clust_id] for clust_id in clust_label)
    return [times for times in per_cluster if len(times) >= min_spikes]


In [None]:
## function for calculating the binned firing rate of each neuron (cluster)
def firing_rate_calc(spike_times, spike_clusters, clust_label, bin_size, cluster):
    """Count the spikes of each selected cluster in fixed-width time bins.

    Parameters
    ----------
    spike_times : np.ndarray
        Spike times in samples (contents of spike_times.npy).
    spike_clusters : np.ndarray
        Cluster ID of every spike (contents of spike_clusters.npy).
    clust_label : sequence
        IDs of the labelled clusters; used when ``cluster == "all"``.
    bin_size : int
        Bin width in samples.
    cluster : "all", scalar or sequence
        ``"all"`` selects every ID in ``clust_label``; otherwise the cluster
        ID(s) to process.

    Returns
    -------
    fr : np.ndarray, shape (number of selected clusters, number of bins)
        Spike count per bin. This equals a rate in Hz only when ``bin_size``
        spans exactly one second of samples.
    t_bins : np.ndarray, shape (number of bins,)
        Left edge of every bin, in samples.
    clust_num : int
        Number of clusters that were processed.
    clust_tot : int
        Number of entries in ``clust_label``.

    Raises
    ------
    ValueError
        If ``cluster`` is a string other than ``"all"``.
    """
    # Defined on every path: it used to be set only for cluster == "all",
    # so any explicit selection crashed at the return statement.
    clust_tot = len(clust_label)

    if isinstance(cluster, str):
        # Check the type first so an array argument is never compared
        # element-wise against the string.
        if cluster != "all":
            raise ValueError('cluster must be "all" or a collection of cluster IDs')
        clust_ind = np.atleast_1d(clust_label)
    else:
        clust_ind = np.atleast_1d(cluster)  # also accepts a single scalar ID

    # Edges run from 0 up to (not including) the last spike time, so spikes
    # that fall after the final complete bin are not counted.
    t_end = np.max(spike_times)
    t_bins = np.arange(0, t_end, bin_size)

    clust_num = len(clust_ind)
    fr = np.zeros((clust_num, len(t_bins) - 1))
    for row, clust_id in enumerate(clust_ind):
        spike_rows = np.where(spike_clusters == clust_id)[0]
        fr[row, :] = np.histogram(spike_times[spike_rows], t_bins)[0]

    return fr, t_bins[:-1], clust_num, clust_tot

In [None]:
## function for smoothing firing rate via gaussian filter
def firing_rate_smooth(fr, sigma):
    """Smooth every row of a binned firing-rate array with a gaussian kernel.

    Parameters
    ----------
    fr : array-like
        Binned firing rate; the last axis is time. Typically the 2-D
        (cluster, bin) array returned by ``firing_rate_calc``.
    sigma : float
        Standard deviation of the gaussian kernel, in bins.

    Returns
    -------
    fr_smooth : np.ndarray
        New float array with the same shape as ``fr``.
    """
    # gaussian_filter1d keeps the dtype of its input, so integer spike counts
    # would be truncated to whole numbers; always filter in floating point.
    fr = np.asarray(fr, dtype=float)
    if fr.size == 0:
        return np.zeros(fr.shape)

    # Filtering along the last axis treats each row independently, which
    # matches filtering the rows one at a time.
    return gaussian_filter1d(fr, sigma, axis=-1)

In [None]:
def firing_rate_metics(spike_times_file, spike_clusters_file, sample_rate=30000):
    """Compute the mean firing rate of every cluster in a recording.

    Parameters
    ----------
    spike_times_file : str or path-like
        ``.npy`` file with the spike times in samples.
    spike_clusters_file : str or path-like
        ``.npy`` file with the cluster ID of every spike.
    sample_rate : float, optional
        Samples per second used to convert the last spike time into a
        duration in seconds (default 30000, the value that used to be
        hard-coded).

    Returns
    -------
    pd.DataFrame
        One row per cluster ID found in ``spike_clusters_file`` with columns
        ``cluster_id`` and ``firing_rate`` (spikes per second).
    """
    spike_times = np.load(spike_times_file)
    spike_clusters = np.load(spike_clusters_file)

    # One pass yields both the sorted unique IDs and their spike counts.
    clust_ind, cluster_spike_count = np.unique(spike_clusters, return_counts=True)

    # The recording length is approximated by the time of the last spike.
    duration_s = np.max(spike_times) / sample_rate
    firing_rate = cluster_spike_count / duration_s

    df = pd.DataFrame({'cluster_id': clust_ind, 'firing_rate': firing_rate})

    return df

In [None]:
# MAIN SCRIPT FOR CALCULATING FIRING RATE

# going through all the different recording sessions
t_stamp = ["0900", "1000", "1100", "1200", "1300", "1400", "1500", "1600", "1700", "1800", "1900", "2000", "2100"]
file_dir = "Z:/Wu_sleep/m2/SD1/"
file_date = "20240923"

mean_fr_sessions = []

for s, stamp in enumerate(t_stamp):
    # every session folder repeats the same "<date>_SD1_test_<time>_g0" stem
    session = file_date + "_SD1_test_" + stamp + "_g0"
    file_path = session + "/catgt_" + session + "/" + session + "_imec0/imec0_ks25"
    spike_times_file = file_dir + file_path + "/spike_times.npy"
    spike_clusters_file = file_dir + file_path + "/spike_clusters.npy"
    cluster_label_file = file_dir + file_path + "/cluster_group.tsv"

    # firing_rate_calc works on arrays and a list of cluster IDs, not on file
    # paths, so the session files have to be loaded first
    spike_times, spike_clusters, clust_label = load_files(
        spike_times_file, spike_clusters_file, cluster_label_file
    )

    fr, t_bins, clust_num, clust_tot = firing_rate_calc(spike_times, spike_clusters, clust_label, bin_size, "all")
    fr_smooth = firing_rate_smooth(fr, sigma)

    # taking average of binned firing rate for each cluster
    fr_smooth_avg = np.average(fr_smooth, axis=1)

    # save total cluster fr data in vector for each session
    mean_fr_sessions.append(fr_smooth_avg)

    # update post-curation metrics file with firing rate for neuron tracking
    fr_df = firing_rate_metics(spike_times_file, spike_clusters_file)
    fr_df.to_csv(os.path.join(file_dir + file_path, r"metrics_curated.csv"))

    

In [9]:
# identifying burst segments based on ISI

# NOTE(review): `spikes` is not defined in this cell; it is presumably the
# per-cluster list returned by spike_times_by_cluster - confirm before a fresh run.

# inter-spike intervals up to half a second's worth of samples are counted
max_isi = 0.5 * sample_rate

# one entry per cluster: number of consecutive-spike intervals <= max_isi
spike_t_delta = [
    sum(np.diff(spikes[k]) <= max_isi)
    for k in range(len(spikes))
]

In [None]:
# plot mean Firing Rate (Violin Plot)

# one label per recording session: SD1..SD6 followed by RE1..RE7
session_names = [f"SD{k}" for k in range(1, 7)] + [f"RE{k}" for k in range(1, 8)]
tick_positions = np.arange(13) + 1  # violins are drawn at x = 1..13

plt.violinplot(mean_fr_sessions, showextrema=False, showmedians=True)

plt.title("M2 Firing Rate")
plt.xlabel("Sessions")
plt.ylabel("FR (Hz)")
plt.xticks(tick_positions, session_names)
plt.ylim([0, 50])
plt.show()

In [None]:
# plot mean Firing Rate (Boxplot)

# one label per recording session, in the order the sessions were processed
box_labels = ["SD1", "SD2", "SD3", "SD4", "SD5", "SD6", "RE1", "RE2", "RE3", "RE4", "RE5", "RE6", "RE7"]

# sym="" hides the outlier markers
plt.boxplot(mean_fr_sessions, sym="")

plt.ylim([0,50])
# Title fixed from "M1": the sessions are loaded from the ".../m2/..." folder and
# the violin plot of the same data is titled "M2".
plt.title("M2 Firing Rate")
# boxes sit at x = 1..N, so derive the tick positions from the label list
plt.xticks(np.arange(len(box_labels)) + 1, box_labels)
plt.xlabel("Sessions")
plt.ylabel("FR (Hz)")
plt.show()