In [None]:
# Load epochs, spike data, place fields, and unit stability information for each session

import os
import numpy as np
from scipy.stats import ttest_rel, wilcoxon
from scipy.io import loadmat
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns

%matplotlib widget


main_dir = '/home/kouroshmaboudi/Documents/Learned_tuning_Python/Datasets/'

# Sessions are selected by POSITION in this listing (current_sessions below, and
# the hard-coded session numbers in the loading loop, e.g. 6/7 -> RatN/RatS,
# 9/10 -> RatV). os.listdir returns entries in arbitrary, filesystem-dependent
# order, so sort the listing to make the indices reproducible. Keep directories
# only, so that stray files (e.g. figures saved into main_dir) cannot shift them.
# NOTE(review): the hard-coded indices appear to assume alphabetical order — confirm.
rr = sorted(
    entry for entry in os.listdir(main_dir)
    if os.path.isdir(os.path.join(main_dir, entry))
)

current_sessions = list(range(5))  # Grosmark sessions
# current_sessions = list(range(6, 11))  # Giri sessions
# current_sessions = [x for x in range(11, 17) if x not in [12, 13]]  # Hiro's sessions
# current_sessions.append(5)

number_of_sessions = len(current_sessions)

# Per-session containers; object arrays let each slot hold a variable-size payload.
num_units = np.empty((number_of_sessions,), dtype=int)  # number of active units kept per session
each_unit_session_number = np.empty((number_of_sessions,), dtype=object)  # session identity of each unit, to keep track of it when units from different sessions are pooled and sorted

spatial_tunings_maze = np.empty((number_of_sessions,), dtype=object)  # spatial tunings on MAZE

spikes_all_sessions = np.empty((number_of_sessions,), dtype=object)  # per-session list of unit dicts

def _sleep_firing_rate(firing_rate_by_block, epoch_durations):
    """Return one unit's duration-weighted mean firing rate over the sleep epochs.

    Parameters
    ----------
    firing_rate_by_block : ndarray, shape (1, n_blocks)
        Firing rate of a single unit per block (NaNs already replaced by 0).
        Column 0 is weighted by the PRE duration. With 2 blocks, column 1 is
        weighted by the whole POST duration. With 3 blocks, column 1 covers the
        first 4 h of POST and column 2 covers the remainder.
    epoch_durations : ndarray
        Epoch durations. Index 0 is used as PRE and index 2 as POST.
        Presumably in seconds, given the 4 * 3600 split — TODO confirm.

    Returns
    -------
    float
        Weighted mean rate over PRE + POST.

    Raises
    ------
    ValueError
        If the number of blocks is neither 2 nor 3.
    """
    n_blocks = firing_rate_by_block.shape[1]
    rates = firing_rate_by_block[0]
    total_sleep_duration = np.sum(epoch_durations[[0, 2]])

    if n_blocks == 2:
        weighted_sum = rates[0] * epoch_durations[0] + rates[1] * epoch_durations[2]
    elif n_blocks == 3:
        early_post_duration = 4 * 3600
        weighted_sum = (rates[0] * epoch_durations[0]
                        + rates[1] * early_post_duration
                        + rates[2] * (epoch_durations[2] - early_post_duration))
    else:
        raise ValueError(f'Expected 2 or 3 firing-rate blocks, got {n_blocks}')

    return weighted_sum / total_sleep_duration


running_directions = ('LR', 'RL', 'uni')

