Created on 08/23/2021.
This script has been modified slightly for ease of use by Mariano and Avigail

In [1]:
import ast
import glob
import os
import pickle
import shutil
import sys

import numpy as np
import pandas as pd

In [2]:
def load_pickle(file):
    """Lodas the pickle file and returns a pickle object
    Args:
        file (string): labels pickle from mariano
    Returns:
        pickle: pickle contents returned
    """
    with open(file, 'rb') as fh:
        datum = pickle.load(fh)

    return datum

In [3]:
# def result_of(elec, rep, prod_or_comp, iternum, lag):
#     return results[elec][rep][prod_or_comp][iternum][lag]

## to get the correlation for: 
# electrode elec (subset of range(105)), 
# repetition rep (0 to 49) 
# prod (0) or comp (1) 
# iteration iternum (0 to 19, we have 20 iterations per rep, total 1000)repetitions)
# lag (0 to 159)

#### Slight reuse of Mariano's function (more info in the above cell)

In [4]:
def get_results(result_dict, electrode_idx, idx_str, perm_range):
    idx_of = {'prod': 0, 'comp': 1}
    results = []
    for rep in range(*perm_range):
        result_dict[electrode_idx]
        results.append(result_dict[electrode_idx][rep][idx_of[idx_str]])
    
    return np.vstack(results)

In [7]:
def pickle_to_electrode_csv(pickle_object, subject_id, shuffle_str, perm_range=None):
    subject_str = str(subject_id)
    
    electrode_pkl_name = '_'.join([subject_str, 'electrode_names.pkl'])
    electrode_pkl = load_pickle(electrode_pkl_name)

    subject_dir = os.path.join(os.getcwd(), shuffle_str, subject_str)
    os.makedirs(subject_dir, exist_ok=True)

    for electrode_idx, electrode_name in zip(electrode_pkl['electrode_id'], electrode_pkl['electrode_name']):

        if pickle_object.get(electrode_idx, None):
            ret_prod = get_results(pickle_object, electrode_idx, 'prod', perm_range) # returns 1000by160 ndarray
            ret_comp = get_results(pickle_object, electrode_idx, 'comp', perm_range)
        
        np.savetxt(os.path.join(subject_dir, f'{electrode_name}_comp.csv'), ret_comp, delimiter=',')
        np.savetxt(os.path.join(subject_dir, f'{electrode_name}_prod.csv'), ret_prod, delimiter=',')

### Significance Testing

In [8]:
from statsmodels.stats import multitest


