In [None]:
import os
import pandas as pd
import numpy as np
from scipy.io import loadmat

from scipy.stats import permutation_test
from scipy.stats import ttest_1samp
from scipy.stats import ttest_ind
from statsmodels.stats import multitest

import matplotlib.pyplot as plt
import plotly.graph_objects as go
import seaborn as sns
from tfsplt_utils import load_pickle

Code that generates all plots for the manuscript "Deep speech-to-text models capture the neural basis of spontaneous speech production and comprehension in everyday conversations"

Authors: Leonard Niekerken, Ken Wang, Bobbi Aubrey;
June 2023

README

First, make sure you have all necessary data files and that the paths specified under PARAMETERS point to the right locations. TODO: organize the file structure.

Required files: 
1) base_df.csv - containing 1) subject ids (sid), 2) electrode names for all patients, in two naming conventions (elec_1 and elec_2), 3) electrode coordinates in MNI space (x,y,z) and electrode type (type), 4) ROI information: NYU_roi refers to the original classification done by NYU, roi_1 to our manual adaptation (selecting the area with the highest percentage per ROI), and roi_2 further combines preCG and postCG into SM (sensorimotor), with subdivisions into dSM, mSM and vSM, 5) a column for each significance test done by Google, indicating whether the given electrode is significant or not
2) example_electrodes.csv - containing the electrodes selected as example electrodes for figures 4 and 6: 1) subject ids (sid), 2) subject ids used in the figure (sid_2), 3) electrode names (elec_1), 4) ROI as in our adaptation (roi), 5) which scale to use for plotting (scale), 6) columns indicating whether the given electrode is significant in tests done by Google
3) ch2_template_mni_lh_pial.mat - containing anatomical information to plot the brain template for brain maps (figure 2, 3, 4, 5, 6)
4) encoding_results/whisper-tiny.en-[encoder/decoder/full] - containing encoding results for all electrodes for different embeddings 

The code is organized as follows: 
1) PARAMETERS BLOCK - specifies 1) input and output paths, 2) data-import parameters, and 3) plotting style settings
2) FUNCTIONS BLOCK - contains all functions that are required to run the notebook
3) a LOAD DATUM BLOCK, where the base_df is loaded
4) a BLOCK for each FIGURE

Just run each block in this order to generate plots and statistical test results corresponding to each figure in the manuscript.

Images will be saved to the /results folder as .svg or .png files. Statistical test results will be saved as .csv files.

In [None]:
## PARAMETERS ##

## filestructure parameters

# root of the paper's data/results tree; every path below is derived from it,
# so relocating the data only requires editing this one line
PATH = '/scratch/gpfs/ln1144/247-plotting/paper'
BASE_PATH = f'{PATH}/data/base_df.csv'
DATA_PATH = f'{PATH}/data/encoding_results'
COOR_PATH = f'{PATH}/data/brainmap/ch2_template_mni_lh_pial.mat'
OUT_PATH = f'{PATH}/results'

## data parameters

# encoding data is from -10 to 10 seconds (relative to word onset) - don't change this!
LAGS_START = -10000
LAGS_STOP = 10000
# define the time window to use for all plots and statistical analyses 
PLT_START = -2000
PLT_STOP = 2000
# sliding window that was used during encoding analysis - don't change this!
STEPS = 25

# The plotting window must be a sub-grid of the encoding lags; otherwise load_elec_data
# selects fewer rows than len(LAGS_SHOW) and fails later with an opaque length mismatch.
assert LAGS_START <= PLT_START < PLT_STOP <= LAGS_STOP, "plot window must lie inside the encoding lags"
assert (PLT_START - LAGS_START) % STEPS == 0 and (PLT_STOP - PLT_START) % STEPS == 0, \
    "plot window must be aligned to the STEPS grid"

# inclusive lag grids in ms, spaced STEPS apart: full encoding range and plotted window
LAGS = np.linspace(LAGS_START,LAGS_STOP,int((LAGS_STOP-LAGS_START)/STEPS+1))
LAGS_SHOW = np.linspace(PLT_START,PLT_STOP,int((PLT_STOP-PLT_START)/STEPS+1))
X_VALS_SHOW = LAGS_SHOW

# the same grids converted from ms to seconds (x-axis values and index labels for plots)
x_vals_show = [x_val / 1000 for x_val in X_VALS_SHOW]
lags_show = [lag / 1000 for lag in LAGS_SHOW]

## plotting parameters

# plt.style.use('/scratch/gpfs/ln1144/247-plotting/scripts/paper.mlpstyle')
ls = '-'
lw = 2

In [None]:
## FUNCTIONS ##

# filter datum on significant electrodes for a given emb
def filter_datum(emb, df):
    """Return the rows of `df` whose column `emb` equals True.

    `emb` is the name of a significance column (e.g. 'whisper-en-last-0.01-prod'). Rows where
    that column is False or missing (NaN) are dropped, because NaN == True is False.

    The result is an independent copy. Callers add columns and reset the index in place on the
    returned frame, which on a bare boolean-mask selection triggers pandas' SettingWithCopyWarning.
    """
    # `== True` (rather than using the column as the mask directly) also works when the column
    # is object dtype with NaN entries
    return df[df[emb] == True].copy()  # noqa: E712

# load encoding results for a given embedding, layer, subject, electrode, mode and filter by the lags you want to show in the plot
def load_elec_data(emb,layer,sid,elec,mode):
    """Load one electrode's encoding-results CSV, restricted to the plotting window.

    Reads DATA_PATH/whisper-tiny.en-{emb}/whisper-tiny.en-{emb}-{sid}-lag10k-25-all-{layer}/{sid}_{elec}_{mode}.csv.
    The file must have exactly one row per entry of LAGS; assigning LAGS raises otherwise.

    Returns
    -------
    pd.DataFrame
        The rows whose lag is in LAGS_SHOW, plus a `lags` column (ms) and a `lags_show` column
        (s). The frame is indexed by `lags_show`.
    """
    fid = f"whisper-tiny.en-{emb}"
    fpath = f"/{fid}/{fid}-{sid}-lag10k-25-all-{layer}"
    fname = f"/{sid}_{elec}_{mode}.csv"

    df = pd.read_csv(DATA_PATH + fpath + fname)

    # label each row with its lag (ms), then keep only the plotted window.
    # .copy() makes the selection an independent frame, so the column assignment below does
    # not trigger SettingWithCopyWarning on every electrode load.
    df['lags'] = LAGS
    df = df.loc[df['lags'].isin(LAGS_SHOW)].copy()

    df['lags_show'] = lags_show
    df = df.set_index(df['lags_show'])

    return df

# get encoding mean and se for a given embedding, layer, subject, electrode, mode
def get_encoding(emb,layer,sid,elec,mode):
    """Return the (avg, se) encoding columns for one electrode, indexed by lag in seconds."""
    enc = load_elec_data(emb, layer, sid, elec, mode)
    mean_r = enc.avg
    se_r = enc.se
    return mean_r, se_r

# get encoding mean for a given embedding, layer, subject, electrode, mode
def get_encoding_avg(emb,layer,sid,elec,mode):
    """Return the mean encoding correlation (`avg` column) per plotted lag for one electrode."""
    enc = load_elec_data(emb, layer, sid, elec, mode)
    mean_r = enc.avg
    return mean_r