for session_idx, session_number in enumerate(current_sessions):

    session_name = rr[session_number]
    print(f'{session_name} {session_idx + 1}/{number_of_sessions}')

    session_dataset_path = os.path.join(main_dir, session_name)

    # Load epochs information
    session_info = loadmat(os.path.join(session_dataset_path, f'{session_name}.fileInfo_for_python.mat'))['fileInfo']
    epochs = session_info['behavior'][0][0][0][0]['time']
    epoch_durations = epochs[:, 1] - epochs[:, 0]

    # Load spike data (once)
    spikes_pyr = loadmat(os.path.join(session_dataset_path, f'{session_name}.spikes_for_python.mat'))['spikes_pyr']

    # RatN and RatS (sessions 6, 7) store the per-unit struct array with units along
    # the first axis ([unit][0]); every other session stores them as [0][unit].
    units_along_first_axis = session_number in [6, 7]
    if units_along_first_axis:
        num_units_total = spikes_pyr['spatialTuning_smoothed'].shape[0]
    else:
        num_units_total = spikes_pyr['spatialTuning_smoothed'][0].shape[0]

    num_pos_bins = spikes_pyr['spatialTuning_smoothed'][0][0]['uni'][0][0].size

    # Load unit stability information (loadmat appends the '.mat' extension itself)
    cluster_quality_by_block = loadmat(
        os.path.join(session_dataset_path, f'{session_name}.cluster_quality_by_block')
    )['cluster_quality_by_block'][0]

    spikes = []  # spike data and place field info of each unit
    reported_single_direction = False

    for unit in range(num_units_total):

        unit_spikes = dict()

        # The layout of the 'id' field (shank/cluster order, 0- or 1-based shank) differs per dataset
        if session_number in [9, 10]:  # Rat V sessions
            unit_spikes['spike_times'] = spikes_pyr['time'][0][unit]
            unit_spikes['shank_id'] = spikes_pyr['id'][0][unit][0][1]
            unit_spikes['cluster_id'] = spikes_pyr['id'][0][unit][0][0]
        elif session_number in [6, 7]:  # RatN and RatS
            unit_spikes['spike_times'] = spikes_pyr['time'][unit][0]
            unit_spikes['shank_id'] = spikes_pyr['id'][unit][0][0][0] - 1
            unit_spikes['cluster_id'] = spikes_pyr['id'][unit][0][0][1]
        elif session_number == 8:  # RatU: shank indices already start at zero
            unit_spikes['spike_times'] = spikes_pyr['time'][0][unit]
            unit_spikes['shank_id'] = spikes_pyr['id'][0][unit][0][0]
            unit_spikes['cluster_id'] = spikes_pyr['id'][0][unit][0][1]
        else:  # other datasets: shank indices are 1-based
            unit_spikes['spike_times'] = spikes_pyr['time'][0][unit]
            unit_spikes['shank_id'] = spikes_pyr['id'][0][unit][0][0] - 1
            unit_spikes['cluster_id'] = spikes_pyr['id'][0][unit][0][1]

        # Extract the cluster quality information by block for the current unit
        shank = unit_spikes['shank_id']
        curr_unit_idx = np.where(cluster_quality_by_block['cluster_ids'][shank] == unit_spikes['cluster_id'])[0]

        # Spike amplitude as a fraction of the session mean amplitude
        spike_amplitude_by_block = np.nan_to_num(
            cluster_quality_by_block['spike_amplitude_by_block'][shank][curr_unit_idx], nan=0)
        unit_spikes['spike_amplitude_by_block'] = (
            spike_amplitude_by_block / cluster_quality_by_block['session_mean_spike_amplitude'][shank][curr_unit_idx])

        # Firing rate as a fraction of the unit's mean rate across sleep (PRE + POST)
        firing_rate_by_block = np.nan_to_num(
            cluster_quality_by_block['firing_rate_by_block'][shank][curr_unit_idx], nan=0)
        sleep_firing_rate = _sleep_firing_rate(firing_rate_by_block, epoch_durations)
        if sleep_firing_rate > 0:
            unit_spikes['firing_rate_by_block'] = firing_rate_by_block / sleep_firing_rate
        else:
            # Same shape as the normal branch so later slicing/stacking behaves identically
            unit_spikes['firing_rate_by_block'] = np.zeros_like(firing_rate_by_block, dtype=float)

        unit_spikes['isolation_distance_by_block'] = np.nan_to_num(
            cluster_quality_by_block['isolation_distance_by_block'][shank][curr_unit_idx], nan=0)

        # PRE/POST stability: the first two blocks must pass every threshold.
        # The metric arrays are 2-D with one row and one column per block (the
        # sleep-rate helper reads them as [0][block]), so blocks live on the LAST
        # axis. A plain [:2] would slice rows and silently include latePOST too.
        unit_spikes['pre_post_unit_stability'] = (
            (unit_spikes['spike_amplitude_by_block'][..., :2] > 0.67)
            & (unit_spikes['firing_rate_by_block'][..., :2] > 0.33)
            & (unit_spikes['isolation_distance_by_block'][..., :2] > 15)
        ).all()

        # unit_spikes['unit_stability_latePOST'] = (
        #     (unit_spikes['spike_amplitude_by_block'][0][2] > 0.67) &
        #     (unit_spikes['firing_rate_by_block'][0][2] > 0.33) &
        #     (unit_spikes['isolation_distance_by_block'][0][2] > 15)
        #     ).all()

        unit_spikes['place_fields_maze'] = {}
        unit_spikes['peak_pos_bins_maze'] = {}

        for direction in running_directions:
            try:
                if units_along_first_axis:
                    unit_spikes['place_fields_maze'][direction] = (
                        spikes_pyr['spatialTuning_smoothed'][unit][0][direction][0][0].reshape(num_pos_bins))
                    unit_spikes['peak_pos_bins_maze'][direction] = spikes_pyr['peakPosBin'][unit][0][direction][0][0][0][0]
                else:
                    unit_spikes['place_fields_maze'][direction] = (
                        spikes_pyr['spatialTuning_smoothed'][0][unit][direction][0][0].reshape(num_pos_bins))
                    unit_spikes['peak_pos_bins_maze'][direction] = spikes_pyr['peakPosBin'][0][unit][direction][0][0][0][0]
            except ValueError:
                # Presumably raised when a direction field is absent/empty; report once per session.
                if not reported_single_direction:
                    print("This session has only one running direction")
                    reported_single_direction = True

        spikes.append(unit_spikes)

    # Interpolate the spatial tunings onto 200 points to be consistent with Figure 2.
    # NOTE(review): the grid starts at 0 while the tuning bins sit at 1..num_pos_bins,
    # so np.interp clamps grid points below 1 to bin 1's value; kept for consistency with Figure 2.
    interp_pos_bins = np.linspace(0, num_pos_bins, 200)
    tuning_pos_bins = np.arange(1, num_pos_bins + 1)

    session_tunings = np.zeros((num_units_total, len(interp_pos_bins)))
    for unit, unit_spikes in enumerate(spikes):
        session_tunings[unit] = np.interp(interp_pos_bins, tuning_pos_bins, unit_spikes['place_fields_maze']['uni'])
    spatial_tunings_maze[session_idx] = session_tunings

    # Keep units whose interpolated MAZE tuning peaks above 1 (presumably Hz)
    active_units = np.where(np.nanmax(session_tunings, axis=1) > 1)[0]
    num_units[session_idx] = len(active_units)

    spikes_all_sessions[session_idx] = [spikes[unit] for unit in active_units]

    each_unit_session_number[session_idx] = np.full((num_units[session_idx],), session_number)



