In [None]:
import os
import scipy.io
from scipy.stats import ranksums
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import h5py
from learned_tuning.learned_tuning import calculate_learned_tuning, calculate_place_field_fidelity_of_learned_tuning
import pandas as pd


# Root folder with one sub-folder per recording session; set the LEARNED_TUNING_DATA_DIR
# environment variable to point elsewhere instead of editing this path.
data_dir = os.environ.get('LEARNED_TUNING_DATA_DIR', r'/home/kouroshmaboudi/Documents/Learned_tuning_Python/Datasets')

# os.listdir returns entries in arbitrary (filesystem-dependent) order, but the session indices
# below - and the dataset-specific branches keyed on session_number - require a stable ordering.
sessions = sorted(d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d)))

included_session_IDX = [x for x in range(17) if x not in (12, 13)]
# included_session_IDX = [10]

# Fail early with an informative message if the data folder has fewer sessions than expected
missing_session_IDX = [i for i in included_session_IDX if i >= len(sessions)]
if missing_session_IDX:
    raise IndexError(f'Session indices {missing_session_IDX} exceed the {len(sessions)} session folders found in {data_dir}')

sessions = [sessions[i] for i in included_session_IDX]

# Per-session pipeline: load epochs, brain states, population burst events (PBEs), units and
# cluster quality for every included session, then compute learned tunings separately from NREM
# and quiet-wake (QW) PBEs in the PRE and POST epochs and save them next to the session data.
#
# Brain-state labels (label k corresponds to brainStates_names[k - 1]); the Miyawaki labels are
# remapped to this convention below:
#   1 = NREM, 2 = REM, 3 = WAKE (active wake), 4 = QWAKE (quiet wake), 5 = Undetermined
brainStates_names = ['NREM', 'REM', 'WAKE', 'QWAKE', 'Undetermined']
NREM_LABEL = 1
QW_LABEL = 4
UNDETERMINED_LABEL = 5
MIN_BOUT_DURATION = 6  # seconds; shorter bouts are relabeled as UNDETERMINED_LABEL

