# Paper figures

In [None]:
import utils
import glob
import pickle
import numpy as np
import pandas as pd
import matplotlib as mpl
import seaborn as sns
import matplotlib.pyplot as plt

from skimage import measure
from scipy.stats import permutation_test, norm
from statsmodels.stats.multitest import multipletests
from nilearn.plotting import plot_markers
from util.ise import ROIS, ALL_ROIs, getMaps
from util.stats import mean_diff_axis, calculate_pvalues, permute_differences, cdf_pvalues, cdf_pvals2
from util.path import derivpath
from util.plot import formatenc, formatim
from utils import load_pickle

# Global figure style, colour palette and standard figure sizes.
plt.style.use('presentation.mplstyle')
colors = sns.color_palette()
BLUE, RED, GRAY = colors[0], colors[3], colors[7]
YELLOW = 'goldenrod'  # RGB (218, 165, 32)
RED_CMAP = sns.light_palette(RED, as_cmap=True)
BLUE_CMAP = sns.light_palette(BLUE, as_cmap=True)
# (width, height) in inches for 1-, 2- and 3-column figures, all with a 4:3 aspect ratio.
fs1c, fs2c, fs3c = ((width, width/(4/3)) for width in (2.24, 4.76, 7.24))
gr = 1 - np.sqrt(5)/2 + .5  # NOTE(review): evaluates to ~0.382 (2 - golden ratio), not ~1.618 — confirm intent

In [None]:
# Options: data root, subjects and conversational modes.
root = '../dataset/derivatives/encoding/'
Ss = [1, 2, 3, 4, 5, 6, 9, 10, 11, 12]
modes = ['prod', 'comp']

# Lag grid from -tmax to +tmax (seconds) in steps of `jump`.
tmax, window, jump = 4, .250, .03125*2
lags = np.arange(-tmax, tmax+jump, jump)

rois = list(ROIS.keys())

# Encoding model to analyse. Alternatives used for control analyses:
#   'model-llama-7b_maxlen-1024_layer-16_reg-l2'
#   'model-bert-large-wwm_layer-12_reg-l2'              # actual bert
#   'model-gpt2-xl_maxlen-0_reg-l2'                     # static gpt2
#   'model-gpt2-xl_maxlen-1024_layer-24_random_reg-l2'  # untrained gpt2
#   'model-random_ndim-1600_reg-l2'                     # random embeddings
models = ['model-gpt2-xl_maxlen-1024_layer-24_reg-l2']  # actual gpt2

modelname = models[-1]
parts = utils.getparts(modelname)
shortname = '_'.join(str(value) for value in list(parts.values())[:4])

print(modelname, shortname)

# Load data

In [None]:
# Electrode-selection masks, loaded from the electrode-selection derivative of `sigmodelname`.
method, alpha = 'fdr_bh', .01
datatype = f'method-{method}_alpha-{alpha}_lags-1'
sigmodelname = 'model-gpt2-xl_maxlen-0_reg-l2_perm-phase'
p = derivpath(f'sub-all_model-{sigmodelname}.pkl', derivative='electrode-selection', datatype=datatype)
# Re-key by the first two key fields, (sub, mode).
sigmasks = {key[:2]: mask for key, mask in load_pickle(p).items()}
nsig_elecs = sum(mask.sum() for mask in sigmasks.values())
nsig_elecs, sigmasks.keys()

In [None]:
band = 'highgamma'
# outdir = f'../results/paper/{sigmodelname}_method-{method}_alpha-{alpha}/{band}/{modelname}'
outdir = f'../results/paper/{sigmodelname}_method-{method}_alpha-{alpha}/{modelname}'
# Create the output directory in pure Python (portable, unlike the `!mkdir -p` shell escape).
from pathlib import Path
Path(outdir).mkdir(parents=True, exist_ok=True)
outdir

In [None]:
# Load actual results: one pickle per (subject, mode, model) that exists on disk.
import gc
results = {}
for sub in Ss:
    for mode in modes:
        for mname in models:
            pattern = f'{root}sub-{sub:02d}/{mname}/sub-{sub:02d}_task-conversation_encoding_mode-{mode}_band-{band}.pkl'
            # without band: f'{root}sub-{sub:02d}/{mname}/sub-{sub:02d}_task-conversation_encoding_mode-{mode}.pkl'
            matches = glob.glob(pattern)
            if not matches:
                continue
            print(sub, mode, mname)
            with open(matches[0], 'rb') as f:
                result = pickle.load(f)  # locally generated derivative; never point this at untrusted files
                # Drop large fields not needed for the encoding figures ('coefs' is kept).
                for field in ('embs', 'df', 'args'):
                    del result[field]
                results[(sub, mode, mname)] = result
                gc.collect()

In [None]:
# Electrode coordinates per subject; later cells index allcoords[sub] with boolean electrode masks.
allcoords = utils.getallcoords(results, Ss, modelname)

# Figure 1 - method

In [None]:
# Figure 1: one subject's electrodes drawn as uniformly coloured markers (left view).
sub, label, num = 6, 'blue', 0
# sub, label, num = 11, 'red', 3
coords = allcoords[sub]

# Discrete seaborn palette; with vmin=0 and vmax=9, `num` picks the palette colour.
cmap = mpl.colors.ListedColormap(sns.color_palette().as_hex())
node_values = np.full(len(coords), num, dtype=float)
plot_markers(node_values, coords, display_mode='l', colorbar=False,
             node_cmap=cmap, alpha=0.85, node_vmin=0, node_vmax=9, node_size=40)
plt.savefig(f'{outdir}/fig1-brain-{label}.svg')
plt.show()

# Figure 2 - within subject encoding

In [None]:
# Figure 2: encoding performance at selected lags on the brain (row 0: prod, row 1: comp).
# plot_markers one axis is default to size [2.6, 2.3] in
vmin, vmax = 0, 0.4
sig = False  # True: only plot electrodes selected by `sigmasks`
show_secs = [-1, -.5, 0, .5, 1]  # alternative: [-.5, -.25, 0, .25, .5]

fig, axes = plt.subplots(2, len(show_secs), figsize=(2.6*len(show_secs), 2.3*2))
fig.subplots_adjust(wspace=0, hspace=0)
for i, mode in enumerate(modes):
    for j, sec in enumerate(show_secs):
        ax = axes[i, j]
        # Tolerant lookup of the lag index instead of exact float equality.
        lag = np.flatnonzero(np.isclose(lags, sec)).item()

        maxes = []
        coords = []
        for sub in Ss:
            corrs = np.mean(results[(sub, mode, modelname)]['corrs'], axis=0)[:, lag]
            # Boolean mask, or slice(None) for all electrodes. The old None-index + squeeze()
            # trick collapsed a single selected electrode to a 0-d array and broke np.concatenate.
            mask = sigmasks[(sub, mode)] if sig else slice(None)
            coords.append(allcoords[sub][mask])
            maxes.append(corrs[mask])
        values = np.concatenate(maxes)
        coords = np.vstack(coords)
        order = np.argsort(values)  # draw strongest electrodes last (on top)
        print(sec, values.min(), values.max(), len(values), len(coords))
        plot_markers(values[order], coords[order], display_mode='l', figure=fig, axes=ax,
                     node_vmax=vmax, node_vmin=vmin, node_size=15,
                     colorbar=False, node_cmap='Reds' if i else 'Blues')

bar_norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
fig.colorbar(mpl.cm.ScalarMappable(norm=bar_norm, cmap='Blues'), ax=axes)
fig.colorbar(mpl.cm.ScalarMappable(norm=bar_norm, cmap='Reds'), label='encoding performance (r)', ax=axes)

# fig.suptitle(modelname)
fig.savefig(f'{outdir}/fig2-encoding-brain-lags_sig-{sig}.svg')
plt.show()

In [None]:
# Plot average per ROI: speaker vs listener encoding curves, mean ± SEM across electrodes.
fig, axes = plt.subplots(2, 4, dpi=120, figsize=(7.24, 1.68*2), sharey=True, sharex=True)

sig = sigmasks  # set to None to use all electrodes