In [None]:
# Plot unit-stability metrics (spike amplitude, isolation distance, firing rate) per block, for each session and pooled across sessions

def get_pval_statement(pvalue):
    """Format a p-value as a short plot annotation.

    Values below 0.001 are reported as the bound 'P<0.001'; anything else
    is printed with three decimals, e.g. 'P=0.042'.
    """
    return 'P<0.001' if pvalue < 0.001 else f'P={pvalue:.3f}'

# Define a function to plot violin plots
def plot_violin(ax, data, color):
    """Draw one violin per block (column of *data*) and annotate pairwise Wilcoxon p-values.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    data : ndarray, shape (n_units, n_blocks)
        One row per unit and one column per block (2 -> PRE/POST,
        3 -> PRE/POST/latePOST). A row is dropped before plotting and testing
        if any of its values lies outside the per-column fences
        [Q1 - 1.5*IQR, Q3 + 1.5*IQR].
    color : sequence of colours
        One colour per block, used for the violins and the x tick labels.
    """
    alpha = 0.8
    num_blocks = data.shape[1]  # derived from the data, not from a global

    # Filter out the outlier rows (Tukey fences computed per column)
    q1 = np.percentile(data, 25, axis=0)
    q3 = np.percentile(data, 75, axis=0)
    iqr = q3 - q1
    is_outlier = (data < q1 - 1.5 * iqr) | (data > q3 + 1.5 * iqr)
    data = data[~is_outlier.any(axis=1)]

    sns.violinplot(data=data, ax=ax, inner='quartiles', linewidth=0, palette=color)

    for violin in ax.collections[:num_blocks]:
        violin.set_alpha(alpha)

    # inner='quartiles' draws three lines per violin; every third line starting at
    # index 1 is presumably the median, which is drawn thicker.
    for line in ax.lines:
        line.set_linestyle('-')
        line.set_linewidth(0.75)
        line.set_color('white')
        line.set_alpha(1)
    for line in ax.lines[1::3]:
        line.set_linewidth(1.5)

    ax.set_xlim([-0.5, 2.5])  # fixed at three slots (presumably so 2- and 3-block panels align)
    block_labels = {2: ['PRE', 'POST'], 3: ['PRE', 'POST', 'latePOST']}
    if num_blocks in block_labels:
        ax.set_xticklabels(block_labels[num_blocks], rotation=45, ha='center')

    for tick_label, tick_color in zip(ax.get_xticklabels(), color):
        tick_label.set_color(tick_color)

    ax.tick_params(axis='both', which='major', labelsize=7, length=2, pad=0.5)

    ax.grid(axis='y', color='gray', linewidth=1)
    for side in ['left', 'bottom']:
        ax.spines[side].set_linewidth(1.5)

    # Paired Wilcoxon signed-rank test for every pair of blocks, stacked as bars above the data
    data = np.nan_to_num(data, nan=0)
    significance_bar_height = np.max(data) * 1.05
    p_value = np.full((num_blocks, num_blocks), np.nan)
    for i in range(num_blocks):
        for j in range(i + 1, num_blocks):
            try:
                p_value[i, j] = wilcoxon(data[:, i], data[:, j]).pvalue
            except ValueError:
                # e.g. all paired differences are zero: annotate P=nan instead of aborting the figure
                p_value[i, j] = np.nan
            ax.plot([i, j], [significance_bar_height, significance_bar_height], lw=1, color='black')
            ax.text((i + j) / 2, significance_bar_height * 1.01, get_pval_statement(p_value[i, j]),
                    ha='center', va="bottom", fontsize=6)
            significance_bar_height = significance_bar_height + significance_bar_height * 0.1