# get encoding se for a given embedding, layer, subject, electrode, mode
def get_encoding_sem(emb,layer,sid,elec,mode):
    """Return the standard-error column per plotted lag for one electrode.

    The previous `df.sem` resolved to the DataFrame.sem *method*, not a column, because methods
    take precedence over attribute-style column access. Callers therefore got a bound method
    instead of data.

    The column is read as `se`, matching get_encoding. It falls back to a column literally named
    `sem` if the file uses that name instead (TODO confirm which one the CSVs contain).
    """
    df = load_elec_data(emb,layer,sid,elec,mode)

    se_col = "se" if "se" in df.columns else "sem"
    return df[se_col]

# get encoding for all folds for a given embedding, layer, subject, electrode, mode
def get_encoding_folds(emb,layer,sid,elec,mode):
    """Return the first 10 columns of one electrode's results, indexed by lag in seconds.

    These are presumably the per-fold correlations (one column per cross-validation fold).
    TODO: confirm against the CSV layout.
    """
    n_folds = 10
    enc = load_elec_data(emb, layer, sid, elec, mode)
    return enc.iloc[:, :n_folds]

# get maximal encoding for a given embedding, layer, subject, electrode, mode
def get_encoding_max(emb,layer,sid,elec,mode):
    """Return the peak mean encoding correlation within the plotted lag window."""
    mean_curve = load_elec_data(emb, layer, sid, elec, mode)["avg"]
    peak = mean_curve.max()
    return peak

# get classifier bar plot results
def get_classifier_results(filename, class_type, class_type_plt, class_cat):
    """Summarise per-row classifier scores into mean/std bars.

    Parameters
    ----------
    filename : str or file-like
        CSV readable by pd.read_csv. Column "10" is discarded; it is presumably the whole-data
        score (see the commented `whole` line in the original). Every remaining column is
        averaged per row.
    class_type, class_cat : sequences
        Labels attached to each CSV row, in row order.
    class_type_plt : sequence
        Classifier types to keep.

    Returns
    -------
    pd.DataFrame
        Columns "balanced accuracy", "std", "class_type" and "class_cat", restricted to rows whose
        class_type is in class_type_plt. The original row index is kept.
    """
    scores = pd.read_csv(filename).drop(columns=["10"])

    summary = pd.DataFrame(
        {
            "balanced accuracy": scores.mean(axis=1).tolist(),
            "std": scores.std(axis=1).tolist(),
            "class_type": class_type,
            "class_cat": class_cat,
        }
    )

    keep = summary.class_type.isin(class_type_plt)
    return summary[keep]



# compute statistic (difference of the means) for paired samples permutation test
def statistic(x,y,axis):
    """Mean of `x` minus mean of `y` along `axis`.

    Written to accept an `axis` argument so scipy.stats.permutation_test can call it vectorised.
    """
    mean_x = np.mean(x, axis=axis)
    mean_y = np.mean(y, axis=axis)
    return mean_x - mean_y

# run paired samples permutation test
def run_permutation_test(x,y,alpha,n_resamples=10000,random_state=None):
    """Paired-samples permutation test on the difference of means, FDR-corrected across tests.

    Parameters
    ----------
    x, y : array-like
        Paired observations. scipy permutes pairs along axis 0 (its default), so samples (e.g.
        folds) go on axis 0. One test runs for every position along the remaining axis
        (e.g. lags).
    alpha : float
        Family-wise FDR level for Benjamini-Hochberg.
    n_resamples : int, default 10000
        Number of permutations. scipy switches to the exact null distribution when this is at
        least the number of distinct pairings.
    random_state : None, int or numpy Generator, default None
        Seed for the Monte-Carlo permutations. Pass one to make p-values reproducible. When None,
        the argument is not forwarded, so the call is identical to the original.

    Returns
    -------
    tuple
        The result of statsmodels multipletests: (reject, pvals_corrected, alphacSidak, alphacBonf).
    """
    perm_kwargs = dict(
        vectorized=True,
        n_resamples=n_resamples,
        alternative='two-sided',
        permutation_type='samples',
    )
    if random_state is not None:
        perm_kwargs['random_state'] = random_state

    res = permutation_test((x,y),statistic,**perm_kwargs)
    p_vals = res.pvalue
    q_vals = multitest.multipletests(p_vals,method="fdr_bh", is_sorted=False, alpha=alpha)

    return q_vals

# plot t-SNE results
def plot_tsne(df, x, y, color, title):
    """Scatter a 2-D embedding coloured by category and save it to ../{title}.jpeg.

    Only categories with at least 20 points are drawn. Rows are ordered by category frequency
    (most frequent first) before plotting.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain columns `x`, `y`, `color` and "marker".
    x, y : str
        Column names of the 2-D coordinates.
    color : str
        Categorical column used for hue.
    title : str
        Plot title, also used as the output file stem.

    Changes from the earlier version: the figure is created explicitly instead of drawing onto
    whatever axes happened to be current. The palette falls back to seaborn's "husl" palette
    when there are more categories than distinct colours, instead of cycling or failing
    depending on the seaborn version.
    """
    colors_distinct = [
        "#000000",
        "#00FF00",
        "#0000FF",
        "#FF0000",
        "#01FFFE",
        "#FFA6FE",
        "#774D00",
        "#006401",
        "#010067",
        "#95003A",
        "#007DB5",
        "#FF00F6",
        "#FFEEE8",
        "#FFDB66",
        "#90FB92",
        "#0076FF",
        "#D5FF00",
        "#FF937E",
        "#6A826C",
        "#FF029D",
        "#FE8900",
        "#7A4782",
        "#7E2DD2",
        "#85A900",
        "#FF0056",
        "#A42400",
        "#00AE7E",
        "#683D3B",
        "#BDC6FF",
        "#263400",
        "#BDD393",
        "#00B917",
        "#9E008E",
        "#001544",
        "#C28C9F",
        "#FF74A3",
        "#01D0FF",
        "#004754",
        "#E56FFE",
        "#788231",
        "#0E4CA1",
        "#91D0CB",
        "#BE9970",
        "#968AE8",
        "#BB8800",
        "#43002C",
        "#DEFF74",
        "#00FFC6",
        "#FFE502",
        "#620E00",
        "#008F9C",
        "#98FF52",
        "#7544B1",
        "#B500FF",
        "#00FF78",
        "#FF6E41",
        "#005F39",
        "#6B6882",
        "#5FAD4E",
        "#A75740",
        "#A5FFD2",
        "#FFB167",
        "#009BFF",
        "#E85EBE",
    ]

    # keep only categories with >= 20 points, then order rows by category frequency
    df2 = df.groupby(df[color]).filter(lambda grp: len(grp) >= 20).copy()
    df2["freq"] = df2.groupby(df2[color])[color].transform("count")
    df2 = df2.sort_values("freq", ascending=False)

    n_levels = len(df2[color].unique())
    if n_levels <= len(colors_distinct):
        palette = colors_distinct[0:n_levels]
    else:
        palette = sns.color_palette("husl", n_levels)

    # plt.style.use("/scratch/gpfs/ln1144/247-plotting/scripts/paper.mlpstyle")
    fig, ax = plt.subplots()
    sns.scatterplot(data=df2,x=df2[x],y=df2[y],hue=df2[color],palette=palette, linewidth=0,style=df2["marker"],s=5, markers=["o"], ax=ax)
    ax.set_title(f"{title}")
    fig.savefig(f"../{title}.jpeg")
    plt.close(fig)
    return