for roi, ax in zip(['all'] + rois[:-1], axes.flatten()):
    roi_labels = ALL_ROIs if roi == 'all' else ROIS[roi]
    common = dict(modelname=modelname, speakerROI=roi_labels, partnerROI=roi_labels, significant=sig, reduce=False)
    MS, NS, _ = getMaps(results, Ss, mode='within-prod', **common)
    MP, _, NP = getMaps(results, Ss, mode='within-comp', **common)

    for maps, counts, name, color in ((MS, NS, 'speaker', BLUE), (MP, NP, 'listener', RED)):
        n = np.count_nonzero(counts)
        if n:
            mean = np.ma.mean(maps, axis=0)
            err = np.ma.std(maps, axis=0) / np.sqrt(n)
            ax.plot(lags, mean, label=name, color=color)
            ax.fill_between(lags, mean-err, mean+err, color=color, alpha=0.1)

    formatenc(ax)
    ax.text(0.05, 1, f'N={sum(NS)}', color=BLUE, transform=ax.transAxes, alpha=1, ha='left', va='top',  weight='bold', fontsize=6)
    ax.text(0.05, .91, f'N={sum(NP)}', color=RED, transform=ax.transAxes, alpha=1, ha='left', va='top', weight='bold', fontsize=6)
    ax.set(title=roi)
    ax.set_ylim(0, .32)
    ax.set_xticks(lags[::lags.size//4])

axes[-1, -1].legend(loc='best', frameon=False)

fig.supylabel('encoding performance (r ± sem)')
fig.supxlabel('lag (s)')

fig.tight_layout()
fig.savefig(f'{outdir}/fig2-encoding-rois_sig-{sig is not None}.svg')
plt.show()

In [None]:
def lag_diff_statistic(x, y, axis):
    """Absolute difference between the peak-lag indices of the electrode-averaged curves.

    Intended for ``scipy.stats.permutation_test(..., vectorized=True)``. Assumes x and y are
    (perms, lags, elecs) with ``axis`` indexing electrodes. The result is in lag-index units.
    """
    peak_x = np.mean(x, axis=axis).argmax(axis=-1)
    peak_y = np.mean(y, axis=axis).argmax(axis=-1)
    return np.abs(peak_x - peak_y)


def mag_diff_statistic(x, y, axis):
    """Peak of the electrode-averaged x curve minus the peak of the averaged y curve."""
    peak_x = np.mean(x, axis=axis).max(axis=-1)
    peak_y = np.mean(y, axis=axis).max(axis=-1)
    return peak_x - peak_y

In [None]:
# Per-ROI permutation tests for speaker-vs-listener differences in peak lag and peak magnitude.
n_perms = 10000
perm_seed = 0  # fixed seed so the permutation p-values are reproducible

records = []
sig = sigmasks
for roi in ['all'] + rois:
    roi_labels = ALL_ROIs if roi == 'all' else ROIS[roi]
    common = dict(modelname=modelname, speakerROI=roi_labels, partnerROI=roi_labels,
                  significant=sig, reduce=False, reduce_elecs=False)
    MS, NS, _ = getMaps(results, Ss, mode='within-prod', **common)
    MP, _, NP = getMaps(results, Ss, mode='within-comp', **common)

    # Exclude electrodes with any masked lag (getmaskarray also handles an empty `nomask`).
    MS = np.ma.getdata(MS[~np.ma.getmaskarray(MS).any(-1)])
    MP = np.ma.getdata(MP[~np.ma.getmaskarray(MP).any(-1)])

    meanS = MS.mean(0)
    meanP = MP.mean(0)
    bestLagS = meanS.argmax()
    bestLagP = meanP.argmax()

    lag_permres = permutation_test((MS, MP), statistic=lag_diff_statistic, vectorized=True, axis=0, permutation_type='independent',
                                   n_resamples=n_perms, alternative='greater', random_state=perm_seed)
    mag_permres = permutation_test((MS, MP), statistic=mag_diff_statistic, vectorized=True, axis=0, permutation_type='independent',
                                   n_resamples=n_perms, random_state=perm_seed)

    records.append({
        'roi': roi,
        'lag_prod': lags[bestLagS],
        'lag_comp': lags[bestLagP],
        'lag_diff': lag_permres.statistic * jump,  # lag-index difference -> seconds
        'lag_pvalue_corr': lag_permres.pvalue,  # raw here; FDR-corrected after the loop
        'mag_prod': meanS[bestLagS],
        'mag_comp': meanP[bestLagP],
        'mag_diff': mag_permres.statistic,
        'mag_pvalue_corr': mag_permres.pvalue,  # raw here; FDR-corrected after the loop
    })

df = pd.DataFrame(records)
# FDR across all rows, including 'all'.
df['lag_pvalue_corr'] = multipletests(df.lag_pvalue_corr.values, method='fdr_bh')[1]
df['mag_pvalue_corr'] = multipletests(df.mag_pvalue_corr.values, method='fdr_bh')[1]
# Alternative (exclude 'all' from the correction):
# df['lag_pvalue_corr'] = [df.lag_pvalue_corr[0].item()] + multipletests(df.lag_pvalue_corr.values[1:], method='fdr_bh')[1].tolist()
# df['mag_pvalue_corr'] = [df.mag_pvalue_corr[0].item()] + multipletests(df.mag_pvalue_corr.values[1:], method='fdr_bh')[1].tolist()
df

In [None]:
# Plot each ROI's significant electrodes (prod OR comp) as uniform markers, mostly for schematics.
# plot_markers one axis is default to size [2.6, 2.3] in
fig, axes = plt.subplots(2, 4, figsize=(2.6*5, 2.3*2))
palette_cmap = mpl.colors.ListedColormap(sns.color_palette().as_hex())
all_roi_labels = sum(list(ROIS.values()), ())
# NOTE: zip stops at the 8 axes, so with 'all' + 8 ROIs the last ROI is not drawn.
for roi, ax in zip(['all'] + list(ROIS.keys()), axes.flatten()):
    # Dedicated name: this loop used to assign to `rois`, clobbering the global ROI-name list.
    roi_labels = ROIS[roi] if roi != 'all' else all_roi_labels

    coords = []
    for sub in Ss:
        roimask = np.isin(results[(sub, 'prod', modelname)]['rois'], roi_labels)
        mask = (sigmasks[(sub, 'prod')] | sigmasks[(sub, 'comp')]) & roimask
        coords.append(allcoords[sub][mask])

    coords = np.vstack(coords)
    values = np.ones(len(coords)) * 7  # palette index 7 (GRAY) for every marker
    if len(values) == 1:
        # Duplicate a lone marker; presumably plot_markers needs >1 node — TODO confirm.
        values = np.repeat(values, 2)
        coords = np.repeat(coords, 2, axis=0)

    order = np.argsort(values)

    plot_markers(values[order], coords[order], display_mode='l', figure=fig, axes=ax,
                node_vmax=9, node_vmin=0, node_size=20, colorbar=False,
                node_cmap=palette_cmap)
    ax.set_title(roi)

fig.savefig(f'{outdir}/fig2-brain-roi-sig-schematics.svg')
plt.show()

# Figure 3 - intersubject encoding

In [None]:
# save for sig test later
# Unreduced ISE maps (reduce=False, reduce_folds=False) over significant electrodes, saved as
# `observed` for the external permutation test (the next cell mentions sigtest.py) — confirm it reads this file.
M, ns, ms = getMaps(results, Ss, modelname=modelname, significant=sigmasks, reduce=False, reduce_folds=False)
np.savez(f'{outdir}/data-fig3-ise.npz', observed=M)

In [None]:
# Load permutated null distribution
# this can be generated from sigtest.py
path = derivpath(f'sub-all_model-{modelname}_perm-phase.npz', derivative='ise')
# Context manager closes the .npz handle once the arrays are read into memory.
with np.load(path.fpath) as npz:
    observed = npz['observed']
    null_distribution = npz['null_distribution']
null_distribution.shape, observed.shape

In [None]:
# Calculate pvalue for max: compare the observed global maximum of the ISE map against
# the maxima of each null map (assumes null_distribution is (perms, lag, lag) — confirm shape above).
obsmax = observed.max(keepdims=True)
null_maxes = null_distribution.max((1, 2), keepdims=True)[..., 0]
# pvalues = calculate_pvalues(observed.max(), null_maxes, alternative='greater')
pvalues = cdf_pvalues(obsmax, null_maxes, alternative='greater').item()
# Threshold = 99th percentile of a normal distribution fitted to the null maxima.
threshold_max = norm.ppf(.99, loc=null_maxes.mean(), scale=null_maxes.std())
immask = observed <= threshold_max
pvalues, threshold_max

In [None]:
# Plot thresholded image: observed ISE map with entries <= threshold_max masked out.

# Hard-coded threshold (overrides the one computed in the previous cell) so every model
# is thresholded with the main model's value.
threshold_max = 0.052362467174286775  # from main model

M = np.ma.array(observed)
immask = observed <= threshold_max  # for plotting we'll be more strict
M.mask = immask
print(M.max())
mx = .15  # symmetric colour limit

w = 7.24 / 2 - .5
fig, ax = plt.subplots(figsize=(w, w/(4/3)))
# NOTE(review): interpolation=None means "use rcParams default"; the string 'none' disables interpolation.
im = ax.imshow(M, origin='lower', interpolation=None, cmap='coolwarm', vmin=-mx, vmax=mx)
formatim(fig, ax, im, lags, cbar_loc='v')
ax.set(xlabel='listener lag (s)', ylabel='speaker lag (s)')

# utils.plot_outlines((M > threshold).T, ax=ax, color='white', alpha=0.7, zorder=3)

# Rows are speaker lags, columns listener lags, so this prints (speaker lag, listener lag).
print('Max is at', lags[np.array(np.unravel_index(M.argmax(), M.shape))])
fig.savefig(f'{outdir}/fig3s-ise-all-threshold.svg')
plt.show()

In [None]:
# Plot main ISE result
# Reduced ISE map over significant electrodes (reduce=True, weight=True — weighting scheme is
# defined in getMaps; confirm there).
M, ns, ms = getMaps(results, Ss, modelname=modelname, significant=sigmasks, reduce=True, weight=True)
mx = M.max()
print(mx)
mx = .05
mx = .15  # fixed colour limit actually used (overrides the data max and .05 above)

w = 7.24 / 2 - .5
fig, ax = plt.subplots(figsize=(w, w/(4/3)))  # half 3c
im = ax.imshow(M, origin='lower', interpolation=None, cmap='coolwarm', vmin=-mx, vmax=mx)
formatim(fig, ax, im, lags, cbar_loc='v')
ax.set(xlabel='listener lag (s)', ylabel='speaker lag (s)')
print('Max is at', lags[np.array(np.unravel_index(M.argmax(), M.shape))])

fig.savefig(f'{outdir}/fig3-ise-all.svg')
plt.show()

In [None]:
# Plot smaller version of it: the same global map M at single-column size (fs1c).
mx = .15

fig, ax = plt.subplots(figsize=fs1c)
im = ax.imshow(M, origin='lower', interpolation=None, cmap='coolwarm', vmin=-mx, vmax=mx)
formatim(fig, ax, im, lags)
ax.set(xlabel='listener lag (s)', ylabel='speaker lag (s)')

fig.savefig(f'{outdir}/fig3-ise-all-small.svg')
plt.show()

In [None]:
# Figure 3 (2D view): within-subject encoding curves plus grey links between speaker/listener
# lag pairs whose ISE exceeds `threshold_max`.
MS, NS, _ = getMaps(results, Ss, modelname=modelname, mode='within-prod', speakerROI=ALL_ROIs, partnerROI=ALL_ROIs, significant=sigmasks, reduce=False, weight=False)
MP, _, NP = getMaps(results, Ss, modelname=modelname, mode='within-comp', speakerROI=ALL_ROIs, partnerROI=ALL_ROIs, significant=sigmasks, reduce=False, weight=False)

# Uses the global ISE map M (speaker lag x listener lag) from the "Plot main ISE result" cell.
Z = M

fig, ax = plt.subplots(figsize=(w, w/(4/3)))

if n := np.count_nonzero(NS):
    mean = np.ma.mean(MS, axis=0)
    err = np.ma.std(MS, axis=0) / np.sqrt(n)
    ax.plot(lags, mean, label='speaker', color=BLUE, alpha=0.7, zorder=2.1)
    ax.fill_between(lags, mean-err, mean+err, color=BLUE, alpha=0.1, zorder=2)
    MS = mean

if n := np.count_nonzero(NP):
    mean = np.ma.mean(MP, axis=0)
    err = np.ma.std(MP, axis=0) / np.sqrt(n)
    ax.plot(lags, mean, label='listener', color=RED, alpha=0.7, zorder=2.1)
    ax.fill_between(lags, mean-err, mean+err, color=RED, alpha=0.1, zorder=2)
    MP = mean

z = M > threshold_max
print(z.sum())

# Without supra-threshold pairs there is nothing to overlay (a.min() and `line` would fail).
if z.any():
    z0 = z.any(0).nonzero()[0]  # listener-lag columns with any supra-threshold ISE
    z1 = z.any(1).nonzero()[0]  # speaker-lag rows with any supra-threshold ISE

    # Marker size (2-10) scales with the L2 norm of the ISE map along that lag.
    a = np.linalg.norm(Z, axis=0)[z0]
    a0 = np.interp(a, (a.min(), a.max()), (2, 10))
    a = np.linalg.norm(Z, axis=1)[z1]
    a1 = np.interp(a, (a.min(), a.max()), (2, 10))

    ax.scatter(lags[z1], MS[z1], alpha=1, zorder=2.9, s=a1, color=BLUE)
    ax.scatter(lags[z0], MP[z0], alpha=1, zorder=2.9, s=a0, color=RED)

    a = Z[np.nonzero(z)]
    a2 = np.interp(a, (a.min(), a.max()), (0.1, .2))  # link alpha grows with ISE value
    # Rank of each value (argsort of argsort). A single argsort gives sorting indices, not
    # ranks, so zorder was not tied to link strength.
    a = a.argsort().argsort()
    a3 = np.interp(a, (a.min(), a.max()), (2.1, 2.8))  # stronger links drawn on top

    # One grey line per supra-threshold (speaker lag, listener lag) pair.
    # (An earlier variant coloured links by lag quadrant: before / cross / after.)
    rows, cols = np.nonzero(z)
    for i, (x, y) in enumerate(zip(rows, cols)):
        line = ax.plot(lags[[x, y]], [MS[x], MP[y]], c=np.array([64, 64, 64])/255, alpha=a2[i], zorder=a3[i])
    line[0].set_label('ISE')

ax.text(0.05, 1, f'N={sum(NS)}', color=BLUE, transform=ax.transAxes, alpha=1, ha='left', va='top',  weight='bold', fontsize=6)
ax.text(0.05, .93, f'N={sum(NP)}', color=RED, transform=ax.transAxes, alpha=1, ha='left', va='top', weight='bold', fontsize=6)

formatenc(ax)
ax.set(xlabel='lag (s)', ylabel='pearson correlation', ylim=(-.01, .25))
ax.legend(frameon=False, loc='upper right')
# ax.set_title(f'ISE ({sum(NS)} {sum(NP)})')
ax.set_ylim(0, ax.get_ylim()[1])
fig.tight_layout()
fig.savefig(f'{outdir}/fig3-ise-all2d.svg')
plt.show()

In [None]:
# mode='direct' map, saved as the ISFC figure. NOTE: overwrites the global M used by the ISE cells above.
M, ns, ms = getMaps(results, Ss, modelname=modelname, mode='direct', significant=sigmasks, reduce=True, weight=True)

mx = M.max()
print(mx)
mx = .05
mx = .15  # colour limit actually used

fig, ax = plt.subplots(figsize=fs1c)
im = ax.imshow(M, origin='lower', interpolation=None, cmap='coolwarm', vmin=-mx, vmax=mx, aspect='equal')
formatim(fig, ax, im, lags, cbar_loc='v', cbar_sym=False)
ax.set(xlabel='listener lag (s)', ylabel='speaker lag (s)')

# fig.tight_layout()
fig.savefig(f'{outdir}/fig3-isfc-all.svg')
plt.show()

In [None]:
# Plot one brain for speaker, listener: each significant electrode's peak encoding across lags.
maxcorrs = {}
maxcoords = {}
for mode in modes:
    maxes = []
    coords = []
    for sub in Ss:
        mask = sigmasks[(sub, mode)]
        corrs = np.mean(results[(sub, mode, modelname)]['corrs'], axis=0)
        coords.append(allcoords[sub][mask])
        maxes.append(corrs.max(-1)[mask])
    maxcorrs[mode] = np.concatenate(maxes)
    maxcoords[mode] = np.vstack(coords)

fig, axes = plt.subplots(1, 2, figsize=(2.6*2, 2.3))
for i, (mode, ax) in enumerate(zip(modes, axes)):
    values = maxcorrs[mode]
    coords = maxcoords[mode]
    order = np.argsort(values)
    print(values.min(), values.max(), len(values))
    # node_vmin was previously the global `vmin` (0) from the Figure 2 cell; pinned here so
    # this cell does not depend on that cell's state.
    plot_markers(values[order], coords[order], display_mode='l', figure=fig, axes=ax,
                    node_vmax=0.5, node_vmin=0, node_size=20,
                    colorbar=False, node_cmap='Reds' if i else 'Blues')
    # ax.set_title(mode)
# fig.suptitle(modelname)
fig.savefig(f'{outdir}/fig3-encoding-brain-max-prodcomp.svg')
plt.show()

## Figure S3.1 - matched/unmatched weights

In [None]:
# Weight maps for matched dyads (observed) and for every unmatched speaker/partner pair (null).
i = 0
observed = []
# One null slot per ordered pair (s, p) where p is neither s nor s's dyad partner
# (previously hard-coded as 80 = 10 subjects x 8).
n_null = sum(p != s and p != utils.getpartner(s) for s in Ss for p in Ss)
null_distribution = np.empty((n_null, lags.size, lags.size))
for s in Ss:
    part = utils.getpartner(s)
    for p in Ss:
        if p != s and p != part:
            null_distribution[i], _, _ = getMaps(results, [s], [p], modelname=modelname, mode='weights', significant=sigmasks, reduce=False)
            i += 1
        elif p == part:
            observed.append(getMaps(results, [s], [p], modelname=modelname, mode='weights', significant=sigmasks, reduce=False)[0])
observed = np.vstack(observed)
assert i == n_null, (i, n_null)
M2 = np.mean(null_distribution, 0)
print(i)

In [None]:
# Matched-dyad weight map reduced over subjects (reduce=True, unweighted), to compare with the unmatched mean M2.
M1, _, _ = getMaps(results, Ss, modelname=modelname, mode='weights', significant=sigmasks, reduce=True, weight=False)
M1.shape, M2.shape

In [None]:
# Figure S3.1: matched vs unmatched dyad weight maps on a shared colour scale.
w = 7.24 / 2 - .5
fig, axes = plt.subplots(1, 2, figsize=(w*1.75, 1.68), sharex=True, sharey=True)

print(M1.max(), M2.max())

mx = 0.25

for ax, dyad_map, title in zip(axes, (M1, M2), ('Matched Dyads', 'Unmatched Dyads')):
    im = ax.imshow(dyad_map, origin='lower', interpolation=None, cmap='coolwarm', vmin=-mx, vmax=mx, aspect='equal')
    formatim(fig, ax, im, lags, cbar_loc=None)
    ax.set(xlabel='listener lag (s)', ylabel='speaker lag (s)', title=title)

# One colourbar for both panels; `im` is the last panel's image and both share vmin/vmax.
cbar = fig.colorbar(im, ax=axes)
# cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04, orientation='vertical')#, shrink=0.85)
cbar.ax.set_ylabel('pearson correlation')
cbar.set_ticks([im.get_clim()[0], 0, im.get_clim()[-1]])

fig.savefig(f'{outdir}/fig3s-weights.svg')
plt.show()

# Figure 4 - inter-regional ISE

In [None]:
# Re-derive the ROI name list, in case an earlier cell reassigned the global `rois`.
rois = list(ROIS.keys())

In [None]:
# get actual result
# ROI x ROI ISE maps: later cells index M[i, j] as a (lag, lag) map for speaker ROI i and
# listener ROI j. nS / nP hold electrode counts (summed over axis 0 below, presumably subjects).
M, nS, nP = getMaps(results, Ss, modelname=modelname, reduce=True, weight=True, significant=sigmasks, rois=rois)
M.shape, M.mask.any()

In [None]:
# Electrode and subject counts per (speaker ROI, listener ROI) pair, annotated in each cell.
fig, axes = plt.subplots(3, 1, figsize=(fs2c[0], fs3c[0]))

count_panels = [
    (nS.sum(0), 'num speaker electrodes'),
    (nP.sum(0), 'num listener electrodes'),
    ((nP > 0).sum(0), 'num subjects'),
]
for ax, (Z, title) in zip(axes, count_panels):
    im = ax.imshow(Z, cmap='Blues', vmin=0)
    ax.set_title(title)
    for i, j in np.ndindex(Z.shape):
        ax.text(j, i, str(Z[i, j]), va='center', ha='center')

x = np.arange(len(rois))
for ax in axes:
    ax.set(xticks=x, yticks=x, yticklabels=rois)
    ax.set_xticklabels(rois, rotation=45)

# fig.supylabel('speaker ROIs')
# fig.supxlabel('listener ROIs')

fig.tight_layout()
fig.savefig(f'{outdir}/fig4s_mode-spklst_roicounts.svg')
plt.show()

In [None]:
# Mask every lag pair of ROI pairs covered by fewer than `subsiglimit` subjects (nS > 0).
subsiglimit = 4
limit = (nS > 0).sum(0) < subsiglimit
# Broadcast the per-ROI-pair flag over the trailing lag axes (works for any number of ROIs,
# unlike the previous hard-coded 8x8 loop).
submask = np.broadcast_to(limit.reshape(limit.shape + (1,) * (M.ndim - limit.ndim)), M.shape).copy()
M.mask = submask

In [None]:
# Load null distribution
path = derivpath(f'sub-all_mode-spklst_model-{modelname}_perm-phase.npz', derivative='ise')
# Context manager closes the .npz handle once the array is read into memory.
with np.load(path.fpath) as npz:
    null_distribution = npz['null_distribution']
null_distribution.shape

In [None]:
# Calc pval per lag: one p-value per entry of M against the phase-permutation null.
pvalues = cdf_pvalues(M[None, ...], null_distribution).squeeze()
np.nanmin(pvalues)

In [None]:
# correct across lags and rois, counting only unmasked (testable) entries
# NOTE: reassigns the global `method, alpha` used earlier for electrode selection.
method, alpha = 'bonf', .05

valid = ~np.ma.getmaskarray(M)
if method == 'bonf':
    # manual Bonferroni so masked entries are not counted as tests
    n_tests = float(valid.sum())
    alphacBonf = alpha / n_tests
    pvalues_ma = np.ma.array(pvalues, mask=~valid)
    sigs = (pvalues_ma <= alphacBonf).filled(False)
else:
    # Correct over unmasked entries only (previously all entries, masked ones included).
    sigs = np.zeros(M.shape, dtype=bool)
    sigs[valid] = multipletests(np.ma.getdata(pvalues)[valid], method=method, alpha=alpha)[0]

sigs.sum(), sigs.any((2,3)).sum()

In [None]:
# Ensure there at least `cc_limit` connected sig lags
# For each ROI pair, label connected clusters (connectivity=2) of significant lag pairs; if the
# largest cluster has fewer than `cc_limit` entries, the whole ROI pair becomes non-significant.
cc_limit = 20
cc_n = np.zeros((8, 8), dtype=int)  # size of the largest cluster per ROI pair
sigs_label = np.zeros_like(sigs)  # mask of the two largest clusters (not used by the filter below)
for i in range(8):
    for j in range(8):
        if sigs[i, j].any():
            labelmap = measure.label(sigs[i, j], connectivity=2)
            labels, counts = np.unique(labelmap, return_counts=True)

            # Cluster labels ordered by size (ascending), skipping background label 0.
            # NOTE(review): assumes label 0 is present, i.e. not every lag pair is significant.
            sortedcounts = labels[counts[1:].argsort() + 1]
            sigs_label[i, j] = (labelmap == sortedcounts[-1])
            if len(sortedcounts) > 1:
                sigs_label[i, j] |= (labelmap == sortedcounts[-2])

            maxcount = counts[1:].max()  # skip label `0` corresponding to non-sig lag pairs
            cc_n[i, j] = maxcount
            if maxcount < cc_limit:
                sigs[i, j] = False
cc_n

In [None]:
# Number of significant lag pairs per ROI pair after the cluster-size filter.
sigs.sum((2,3))

In [None]:
# M.mask = M.mask | (~sigs)
# Keep only significant lag pairs. `sigs` was filled with False at masked entries, so the
# previous subject-count mask is implied.
M.mask = ~sigs
M.max()

In [None]:
mx = 0.1

# Combined panel. Top: 8x8 grid of ISE maps (speaker ROI rows x listener ROI columns).
# Bottom: count matrices. NOTE(review): this figure is shown but never saved.
fig = plt.figure(figsize=(7.25, 8))
gs0 = mpl.gridspec.GridSpec(2, 1, figure=fig, height_ratios=[2, 1], wspace=0.1)

# Top: one map per ROI pair (raw .data, ignoring the mask); significant pairs get a full frame
nrow, ncol = 8, 8
gs00 = gs0[0].subgridspec(8, 8, wspace=0.05, hspace=0.05)
for i in range(nrow):
    for j in range(ncol):
        ax = fig.add_subplot(gs00[i, j])
        if sigs[i,j].any():
            im = ax.matshow(M[i, j].data, origin='lower', interpolation=None, cmap='coolwarm', vmin=-mx, vmax=mx)
            ax.spines.right.set_visible(True)
            ax.spines.top.set_visible(True)
        else:
            im = ax.matshow(M[i, j].data, origin='lower', interpolation=None, cmap='coolwarm', vmin=-mx, vmax=mx)
            ax.spines.left.set_visible(False)
            ax.spines.bottom.set_visible(False)
        ax.set(xticks=[], yticks=[])
        if j == 0: ax.set_ylabel(rois[i])
        if i == nrow - 1: ax.set_xlabel(rois[j])

# Bottom: speaker electrodes, listener electrodes, number of subjects
gs01 = gs0[1].subgridspec(1, 3)
x = np.arange(len(rois))

ax = fig.add_subplot(gs01[0, 0])
Z = nS.sum(0)
ax.imshow(Z, cmap='Blues', vmin=0)
ax.set_title('num speaker electrodes')
for i in range(Z.shape[0]):
    for j in range(Z.shape[1]):
        ax.text(j, i, str(Z[i, j]), va='center', ha='center')
ax.set(xticks=x, yticks=x)
ax.set_xticklabels(rois, rotation=45)
ax.spines.right.set_visible(True)
ax.spines.top.set_visible(True)

ax = fig.add_subplot(gs01[0, 1])
Z = nP.sum(0)
ax.imshow(Z, cmap='Blues', vmin=0)
ax.set_title('num listener electrodes')
for i in range(Z.shape[0]):
    for j in range(Z.shape[1]):
        ax.text(j, i, str(Z[i, j]), va='center', ha='center')
ax.set(xticks=x, yticks=())
ax.set_xticklabels(rois, rotation=45)
ax.spines.right.set_visible(True)
ax.spines.top.set_visible(True)

ax = fig.add_subplot(gs01[0, 2])
Z = (nP > 0).sum(0)
ax.imshow(Z, cmap='Blues', vmin=0)
ax.set_title('num subjects')
for i in range(Z.shape[0]):
    for j in range(Z.shape[1]):
        ax.text(j, i, str(Z[i, j]), va='center', ha='center')
ax.set(xticks=x, yticks=())
ax.set_xticklabels(rois, rotation=45)
ax.spines.right.set_visible(True)
ax.spines.top.set_visible(True)

# colorbar (uses `im` from the last ISE map drawn; all maps share vmin/vmax)
cax = plt.axes([0.4, 0.05, 0.2, 0.02])  # left, bottom, width, height
cbar = fig.colorbar(im, cax=cax, orientation='horizontal')
cbar.set_ticks([-mx,  0, mx])
cbar.set_label('ISE (r)')

plt.show()

In [None]:
# Plot all parcels: speaker ROI (rows) x listener ROI (columns) ISE maps.
mx = .1
# True: pass the masked array so non-significant lag pairs render blank. False: raw data
# (the previous behaviour). `onlysig` was referenced in the filename but never defined.
onlysig = False

nrow, ncol = len(rois), len(rois)
# One figure with one GridSpec. The old plt.subplots() grid stayed underneath the
# plt.subplot(gs[...]) axes, producing a second, misaligned set of empty axes.
fig = plt.figure(figsize=(6, 6))
gs = mpl.gridspec.GridSpec(nrow, ncol, figure=fig,
         wspace=0.05, hspace=0.05,
         top=1.-0.5/(nrow+1), bottom=0.5/(nrow+1),
         left=0.5/(ncol+1), right=1-0.5/(ncol+1))

for i in range(nrow):
    for j in range(ncol):
        ax = fig.add_subplot(gs[i, j])
        parcel = M[i, j] if onlysig else M[i, j].data
        im = ax.matshow(parcel, origin='lower', interpolation=None, cmap='coolwarm', vmin=-mx, vmax=mx)
        if sigs[i, j].any():
            # significant ROI pair: full frame
            ax.spines.right.set_visible(True)
            ax.spines.top.set_visible(True)
        else:
            ax.spines.left.set_visible(False)
            ax.spines.bottom.set_visible(False)
        ax.set(xticks=[], yticks=[])
        if j == 0: ax.set_ylabel(rois[i])
        if i == nrow - 1: ax.set_xlabel(rois[j])

cax = plt.axes([0.4, -.05, 0.2, 0.02])  # left, bottom, width, height
cbar = fig.colorbar(im, cax=cax, orientation='horizontal')
cbar.set_ticks([-mx,  0, mx])
cbar.set_label('ISE (r)')

fig.supylabel('speaker ROIs', x=-.01)
fig.supxlabel('listener ROIs', y=-.01)

fig.savefig(f'{outdir}/fig4s_ise-full_mode-spklst_onlysig-{onlysig}_roixroi.svg')
plt.show()

In [None]:
# Plot roi x roi
# Peak ISE per (speaker ROI, listener ROI) pair over all lag pairs. With the mask from the
# earlier `M.mask = ~sigs` cell, ROI pairs without significant lags come out masked.
w = 7.24 / 2 - .5
fig, ax = plt.subplots(figsize=(w, w*(3/4)))

Mmax = M.max((2,3))
# Mmax.mask = ~sigs.any((2,3))
print(Mmax.min(), Mmax.max())
im = ax.imshow(Mmax, cmap='Reds', vmin=0)

# x, y = np.nonzero(sigs.any((2,3)))
# ax.scatter(y, x, marker='*', color='yellow', s=8)

ax.set(xlabel='listener ROI', ylabel='speaker ROI')
ax.set_xticks(range(len(rois)))
ax.set_yticks(range(len(rois)))
ax.set_xticklabels(rois, rotation=45)
ax.set_yticklabels(rois)
ax.spines.right.set_visible(True)
ax.spines.top.set_visible(True)
cbar = fig.colorbar(im, ax=ax)
cbar.set_label('max ISE (r)')
# cbar.set_ticks([0, mx])

fig.savefig(f'{outdir}/fig4_mode-spklst_roisummary.svg')
plt.show()

## within-speaker

In [None]:
# Within-speaker ROI x ROI maps: the partner side also uses 'prod' data (partMode='prod') —
# presumably both ROIs come from the speaker; confirm in getMaps.
Ms, nSs, nPs = getMaps(results, Ss, Ss, modelname=modelname, reduce=True, weight=True, rois=rois, significant=sigmasks, partMode='prod')
Ms.shape

In [None]:
# Within-speaker electrode and subject counts per ROI pair, annotated in each cell.
fig, axes = plt.subplots(2, 1, figsize=(fs2c[0], fs2c[0]))

count_panels = [
    (nSs.sum(0), 'within-speaker elec. count'),
    ((nPs > 0).sum(0), 'num subjects'),
]
for ax, (Z, title) in zip(axes, count_panels):
    im = ax.imshow(Z, cmap='Blues', vmin=0)
    ax.set_title(title)
    for i, j in np.ndindex(Z.shape):
        ax.text(j, i, str(Z[i, j]), va='center', ha='center')

x = np.arange(len(rois))
for ax in axes:
    ax.set(xticks=x, yticks=x, yticklabels=rois)
    ax.set_xticklabels(rois, rotation=45)

# fig.supylabel('speaker ROIs')
# fig.supxlabel('listener ROIs')

fig.tight_layout()
fig.savefig(f'{outdir}/fig4s_mode-spk_roicounts.svg')
plt.show()

In [None]:
# Mask every lag pair of ROI pairs covered by fewer than `subsiglimit` subjects (nSs > 0).
limit = (nSs > 0).sum(0) < subsiglimit
# Broadcast over the trailing lag axes (works for any number of ROIs, unlike an 8x8 loop).
submask = np.broadcast_to(limit.reshape(limit.shape + (1,) * (Ms.ndim - limit.ndim)), Ms.shape).copy()
Ms.mask = submask

In [None]:
# Load null distribution
path = derivpath(f'sub-all_mode-speaker_model-{modelname}_perm-phase.npz', derivative='ise')
# Context manager closes the .npz handle once the array is read into memory.
with np.load(path.fpath) as npz:
    null_distribution = npz['null_distribution']
null_distribution.shape

In [None]:
# Calc pval per lag, correct over unmasked entries, then apply the cluster-size filter.
pvalues = cdf_pvalues(Ms[None, ...], null_distribution).squeeze()

valid = ~np.ma.getmaskarray(Ms)
if method == 'bonf':
    n_tests = float(valid.sum())
    alphacBonf = alpha / n_tests
    pvalues = np.ma.array(pvalues, mask=~valid)
    sigsS = (pvalues <= alphacBonf).filled(False)
else:
    # Correct over unmasked entries only (previously masked entries were counted as tests).
    sigsS = np.zeros(Ms.shape, dtype=bool)
    sigsS[valid] = multipletests(np.ma.getdata(pvalues)[valid], method=method, alpha=alpha)[0]

# Ensure there at least `cc_limit` connected sig lags
sigsS_label = np.zeros_like(sigsS)  # two largest clusters per ROI pair (not used by the filter)
for i in range(8):
    for j in range(8):
        if sigsS[i, j].any():
            labelmap = measure.label(sigsS[i, j], connectivity=2)
            labels, counts = np.unique(labelmap, return_counts=True)

            # cluster labels ordered by size, skipping background label 0
            sortedcounts = labels[counts[1:].argsort() + 1]
            sigsS_label[i, j] = (labelmap == sortedcounts[-1])
            if len(sortedcounts) > 1:
                sigsS_label[i, j] |= (labelmap == sortedcounts[-2])

            if counts[1:].max() < cc_limit:
                print(rois[i], rois[j], counts[1:].max())
                sigsS[i, j] = False

sigsS.sum(), sigsS.any((2,3)).sum()

In [None]:
# Keep only significant lag pairs; also report the peak over the strict upper triangle of ROI pairs.
Ms.mask = ~ sigsS
Ms.max(), np.triu(Ms.max((2,3)).filled(0), 1).max()

In [None]:
# Plot roi x roi: peak within-speaker ISE per ROI pair (masked pairs drawn in light grey).
w = 7.24 / 2 - .5
fig, ax = plt.subplots(figsize=(w, w*(3/4)))

palette = plt.cm.Blues.with_extremes(bad='#fafafa')

Mmax = Ms.max((2,3))
# Mmax.mask[np.triu_indices_from(Mmax, 1)] = True
print(Mmax.min(), Mmax.max())
im = ax.imshow(Mmax, cmap=palette, vmin=0)#, vmax=mx)

# Both axes are speaker ROIs here (within-speaker); the x label used to say 'listener ROI'.
ax.set(xlabel='speaker ROI', ylabel='speaker ROI')
ax.set_xticks(range(len(rois)))
ax.set_yticks(range(len(rois)))
ax.set_xticklabels(rois, rotation=45)
ax.set_yticklabels(rois)
ax.spines.right.set_visible(True)
ax.spines.top.set_visible(True)
cbar = fig.colorbar(im, ax=ax)
cbar.set_label('max ISE (r)')
# cbar.set_ticks([0, mx])

fig.savefig(f'{outdir}/fig4_mode-spk_roisummary.svg')
plt.show()

In [None]:
# Plot all parcels within speaker: lower triangle (incl. diagonal) drawn, upper left blank.
mx = .10

nrow, ncol = len(rois), len(rois)
# One figure with one GridSpec. The old plt.subplots() grid stayed underneath the
# plt.subplot(gs[...]) axes, so the two sets of axes were misaligned and overlapping.
fig = plt.figure(figsize=(6, 6))
gs = mpl.gridspec.GridSpec(nrow, ncol, figure=fig,
         wspace=0.05, hspace=0.05,
         top=1.-0.5/(nrow+1), bottom=0.5/(nrow+1),
         left=0.5/(ncol+1), right=1-0.5/(ncol+1))

for i in range(nrow):
    for j in range(ncol):
        ax = fig.add_subplot(gs[i, j])
        if j > i:
            # upper triangle: empty, unframed cell
            ax.spines.left.set_visible(False)
            ax.spines.bottom.set_visible(False)
            ax.set(xticks=[], yticks=[])
            continue
        im = ax.matshow(Ms[i, j].data, origin='lower', interpolation=None, cmap='coolwarm', vmin=-mx, vmax=mx)
        ax.set(xticks=[], yticks=[])
        if sigsS[i, j].any():
            ax.spines.right.set_visible(True)
            ax.spines.top.set_visible(True)
        else:
            ax.spines.left.set_visible(False)
            ax.spines.bottom.set_visible(False)
        if j == 0: ax.set_ylabel(rois[i])
        if i == nrow - 1: ax.set_xlabel(rois[j])

cax = plt.axes([0.4, -.05, 0.2, 0.02])  # left, bottom, width, height
cbar = fig.colorbar(im, cax=cax, orientation='horizontal')
cbar.set_ticks([-mx,  0, mx])
cbar.set_label('ISE (r)')

fig.supylabel('speaker ROIs', x=-.01)
fig.supxlabel('speaker ROIs', y=-.01)

fig.savefig(f'{outdir}/fig4s_ise-full_mode-spk_roixroi.svg')
plt.show()

## within-listener

In [None]:
# Within-listener ROI x ROI ISE maps (subject mode 'comp'), plus the per-subject
# electrode counts returned alongside
Ml, nPs, nPl = getMaps(
    results, Ss, Ss,
    modelname=modelname, rois=rois, subMode='comp',
    significant=sigmasks, reduce=True, weight=True,
)
Ml.shape

In [None]:
fig, axes = plt.subplots(2, 1, figsize=(fs2c[0], fs2c[0]))

# Top: total electrode count per ROI pair. Bottom: number of subjects with any electrode.
count_panels = [
    (nPs.sum(0), 'within-listener elec. count'),
    ((nPs > 0).sum(0), 'num subjects'),
]
for ax, (Z, title) in zip(axes, count_panels):
    im = ax.imshow(Z, cmap='Blues', vmin=0)
    ax.set_title(title)
    # Annotate every cell with its count (row-major, same order as nested loops)
    for (i, j), count in np.ndenumerate(Z):
        ax.text(j, i, str(count), va='center', ha='center')

x = np.arange(len(rois))
for ax in axes:
    ax.set(xticks=x, yticks=x, yticklabels=rois)
    ax.set_xticklabels(rois, rotation=45)

fig.tight_layout()
fig.savefig(f'{outdir}/fig4s_mode-lst_roicounts.svg')
plt.show()

In [None]:
# Mask every lag bin of ROI pairs supported by fewer than `subsiglimit` subjects
# ((nPs > 0).sum(0) is the "num subjects" count shown above)
limit = (nPs > 0).sum(0) < subsiglimit
# Broadcast the per-pair flag over Ml's trailing (lag) axes. This replaces a loop over
# a hard-coded 8 x 8 grid that left np.empty_like garbage elsewhere.
limit_bc = limit.reshape(limit.shape + (1,) * (Ml.ndim - limit.ndim))
submask = np.broadcast_to(limit_bc, Ml.shape).copy()
Ml.mask = submask

In [None]:
# Load the within-listener permutation null distribution (perm-phase)
path = derivpath(f'sub-all_mode-listener_model-{modelname}_perm-phase.npz', derivative='ise')
# The context manager closes the .npz file handle once the array has been read
with np.load(path.fpath) as file:
    null_distribution = file['null_distribution']
null_distribution.shape

In [None]:
# Calc pval per lag: compare the observed within-listener maps with the null, then
# correct for multiple comparisons
pvalues = cdf_pvalues(Ml[None, ...], null_distribution).squeeze()

if method == 'bonf':
    # Number of unmasked tests, halved (presumably because within-subject pairs (i, j)
    # and (j, i) are redundant — NOTE(review) confirm). The previous `size - sum / 2`
    # subtracted only half of the masked count because of operator precedence.
    n_tests = float((Ml.mask.size - Ml.mask.sum()) / 2)
    alphacBonf = alpha / n_tests
    pvalues_ma = np.ma.array(pvalues, mask=Ml.mask)
    sigsL = (pvalues_ma <= alphacBonf).filled(False)
else:
    # NOTE(review): every bin, including bins of ROI pairs masked by `subsiglimit`,
    # enters the correction; the mask is applied only to the result — confirm intended.
    sigsL, pvals_corr, _, _ = multipletests(pvalues.flatten(), method=method, alpha=alpha)
    sigsL = np.ma.array(sigsL.reshape(Ml.shape), mask=Ml.mask).filled(False)

# Ensure there at least `cc_limit` connected sig lags
sigsL_label = np.zeros_like(sigsL)
for i in range(8):
    for j in range(8):
        if sigsL[i, j].any():
            labelmap = measure.label(sigsL[i, j], connectivity=2)
            labels, counts = np.unique(labelmap, return_counts=True)

            sortedcounts = labels[counts[1:].argsort() + 1]
            sigsL_label[i, j] = (labelmap == sortedcounts[-1])
            if len(sortedcounts) > 1:
                sigsL_label[i, j] |= (labelmap == sortedcounts[-2])

            if counts[1:].max() < cc_limit:
                print(rois[i], rois[j], counts[1:].max())
                sigsL[i, j] = False

sigsL = sigsL.reshape(Ml.shape)
sigsL.sum(), sigsL.any((2,3)).sum()

In [None]:
# Hide every lag bin that is not significant after correction and cluster pruning
Ml.mask = np.logical_not(sigsL)
# Peak ISE overall, and the peak over strictly-upper-triangle ROI pairs
Ml.max(), np.triu(Ml.max((2,3)).filled(0), 1).max()

In [None]:
# Plot roi x roi: within-listener summary, the peak significant ISE for each ROI pair
w = 7.24 / 2 - .5
fig, ax = plt.subplots(figsize=(w, w*(3/4)))

# Light grey for masked entries (non-significant pairs and the hidden upper triangle)
palette = plt.cm.Reds.with_extremes(bad='#fafafa')

Mmax = Ml.max((2,3))
# Show only j <= i (presumably (i, j) and (j, i) are redundant within one subject)
Mmax.mask[np.triu_indices_from(Mmax, 1)] = True
print(Mmax.min(), Mmax.max())
im = ax.imshow(Mmax, cmap=palette, vmin=0)

# Both axes index ROIs of the listening subject (within-listener ISE)
ax.set(xlabel='listener ROI', ylabel='listener ROI')
ax.set_xticks(range(len(rois)))
ax.set_yticks(range(len(rois)))
ax.set_xticklabels(rois, rotation=45)
ax.set_yticklabels(rois)
ax.spines.right.set_visible(True)
ax.spines.top.set_visible(True)
cbar = fig.colorbar(im, ax=ax)
cbar.set_label('max ISE (r)')

fig.savefig(f'{outdir}/fig4_mode-lst_roisummary.svg')
plt.show()

In [None]:
# Plot all parcels within listener
# Grid of within-listener lag x lag ISE maps, one panel per ROI pair in the lower
# triangle (diagonal included). The upper-triangle panels are blanked out.
mx = .15  # symmetric color limit (r) of the coolwarm scale

nrow, ncol = 8, 8  # NOTE(review): assumes len(rois) == 8
fig, axes = plt.subplots(nrow, ncol, figsize=(6, 6))
# Second, tighter grid used for the visible panels below
gs = mpl.gridspec.GridSpec(nrow, ncol,
         wspace=0.05, hspace=0.05, 
         top=1.-0.5/(nrow+1), bottom=0.5/(nrow+1), 
         left=0.5/(ncol+1), right=1-0.5/(ncol+1)) 

for i, roi1 in enumerate(rois):
    # Upper triangle: strip ticks and left/bottom spines from the plt.subplots axes
    for j in range(i+1, len(rois)):
        ax = axes[i, j]
        ax.spines.left.set_visible(False)
        ax.spines.bottom.set_visible(False)
        ax.set(xticks=[], yticks=[])

    # Lower triangle: draw on a new axes from `gs`.
    # NOTE(review): plt.subplot(gs[i, j]) adds an axes from a different GridSpec on top
    # of axes[i, j]; the underlying plt.subplots axes are left untouched — confirm they
    # do not show in the saved figure.
    for j in range(i+1):
        rois2 = rois[j]
        ax = plt.subplot(gs[i, j])
        # `.data` plots the raw values, ignoring the significance mask
        im = ax.matshow(Ml[i, j].data, origin='lower', interpolation=None, cmap='coolwarm', vmin=-mx, vmax=mx)
        ax.set(xticks=[], yticks=[])
        # Pairs with any significant lag get right/top spines as well (full box);
        # the others lose their left/bottom spines
        if sigsL[i,j].any():
            ax.spines.right.set_visible(True)
            ax.spines.top.set_visible(True)
        else:
            ax.spines.left.set_visible(False)
            ax.spines.bottom.set_visible(False)
        if j == 0: ax.set_ylabel(rois[i])
        if i == nrow - 1: ax.set_xlabel(rois[j])

# cbar = fig.colorbar(im, ax=axes, shrink=0.5)
# Horizontal colorbar placed below the grid (axes rect in figure fractions)
cax = plt.axes([0.4, -.05, 0.2, 0.02])  # left, bottom, width, height
cbar = fig.colorbar(im, cax=cax, orientation='horizontal')
cbar.set_ticks([-mx,  0, mx])
cbar.set_label('ISE (r)')

fig.supylabel('listener ROIs', x=-.01)
fig.supxlabel('listener ROIs', y=-.01)

fig.savefig(f'{outdir}/fig4s_ise-full_mode-lst_roixroi.svg')
plt.show()

## interregional heatmaps

In [None]:
from util.plot import plot_outlines

In [None]:
# Panels of the figure 4 heatmap grid, filled row-major into 2 rows x 4 columns.
# Each entry is (mode, roi1, roi2): 'spk' uses the within-speaker maps, 'lst' the
# within-listener maps, and 'spklst' the speaker-listener maps.
fig4rois = [
    # top row
    ('spk', 'SM', 'STG'),
    ('spklst', 'SM', 'STG'),
    ('spklst', 'SM', 'IFG'),
    ('lst', 'STG', 'IFG'),
    # bottom row
    ('spk', 'SM', 'IFG'),
    ('spklst', 'STG', 'STG'),
    ('spklst', 'ATL', 'ATL'),
    ('lst', 'STG', 'ATL'),
]

In [None]:
mx = .15


def _panel_spec(mode):
    """Return (ISE maps, cluster outlines, x label, y label) for a fig4 panel mode."""
    if mode == 'spk':
        return Ms, sigsS_label, 'speaker lag (s)', 'speaker lag (s)'
    if mode == 'lst':
        return Ml, sigsL_label, 'listener lag (s)', 'listener lag (s)'
    # any other mode: speaker-listener maps
    return M, sigs_label, 'listener lag (s)', 'speaker lag (s)'


fig, axes = plt.subplots(2, 4, figsize=(8.25, 5.25), sharex=True, sharey=True)

for ax, (mode, roi1, roi2) in zip(axes.ravel(), fig4rois):
    i, j = rois.index(roi1), rois.index(roi2)
    heatmap, outlines, xlabel, ylabel = _panel_spec(mode)

    # Outline the largest significant clusters above the map
    plot_outlines(outlines[i, j].T, ax=ax, color='black', lw=0.5, ls='-', alpha=0.7, zorder=3)
    im = ax.imshow(heatmap[i, j].data, origin='lower', cmap='coolwarm', vmin=-mx, vmax=mx, aspect='equal')
    formatim(fig, ax, im, lags, cbar_loc=None, xl=xlabel, yl=ylabel)
    ax.set_xlabel(xlabel, loc='left')
    ax.set_ylabel(ylabel, loc='bottom')

cbar = fig.colorbar(im, ax=axes, orientation='horizontal', shrink=0.25)
cbar.set_label('ISE (r)')
cbar.set_ticks([-mx, mx])

fig.savefig(f'{outdir}/fig4_heatmaps.svg')
plt.show()

In [None]:
# plot roixroi spk-list together
w = 7.24 / 2 - .5
fig, axes = plt.subplots(1, 2, figsize=(w*1.75, 1.68), sharex=True)

mx = 0.3

ax = axes[0]
Mmax = Ms.max((2,3))
print(Mmax.min(), Mmax.max())
Mmax[np.triu_indices_from(Mmax, 1)] = np.nan
# mx = np.nanmax(Mmax)
im = ax.imshow(Mmax, cmap='Blues', vmin=0, vmax=mx)

ax.set(xlabel='speaker ROI', ylabel='speaker ROI')
ax.set_xticks(range(len(rois)))
ax.set_yticks(range(len(rois)))
ax.spines.right.set_visible(True)
ax.spines.top.set_visible(True)
ax.set_xticklabels(rois, rotation=45)
ax.set_yticklabels(rois)
fig.colorbar(im, ax=ax)

# x, y = np.nonzero(np.tril(sigsS.any((2, 3))))
# ax.scatter(y, x, marker='*', color='yellow', s=4)

# listener
ax = axes[1]
Mmax = Ml.max((2,3))
print(Mmax.min(), Mmax.max())
Mmax[np.tril_indices_from(Mmax, -1)] = np.nan
# mx = np.nanmax(Mmax)
im = ax.imshow(Mmax, cmap='Reds', vmin=0, vmax=mx)

ax.set(xlabel='listener ROI', ylabel='listener ROI')
ax.set_xticks(range(len(rois)))
ax.set_yticks(range(len(rois)))
ax.spines.right.set_visible(True)
ax.spines.top.set_visible(True)
ax.set_xticklabels(rois, rotation=45)
ax.set_yticklabels(rois)
fig.colorbar(im, ax=ax)

# x, y = np.nonzero(np.triu(sigsL.any((2, 3))))
# ax.scatter(y, x, marker='*', color='yellow', s=4)

fig.suptitle('within-speaker and within-listener ISE')

fig.savefig(f'{outdir}/figS-ise-within-subjects.svg')
plt.show()

## Schematic figure (figure 5 connection styles)

In [None]:
def rescale_col(df, col, vmin=0, vmax=1):
    """Linearly rescale column ``col`` of ``df`` onto ``[vmin, vmax]``.

    The column minimum maps to ``vmin`` and its maximum to ``vmax``. A constant
    column (max == min) maps entirely to ``vmin`` instead of producing NaN from 0/0.
    NaN entries stay NaN.

    Returns a Series aligned with ``df``.
    """
    values = df[col]
    lo, hi = values.min(), values.max()
    span = hi - lo
    if span == 0:
        # Keeps the index/name and propagates NaN, like the general branch
        return values * 0 + vmin
    return (values - lo) / span * (vmax - vmin) + vmin

In [None]:
def color_col(df, cmap='Greys', vmin=None, vmax=None):
    """Map each value of a Series to a hex color string.

    Parameters
    ----------
    df : pandas.Series
        Values to color (applied element-wise via ``Series.apply``).
    cmap : str or Colormap
        Colormap name or object.
    vmin, vmax : float, optional
        Color-scale limits; they default to the min/max of ``df``.

    Returns
    -------
    pandas.Series of ``'#rrggbb'`` strings with the same index as ``df``.
    """
    vmin = df.min() if vmin is None else vmin
    vmax = df.max() if vmax is None else vmax
    # matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9. pyplot.get_cmap,
    # already used elsewhere in this notebook, accepts both names and Colormaps.
    colormap = plt.get_cmap(cmap)
    normer = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    return df.apply(lambda x: mpl.colors.to_hex(colormap(normer(x))))

In [None]:
# Per ROI pair: peak ISE, number of significant lag bins, and whether any bin is
# significant. The within-speaker (S) and within-listener (P) pairs keep only the
# strict lower triangle; the speaker-listener (SP) pairs keep the full matrix.
# NOTE(review): `sigs` is reassigned in the significance-test section; re-running this
# cell after that section would mix in the wrong mask.
maxesS, maxesP, maxesSP = Ms.max((2,3)), Ml.max((2,3)), M.max((2,3))
sumsS, sumsP, sumsSP = sigsS.sum((2,3)), sigsL.sum((2,3)), sigs.sum((2,3))
anysigS = np.tril(sigsS.any((2,3)), k=-1)
anysigP = np.tril(sigsL.any((2,3)), k=-1)
anysigSP = sigs.any((2,3))

In [None]:
# One record per significant ROI pair: (mode prefix, roi1, roi2, peak ISE, # sig bins).
# Within-subject pairs come from the lower triangle (i > j), so they are listed as
# (rois[j], rois[i]); speaker-listener pairs keep matrix order (rois[i], rois[j]).
records = [('spk-', rois[j], rois[i], maxesS[i, j], sumsS[i, j])
           for i, j in zip(*anysigS.nonzero())]
records += [('lst-', rois[j], rois[i], maxesP[i, j], sumsP[i, j])
            for i, j in zip(*anysigP.nonzero())]
records += [('spklst-', rois[i], rois[j], maxesSP[i, j], sumsSP[i, j])
            for i, j in zip(*anysigSP.nonzero())]

In [None]:
mx = .15

# One row per significant ROI pair
df = pd.DataFrame(records, columns=('mode', 'roi1', 'roi2', 'maxcorr', 'npairs'))
# Connection key such as 'spk-sm-stg' (presumably matched to element ids in the
# schematic SVG by svgtest.py)
df['label'] = df['mode'] + df['roi1'].str.lower() + '-' + df['roi2'].str.lower()

print(df.groupby('mode')['maxcorr'].max())

# Greyscale stroke color on [0, mx] and a constant stroke width
df['stroke'] = color_col(df['maxcorr'], vmin=0, vmax=mx)
df['stroke-width'] = 1

# save to file: {label: {stroke, stroke-width}}
dfj = df.set_index('label')[['stroke', 'stroke-width']]
dfj.to_json(f'{outdir}/fig5data.json', orient='index')

df

In [None]:
# Apply the per-connection styles in fig5data.json to the schematic SVG via svgtest.py
# (presumably args: input SVG, style JSON, output SVG — confirm in svgtest.py)
!python svgtest.py ../results/figure5h.svg $outdir/fig5data.json $outdir/figure5h-mod.svg

In [None]:
# Standalone horizontal greyscale colorbar on [0, mx] for the schematic, with an extra
# tick at the weakest plotted connection
fig = plt.figure(figsize=(1.863, 0.284))
# Axes rect (left, bottom, width, height) in figure fractions; bottom=1 puts it at the
# top edge of this tiny figure
ax = fig.add_axes([0, 1, .85, .35])
greys = plt.get_cmap('Greys')
ise_norm = mpl.colors.Normalize(vmin=0, vmax=mx)
cb = mpl.colorbar.ColorbarBase(ax, orientation='horizontal', cmap=greys, norm=ise_norm)
cb.set_label('ISE (r)')
weakest = df['maxcorr'].min()
cb.set_ticks([0, weakest, mx])
fig.savefig(f'{outdir}/figS_colorbar25.svg')
plt.show()

## Figure S4.3 – speaker-to-listener (s2l) and listener-to-speaker (l2s) ISE

In [None]:
# All subjects: speaker-to-listener (s2l) and listener-to-speaker (l2s) ISE
mx = .15

w = 7.24 / 2 - .5
fig, axes = plt.subplots(1, 2, figsize=(w*1.75, 1.68), sharex=True, sharey=True)

# Maps go into `M_dir`, not `M`: overwriting the global speaker-listener map `M` broke
# the figure 4 cells whenever they were re-run after this one.
directions = [('s2l', 'Speaker to listener'), ('l2s', 'Listener to speaker')]
for ax, (direction, title) in zip(axes, directions):
    M_dir, ns, ms = getMaps(results, Ss, modelname=modelname, significant=sigmasks,
                            reduce=True, weight=True, mode=direction)
    im = ax.imshow(M_dir, origin='lower', interpolation=None, cmap='coolwarm', vmin=-mx, vmax=mx)
    formatim(fig, ax, im, lags, cbar_loc=None)
    ax.set(title=title)
    print(M_dir.max())

cbar = fig.colorbar(im, ax=axes, orientation='vertical')
cbar.set_ticks([-mx, 0, mx])
cbar.set_label('ISE (r)')  # consistent with the other ISE colorbars

fig.savefig(f'{outdir}/fig4s-s2l-l2s.svg')
plt.show()

# Significance tests

In [None]:
# Load saved ISEs: the observed maps stored for each GPT-2 variant
# (data-fig3-ise.npz, presumably written by the figure 3 cells of each model's run)
static = 'model-gpt2-xl_maxlen-0_reg-l2'
contextual = 'model-gpt2-xl_maxlen-1024_layer-24_reg-l2'
untrained = 'model-gpt2-xl_maxlen-1024_layer-24_random_reg-l2'

method, alpha = 'fdr_bh', .01
# Results folder of the electrode-selection run. It has its own name so the global
# `sigmodelname`, used to locate the electrode-selection pickle, is not overwritten.
sigresultsname = f'model-gpt2-xl_maxlen-0_reg-l2_perm-phase_method-{method}_alpha-{alpha}'

observedISE = {}
for variant in (static, contextual, untrained):
    # The context manager closes each .npz handle once the array is read
    with np.load(f'../results/paper/{sigresultsname}/{variant}/data-fig3-ise.npz') as npz:
        observedISE[variant] = npz['observed']
staticM, contextualM, untrainedM = observedISE[static], observedISE[contextual], observedISE[untrained]
print(staticM.shape, contextualM.shape, untrainedM.shape)

# One (lag x lag) map per leading index
n_lags = len(lags)
staticM = staticM.reshape(-1, n_lags, n_lags)
contextualM = contextualM.reshape(-1, n_lags, n_lags)
untrainedM = untrainedM.reshape(-1, n_lags, n_lags)

In [None]:
# Build null distribution of differences
difference = contextualM - staticM
# difference = contextualM - untrainedM
observed = mean_diff_axis(difference)
null_distribution = permute_differences(difference, summary=mean_diff_axis, n_perms=10000)
difference.shape, observed.shape, null_distribution.shape

In [None]:
# Two-sided permutation p-values of the observed mean difference
pvalues = calculate_pvalues(observed, null_distribution, alternative='two-sided')
np.shape(pvalues)

In [None]:
# Benjamini-Hochberg FDR correction across all lag-lag bins.
# NOTE(review): this overwrites the speaker-listener significance mask `sigs`.
method, alpha = 'fdr_bh', .01
sigs, pvals_corr, alphacSidak, alphacBonf = multipletests(pvalues.ravel(), alpha=alpha, method=method)
sigs.sum()

In [None]:
# Contextual-minus-static mean difference, showing only bins that survive correction
diff2 = np.ma.masked_array(observed)
diff2.mask = np.logical_not(sigs.reshape(129, 129))

In [None]:
# Trained (contextual) vs. untrained GPT-2, computed explicitly. Previously `diff` came
# from the same contextual-minus-static `observed` as `diff2`, so the figure only matched
# its name after re-running the cells above by hand with `contextualM - untrainedM`.
difference_untrained = contextualM - untrainedM
observed_untrained = mean_diff_axis(difference_untrained)
null_untrained = permute_differences(difference_untrained, summary=mean_diff_axis, n_perms=10000)
pvalues_untrained = calculate_pvalues(observed_untrained, null_untrained, alternative='two-sided')
sigs_untrained = multipletests(pvalues_untrained.flatten(), method=method, alpha=alpha)[0]

diff = np.ma.array(observed_untrained)
diff.mask = ~ sigs_untrained.reshape(129, 129)

fig, ax = plt.subplots()
im = ax.imshow(diff, origin='lower', cmap='coolwarm', vmin=-.1, vmax=.1)
formatim(fig, ax, im, lags, cbar_loc='v', cl='ISE (r)')
fig.savefig(f'{outdir}/fig4s-contextual-untrained-sig-diff.svg')
plt.show()

In [None]:
# significant differences: one panel per model comparison, shared color scale
mx = .10

w = 7.24 / 2 - .5
fig, axes = plt.subplots(1, 2, figsize=(w*1.75, 1.68), sharex=True, sharey=True)

comparisons = [(diff, 'Trained vs. Untrained'), (diff2, 'Contextual vs. Static')]
for ax, (masked_diff, title) in zip(axes, comparisons):
    im = ax.imshow(masked_diff, origin='lower', interpolation=None, cmap='coolwarm', vmin=-mx, vmax=mx)
    formatim(fig, ax, im, lags, cbar_loc=None)
    ax.set(title=title)

cbar = fig.colorbar(im, ax=axes, orientation='vertical')
cbar.set_label('ISE (r)')

fig.savefig(f'{outdir}/fig4s-ise-sigtests.svg')
plt.show()

# Tables

In [None]:
# Load data: one row per (subject, mode, electrode) with its ROI, significance and parcel
dfs = []
for (sub, mode, _), result in results.items():
    frame = pd.DataFrame({
        'subject': sub,
        'mode': mode,
        'electrode': result['electrodes'],
        'roi': result['rois'],
        'significant': sigmasks[(sub, mode)],
    })
    dfs.append(frame)

df = pd.concat(dfs)
roi2parcell = {roi: label for label, members in ROIS.items() for roi in members}
# ROIs that belong to no parcel map to None, and those rows are dropped
df['parcell'] = df['roi'].map(roi2parcell.get)
df = df.dropna()
df

In [None]:
# Per subject, mode and parcel: number of significant electrodes
col = 'parcell'
mask = df.significant
dfcounts = (df[mask]
            .groupby(['subject', 'mode', col])
            .electrode.count()
            .to_frame()
            .reset_index())
pd.pivot(dfcounts, index=['subject', 'mode'], columns=[col]).fillna(0).astype(int)

In [None]:
# Cleaned up version of previous: significant electrodes per subject and parcel, one
# table per mode, stacked and given row/column totals
col = 'parcell'
tables = []
for mode in ('prod', 'comp'):
    mask = df.significant & (df['mode'] == mode)
    dfcounts = df[mask].groupby(['subject', col]).electrode.count().to_frame().reset_index()
    table = pd.pivot(dfcounts, index='subject', columns=[col]).fillna(0).astype(int)
    table.insert(0, 'mode', mode)
    tables.append(table)

df3 = pd.concat(tables)
df3.loc['Column_Total'] = df3.sum(numeric_only=True, axis=0).astype(int)
df3.loc[:, 'Row_Total'] = df3.sum(numeric_only=True, axis=1).astype(int)

# Parcels present for only one mode are NaN after concat, and adding the total row can
# upcast counts to float. Cast only the count columns: the previous whole-frame
# `astype(int, errors='ignore')` was a silent no-op because `mode` holds strings.
count_cols = df3.select_dtypes('number').columns
df3[count_cols] = df3[count_cols].fillna(0).astype(int)
df3

In [None]:
# How many electrodes are significant for comp/prod?
# One row per (subject, electrode); True where the electrode is significant in that mode,
# and False where the mode is missing. Booleans keep `&`/`~` logical: on integer counts
# `~` is bitwise (~1 == -2), which only works for exact 0/1 values, and NaN from a
# missing mode made `&` fail.
df2 = (df.groupby(['subject', 'electrode', 'mode']).significant.sum()
         .unstack('mode', fill_value=0)
         .gt(0))
df2['both'] = df2['comp'] & df2['prod']
df2['only_comp'] = df2['comp'] & ~df2['prod']
df2['only_prod'] = df2['prod'] & ~df2['comp']
df2.sum()

In [None]:
# Load data: significant in-ROI electrodes per subject and mode, split by electrode type.
# The type is read from the character after the first four of the electrode name
# ('D' depth, 'G' grid, anything else strip) — NOTE(review) confirm the naming scheme.
roilist = [roi for parcel_rois in ROIS.values() for roi in parcel_rois]  # loop-invariant
records = []
for (sub, mode, _), result in results.items():
    # np.isin replaces np.in1d, which is deprecated as of NumPy 2.0
    roimask = np.isin(result['rois'], roilist)
    sig_elecs = np.array(result['electrodes'])[sigmasks[(sub, mode)] & roimask]
    n_depth = sum(e[4:].startswith('D') for e in sig_elecs)
    n_grid = sum(e[4:].startswith('G') for e in sig_elecs)
    n_strip = len(sig_elecs) - n_depth - n_grid
    records.append((sub, mode, len(sig_elecs), n_grid, n_strip, n_depth))

count_cols = ['all', 'grid', 'strip', 'depth']
df_etype = pd.DataFrame(records, columns=['sub', 'mode'] + count_cols)
# Total the count columns only. Summing with numeric_only also added the subject IDs,
# and a row total would double-count, since 'all' already equals grid + strip + depth.
df_etype.loc['Column_Total'] = df_etype[count_cols].sum()
df_etype

# Fig SX: mock encoding schematic

In [None]:
# Mock data: a unit impulse at 5 s in a 10 s single-channel signal at 512 Hz, epoched at
# the analysis lags/window expressed in samples, then normalized to sum to 1.
# NOTE(review): this overwrites the global `tmax`, `window`, `jump` and `lags` (in
# seconds elsewhere) with sample-based values.
from utils import epochbin

sfreq = 512
data = np.zeros((10 * sfreq, 1))
tmax, window, jump = 4 * sfreq, .250 * sfreq, .0625 * sfreq
lags = np.arange(-tmax, tmax + jump, jump).astype(int)
onsets = [5 * sfreq]
data[onsets, :] = 1
epoched = epochbin(data, onsets, lags, int(window))
S = epoched.squeeze()
S /= S.sum()
S.shape

In [None]:
# Within-speaker ('within-prod') and within-listener ('within-comp') maps over ALL_ROIs.
# NOTE(review): the keyword names here (speakerROI/partnerROI, mode='within-*') and the
# `sig` variable do not appear in the other getMaps calls in this section (which pass
# significant=sigmasks) — confirm this cell still matches getMaps' signature.
MS, NS, _ = getMaps(
    results, Ss, modelname=modelname,
    speakerROI=ALL_ROIs, partnerROI=ALL_ROIs,
    mode='within-prod', significant=sig, reduce=True,
)
MP, _, NP = getMaps(
    results, Ss, modelname=modelname,
    speakerROI=ALL_ROIs, partnerROI=ALL_ROIs,
    mode='within-comp', significant=sig, reduce=True,
)

In [None]:
# Mock figure: the mock response S with the MS (blue) and MP (red) curves on the left,
# and the outer product S x S on the right. Lags are in samples at 512 Hz.
fsh = 2.24*(3/4)
# Local figure size. Reassigning the global `fs2c` here changed the two-column size used
# by earlier cells whenever they were re-run after this one.
figsize_sx = (fs2c[0], fsh)
xaxis = lags/512

fig, axes = plt.subplots(1, 2, figsize=figsize_sx)

ax = axes[0]
ax.plot(xaxis, S, color='black')
ax.plot(xaxis, MS, color=BLUE)
ax.plot(xaxis, MP, color=RED)
ax.set(xlabel='lag (s)', ylabel='encoding (r)')
formatenc(ax)

ax = axes[1]
im = ax.imshow(np.outer(S, S), cmap='Greys', origin='lower')
formatim(fig, ax, im, xaxis, cbar_loc=None)
plt.show()