colors = [
    '#005CE9',  # PRE
    '#DD335D',  # POST
    '#DC9A5D',  # late POST
]

sns.set_style('whitegrid')
sns.set_context('paper')

custom_params = {"axes.spines.right": False, "axes.spines.top": False}
sns.set_theme(style="ticks", rc=custom_params)

plt.rcParams['axes.linewidth'] = 1.5
plt.rcParams['pdf.fonttype'] = 42  # TrueType fonts in the PDF
plt.rcParams['ps.fonttype'] = 42


plotheight = 800  # points (72 per inch)
plotwidth = 300
font_size = 6


def _unit_block_count(units):
    """Number of blocks for a session, read from its first unit's spike-amplitude array."""
    return units[0]['spike_amplitude_by_block'].shape[-1]


def _stack_metric(units, key, n_blocks):
    """Stack the per-unit arrays stored under *key* into an (n_units, n_blocks) float array.

    Only the first *n_blocks* blocks of each unit are kept, so sessions with
    more blocks can be pooled with sessions that have fewer.
    """
    stacked = np.empty((len(units), n_blocks), dtype=float)
    for row, unit in enumerate(units):
        stacked[row, :] = np.ravel(unit[key])[:n_blocks]
    return stacked


def _plot_stability_row(fig, gs, row, amplitude, iso_distance, firing_rate, title, title_kwargs):
    """Plot spike amplitude, isolation distance and firing rate violins in GridSpec row *row*."""
    # Spike amplitude (fraction of session mean -> %)
    ax = fig.add_subplot(gs[row, 0])
    plot_violin(ax, amplitude * 100, colors)
    ax.set_ylabel('Spike amplitude (%)', fontsize=8)
    ax.set_title(title, loc='left', **title_kwargs)
    ax.set_ylim([0, ax.get_ylim()[1]])

    # Isolation distance
    ax = fig.add_subplot(gs[row, 1])
    plot_violin(ax, iso_distance, colors)
    ax.set_ylabel('Iso. distance (A.U.)', fontsize=8)
    ax.set_ylim([0, ax.get_ylim()[1]])

    # Firing rate (fraction of sleep rate -> %)
    ax = fig.add_subplot(gs[row, 2])
    plot_violin(ax, firing_rate * 100, colors)
    ax.set_ylabel('Firing rate (%)', fontsize=8)