# plot classifier bar plots
def plot_classifier_bar(results_df, filename, colors):
    """Draw a grouped bar chart (one group per class_cat, one bar per class_type) and save it.

    Bar heights are the "balanced accuracy" column and error bars the "std" column of `results_df`
    (as produced by get_classifier_results). The figure is written to `filename` and then closed.
    """
    pivot_kw = dict(index="class_cat", columns="class_type")
    bar_heights = results_df.pivot(values="balanced accuracy", **pivot_kw)
    bar_errors = results_df.pivot(values="std", **pivot_kw)

    error_style = dict(ecolor="black", elinewidth=1, capsize=1)
    bar_heights.plot(kind="bar", yerr=bar_errors, rot=0, color=colors, error_kw=error_style)

    plt.savefig(filename)
    plt.close()

    return


# plot classic encoding plot: one mean +/- error band per entry of vals/errs/colors, with grey
# asterisks marking significant lags along the top of the axes
def plot_encoding(vals,errs,siglags,colors,scale,title):
    """Plot encoding curves over x_vals_show with shaded error bands and significance markers.

    Parameters
    ----------
    vals, errs, colors : sequences of equal length
        Mean curve, error half-width and colour for each line. Each curve is drawn against the
        module-level x_vals_show.
    siglags : sequence of int
        Positions into x_vals_show that are marked with a grey asterisk near the top of the axes.
    scale : str
        One of '0.25', '0.3', '0.35', '0.4' or '0.5' selects a preset y-range and ticks. Any other
        value keeps matplotlib's autoscaled range.
    title : str
        Axes title, drawn in bold.

    The figure is left open as the current figure so the caller can save it.
    """
    plt.style.use('/scratch/gpfs/ln1144/247-plotting/scripts/paper.mlpstyle')

    fig, ax = plt.subplots()

    # dashed reference lines at r = 0 and lag = 0
    for draw_reference in (ax.axhline, ax.axvline):
        draw_reference(0, ls="dashed", alpha=0.3, c="k")

    # scale -> (ylim, yticks, yticklabels or None)
    ylim_presets = {
        '0.5': ((-0.05, 0.5), [0,0.250,0.5], ['0.000','0.25','0.5']),
        '0.25': ((-0.05, 0.25), [0,0.125,0.25], None),
        '0.3': ((-0.05, 0.3), [0,0.125,0.25], None),
        '0.35': ((-0.05, 0.35), [0,0.15,0.3], None),
        '0.4': ((-0.05, 0.4), [0,0.15,0.3], None),
    }
    if scale in ylim_presets:
        (y_lo, y_hi), ticks, tick_labels = ylim_presets[scale]
        ax.set_ylim(y_lo, y_hi)
        ax.set_yticks(ticks)
        if tick_labels is not None:
            ax.set_yticklabels(tick_labels)

    ax.set(xlabel="Lag (s)", ylabel="Correlation (r)")
    ax.set_title(title, weight = "bold")

    for mean_r, err_r, line_color in zip(vals, errs, colors):
        ax.fill_between(x_vals_show, mean_r - err_r, mean_r + err_r, alpha=0.2, color=line_color)
        ax.plot(x_vals_show, mean_r, color=line_color, ls=ls, lw=lw)

    # significance asterisks just below the upper y-limit
    star_y = ax.get_ylim()[1] - 0.005
    star_x = np.asarray(x_vals_show)[siglags]
    ax.scatter(star_x,[star_y] * len(siglags),marker="*",color="grey",s=0.25)

# plot encoding curves plus per-curve peak-lag markers: each entry of argmaxs is drawn as small
# stars at its yheight, and its mean as a diamond
def plot_temporal_encoding(vals,errs,argmaxs,yheights,colors,scale,title):
    """Plot encoding curves with markers for the distribution of peak lags.

    Parameters
    ----------
    vals, errs : sequences
        Mean curve and error half-width per line, drawn against the module-level x_vals_show.
    argmaxs : sequence of array-like
        Peak lags (same units as x_vals_show) per line. Each value is drawn as a star and their
        mean as a diamond.
    yheights : sequence of float
        y-position of each line's peak-lag markers.
    colors : sequence
        Colour per line.
    scale : str
        '0.25', '0.3', '0.35', '0.4' or '0.5' selects a preset y-range and ticks. Anything else
        keeps autoscaling.
    title : str
        Axes title, drawn in bold.

    The figure is left open as the current figure so the caller can save it.
    """
    plt.style.use('/scratch/gpfs/ln1144/247-plotting/scripts/paper.mlpstyle')

    fig, ax = plt.subplots()

    for mean_r, err_r, peaks, marker_y, line_color in zip(vals, errs, argmaxs, yheights, colors):
        ax.fill_between(x_vals_show, mean_r - err_r, mean_r + err_r, alpha=0.2, color=line_color)
        ax.plot(x_vals_show, mean_r, color=line_color, ls=ls, lw=lw)
        ax.scatter(peaks, [marker_y] * len(peaks), marker="*", color=line_color, s=0.25)
        ax.scatter(np.asarray(np.mean(peaks)), [marker_y], marker="D", color=line_color, s=5)

    # dashed reference lines at r = 0 and lag = 0
    for draw_reference in (ax.axhline, ax.axvline):
        draw_reference(0, ls="dashed", alpha=0.3, c="k")

    # scale -> (ylim, yticks, yticklabels or None)
    ylim_presets = {
        '0.5': ((-0.05, 0.5), [0,0.250,0.5], ['0.000','0.25','0.5']),
        '0.25': ((-0.05, 0.25), [0,0.125,0.25], None),
        '0.3': ((-0.05, 0.3), [0,0.125,0.25], None),
        '0.35': ((-0.05, 0.35), [0,0.15,0.3], None),
        '0.4': ((-0.05, 0.4), [0,0.15,0.3], None),
    }
    if scale in ylim_presets:
        (y_lo, y_hi), ticks, tick_labels = ylim_presets[scale]
        ax.set_ylim(y_lo, y_hi)
        ax.set_yticks(ticks)
        if tick_labels is not None:
            ax.set_yticklabels(tick_labels)

    ax.set(xlabel="Lag (s)", ylabel="Correlation (r)")
    ax.set_title(title, weight = "bold")

# loading brain surface plot (for now only for one hemisphere)
def load_surf(fpath):
    """Load a hemisphere surface .mat file.

    Returns (surf1, surf2). surf1 is the dict produced by scipy.io.loadmat. surf2 is an empty
    list, because only one hemisphere is supported for now.
    """
    hemisphere = loadmat(fpath)
    other_hemisphere = []
    return hemisphere, other_hemisphere

def _surface_mesh(surf):
    """Return a grey Mesh3d trace for one hemisphere surface dict (coords + 1-based faces)."""
    coords = surf["coords"]
    # surf["faces"] is an n x 3 matrix of indices into surf["coords"]; subtract 1 to convert
    # MATLAB (1-based) indexing to Python (0-based). This works on a new array, so the caller's
    # dict is left untouched.
    faces = np.asarray(surf["faces"]) - 1
    return go.Mesh3d(x=coords[:,0], y=coords[:,1], z=coords[:,2],
                     i=faces[:,0], j=faces[:,1], k=faces[:,2],
                     color='rgb(175,175,175)')

# plot 3D brain
def plot_brain(surf1, surf2):
    """Build a plotly figure containing the template cortical surface mesh(es).

    Parameters
    ----------
    surf1 : mapping
        Has "coords" (vertex coordinates, indexed [:, 0..2]) and "faces" (triangles as 1-based
        vertex indices), as returned by load_surf.
    surf2 : mapping or falsy
        Optional second hemisphere in the same format. Skipped when falsy (load_surf returns []).

    Unlike the earlier version, the inputs are no longer modified in place. Calling this twice
    on the same surface dict no longer shifts the face indices a second time.
    """
    fig = go.Figure()
    fig.add_trace(_surface_mesh(surf1))

    # if both hemispheres
    if surf2:
        fig.add_trace(_surface_mesh(surf2))

    fig.update_traces(lighting_ambient=0.3)
    return fig