def perform_significance_testing(project_id, shuffle_folder, noshuffle_folder, conv_flag):
    subjects = sorted(glob.glob(os.path.join(shuffle_folder, '*')))
    
    csv_save_path = os.path.dirname(shuffle_folder)

    # This is only for podcast where we ignore electrodes on right hemisphere
    hemisphere_indicator = load_pickle('podcast_hemisphere_indicator.pkl')

    lags = np.arange(-2000, 2000, 25)

    pval_dict = dict()
    some_list = []
    for subject in subjects:
        subject_key = os.path.basename(subject)

        # Load all csv files in the shuffle folder
        shuffle_elec_file_list = sorted(
            glob.glob(
                os.path.join(
                    shuffle_folder,
                    os.path.basename(subject), '*' + conv_flag + '.csv')))

        # Load all csv files in the noshuffle folder
        main_elec_file_list = sorted(
            glob.glob(
                os.path.join(
                    noshuffle_folder,
                    os.path.basename(subject), '*' + conv_flag + '.csv')))

        if project_id == 'podcast':
            curr_key = hemisphere_indicator.get(int(subject_key), None)

            if not curr_key:
                pass
            elif len(curr_key) == 2:
                shuffle_elec_file_list = list(
                    filter(lambda x: os.path.basename(x).startswith(('L', 'DL')),
                           shuffle_elec_file_list))
                main_elec_file_list = list(
                    filter(lambda x: os.path.basename(x).startswith(('L', 'DL')),
                           main_elec_file_list))
            elif len(curr_key) == 1 and 'RH' in curr_key:
                continue
            else:
                pass

        a = [os.path.basename(item) for item in shuffle_elec_file_list]
        b = [os.path.basename(item) for item in main_elec_file_list]

        assert set(a) == set(b), "Mismatch: Electrode Set"

        for elec_file1, elec_file2 in zip(shuffle_elec_file_list,
                                          main_elec_file_list):
            elecname1 = os.path.split(os.path.splitext(elec_file1)[0])[1]
            elecname2 = os.path.split(os.path.splitext(elec_file2)[0])[1]

            assert elecname1 == elecname2, 'Mismatch: Electrode Name'

            if elecname1.startswith(('SG', 'ECGEKG', 'EEGSG')):
                continue

            perm_result = pd.read_csv(elec_file1, header=None).values
            rc_result = pd.read_csv(elec_file2, header=None).values
            
            if perm_result.shape[1] != rc_result.shape[1]:
                rc_result = rc_result[0, 1:]
                
            if perm_result.shape[1] != len(lags):
                print('perm is wrong length')
            else:
                omaxs = np.max(perm_result, axis=1)

            s = 1 - (sum(np.max(rc_result) > omaxs) / perm_result.shape[0])
            some_list.append((subject_key, elecname1, s))

    df = pd.DataFrame(some_list, columns=['subject', 'electrode', 'score'])
    thresh = 0.01

    df1 = df.copy(deep=True)
    flag = np.logical_or(np.isclose(df1.score.values, thresh, atol=1e-6), df1.score.values > thresh)

    df1 = df1[flag]
    df1['electrode'] = df1['electrode'].str.strip('_' + conv_flag)
    df1.to_csv(os.path.join(csv_save_path, 'mariano_glove_pre_fdr_' + conv_flag + '.csv'),
              index=False,
              columns=['subject', 'electrode'])

    _, pcor, _, _ = multitest.multipletests(df.score.values,
                                            method='fdr_bh',
                                            is_sorted=False)

    flag = np.logical_or(np.isclose(pcor, thresh), pcor < thresh)

    df = df[flag]
    df['electrode'] = df['electrode'].str.strip('_' + conv_flag)
    df.to_csv(os.path.join(csv_save_path, 'mariano_glove_post_fdr_' + conv_flag + '.csv'), index=False, columns=['subject', 'electrode'])

    # Probably have to write a condition if it is podcast or tfs but can be safely run without issues
    if project_id == 'podcast':
        filter_hemisphere = []
        for row in df.itertuples(index=False):
            subject = row.subject
            electrode = row.electrode

            curr_key = hemisphere_indicator.get(int(subject), None)

            if not curr_key:
                if int(subject) == 798:
                    filter_hemisphere.append((subject, electrode))
            elif len(curr_key) == 2:
                if electrode.startswith(('L', 'DL')):
                    filter_hemisphere.append((subject, electrode))
            elif len(curr_key) == 1 and 'RH' in curr_key:
                continue
            else:
                filter_hemisphere.append((subject, electrode))

        df2 = pd.DataFrame(filter_hemisphere, columns=['subject', 'electrode'])
        df2.to_csv(os.path.join(csv_save_path, 'mariano_glove_post_fdr_lhp_' + conv_flag + '.csv'),
                   index=False,
                   columns=['subject', 'electrode'])

ModuleNotFoundError: No module named 'statsmodels'

### Plotting

In [7]:
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

plt.rcParams.update({"text.usetex": False})


def plot_results(shuffle_folder, noshuffle_folder, conv_flag):
    
    pdf_save_path = os.path.dirname(shuffle_folder)
    
    subjects = sorted(glob.glob(os.path.join(shuffle_folder, '*')))
    lags = np.arange(-2000, 2000, 25)

    pval_dict = dict()
    some_list = []
    for subject in subjects:
        subject_key = os.path.basename(subject)

        shuffle_elec_file_list = sorted(
            glob.glob(
                os.path.join(
                    shuffle_folder,
                    os.path.basename(subject), '*' + conv_flag + '.csv')))

        main_elec_file_list = sorted(
            glob.glob(
                os.path.join(
                    noshuffle_folder,
                    os.path.basename(subject), '*' + conv_flag + '.csv')))

        a = [os.path.basename(item) for item in shuffle_elec_file_list]
        b = [os.path.basename(item) for item in main_elec_file_list]

        assert set(a) == set(b), "Mismatch: Electrode Set"

        pp = PdfPages(os.path.join(pdf_save_path, 'phase_shuffle_' + conv_flag + '.pdf'))

        for elec_file1, elec_file2 in zip(shuffle_elec_file_list,
                                          main_elec_file_list):
            elecname1 = os.path.split(os.path.splitext(elec_file1)[0])[1]
            elecname2 = os.path.split(os.path.splitext(elec_file2)[0])[1]

            assert elecname1 == elecname2, 'Mismatch: Electrode Name'
            
            if elecname1.startswith(('SG', 'ECGEKG', 'EEGSG')):
                continue

            shuffle_elec_data = pd.read_csv(elec_file1, header=None)
            main_elec_data = pd.read_csv(elec_file2, header=None)

            fig, ax = plt.subplots()