fig = plt.figure()
fig.set_size_inches([plotwidth / 72, plotheight / 72])
gs = GridSpec(nrows=number_of_sessions + 1, ncols=3, figure=fig, height_ratios=[1] * (number_of_sessions + 1))

# NOTE: `num_blocks` is kept as a module-level name because plot_violin may read it
# as a global; it must always equal the column count of the data being plotted.
for session_idx in range(number_of_sessions):

    session_name = rr[current_sessions[session_idx]]
    session_units = spikes_all_sessions[session_idx]
    if not session_units:
        print(f'{session_name}: no active units, skipping')
        continue

    # Block count of THIS session (previously always taken from session 0)
    num_blocks = _unit_block_count(session_units)

    _plot_stability_row(
        fig, gs, session_idx,
        _stack_metric(session_units, 'spike_amplitude_by_block', num_blocks),
        _stack_metric(session_units, 'isolation_distance_by_block', num_blocks),
        _stack_metric(session_units, 'firing_rate_by_block', num_blocks),
        title=session_name,
        title_kwargs=dict(fontsize=9, fontweight='normal'),
    )


# Concatenate the data from all sessions; only blocks common to every session are pooled
num_blocks = min(_unit_block_count(units) for units in spikes_all_sessions if units)
pooled_units = [unit for units in spikes_all_sessions for unit in units]
num_unit_all_sessions = len(pooled_units)

spike_amplitude_by_block_session_concat = _stack_metric(pooled_units, 'spike_amplitude_by_block', num_blocks)
firing_rate_by_block_session_concat = _stack_metric(pooled_units, 'firing_rate_by_block', num_blocks)
isolation_distance_by_block_session_concat = _stack_metric(pooled_units, 'isolation_distance_by_block', num_blocks)

_plot_stability_row(
    fig, gs, number_of_sessions,
    spike_amplitude_by_block_session_concat,
    isolation_distance_by_block_session_concat,
    firing_rate_by_block_session_concat,
    title='Pooled',
    title_kwargs=dict(fontsize=8, fontweight='bold'),
)


plt.tight_layout()
plt.subplots_adjust(wspace=0.5)
sns.despine()

# NOTE(review): this writes into the dataset directory that the loading cell lists
# by position; a non-directory entry there can shift session indices unless that
# listing filters to directories.
filename = 'unit_stability_Grosmark_indiv_sessions.pdf'
file_path = os.path.join(main_dir, filename)
plt.savefig(file_path, format='pdf', dpi=300)

plt.show()



In [None]:
# Experimental: overlaying the figure onto an existing PDF. Please ignore the code below; it is not yet functional.

In [None]:
import PyPDF2



def get_pval_statement(pvalue):
    """Return a plot label for *pvalue*: 'P<0.001' below 0.001, otherwise 'P=x.xxx'."""
    if pvalue < 0.001:
        return 'P<0.001'
    return f'P={pvalue:.3f}'

import io

colors = [
    '#005CE9',  # PRE
    '#DD335D',  # POST
    '#DC9A5D',  # late POST
]

pdf_file_name = 'figure_test.pdf'  # existing PDF; the figure is merged onto its first page
pdf_figure = "figure.pdf"  # temporary file holding the freshly rendered figure