# plot 3D electrodes on the brain map
def plot_electrodes(df,cbar_title,colorscale,radius=1.5):
    """Return a plotly figure with one small sphere (go.Surface) per electrode.

    Parameters
    ----------
    df : pd.DataFrame
        Needs columns x, y, z (sphere centre) and elec_2 (trace name). Traces are added in row
        order.
    cbar_title : str
        Used as the legendgroup of every trace.
    colorscale : list
        Plotly colorscale applied to every sphere.
    radius : float, default 1.5
        Sphere radius, in the same units as x/y/z.

    Every sphere starts with surfacecolor == 1. The effect-based colouring is applied afterwards.
    """
    # the sphere parameterisation is the same for every electrode, so compute it once
    u, v = np.mgrid[0:2*np.pi:26j, 0:np.pi:26j]
    sphere_x = radius * np.cos(u) * np.sin(v)
    sphere_y = radius * np.sin(u) * np.sin(v)
    sphere_z = radius * np.cos(v)

    fignew = go.Figure()

    for _, row in df.iterrows():
        fignew.add_trace(go.Surface(x=sphere_x + row.x, y=sphere_y + row.y, z=sphere_z + row.z,
                                    surfacecolor=np.ones(shape=sphere_z.shape), name=row.elec_2,
                                    legendgroup=cbar_title, colorscale=colorscale))

    return fignew
    
# set min/max of colorbar for brain map
def scale_colorbar(fignew, df, cbar_min, cbar_max, cbar_title):
    """Set the colour range and colorbar title on every trace of `fignew`.

    A limit left as None falls back to the minimum / maximum of df["effect"]. Returns `fignew`.
    """
    cmin = cbar_min if cbar_min is not None else df["effect"].min()
    cmax = cbar_max if cbar_max is not None else df["effect"].max()

    colorbar_style = dict(
        colorbar_title=cbar_title,
        colorbar_title_font_size=40,
        colorbar_title_side='right',
    )
    fignew.update_traces(cmin=cmin, cmax=cmax, **colorbar_style)

    return fignew
    
# color electrodes on brain map according to effect
def electrode_colors(fignew, df, subset):
    """Scale each electrode sphere's surfacecolor by |effect| and offset the colorbar per subset.

    Rows of `df` address traces of `fignew` by their index label. The index must therefore be
    0..n-1 in the same order the traces were created. Each trace's surfacecolor array is
    multiplied by the absolute effect, so the colour reflects magnitude regardless of sign.
    For subset > 0 the colorbar is shifted right by 0.2 per subset so bars do not overlap.
    Returns `fignew`.
    """
    if subset > 0:
        fignew.update_traces(colorbar_x = 1 + 0.2*subset)

    for trace_idx, eff in zip(df.index, df["effect"]):
        trace = fignew.data[trace_idx]
        trace["surfacecolor"] = trace["surfacecolor"] * abs(eff)

    return fignew

# add figure properties for the brain map
def update_properties(fig):
    """Apply the shared camera, hidden axes and lighting to a brain-map figure; returns `fig`.

    The camera sits on the -x side looking at the origin with z up. The original note labels
    this the left-hemisphere view. TODO: add camera for other views.
    """
    fig.update_layout(
        scene_camera={
            "up": {"x": 0, "y": 0, "z": 1},
            "center": {"x": 0, "y": 0, "z": 0},
            "eye": {"x": -1.5, "y": 0, "z": 0},
        },
        scene={
            "xaxis": {"visible": False},
            "yaxis": {"visible": False},
            "zaxis": {"visible": False},
            "aspectmode": "auto",
        },
    )
    fig.update_traces(
        lighting_specular=0.4,
        colorbar_thickness=40,
        colorbar_tickfont_size=30,
        lighting_roughness=0.4,
        lightposition={"x": 0, "y": 0, "z": 100},
    )
    return fig

# plot brain map
def plot_brainmap(effect_dfs,cbar_titles,cbar_min,cbar_max,colorscales,out_path):

    surf1, surf2 = load_surf(COOR_PATH)
    fig = plot_brain(surf1, surf2)

    for subset, cbar_title in enumerate(cbar_titles):
        
        if colorscales is None:
            colorscale = [[0,'rgb(255,0,0)'], [1,'rgb(255,255,0)']]
        else:
            colorscale = colorscales[cbar_title] 
    
        effect_df = effect_dfs[cbar_title]
        
        fignew = plot_electrodes(effect_df,cbar_title,colorscale)   
        fignew = scale_colorbar(fignew, effect_df, cbar_min, cbar_max, cbar_title)
        fignew = electrode_colors(fignew, effect_df, subset)
        
        # Add electrode traces to main figure
        for trace in range(0,len(fignew.data)):
            fig.add_trace(fignew.data[trace])

    fig = update_properties(fig)

    fig.write_image(out_path, scale=6, width=1200, height=1000)

Load base datum

In [None]:
# base_df: the electrode table described in the README (ids, electrode names, MNI coordinates,
# ROI labels and one significance column per test); every figure below filters it
base_df = pd.read_csv(filepath_or_buffer=BASE_PATH)

Figure 1

In [None]:
## FIGURE 1 ##

# TODO brainmap per patient

Figure 2

In [None]:
## FIGURE 2 ##

# t-SNE Plots
# earlier inputs, kept for reference:
# tsne_df = load_pickle(f"/scratch/gpfs/kw1166/247-plotting/results/20230612-whisper-tsne-no-filter/all4-whisper-tsne-ave.pkl")
# tsne_df = load_pickle(f"/scratch/gpfs/kw1166/247-plotting/results/20240131-podcast-rep-for-daria/pod-whisper-tsne-ave.pkl")
# pca_file = "/scratch/gpfs/kw1166/247-plotting/results/paper-whisper/all4-whisper-pca-ave.pkl"
# tsne_df = pd.read_pickle(pca_file)
tsne_df = load_pickle("/scratch/gpfs/kw1166/247-plotting/results/20240212-podcast-pkl-from-daria/pod-whisper-tsne-ave.pkl")
# plot_tsne uses "marker" as the style column; a constant gives every point the same marker
tsne_df["marker"] = 1

plot_dict = {
    "pho": "phoneme",
    "part_of_speech": "part_of_speech",
    # "place_artic": "place_artic",
    # "manner_artic": "manner_artic",
}

# one plot per category for the encoder ("speech") and one for the decoder ("language") embedding
for cat, label in plot_dict.items():
    plot_tsne(tsne_df, "en_x", "en_y", cat, f"speech-{label}")
    plot_tsne(tsne_df, "de_x", "de_y", cat, f"language-{label}")


# # Get dataframe for classification results
# class_type = [ # types of classifier (do not change)
#     "1speech",
#     "2language",
#     "3control",
#     "uniform",
#     "strat",
# ] * 4
# class_type_plt = ["1speech", "2language", "3control"] # types to plot (can change)
# class_cat = np.repeat(["1Phoneme", "2PoA", "3MoA", "4PoS"], len(class_type) // 4) # categories (np.repeat needs an int count)

# class_results_df = get_classifier_results(
#     "../results/20230612-whisper-tsne-no-filter/classifier_pca50_filter-100_ave_L.csv",
#     class_type,
#     class_type_plt,
#     class_cat,
# )

# # Plot classifier bar plots
# class_bar_colors = ["red", "blue", "grey"]
# class_bar_name = "../results/20230612-whisper-tsne-no-filter/barplot.svg"
# plot_classifier_bar(class_results_df, class_bar_name, class_bar_colors)