#             for row in shuffle_elec_data.values:
#                 ax.plot(lags, row, linewidth=0.005, color='k', linestyle=':')
            ax.plot(lags, main_elec_data.values.T, linewidth=2, color='r')

            ax.set(xlabel='lag (s)',
                   ylabel='correlation',
                   title=elecname1)
            ax.set_ylim(-0.05, 0.35)
            ax.vlines(0, -0.05, 0.50, linestyles='dashed', linewidth=.25)

            pp.savefig(fig)
#             plt.show()
            plt.close()

        pp.close()

In [8]:
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

plt.rcParams.update({"text.usetex": False})


def plot_results_new(shuffle_folder, noshuffle_folder, conv_flag):
    
    pdf_save_path = os.path.join(os.path.dirname(shuffle_folder), conv_flag)
    os.makedirs(pdf_save_path, exist_ok=True)
    
    subjects = sorted(glob.glob(os.path.join(shuffle_folder, '*')))
    lags = np.arange(-2000, 2000, 25)

    pval_dict = dict()
    some_list = []
    for subject in subjects:
        subject_key = os.path.basename(subject)

        shuffle_elec_file_list = sorted(
            glob.glob(
                os.path.join(
                    shuffle_folder,
                    os.path.basename(subject), '*' + conv_flag + '.csv')))

        main_elec_file_list = sorted(
            glob.glob(
                os.path.join(
                    noshuffle_folder,
                    os.path.basename(subject), '*' + conv_flag + '.csv')))

        a = [os.path.basename(item) for item in shuffle_elec_file_list]
        b = [os.path.basename(item) for item in main_elec_file_list]

        assert set(a) == set(b), "Mismatch: Electrode Set"

        for elec_file1, elec_file2 in zip(shuffle_elec_file_list,
                                          main_elec_file_list):
            elecname1 = os.path.split(os.path.splitext(elec_file1)[0])[1]
            elecname2 = os.path.split(os.path.splitext(elec_file2)[0])[1]

            assert elecname1 == elecname2, 'Mismatch: Electrode Name'
            
            if elecname1.startswith(('SG', 'ECGEKG', 'EEGSG')):
                continue

            shuffle_elec_data = pd.read_csv(elec_file1, header=None)
            main_elec_data = pd.read_csv(elec_file2, header=None)

            fig, ax = plt.subplots()
#             for row in shuffle_elec_data.values:
#                 ax.plot(lags, row, linewidth=0.005, color='k', linestyle=':')
            ax.plot(lags, main_elec_data.values[:, :160].T, linewidth=2, color='r')

            ax.set(xlabel='lag (s)',
                   ylabel='correlation',
                   title=elecname1)
            ax.set_ylim(-0.05, 0.35)
            ax.vlines(0, -0.05, 0.50, linestyles='dashed', linewidth=.25)

            plt.savefig(os.path.join(pdf_save_path, elecname1 + '.png'))
            plt.close()

In [12]:
if __name__ == '__main__':
    subjects = [625, 676]
    shuffle_pickle = {625: 'results_114_625.pickle',
                         676: 'results_114_676.pickle'}
    
    noshuffle_pickle = {625: 'results_120_625.pickle',
                         676: 'results_120_676.pickle'}

    perm_ranges = [(0, 500), (0, 50), (50, 100), (100, 150), (150, 200), (200, 250),
                   (250, 300), (300, 350), (350, 400), (400, 450), (450, 500),
                   (0, 250), (250, 500)]

    # Save permutations from each electrode into a csv file
    # Note: For the 'noshuffle' case, I manually copied the data from elsewhere
    
    for subject in subjects:
        for idx, perm_range in enumerate(perm_ranges[1:2]):

            # Create folder for that combination
            combo_folder_name = f'{subject}_{perm_range[0]*20:05d}-to-{perm_range[1]*20:05d}'
            print(combo_folder_name)
            os.makedirs(combo_folder_name, exist_ok=True)
            
            # folder where permutations are stored
            shuffle_folder = os.path.join(os.getcwd(), combo_folder_name, 'shuffle')
            noshuffle_folder = os.path.join(os.getcwd(), combo_folder_name, 'noshuffle') 
                    
            # move info from shuffle pickle to csv files
            shuffle_pickle = load_pickle(shuffle_pickle[subject])
            pickle_to_electrode_csv(shuffle_pickle, subject, shuffle_folder, perm_range)

            # move info from no-shuffle pickle to csv files
            # FIXME: I manually copied the files from a different run
            try:
                noshuffle_pickle = load_pickle(noshuffle_pickle[subject])
                pickle_to_electrode_csv(noshuffle_pickle, subject, noshuffle_folder, (0, 1))
            except Exception:
                print('No shuffle pickle does not exist')