for session_number, session_name in zip(included_session_IDX, sessions):

    print(session_name)

    session_dataset_path = os.path.join(data_dir, session_name)

    #--------------------------------------------------------------------------------------------
    # Load epochs information
    # epochs[row, 0] is an epoch start and epochs[row, 1] its end; row 0 is used as PRE and row 2
    # as POST below (row 1 presumably the maze epoch - confirm).

    file_info_path = os.path.join(session_dataset_path, f'{session_name}.fileInfo_for_python.mat')
    file_info_mat = scipy.io.loadmat(file_info_path)
    session_info = file_info_mat["fileInfo"]

    epochs = session_info["behavior"][0][0][0][0]["time"]

    #---------------------------------------------------------------------------------------------
    # Brain state detection results: one row per bout, [start, end, ..., label]

    if 0 <= session_number <= 4 or 6 <= session_number <= 10:  # Grosmark and Giri datasets
        brain_state_path = os.path.join(
            session_dataset_path,
            f'{session_name}.brainStateDetection_HMMtheta_EMG_SWS_SchmidtTrigger.mat')
        brain_state_mat = scipy.io.loadmat(brain_state_path)
        # NOTE(review): the last column of 'bouts' is dropped here, so labels come from the
        # second-to-last column of the stored array - confirm against the .mat layout.
        brainStates_bouts_label = brain_state_mat['brainState']['bouts'][0][0][:, :-1]
        bout_labels = brainStates_bouts_label[:, -1].astype(int)

    else:  # Miyawaki dataset: the bouts are stored in the fileInfo file loaded above
        brainStates_bouts_label = file_info_mat['fileInfo']['brainStates'].item()
        hiro_labels = brainStates_bouts_label[:, -1].astype(int)

        # Swap the 3s and 4s: in Hiro's data 3 is QWAKE, whereas in the other datasets 3 is
        # active wake (WAKE) and 4 is QWAKE.
        bout_labels = hiro_labels.copy()
        bout_labels[hiro_labels == 3] = 4
        bout_labels[hiro_labels == 4] = 3

    bouts_start_end = brainStates_bouts_label[:, :-1]
    bout_duration = bouts_start_end[:, 1] - bouts_start_end[:, 0]

    bout_labels[bout_duration < MIN_BOUT_DURATION] = UNDETERMINED_LABEL

    # Label each 1-s time bin [i, i+1) with the state of the bout that strictly contains the bin
    # center (i + 0.5); bins that are not inside any bout stay NaN.
    time_bins_centers = np.arange(0, bouts_start_end[-1, 1] + 1, 1) + 0.5
    time_bins_brain_state = np.full(time_bins_centers.shape, np.nan)

    for bout_start, bout_end, bout_label in zip(bouts_start_end[:, 0], bouts_start_end[:, 1], bout_labels):
        inside_bout = (bout_start < time_bins_centers) & (time_bins_centers < bout_end)
        time_bins_brain_state[inside_bout] = bout_label

    # Percentage of time spent in each brain state within consecutive one-minute chunks
    chunk_size_in_hours = 1/60  # one minute
    chunk_size = int(round(chunk_size_in_hours * 3600))  # in seconds, i.e. number of 1-s bins per chunk
    num_chunks = len(time_bins_brain_state) // chunk_size
    chunk_start_times = [i * chunk_size_in_hours for i in range(num_chunks)]  # start of each chunk, in hours

    reshaped_time_bins_brain_state = time_bins_brain_state[:num_chunks * chunk_size].reshape(num_chunks, chunk_size)

    # One histogram bin per state label 1..5. The edges are fixed (not derived from the labels that
    # happen to occur in this session), so the column count always matches brainStates_names.
    bins = np.arange(1, len(brainStates_names) + 2)
    counts = np.apply_along_axis(lambda x: np.histogram(x, bins=bins)[0], axis=1, arr=reshaped_time_bins_brain_state)

    # Normalize the counts to percentages; chunks without any labeled bin become NaN
    with np.errstate(invalid='ignore', divide='ignore'):
        percentages = (counts / counts.sum(axis=1, keepdims=True)) * 100

    # Percentages for each brain state (columns) in each chunk (rows)
    brain_state_df = pd.DataFrame(percentages, columns=brainStates_names)

    #-------------------------------------------------------------------------------------------
    # Load the population burst events (PBEs) with all their measured LFP features.
    # The parsed PBEs are cached as a pickle next to the source .mat file.

    overwrite = False  # set True to re-read the .mat file (e.g. after it changed)

    pbe_cache_path = os.path.join(session_dataset_path, f'{session_name}.PBEs.pkl')

    if os.path.exists(pbe_cache_path) and not overwrite:
        # NOTE: unpickling can execute arbitrary code - only load caches produced by this notebook
        PBEs = pd.read_pickle(pbe_cache_path)

    else:  # no cache yet (or overwrite requested): parse the MATLAB v7.3 (HDF5) file
        pbe_mat_path = os.path.join(
            session_dataset_path,
            f'{session_name}.PBEInfo_replayScores_with_spindle_and_deltaPowers.mat')

        with h5py.File(pbe_mat_path, "r") as f:
            PBEInfo = f['PBEInfo_replayScores']
            num_PBEs = PBEInfo["fr_1msbin"].shape[0]

            # NOTE(review): the first key is skipped - presumably not a per-PBE field; confirm
            pbe_columns = [attr for attr in list(PBEInfo.keys())[1:]
                           if attr not in ['posteriorProbMat', 'postMat_nonNorm']]

            progress_step = max(1, num_PBEs // 10)  # report progress roughly every 10% of PBEs
            records = []

            for pbe in range(num_PBEs):
                record = {}
                for attr in pbe_columns:
                    obj = f[PBEInfo[attr][pbe][0]]  # dereference the HDF5 object reference

                    if attr in ['epoch', 'brainState']:  # stored as ASCII codes -> convert to str
                        record[attr] = "".join(chr(code) for code in np.array(obj).flatten())
                    elif attr in ['fr_1msbin', 'fr_20msbin']:  # matrices: keep their shape
                        record[attr] = np.array(obj)
                    else:
                        record[attr] = np.array(obj).flatten()
                records.append(record)

                if (pbe + 1) % progress_step == 0:
                    print("Importing PBEs" + "." * ((pbe + 1) // progress_step), end="\r")

        # Build the frame once from all records instead of growing it row by row
        PBEs = pd.DataFrame(records, columns=pbe_columns)
        print("All PBEs were imported")

        PBEs.to_pickle(pbe_cache_path)

    num_PBEs = PBEs.shape[0]

    # Peak time of each PBE as a plain float (the .mat import stores it as a presumably
    # 1-element array)
    pbe_peak_times = np.array([float(np.squeeze(peak_t)) for peak_t in PBEs['peakT']], dtype=float)

    # Brain state of the 1-s bin [i, i+1) containing each PBE peak, i.e. i = floor(peakT), stored as
    # a scalar so it can be compared with state labels; NaN when the peak is outside the scored range.
    pbe_brain_state = np.full(num_PBEs, np.nan)
    in_scored_range = (pbe_peak_times >= 0) & (pbe_peak_times < len(time_bins_brain_state))
    pbe_brain_state[in_scored_range] = time_bins_brain_state[np.floor(pbe_peak_times[in_scored_range]).astype(int)]
    PBEs['brain_state'] = pbe_brain_state

    #----------------------------------------------------------------------------------------------
    # Load spike data
    spikes_path = os.path.join(session_dataset_path, f'{session_name}.spikes_for_python.mat')
    spikes_pyr = scipy.io.loadmat(spikes_path)["spikes_pyr"]

    # Load unit stability information (loadmat appends the '.mat' extension when it is missing)
    cluster_quality_by_block_path = os.path.join(session_dataset_path, f'{session_name}.cluster_quality_by_block')
    cluster_quality_by_block = scipy.io.loadmat(cluster_quality_by_block_path)['cluster_quality_by_block'][0]

    #### Extracting all place fields from the imported .mat file
    # RatN / RatS (sessions 6, 7) store one unit per row ([unit][0]); all other sessions store one
    # unit per column ([0][unit]). The unit count must follow the same layout.
    is_unit_per_row = session_number in [6, 7]
    spatial_tuning_smoothed = spikes_pyr["spatialTuning_smoothed"]

    num_units = spatial_tuning_smoothed.shape[0] if is_unit_per_row else spatial_tuning_smoothed[0].shape[0]
    num_pos_bins = spatial_tuning_smoothed[0][0]['uni'][0][0].size

    spikes = []  # spike data and place field info of each unit
    running_directions = ('LR', 'RL', 'uni')
    num_direction_failures = 0

    for unit in range(num_units):
        # Place fields and peak position bins per running direction, plus spike times and IDs
        unit_spikes = {'place_fields': {}, 'peak_pos_bins': {}}

        for direction in running_directions:
            try:
                if is_unit_per_row:
                    unit_spikes['place_fields'][direction] = spatial_tuning_smoothed[unit][0][direction][0][0].reshape(num_pos_bins)
                    unit_spikes['peak_pos_bins'][direction] = spikes_pyr['peakPosBin'][unit][0][direction][0][0][0][0]
                else:
                    unit_spikes['place_fields'][direction] = spatial_tuning_smoothed[0][unit][direction][0][0].reshape(num_pos_bins)
                    unit_spikes['peak_pos_bins'][direction] = spikes_pyr['peakPosBin'][0][unit][direction][0][0][0][0]

            except ValueError:
                # presumably raised when this direction's field is empty and cannot be reshaped
                if num_direction_failures == 0:
                    print("This session has only one running direction")
                num_direction_failures += 1

        if session_number in [9, 10]:  # for Rat V sessions
            unit_spikes['spike_times'] = spikes_pyr['time'][0][unit]
            unit_spikes['shank_id'] = spikes_pyr['id'][0][unit][0][1] + 1
            unit_spikes['cluster_id'] = spikes_pyr['id'][0][unit][0][0]

        elif is_unit_per_row:  # for RatN and RatS
            unit_spikes['spike_times'] = spikes_pyr['time'][unit][0]
            unit_spikes['shank_id'] = spikes_pyr['id'][unit][0][0][0]
            unit_spikes['cluster_id'] = spikes_pyr['id'][unit][0][0][1]

        elif session_number == 8:  # RatU: shank indices start at zero, shift them by one
            unit_spikes['spike_times'] = spikes_pyr['time'][0][unit]
            unit_spikes['shank_id'] = spikes_pyr['id'][0][unit][0][0] + 1
            unit_spikes['cluster_id'] = spikes_pyr['id'][0][unit][0][1]

        else:  # Grosmark, Hiro, and all other sessions
            unit_spikes['spike_times'] = spikes_pyr['time'][0][unit]
            unit_spikes['shank_id'] = spikes_pyr['id'][0][unit][0][0]  # need to go one down for the other datasets
            unit_spikes['cluster_id'] = spikes_pyr['id'][0][unit][0][1]

        spikes.append(unit_spikes)

    # Place fields computed by pooling spikes across both running directions (num_units x num_pos_bins)
    place_fields_uni = np.array([unit_spikes['place_fields']['uni'] for unit_spikes in spikes])
    place_fields_uni[place_fields_uni == 0] = 1e-4  # replace exact zeros with a small floor value

    # --------------------------------------------------------------------------------------------
    # Load cluster quality data (L-ratios): one dict per shank with its L-ratios and cluster IDs

    cluster_quality_path = os.path.join(session_dataset_path, f'{session_name}.clusterQuality.mat')
    cluster_quality = scipy.io.loadmat(cluster_quality_path)["clusterQuality"]
    shank_L_ratios = cluster_quality["Lratio"][0]
    shank_cluster_ids = cluster_quality["clus"][0]

    L_ratios = [
        {"L_ratios": shank_L_ratios[shank], "cluster_ids": shank_cluster_ids[shank]}
        for shank in range(len(shank_L_ratios))
    ]

    #------------------------------------------------------------------------------------------------------
    # Learned tunings during Non-REM versus Quiet Wake ripples

    # Calculate learned tunings separately for NREM and QW PBEs (the PF-fidelity calculation is
    # currently disabled below)
    time_bin_duration = 0.02
    # num_PF_shuffles = 10000
    learned_tunings_NREM_vs_QW = {}  # epoch name -> results for that epoch

    # epoch name -> (row in `epochs`, analysis window length); POST is restricted to its first 4 hours
    epoch_windows = {
        'PRE': (0, epochs[0, 1] - epochs[0, 0]),
        'POST': (2, 4*60*60),
    }

    for epoch, (epoch_idx, epoch_duration) in epoch_windows.items():
        epoch_start = epochs[epoch_idx, 0]

        # PBEs whose peak lies in [epoch_start, epoch_start + epoch_duration] (both ends inclusive)
        if_inside_epoch = (pbe_peak_times >= epoch_start) & (pbe_peak_times <= epoch_start + epoch_duration)

        # NREM PBEs
        epoch_NREM_PBEs = PBEs[if_inside_epoch & (pbe_brain_state == NREM_LABEL)].reset_index(drop=True)
        num_PBEs_NREM = len(epoch_NREM_PBEs)

        # QW PBEs
        epoch_QW_PBEs = PBEs[if_inside_epoch & (pbe_brain_state == QW_LABEL)].reset_index(drop=True)
        num_PBEs_QW = len(epoch_QW_PBEs)

        # learned tunings
        # NREM
        learned_tunings_NREM = calculate_learned_tuning(epoch_NREM_PBEs, spikes, L_ratios, time_bin_duration)

        # learned_tuning_place_field_pearson_corr_NREM = np.full((num_units,), np.nan)

        # learned_tuning_place_field_pearson_corr_NREM, _, median_LT_PF_pearson_corr_NREM = calculate_place_field_fidelity_of_learned_tuning(learned_tunings_NREM[active_units_epochs_intersect, :], place_fields_uni[active_units_epochs_intersect, :], num_PF_shuffles)
        # The reason for why we are restricting the calculations to only active_units_epochs_intersect is for the shuffle procedure. We want to make sure that the only units with significant PFs are swapped for the each unit's PF

        # QW
        learned_tunings_QW = calculate_learned_tuning(epoch_QW_PBEs, spikes, L_ratios, time_bin_duration)

        # learned_tuning_place_field_pearson_corr_QW = np.full((num_units,), np.nan)
        # learned_tuning_place_field_pearson_corr_QW, _, median_LT_PF_pearson_corr_QW = calculate_place_field_fidelity_of_learned_tuning(learned_tunings_QW[active_units_epochs_intersect, :], place_fields_uni[active_units_epochs_intersect, :], num_PF_shuffles)

        # store the results in the dictionary
        learned_tunings_NREM_vs_QW[epoch] = {
            'learned_tunings_NREM': learned_tunings_NREM,
            # 'learned_tuning_place_field_pearson_corr_NREM': learned_tuning_place_field_pearson_corr_NREM,
            # 'median_LT_PF_pearson_corr_NREM': median_LT_PF_pearson_corr_NREM,
            'number_of_PBEs_NREM': num_PBEs_NREM,
            'learned_tunings_QW': learned_tunings_QW,
            # 'learned_tuning_place_field_pearson_corr_QW': learned_tuning_place_field_pearson_corr_QW,
            # 'median_LT_PF_pearson_corr_QW': median_LT_PF_pearson_corr_QW,
            'number_of_PBEs_QW': num_PBEs_QW
        }

    # Save once per session, after both epochs are processed. np.save stores the dict as a pickled
    # 0-d object array: reload with np.load(..., allow_pickle=True).item().
    filename = f'{session_name}.learned_tunings_NREM_vs_QW_new_brain_state.npy'
    file_path = os.path.join(session_dataset_path, filename)
    np.save(file_path, learned_tunings_NREM_vs_QW)




    # # Plot the distributions of learned tunings for NREM and QW periods within each epoch

    # colors = sns.color_palette("husl", 2) # Set color palette
    # sns.set_style('whitegrid') # Set style and context
    # sns.set_context('paper')

    # custom_params = {"axes.spines.right": False, "axes.spines.top": False}
    # sns.set_theme(style="ticks", rc=custom_params)

    # fig, axes = plt.subplots(1,2, figsize = (10, 4))

    # for i, curr_epoch in enumerate(['PRE', 'POST']):

    #     learned_tuning_place_field_pearson_corr_NREM = learned_tunings_NREM_vs_QW[curr_epoch]['learned_tuning_place_field_pearson_corr_NREM']
    #     median_LT_PF_pearson_corr_NREM_pvalue = learned_tunings_NREM_vs_QW[curr_epoch]['median_LT_PF_pearson_corr_NREM']['p_value']
        
    #     learned_tuning_place_field_pearson_corr_QW = learned_tunings_NREM_vs_QW[curr_epoch]['learned_tuning_place_field_pearson_corr_QW']
    #     median_LT_PF_pearson_corr_QW_pvalue = learned_tunings_NREM_vs_QW[curr_epoch]['median_LT_PF_pearson_corr_QW']['p_value']

    #     num_PBEs_NREM = learned_tunings_NREM_vs_QW[curr_epoch]['number_of_PBEs_NREM']
    #     num_PBEs_QW = learned_tunings_NREM_vs_QW[curr_epoch]['number_of_PBEs_QW']


    #     # Plot the distributions using ECDFs with overlaid ticks
    #     sns.ecdfplot(learned_tuning_place_field_pearson_corr_NREM, ax = axes[i], label=f'{curr_epoch} NREM', color=colors[0], linewidth = 2)
    #     sns.ecdfplot(learned_tuning_place_field_pearson_corr_QW, ax = axes[i], label=f'{curr_epoch} QW', color=colors[1], linewidth = 2)
    #     axes[i].set_xlim([-1,1])
    #     axes[i].set_xlabel(f'{curr_epoch} LT fidelity', fontsize=10)
    #     axes[i].set_ylabel('Proportion of units', fontsize=10)
    #     axes[i].tick_params(labelsize=8)
    #     axes[i].legend(fontsize=8)


    #     # Add ticks to the x-axis
    #     y_min, y_max = axes[i].get_ylim()
    #     for x in learned_tuning_place_field_pearson_corr_NREM:
    #         axes[i].plot([x, x], [y_min+0.01, y_min + 0.02], '|-', color=colors[0], linewidth=0.25)
    #     for x in learned_tuning_place_field_pearson_corr_QW:
    #         axes[i].plot([x, x], [y_min+0.06, y_min + 0.07], '|-', color=colors[1], linewidth=0.25)

    #     # Calculate and display medians
    #     median_nrem = np.nanmedian(learned_tuning_place_field_pearson_corr_NREM)
    #     median_qw = np.nanmedian(learned_tuning_place_field_pearson_corr_QW)


    #     def get_pval_statement(pvalue):
    #         if pvalue < 0.0001:
    #             pvalue_statement = 'P<1e-4'
    #         else:
    #             pvalue_statement = f'P={pvalue:.4f}'
    #         return pvalue_statement


    #     axes[i].axvline(median_nrem, color=colors[0], linestyle='dashed', label=f'Median {curr_epoch} NREM={median_nrem:.2f},{get_pval_statement(median_LT_PF_pearson_corr_NREM_pvalue)}')
    #     axes[i].axvline(median_qw, color=colors[1], linestyle='dashed', label=f'Median {curr_epoch} QW={median_qw:.2f},{get_pval_statement(median_LT_PF_pearson_corr_QW_pvalue)}')
    #     axes[i].legend(fontsize=6)

    #     # Perform rank-sum test
    #     statistic, p_value = ranksums(learned_tuning_place_field_pearson_corr_NREM,
    #                                 learned_tuning_place_field_pearson_corr_QW)

    #     # Add line with p-value above the plot
    #     axes[i].text((median_nrem + median_qw) / 2, 0.2, f'p-value = {p_value:.4f}', ha='center', fontsize=6)

        

    #     # Bar plot of frequency of ripples in each brain states
    #     ax_inset = axes[i].inset_axes([0.2, 0.4, 0.2, 0.2])


    #     ax_inset.bar([0, 1],[num_PBEs_NREM, num_PBEs_QW], color=[colors[0], colors[1]]) #unique_strings, 
    #     ax_inset.set_xticks([0, 1], ['NREM', 'QW'])
    #     ax_inset.set_ylabel('number of PBEs', fontsize=6)
    #     ax_inset.tick_params(labelsize=6)
    #     # axes[i].tight_layout()


    # filename = f'{session_name}.learned_tunings_NREM_vs_QW_with_cluster_quality_criterion.svg'
    # file_path = os.path.join(session_dataset_path, filename)
    # plt.savefig(file_path)



In [None]:
# Inspect the NREM PBEs left over from the loop above: the last epoch processed (POST) of the
# last session. Depends on the previous cell having run in this kernel.
epoch_NREM_PBEs