Figure 3

In [None]:
## FIGURE 3 ##

# TODO add chisquare test

MODELS = ['whisper-en-last-0.01-prod','whisper-en-last-0.01-comp','whisper-de-best-0.01-prod','whisper-de-best-0.01-comp']

# substring of the model name -> (embedding, layer) whose encoding results are plotted.
# Checked in this order, matching the original if/elif.
EMB_LAYERS = {'en-last': ('encoder', 4), 'de-best': ('decoder', 3)}

for model in MODELS:

    # select electrodes significant for this model
    effect_df = filter_datum(model, base_df).reset_index(drop=True)

    # model names end in the speech mode ('prod' / 'comp')
    mode = model[-4:]

    matches = [key for key in EMB_LAYERS if key in model]
    if not matches:
        # previously this silently reused emb/l from the last iteration (or raised NameError)
        raise ValueError(f"cannot infer embedding/layer from model name {model!r}")
    emb, l = EMB_LAYERS[matches[0]]

    # aggregate data: peak encoding correlation per electrode within the plotted window
    effect = [get_encoding_max(emb,l,row.sid,row.elec_1,mode) for _, row in effect_df.iterrows()]
    effect_df = effect_df.assign(effect=effect)

    fpath = "/figure_3"
    fname = f"/figure_3_brainmap_{emb}_{mode}.png"
    os.makedirs(OUT_PATH + fpath, exist_ok=True)

    # plotting: plot_brainmap looks up each subset by colorbar title, so wrap the frame in a dict
    # (passing the bare DataFrame made effect_dfs['correlation'] a missing-column KeyError)
    plot_brainmap({'correlation': effect_df},cbar_titles=['correlation'],cbar_min=0.04,cbar_max=0.4,colorscales=None,out_path=(OUT_PATH+fpath+fname))

Figure 4

In [None]:
## FIGURE 4 ##

MODELS = ['whisper-en-last-whisper-de-best-contrast-0.01-prod', 'whisper-en-last-whisper-de-best-contrast-0.01-comp']

# encoder ("speech") vs decoder ("language") embeddings compared in this figure
emb1, l1 = 'encoder', 4
emb2, l2 = 'decoder', 3

for model in MODELS:

    ## plot brainmap

    # select electrodes significant for this contrast; keep the first 7 columns
    # (presumably the electrode descriptors sid/elec_1/elec_2/x/y/z/type -- TODO confirm)
    effect_df = filter_datum(model, base_df).reset_index(drop=True).iloc[:,:7]

    # model names end in the speech mode ('prod' / 'comp')
    mode = model[-4:]

    # aggregate data: encoder peak minus decoder peak correlation per electrode
    effect = [get_encoding_max(emb1,l1,row.sid,row.elec_1,mode) - get_encoding_max(emb2,l2,row.sid,row.elec_1,mode)
              for _, row in effect_df.iterrows()]
    effect_df = effect_df.assign(effect=effect)

    # split by sign: 'a' = encoder higher (red scale), 'b' = decoder higher (blue scale).
    # The negative subset previously used `effect < 0.01`, which also put small *positive*
    # differences in the blue map; the threshold is now symmetric.
    effect_df_list = [
        effect_df[effect_df.effect > 0.01].reset_index(drop=True),
        effect_df[effect_df.effect < -0.01].reset_index(drop=True),
    ]

    cbar_min=0.01
    cbar_max=0.15
    cbar_titles = ['a','b']
    colorscales = {cbar_titles[0]:[[0,"#ffa07a"],[1,"#ff0000"]],cbar_titles[1]:[[0,"#87ceff"],[1,"#0000ff"]]}
    effect_dfs = {cbar_titles[0]:effect_df_list[0],cbar_titles[1]:effect_df_list[1]}

    # write next to the other figures instead of another user's scratch directory
    fpath = "/figure_4"
    fname = f"/figure_4_brainmap_{mode}.png"
    os.makedirs(OUT_PATH + fpath, exist_ok=True)

    # plotting

    plot_brainmap(effect_dfs,cbar_titles,cbar_min,cbar_max,colorscales,out_path=(OUT_PATH+fpath+fname))

    ## plot encoding for example electrodes

    # select electrodes

    # fname = '/data/example_electrodes.csv'
    # example_elec_df = pd.read_csv((PATH+fname))
    # example_elec_df = filter_datum(model,example_elec_df)

    # for index, row in example_elec_df.iterrows():

    #     # aggregate data 

    #     vals = [0] * 2
    #     errs = [0] * 2
    #     vals[0], errs[0] = get_encoding(emb1,l1,row.sid,row.elec_1,mode)
    #     vals[1], errs[1] = get_encoding(emb2,l2,row.sid,row.elec_1,mode)

    #     folds1 = get_encoding_folds(emb1,l1,row.sid,row.elec_1,mode)
    #     folds2 = get_encoding_folds(emb2,l2,row.sid,row.elec_1,mode)

    #     x_allfolds = np.transpose(np.array(folds1))
    #     y_allfolds = np.transpose(np.array(folds2))

    #     # stat test

    #     q_vals = run_permutation_test(x_allfolds,y_allfolds,0.05)
    #     siglags = q_vals[0].nonzero()[0]

    #     # plotting

    #     colors = [(1,0,0), (0,0,1)]
    #     scale = str(row.scale)
    #     title = row.roi + ' (' + row.sid_2 + ')'
        
    #     plot_encoding(vals,errs,siglags,colors,scale,title)

    #     fpath = "/figure_4"
    #     fname = f"/figure_4_{mode}_{row.elec_1}.svg"
        
    #     plt.savefig((OUT_PATH+fpath+fname)) 

Figure 5

In [None]:
## FIGURE 5 ##

# Helpers for the peak-lag statistics table (one row per ROI/test, written to stat_test.csv).

def _peak_lag_row(roi, mode, peak_lags):
    """Start a results row with the mean and SD of one ROI's per-electrode peak lags (seconds)."""
    return {'roi': roi, 'mode': mode, 'mean': peak_lags.mean(), 'SD': peak_lags.std()}

def _add_ttest(row, label, a, b):
    """Add a one-sided independent-samples t-test (H1: mean(a) < mean(b)) to `row` and return it."""
    res = ttest_ind(a, b, alternative='less')
    row['test'] = label
    row['t_statistic'] = res.statistic
    row['p_value'] = res.pvalue
    # ttest_ind defaults to the equal-variance (Student) test -> df = n_a + n_b - 2
    row['df'] = len(a) + len(b) - 2
    return row

def _save_and_close(subdir, fname):
    """Save the current matplotlib figure to OUT_PATH + subdir + fname, then close it to free memory."""
    os.makedirs(OUT_PATH + subdir, exist_ok=True)
    plt.savefig(OUT_PATH + subdir + fname)
    plt.close()

## panels A and B

MODES = ['prod','comp']

ROIS = ['IFG','SM','STG']
SM = ['preCG', 'postCG']

# ROI -> (significance-column template, embedding, layer, roi_1 labels pooled into the ROI)
ROI_SOURCES = {
    'IFG': ('whisper-de-best-0.01-{mode}', 'decoder', 3, ['IFG']),
    'SM': ('whisper-en-last-0.01-{mode}', 'encoder', 4, SM),
    'STG': ('whisper-en-last-0.01-{mode}', 'encoder', 4, ['STG']),
}