#                 # # copy noshuffle data to this subfolder
#                 shutil.copytree(os.path.join(os.getcwd(), 'noshuffle', str(subject)), os.path.join(noshuffle_folder, str(subject)))
            
#             perform_significance_testing('tfs', shuffle_folder, noshuffle_folder, 'prod')
#             perform_significance_testing('tfs', shuffle_folder, noshuffle_folder, 'comp')

            plot_results(shuffle_folder, noshuffle_folder, 'prod')
            plot_results(shuffle_folder, noshuffle_folder, 'comp')

625_00000-to-01000
676_00000-to-01000


KeyError: 676

#### Plotting average encoding of significant electrodes

In [None]:
import glob
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

main_folder = os.getcwd()
folder_list = sorted(glob.glob(os.path.join(main_folder, '6*_*-to-*')))

lags = np.arange(-2000, 2000, 25)

for folder in folder_list:
    print(os.path.basename(folder))
    subject_id = os.path.basename(folder)[:3]
    csv_list = sorted(glob.glob(os.path.join(folder, 'mariano*.csv')))

    for f in glob.glob(os.path.join(folder, '*.png')):
        os.remove(f)

    for csv_file in csv_list:
        csv_contents = pd.read_csv(csv_file)
        csv_file_name = os.path.basename(csv_file)
        csv_file_name = os.path.splitext(csv_file_name)[0]
        print(csv_file_name)

        conv_flag = csv_file_name.split('_')[-1]

        encoding = []
        for electrode in csv_contents.electrode:
            electrode_path = os.path.join(folder, 'noshuffle', subject_id,
                                          electrode)
            try:
                with open(electrode_path + '_' + conv_flag + '.csv', 'r') as fh:
                    my_data = np.genfromtxt(fh,
                                        delimiter=',')
            except Exception:
                with open(electrode_path +  '.csv', 'r') as fh:
                    my_data = np.genfromtxt(fh,
                                        delimiter=',')

            encoding.append(my_data)
        encoding = np.vstack(encoding)
        mean_encoding = np.mean(encoding, axis=0)
        fig, ax = plt.subplots()
        ax.plot(lags, mean_encoding[:160], linewidth=1, color='r')

        ax.set(xlabel='lag (s)', ylabel='correlation')
        ax.set_ylim(-0.05, 0.250)
        ax.vlines(0, -0.05, 0.50, linestyles='dashed', linewidth=.25)

        plt.savefig(os.path.join(folder, csv_file_name + '.png'))
        plt.close()

#### Plot no-shuffle average encoding for glove

In [None]:
import glob
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

main_folder = os.getcwd()
folder_list = sorted(glob.glob(os.path.join(main_folder, '6*_*-to-*')))

lags = np.arange(-2000, 2000, 25)

for folder in folder_list:
    print(os.path.basename(folder))
    subject_id = os.path.basename(folder)[:3]
    
    noshuffle_folder = os.path.join(folder, 'noshuffle', subject_id)

    prod_elec_list = sorted(glob.glob(os.path.join(noshuffle_folder, '*_prod.csv')))
    comp_elec_list = sorted(glob.glob(os.path.join(noshuffle_folder, '*_comp.csv')))
        
    prod_encoding = []
    for csv_file in prod_elec_list:
        my_data = np.genfromtxt(csv_file, delimiter=',')
        prod_encoding.append(my_data)
    prod_encoding = np.vstack(prod_encoding)
    mean_prod_encoding = np.mean(prod_encoding, axis=0)

    comp_encoding = []
    for csv_file in comp_elec_list:
        my_data = np.genfromtxt(csv_file, delimiter=',')
        comp_encoding.append(my_data)
    comp_encoding = np.vstack(comp_encoding)
    mean_comp_encoding = np.mean(comp_encoding, axis=0)

    fig, ax = plt.subplots()
    ax.plot(lags, mean_prod_encoding[:160], linewidth=1, color='r', label='Production')
    ax.plot(lags, mean_comp_encoding[:160], linewidth=1, color='b', label='Comprehension')
    
    ax.set(xlabel='lag (s)', ylabel='correlation')
    ax.set_ylim(-0.05, 0.250)
    ax.vlines(0, -0.05, 0.50, linestyles='dashed', linewidth=.25)
    ax.legend()

    plt.savefig(os.path.join(folder, 'average_encoding.png'))
    plt.close()