# Define a function to plot violin plots
def plot_violin(ax, data, color):
    """Draw one violin per block (column of *data*) and annotate pairwise Wilcoxon p-values.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    data : ndarray, shape (n_units, n_blocks)
        One row per unit and one column per block. A row is dropped if any of
        its values lies outside the per-column fences [Q1 - 1.5*IQR, Q3 + 1.5*IQR].
    color : sequence of colours
        One colour per block, used for the violins and the x tick labels.
    """
    alpha = 0.8
    num_blocks = data.shape[1]  # derived from the data, not from a global

    q1 = np.percentile(data, 25, axis=0)
    q3 = np.percentile(data, 75, axis=0)
    iqr = q3 - q1
    is_outlier = (data < q1 - 1.5 * iqr) | (data > q3 + 1.5 * iqr)
    data = data[~is_outlier.any(axis=1)]

    sns.violinplot(data=data, ax=ax, inner='quartiles', linewidth=0, palette=color)

    for violin in ax.collections[:num_blocks]:
        violin.set_alpha(alpha)

    # inner='quartiles' draws three lines per violin; every third line starting at
    # index 1 is presumably the median, which is drawn thicker.
    for line in ax.lines:
        line.set_linestyle('-')
        line.set_linewidth(0.75)
        line.set_color('white')
        line.set_alpha(1)
    for line in ax.lines[1::3]:
        line.set_linewidth(1.5)

    ax.set_xlim([-0.5, 2.5])
    block_labels = {2: ['PRE', 'POST'], 3: ['PRE', 'POST', 'latePOST']}
    if num_blocks in block_labels:
        ax.set_xticklabels(block_labels[num_blocks], rotation=45, ha='center')

    for tick_label, tick_color in zip(ax.get_xticklabels(), color):
        tick_label.set_color(tick_color)

    ax.tick_params(axis='both', which='major', labelsize=7, length=2, pad=0.5)

    ax.grid(axis='y', color='gray', linewidth=1)
    for side in ['left', 'bottom']:
        ax.spines[side].set_linewidth(1.5)

    # Paired Wilcoxon signed-rank test for every pair of blocks, stacked as bars above the data
    data = np.nan_to_num(data, nan=0)
    significance_bar_height = np.max(data) * 1.05
    p_value = np.full((num_blocks, num_blocks), np.nan)
    for i in range(num_blocks):
        for j in range(i + 1, num_blocks):
            try:
                p_value[i, j] = wilcoxon(data[:, i], data[:, j]).pvalue
            except ValueError:
                # e.g. all paired differences are zero: annotate P=nan instead of aborting
                p_value[i, j] = np.nan
            ax.plot([i, j], [significance_bar_height, significance_bar_height], lw=1, color='black')
            ax.text((i + j) / 2, significance_bar_height * 1.01, get_pval_statement(p_value[i, j]),
                    ha='center', va="bottom", fontsize=6)
            significance_bar_height = significance_bar_height + significance_bar_height * 0.1


sns.set_style('whitegrid')
sns.set_context('paper')

custom_params = {"axes.spines.right": False, "axes.spines.top": False}
sns.set_theme(style="ticks", rc=custom_params)

plt.rcParams['axes.linewidth'] = 1.5
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42

fig, axs = plt.subplots(nrows=1, ncols=3, sharey=False)
fig.set_size_inches([plotwidth / 72, plotheight / 72])

# Plot the distribution of spike amplitude for each block/epoch
plot_violin(axs[0], spike_amplitude_by_block_session_concat * 100, colors)
axs[0].set_ylabel('Spike amplitude (%)', fontsize=8)

# Plot the distribution of isolation distance for each block/epoch
plot_violin(axs[1], isolation_distance_by_block_session_concat, colors)
axs[1].set_ylabel('Iso. distance (A.U.)', fontsize=8)

# Plot the distribution of firing rate for each block/epoch
plot_violin(axs[2], firing_rate_by_block_session_concat * 100, colors)
axs[2].set_ylabel('Firing rate (%)', fontsize=8)

sns.despine()

fig.tight_layout(rect=[0, 0, 1, 1])
fig.canvas.draw()
fig.savefig(pdf_figure, format="pdf")
plt.close(fig)

try:
    with open(pdf_file_name, "rb") as existing_file, open(pdf_figure, "rb") as figure_file:
        pdf_reader = PyPDF2.PdfReader(existing_file)

        first_page = pdf_reader.pages[0]
        first_page.merge_page(PyPDF2.PdfReader(figure_file).pages[0])

        # Use the current PyPDF2 API (len(pages) / pages[i] / add_page); the
        # camelCase getNumPages / getPage / addPage methods were removed in PyPDF2 3.x.
        pdf_writer = PyPDF2.PdfWriter()
        pdf_writer.add_page(first_page)
        for page_num in range(1, len(pdf_reader.pages)):
            pdf_writer.add_page(pdf_reader.pages[page_num])

        # Serialise while both source files are still open: PdfReader resolves
        # page objects on demand from its stream.
        merged_pdf = io.BytesIO()
        pdf_writer.write(merged_pdf)

    # Only now is it safe to truncate and overwrite the input file.
    with open(pdf_file_name, "wb") as output_file:
        output_file.write(merged_pdf.getvalue())
finally:
    os.remove(pdf_figure)