# mode -> one (label, i, j) per ROI row (in ROIS order), testing peak_lags[i] < peak_lags[j]
TESTS_AB = {
    'prod': [('prod: IFG<SM', 0, 1), ('prod: SM<STG', 1, 2), ('prod: IFG<STG', 0, 2)],
    'comp': [('comp: SM<STG', 1, 2), ('comp: STG<IFG', 2, 0), ('comp: SM<IFG', 1, 0)],
}

yheights = [0.33, 0.31, 0.29]
colors = [(0,0,1),(253/255,141/255,60/255),(1,0,0)]
scale = '0.3'
title = ''

res_list = []

for mode in MODES:

    vals, errs, maxlags = [], [], []

    for roi in ROIS:

        # select electrodes significant for the ROI's embedding in this mode
        model_tpl, emb, l, roi_labels = ROI_SOURCES[roi]
        model = model_tpl.format(mode=mode)
        df = filter_datum(model, base_df)
        roi_df = df[df.roi_1.isin(roi_labels)]

        # aggregate data: rows = electrodes, columns = lags (s)
        avgs_df = pd.DataFrame([get_encoding_avg(emb, l, row.sid, row.elec_1, mode) for _, row in roi_df.iterrows()])
        vals.append(avgs_df.mean(axis=0).tolist())
        errs.append(avgs_df.sem(axis=0, ddof=0))
        maxlags.append(avgs_df.idxmax(axis=1))

    # plotting
    plot_temporal_encoding(vals,errs,maxlags,yheights,colors,scale,title)
    _save_and_close("/figure_5", f"/figure_5_{mode}.svg")

    # statistical test: each ROI's descriptive row carries one pairwise peak-lag test
    rows = [_peak_lag_row(roi, mode, lags) for roi, lags in zip(ROIS, maxlags)]
    for row, (label, i, j) in zip(rows, TESTS_AB.get(mode, [])):
        res_list.append(_add_ttest(row, label, maxlags[i], maxlags[j]))

## panel C

MODES = ['prod']

ROIS = ['dSM','mSM','vSM']

# mode -> one (label, i, j) per ROI row (in ROIS order), testing peak_lags[i] < peak_lags[j]
TESTS_C = {
    'prod': [('prod: dSM<mSM', 0, 1), ('prod: mSM<vSM', 1, 2), ('prod: dSM<vSM', 0, 2)],
}

yheights = [0.33, 0.31, 0.29]
colors = [(204/255,204/255,0),(1,128/255,0),(153/255,76/255,0)] # change colors
scale = '0.3'
title = ''

emb = 'encoder'
l = 4

for mode in MODES:

    # select electrodes once per mode, then split by sensorimotor subdivision (roi_2)
    model = f"whisper-en-last-0.01-{mode}"
    df = filter_datum(model, base_df)

    vals, errs, maxlags = [], [], []

    for roi in ROIS:

        roi_df = df[df.roi_2 == roi]

        # aggregate data: rows = electrodes, columns = lags (s)
        avgs_df = pd.DataFrame([get_encoding_avg(emb, l, row.sid, row.elec_1, mode) for _, row in roi_df.iterrows()])
        vals.append(avgs_df.mean(axis=0).tolist())
        errs.append(avgs_df.sem(axis=0, ddof=0))
        maxlags.append(avgs_df.idxmax(axis=1))

    # plotting
    plot_temporal_encoding(vals,errs,maxlags,yheights,colors,scale,title)
    _save_and_close("/figure_5", f"/figure_5C_{mode}.svg")

    # statistical test
    rows = [_peak_lag_row(roi, mode, lags) for roi, lags in zip(ROIS, maxlags)]
    for row, (label, i, j) in zip(rows, TESTS_C.get(mode, [])):
        res_list.append(_add_ttest(row, label, maxlags[i], maxlags[j]))

res_df = pd.DataFrame(res_list)

fpath = '/figure_5'
# the leading slash was missing, which wrote OUT_PATH + '/figure_5stat_test.csv'
fname = '/stat_test.csv'
os.makedirs(OUT_PATH + fpath, exist_ok=True)
res_df.to_csv((OUT_PATH+fpath+fname),index=False,header=True)
    

In [None]:
## FIGURE 6 ##

MODELS = ['whisper-en-last-prod-comp-contrast-0.01', 'whisper-de-best-prod-comp-contrast-0.01']

# substring of the model name -> (embedding, layer); checked in this order, matching the original if/elif
EMB_LAYERS = {'en-last': ('encoder', 4), 'de-best': ('decoder', 3)}

# production vs comprehension contrast
mode1 = 'prod'
mode2 = 'comp'

for model in MODELS:

    ## plot brainmap

    # select electrodes significant for this contrast
    effect_df = filter_datum(model, base_df).reset_index(drop=True)

    matches = [key for key in EMB_LAYERS if key in model]
    if not matches:
        raise ValueError(f"cannot infer embedding/layer from model name {model!r}")
    emb, l = EMB_LAYERS[matches[0]]

    # aggregate data: production peak minus comprehension peak correlation per electrode
    effect = [get_encoding_max(emb,l,row.sid,row.elec_1,mode1)-get_encoding_max(emb,l,row.sid,row.elec_1,mode2)
              for _, row in effect_df.iterrows()]
    effect_df = effect_df.assign(effect=effect)

    # split by sign: pos = production higher (purple scale), neg = comprehension higher (green scale).
    # The negative subset previously used `effect < 0.01`, which also put small *positive*
    # differences in the "neg" map; the threshold is now symmetric.
    effect_df_pos = effect_df[effect_df.effect>0.01].reset_index(drop=True)
    effect_df_neg = effect_df[effect_df.effect<-0.01].reset_index(drop=True)

    cbar_min=0.01
    cbar_max=0.15
    cbar_titles = ['Δ corr pos','Δ corr neg']
    colorscales = {cbar_titles[0]:[[0,'rgb(238,130,238)'],[1,'rgb(128,0,128)']],cbar_titles[1]:[[0,'rgb(135,206,1)'],[1,'rgb(0,102,0)']]}
    effect_dfs = {cbar_titles[0]:effect_df_pos,cbar_titles[1]:effect_df_neg}

    # write next to the other figures instead of another user's scratch directory
    fpath = "/figure_6"
    fname = f"/figure_6_brainmap_{emb}.svg"
    os.makedirs(OUT_PATH + fpath, exist_ok=True)

    # plotting

    plot_brainmap(effect_dfs,cbar_titles,cbar_min,cbar_max,colorscales,out_path=(OUT_PATH+fpath+fname))

    ## plot encoding for example electrodes

    # select electrodes

    # fname = '/data/example_electrodes.csv'
    # example_elec_df = pd.read_csv((PATH+fname))
    # example_elec_df = filter_datum(model,example_elec_df)

    # for index, row in example_elec_df.iterrows():

    #     # aggregate data
         
    #     vals = [0] * 2
    #     errs = [0] * 2
    #     vals[0], errs[0] = get_encoding(emb,l,row.sid,row.elec_1,mode1)
    #     vals[1], errs[1] = get_encoding(emb,l,row.sid,row.elec_1,mode2)

    #     folds1 = get_encoding_folds(emb,l,row.sid,row.elec_1,mode1)
    #     folds2 = get_encoding_folds(emb,l,row.sid,row.elec_1,mode2)

    #     x_allfolds = np.transpose(np.array(folds1))
    #     y_allfolds = np.transpose(np.array(folds2))

    #     # stat test

    #     q_vals = run_permutation_test(x_allfolds,y_allfolds,0.05)
    #     siglags = q_vals[0].nonzero()[0]

    #     # plotting
        
    #     colors = [(128/255,0,128/255), (0,100/255,0)]
    #     scale = str(row.scale)
    #     title = row.roi + ' (' + row.elec_1 + ')'
        
    #     plot_encoding(vals,errs,siglags,colors,scale,title)

    #     fpath = "/figure_6"
    #     fname = f"/figure_6_{emb}_{row.elec_1}.svg"
    #     plt.savefig((OUT_PATH+fpath+fname)) 

In [None]:
## SUPPLEMENTARY FIGURE 1 ##
# (a) t-SNE projections of the encoder ("speech") and decoder ("language")
#     embeddings, coloured by place / manner of articulation.
# (b) Bar plots of linear-classifier results for encoder layers, decoder layers
#     and the control classifier.

# t-SNE plots
TSNE_RESULTS_PATH = "/scratch/gpfs/kw1166/247-plotting/results/20230612-whisper-tsne-no-filter/all4-whisper-tsne-ave.pkl"
tsne_df = load_pickle(TSNE_RESULTS_PATH)
tsne_df["marker"] = 1

plot_dict = {
    "place_artic": "place_of_articulation",
    "manner_artic": "manner_of_articulation",
}

for cat, cat_name in plot_dict.items():
    plot_tsne(tsne_df, "en_x", "en_y", cat, f"speech-{cat_name}")
    plot_tsne(tsne_df, "de_x", "de_y", cat, f"language-{cat_name}")

# Get dataframe for classification results
CLASS_RESULTS_DIR = "../results/20230612-whisper-tsne-no-filter"

class_type_base = [ # types of classifier (do not change)
    "1speech",
    "2language",
    "3control",
    "uniform",
    "strat",
]
class_cat_names = ["1Phoneme", "2PoA", "3MoA", "4PoS"] # categories
# one row of classifier types per category: types cycle, categories repeat in blocks
class_type = class_type_base * len(class_cat_names)
class_type_plt = ["1speech", "2language", "3control"] # types to plot (can change)
# np.repeat requires an integer repeat count (a float such as len(...) / 4 raises TypeError)
class_cat = np.repeat(class_cat_names, len(class_type_base))

class_layer_dfs = []
for layer in range(5):
    class_temp_df = get_classifier_results(
        f"{CLASS_RESULTS_DIR}/classifier_pca50_filter-100_ave_L{layer:01}.csv",
        class_type,
        class_type_plt,
        class_cat,
    )
    # the control classifier is layer-independent: keep its rows only once (first layer)
    if layer > 0:
        class_temp_df = class_temp_df[~class_temp_df.class_type.str.contains("control")]
    # assign() returns a new frame, avoiding SettingWithCopy on the filtered slice
    class_temp_df = class_temp_df.assign(class_type=class_temp_df.class_type + f"{layer:01}")
    class_layer_dfs.append(class_temp_df)
# concatenate once instead of growing a DataFrame inside the loop
class_results_df = pd.concat(class_layer_dfs)

lang_filter = class_results_df.class_type.str.contains("language")
speech_filter = class_results_df.class_type.str.contains("speech")
class_speech_df = class_results_df[~lang_filter]  # everything except decoder rows
class_lang_df = class_results_df[~speech_filter]  # everything except encoder rows

# Plot classifier bar plots
class_bar_colors = ["paleturquoise", "darkturquoise", "dodgerblue", "blue", "darkblue", "grey"]
class_bar_name = f"{CLASS_RESULTS_DIR}/barplot-lang.svg"
plot_classifier_bar(class_lang_df, class_bar_name, class_bar_colors)

class_bar_colors = ["mistyrose","lightcoral","indianred","tomato","red","grey"]
class_bar_name = f"{CLASS_RESULTS_DIR}/barplot-speech.svg"
plot_classifier_bar(class_speech_df, class_bar_name, class_bar_colors)

In [None]:
## SUPPLEMENTARY FIGURE 2 ##

# TODO

In [None]:
## SUPPLEMENTARY FIGURE 3 ##
# Encoder (layer 4) vs. decoder (layer 3) encoding, averaged over the electrodes
# of each ROI, separately for the 'prod' and 'comp' electrode selections. Lags at
# which the two embeddings differ come from a fold-level permutation test.

MODELS = ['whisper-en-last-whisper-de-best-contrast-0.01-prod', 'whisper-en-last-whisper-de-best-contrast-0.01-comp']
ROIS = ['preCG','postCG','TP','STG','IFG','pMTG','AG']

# embeddings / layers being compared and plot settings (same for every model and ROI)
emb1 = 'encoder'
l1 = 4
emb2 = 'decoder'
l2 = 3
colors = [(1,0,0), (0,0,1)]  # presumably colors[i] is used for vals[i] (emb1, emb2)
scale = '0.35'

for model in MODELS:

    # select electrodes
    df = filter_datum(model, base_df)
    mode = model[-4:]  # model names end in 'prod' / 'comp'

    for roi in ROIS:

        roi_df = df[df.roi_1 == roi]
        if roi_df.empty:
            # pd.concat([]) below raises "No objects to concatenate" for an ROI
            # without selected electrodes; skip it instead of aborting the cell
            print(f"supp figure 3: no {mode} electrodes in {roi}, skipping")
            continue

        # aggregate data (one entry per electrode)
        avgs1, avgs2 = [], []
        folds1, folds2 = [], []
        for _, row in roi_df.iterrows():
            avgs1.append(get_encoding_avg(emb1,l1,row.sid,row.elec_1,mode))
            avgs2.append(get_encoding_avg(emb2,l2,row.sid,row.elec_1,mode))
            folds1.append(get_encoding_folds(emb1,l1,row.sid,row.elec_1,mode))
            folds2.append(get_encoding_folds(emb2,l2,row.sid,row.elec_1,mode))

        avgs1_df = pd.DataFrame(avgs1)
        avgs2_df = pd.DataFrame(avgs2)
        folds1_df = pd.concat(folds1,axis=1)
        folds2_df = pd.concat(folds2,axis=1)

        # ROI mean and standard error across electrodes (SEM with ddof=0)
        vals = [avgs1_df.mean(axis=0).tolist(), avgs2_df.mean(axis=0).tolist()]
        errs = [avgs1_df.sem(axis=0, ddof=0), avgs2_df.sem(axis=0, ddof=0)]

        # stat test on the transposed fold matrices
        x_allfolds = folds1_df.to_numpy().T
        y_allfolds = folds2_df.to_numpy().T
        q_vals = run_permutation_test(x_allfolds,y_allfolds,0.05)
        siglags = q_vals[0].nonzero()[0]

        # plotting
        title = f"{roi} ({len(roi_df.index)})"
        plot_encoding(vals,errs,siglags,colors,scale,title)

        fpath = "/supp_figure_3"
        fname = f"/supp_figure_3_{roi}_{mode}.svg"
        plt.savefig((OUT_PATH+fpath+fname))

In [None]:
## SUPPLEMENTARY FIGURE 4 ##
# Production vs. comprehension encoding, averaged over the electrodes of each
# ROI, separately for the encoder (layer 4) and decoder (layer 3) embeddings.
# Lags at which prod and comp differ come from a fold-level permutation test.

MODELS = ['whisper-en-last-prod-comp-contrast-0.01', 'whisper-de-best-prod-comp-contrast-0.01']
ROIS = ['preCG','postCG','TP','STG','IFG','pMTG','AG']

# modes being compared and plot settings (same for every model and ROI)
mode1 = 'prod'
mode2 = 'comp'
colors = [(128/255,0,128/255), (0,100/255,0)]  # presumably colors[i] is used for vals[i] (mode1, mode2)
scale = '0.35'

for model in MODELS:

    # select electrodes
    df = filter_datum(model, base_df)

    # Map the electrode-selection model name to the embedding / layer to plot.
    # A separate name (emb) keeps the loop variable `model` intact.
    if 'en-last' in model:
        emb = 'encoder'
        l = 4
    elif 'de-best' in model:
        emb = 'decoder'
        l = 3
    else:
        raise ValueError(f"Cannot infer embedding from model name {model!r}")

    for roi in ROIS:

        roi_df = df[df.roi_1 == roi]
        if roi_df.empty:
            # pd.concat([]) below raises "No objects to concatenate" for an ROI
            # without selected electrodes; skip it instead of aborting the cell
            print(f"supp figure 4: no {emb} electrodes in {roi}, skipping")
            continue

        # aggregate data (one entry per electrode)
        avgs1, avgs2 = [], []
        folds1, folds2 = [], []
        for _, row in roi_df.iterrows():
            avgs1.append(get_encoding_avg(emb, l, row.sid, row.elec_1, mode1))
            avgs2.append(get_encoding_avg(emb, l, row.sid, row.elec_1, mode2))
            folds1.append(get_encoding_folds(emb, l, row.sid, row.elec_1, mode1))
            folds2.append(get_encoding_folds(emb, l, row.sid, row.elec_1, mode2))

        avgs1_df = pd.DataFrame(avgs1)
        avgs2_df = pd.DataFrame(avgs2)
        folds1_df = pd.concat(folds1,axis=1)
        folds2_df = pd.concat(folds2,axis=1)

        # ROI mean and standard error across electrodes (SEM with ddof=0)
        vals = [avgs1_df.mean(axis=0).tolist(), avgs2_df.mean(axis=0).tolist()]
        errs = [avgs1_df.sem(axis=0, ddof=0), avgs2_df.sem(axis=0, ddof=0)]

        # stat test on the transposed fold matrices
        x_allfolds = folds1_df.to_numpy().T
        y_allfolds = folds2_df.to_numpy().T
        q_vals = run_permutation_test(x_allfolds,y_allfolds,0.05)
        siglags = q_vals[0].nonzero()[0]

        # plotting
        title = f"{roi} ({len(roi_df.index)})"
        plot_encoding(vals,errs,siglags,colors,scale,title)

        # the figure was previously never written to disk
        fpath = "/supp_figure_4"
        fname = f"/supp_figure_4_{roi}_{emb}.svg"
        plt.savefig((OUT_PATH+fpath+fname))


In [None]:
## SUPPLEMENTARY FIGURE 5 ##
# Temporal encoding profile per ROI: encoder (layer 4) for the ENCODER ROIs and
# decoder (layer 3) for the DECODER ROIs, one figure per mode. For each ROI the
# per-electrode lag of maximal encoding is tested against 0 (one-sample t-test);
# the results are written to supp_figure_5/stat_test.csv.

MODES = ['prod','comp']

ROIS = ['STG','TP','postCG','preCG','IFG','AG','pMTG']
ENCODER = ['STG','TP','preCG','postCG']
DECODER = ['IFG','AG','pMTG']

yheights = [0.39, 0.37, 0.35, 0.33, 0.29, 0.27, 0.31]
colors = [(1,0,0),(1,102/255,102/255),(1,128/255,0),(1,153/255,51/255),(0,0,1),(0,128/255,1),(51/255,153/255,1)]
scale = '0.4'
title = ''

res_list = []

for mode in MODES:

    vals = []
    errs = []
    maxlags = []

    for roi in ROIS:

        # select electrodes
        if roi in ENCODER:
            model = f"whisper-en-last-0.01-{mode}"
            emb = 'encoder'
            l = 4
        elif roi in DECODER:
            model = f"whisper-de-best-0.01-{mode}"
            emb = 'decoder'
            l = 3
        else:
            # without this, the previous ROI's electrodes and embedding would be reused silently
            raise ValueError(f"ROI {roi!r} is in neither ENCODER nor DECODER")

        df = filter_datum(model, base_df)
        roi_df = df[df.roi_1 == roi]

        # aggregate data (one entry per electrode)
        avgs = [get_encoding_avg(emb, l, row.sid, row.elec_1, mode) for _, row in roi_df.iterrows()]
        avgs_df = pd.DataFrame(avgs)
        # column label of the maximum for each electrode (row)
        peak_lags = avgs_df.idxmax(axis=1)

        vals.append(avgs_df.mean(axis=0).tolist())
        errs.append(avgs_df.sem(axis=0, ddof=0))
        maxlags.append(peak_lags)

        # stat test (the test is the same for both modes)
        # NOTE(review): popmean=0 is compared against the column labels returned
        # by idxmax — confirm that label 0 denotes lag 0 rather than the first lag index.
        res = ttest_1samp(peak_lags,popmean=0,alternative='two-sided')

        res_list.append({
            'mode': mode,
            'model': emb,
            'ROI': roi,
            'mean': np.mean(peak_lags),
            'SD': np.std(peak_lags),  # np.std -> ddof=0
            'statistic': res.statistic,
            'df': res.df,
            'pval': res.pvalue,
        })

    # plotting

    plot_temporal_encoding(vals,errs,maxlags,yheights,colors,scale,title)

    fpath = "/supp_figure_5"
    fname = f"/supp_figure_5_{mode}.svg"

    plt.savefig((OUT_PATH+fpath+fname))

res_df = pd.DataFrame(res_list)

# leading '/' so the CSV lands inside the supp_figure_5 directory
# (without it the path became OUT_PATH + '/supp_figure_5stat_test.csv')
fpath = '/supp_figure_5'
fname = '/stat_test.csv'
res_df.to_csv((OUT_PATH+fpath+fname),index=False,header=True)

In [None]:
## SUPPLEMENTARY FIGURE 6 ##
# Encoder vs. decoder vs. full-model encoding, averaged over the electrodes of
# each ROI, for the 'prod' and 'comp' electrode selections (no statistics).

MODELS = ['whisper-en-last-whisper-de-best-contrast-0.01-prod', 'whisper-en-last-whisper-de-best-contrast-0.01-comp']
ROIS = ['preCG','postCG','TP','STG','IFG','pMTG','AG']

# (embedding, layer) pairs in plotting order; colors are listed in the same order
EMB_LAYERS = [('encoder', 4), ('decoder', 3), ('full', 3)]
colors = [(1,0,0), (0,0,1), (0,(195/255),0)]
scale = '0.35'

for model in MODELS:

    # select electrodes; the model name ends in 'prod' / 'comp'
    df = filter_datum(model, base_df)
    mode = model[-4:]

    for roi in ROIS:

        roi_df = df[df.roi_1 == roi]

        # aggregate data: for each electrode, fetch every embedding in turn
        per_emb = [[] for _ in EMB_LAYERS]
        for _, row in roi_df.iterrows():
            for bucket, (emb_name, layer) in zip(per_emb, EMB_LAYERS):
                bucket.append(get_encoding_avg(emb_name, layer, row.sid, row.elec_1, mode))

        emb_dfs = [pd.DataFrame(bucket) for bucket in per_emb]

        # ROI mean and standard error across electrodes (SEM with ddof=0)
        vals = [emb_df.mean(axis=0).tolist() for emb_df in emb_dfs]
        errs = [emb_df.sem(axis=0, ddof=0) for emb_df in emb_dfs]

        # no significance test for this figure
        siglags = []

        # plotting
        title = f"{roi} ({len(roi_df.index)})"
        plot_encoding(vals,errs,siglags,colors,scale,title)

        out_file = OUT_PATH + "/supp_figure_6" + f"/supp_figure_6_{roi}_{mode}.svg"
        plt.savefig(out_file)