In [None]:
import numpy as np
import xarray as xr
import pandas as pd
from collections import OrderedDict

from pathta import Study

from frites import set_mpl_style
from frites.conn import conn_links
from myfrites.plot import plot_dist

from pingouin import ttest
from mne.stats import fdr_correction
from frites.stats import confidence_interval

import matplotlib as mpl
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.image as mpimg
from matplotlib.lines import Line2D

from ipywidgets import interact

# from paper_fcn import load_piid, plot_nx_rp

import json
# project-wide settings (e.g. the export folder and savefig options used below)
with open("config.json", 'r') as cfg_file:
    config = json.load(cfg_file)

set_mpl_style()

# larger fonts for the publication figures
plt.rcParams.update({
    'xtick.labelsize': 'xx-large',
    'ytick.labelsize': 'xx-large',
    'axes.titlesize': 23,
    'axes.labelsize': 22,
})

# **Parameters and functions**
---
## Overall parameters

In [None]:
###############################################################################
# analysis identifiers
bands, metrics, model = 'lga', 'ii', 'PE'

# display order of the roi pairs; used to reorder the merged statistics
roi_order = [
    'aINS-aINS',
    'dlPFC-dlPFC',
    'aINS-dlPFC',
    'aINS-lOFC',
    'aINS-vmPFC',
    'lOFC-lOFC',
    'vmPFC-vmPFC',
    'dlPFC-vmPFC',
    'dlPFC-lOFC',
    'lOFC-vmPFC',
]
###############################################################################

# study handle used for every `st.search` / `st.load_config` below
st = Study('PBLT')

## Loading function

In [None]:
def load_results(from_folder):
    """Load the intra- and inter-roi statistics of a folder and merge them.

    Parameters
    ----------
    from_folder : str
        Folder searched with `st.search`; it must contain exactly one
        "relation-inter" file and one "relation-intra" file.

    Returns
    -------
    dt : xarray.Dataset
        For every variable of the inter-roi file, the intra- and inter-roi
        values concatenated (intra first) along 'roi', then selected and
        reordered following the global `roi_order`.
    """
    loaded = {}
    for relation in ("relation-inter", "relation-intra"):
        f = st.search(relation, folder=from_folder)
        assert len(f) == 1, (
            f"expected one '{relation}' file in {from_folder}, found {len(f)}"
        )
        loaded[relation] = xr.load_dataset(f[0])
    dt_inter, dt_intra = loaded["relation-inter"], loaded["relation-intra"]

    # merge intra and inter, variable by variable
    dt = xr.Dataset({
        k: xr.concat((dt_intra[k], dt_inter[k]), 'roi')
        for k in dt_inter.keys()
    })
    return dt.sel(roi=roi_order)

# **I/O**
---
## Load stats on II

In [None]:
###############################################################################
from_folder = 'conn/piid/lga-PE/ii/full'
###############################################################################

# intra + inter II statistics merged along 'roi', ordered as `roi_order`
dt = load_results(from_folder=from_folder)

## Bin II according to the specificity of the contacts

In [None]:
###############################################################################
use_roi = [
    'aINS', 'dlPFC', 'vmPFC', 'lOFC'
]
###############################################################################

subjects = list(st.load_config('subjects.json').keys())

# For each subject: tag every contact with its local specificity (from the
# conjunction analysis), then label the across-roi II (syn - red) of every
# pair of contacts with the (unordered) specificities of its two contacts.
ii_per_subject = []
for s in subjects:
    # ------------------------------ CONJUNCTION ------------------------------
    # load conjunction analysis (the file is read once; its 'mi' was unused)
    f = st.search(s, folder='mi/conj/PE')
    if not len(f):
        continue
    assert len(f) == 1
    _conj = xr.load_dataset(f[0])['pv']

    # a contact responds to a condition if p < .05 at any time point
    conj_bin = (_conj < .05).any('times')
    resp_rew = conj_bin.sel(cond='rew').data
    resp_pun = conj_bin.sel(cond='pun').data

    # compute specificity (the three categories must be mutually exclusive)
    is_rew = np.logical_and(resp_rew, ~resp_pun)
    is_pun = np.logical_and(~resp_rew, resp_pun)
    is_both = np.logical_and(resp_rew, resp_pun)
    assert np.c_[is_rew, is_pun, is_both].sum(1).max() <= 1

    contacts_spe = np.array(['ns.'] * len(conj_bin['roi']), dtype='object')
    contacts_spe[is_rew] = 'RPE'
    contacts_spe[is_pun] = 'PPE'
    contacts_spe[is_both] = 'Both'
    contacts_spe = xr.DataArray(
        contacts_spe, dims=('roi',), coords=(conj_bin['roi'],))

    # -------------------------------- ANATOMY --------------------------------
    # load the anatomy
    _anat = st.search(s, folder='anatomy', load=True)

    # hemisphere from the 'L_' / 'R_' marker of the anatomical label
    hemi = []
    for k in _anat['ma_checked']:
        if 'L_' in k:
            hemi.append('Left')
        elif 'R_' in k:
            hemi.append('Right')
        else:
            hemi.append('None')
    _anat['hemisphere'] = hemi

    # strip the hemisphere markers
    # NOTE(review): applied to every string column (incl. 'contact'), not only
    # 'ma_checked', and before the bad-channel matching below -- confirm that
    # no contact name contains 'L_' / 'R_'
    _anat = _anat.replace({
        'L\\_': '', 'R\\_': ''
    }, regex=True)

    # drop bad channels; compare against the *values* whether 'ch_names' is
    # loaded as a list, an array or a pandas Series (`in` on a Series would
    # test its index instead)
    bad_ch = set(np.ravel(
        st.search(s, folder='bad_channels', load=True)['ch_names']))
    keep = [k not in bad_ch for k in _anat['contact']]
    _anat = _anat[keep].reset_index(drop=True)

    # keep only the important roi
    keep = [k in use_roi for k in _anat['ma_checked']]
    if not np.any(keep):
        continue
    _anat = _anat[keep].reset_index(drop=True)[
        ['contact', 'ma_checked', 'hemisphere']]
    _anat['subject'] = [s] * len(_anat)

    # the conjunction analysis must describe the same contacts, same order
    np.testing.assert_array_equal(
        _anat['ma_checked'].values, contacts_spe['roi'].data
    )

    # ----------------------------------- II ----------------------------------
    # load the ii (syn - red) of every pair of contacts
    f = st.search(
        s, 'red-syn', folder='conn/piid/lga-PE/pid_no-stats_bias-0/full')
    if not len(f):
        continue
    assert len(f) == 1
    ii_subj = xr.load_dataset(f[0])[['syn', 'red']]

    # keep only inter-roi pairs ('A-B' with A != B)
    keep = [r1 != r2 for r1, r2 in (k.split('-') for k in ii_subj['roi'].data)]
    ii_subj = ii_subj.sel(roi=keep)
    ii_subj = ii_subj['syn'] - ii_subj['red']

    # ------------------------------- ANAT + II -------------------------------
    # indices of the two contacts of each across-roi pair
    (x_s, x_t), roi_st = conn_links(
        _anat['ma_checked'].values, roi_relation='inter',
        hemisphere=_anat['hemisphere'].values, hemi_links='intra',
        verbose=False
    )
    np.testing.assert_array_equal(ii_subj['roi'].data, roi_st)

    # ------------------------------- CONJ + II -------------------------------
    # II{spe1-spe2} = II{spe2-spe1}: sort each pair so the label is unordered
    spe_st = np.sort(np.c_[contacts_spe.data[x_s],
                     contacts_spe.data[x_t]], axis=1)
    spe_st = [f"{spe_s}-{spe_t}" for spe_s, spe_t in spe_st]
    ii_per_subject.append(ii_subj.assign_coords(spe=('roi', spe_st)))

# fail with an explicit message rather than deep inside xr.concat
if not ii_per_subject:
    raise ValueError("No subject with both conjunction and II results.")
ii = xr.concat(ii_per_subject, 'roi')


## Utility functions and parameters

In [None]:
###############################################################################
yscale, cis, sigma = 1.2, 'sem', 0.1

# display name of each (sorted) pair of contact specificities: any pair that
# involves a 'Both' contact is pooled into "Mixed", the others keep their label
ii_repl = {
    spe: ("Mixed" if "Both" in spe else spe)
    for spe in ("PPE-PPE", "RPE-RPE", "PPE-RPE",
                "Both-Both", "Both-RPE", "Both-PPE")
}
###############################################################################


def lineplot(x, y, ci, cm, ax=None):
    """Plot `y` against `x` as value-colored segments, each with its CI band.

    Segment i (from point i to point i + 1) is colored with `cm.to_rgba(y[i])`,
    and the CI band underneath it uses the same color.

    Parameters
    ----------
    x, y : array_like
        Abscissa and values, same length.
    ci : xarray.DataArray
        Confidence interval with a 'bound' coordinate holding 'low' and
        'high'; assumed aligned point-by-point with `x` -- TODO confirm.
    cm : matplotlib.cm.ScalarMappable
        Maps a value of `y` to an RGBA color.
    ax : matplotlib.axes.Axes | None
        Axes to draw on. Defaults to the current axes. The previous version
        silently used a global `ax` for the lines and the current axes for
        the bands.
    """
    if ax is None:
        ax = plt.gca()
    low = ci.sel(bound='low').data
    high = ci.sel(bound='high').data
    for i in range(len(x) - 1):
        seg = slice(i, i + 2)
        color = cm.to_rgba(y[i])
        # explicit colors instead of mutating the axes' property cycle
        ax.plot(x[seg], y[seg], color=color)
        ax.fill_between(
            x[seg], low[seg], high[seg], facecolor=color, alpha=.5, zorder=1
        )


def plt_img(name, ax):
    """Show the image `images/{name}.png` on `ax`, with no grid, spines or ticks.

    Parameters
    ----------
    name : str
        File name (without extension) of a PNG located in `images/`.
    ax : matplotlib.axes.Axes
        Axes receiving the image. Everything is drawn on this axes; the
        previous version drew the image and the grid on pyplot's current
        axes instead.
    """
    img = mpimg.imread(f'images/{name}.png')
    ax.imshow(img)
    ax.grid(False)
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])


# effect-size, p-values and ci selection
sstr = "%.3f" % sigma
es = dt['mi_']
pv = dt[f"pv__{sstr[2::]}"]
esci = dt["ci_"].sel(ci=cis)


# subselect across roi interactions ('A-B' pairs with A != B)
keep = [r1 != r2 for r1, r2 in (r.split('-') for r in es['roi'].data)]
es = es.sel(roi=keep)
esci = esci.sel(roi=keep)
pv = pv.sel(roi=keep)


# get shape and coords of the across-roi selection (by dimension name so the
# dimension order does not matter; `roi` used to come from the unfiltered
# `dt` and did not match `es`)
n_times, n_roi = es.sizes['times'], es.sizes['roi']
times = es['times'].data
roi = es['roi'].data

# significant clusters: constant line at 1.1 x the lowest CI bound where
# p < 0.05, NaN elsewhere. NaN p-values are not drawn as significant, and NaNs
# in the CI no longer turn every cluster into NaN.
maxmax = np.nanmin(esci.data)
clusters = xr.full_like(pv, 1.1 * maxmax, dtype=float).where(pv < 0.05)

# build the colormap: symmetric around 0, saturating at the largest absolute
# value among the 10th / 90th percentiles of the effect size
vmin, vmax = np.nanpercentile(es, [10, 90])
minmax = max(abs(vmin), abs(vmax))
vmin, vmax = -minmax, minmax
cmap = plt.get_cmap('coolwarm')
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
cm = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)


## Stat functions

In [None]:
def pv_to_star(pv):
    """Convert p-value(s) into significance labels.

    Parameters
    ----------
    pv : float | array_like
        A single p-value (python or numpy scalar) or a sequence/array of them.

    Returns
    -------
    stars : list of str
        One label per p-value: '***' (p <= 0.001), '**' (p <= 0.01),
        '*' (p <= 0.05), 'ns' otherwise. NaN p-values are labelled 'ns'; the
        previous version reused the label of the preceding value, or raised
        if the NaN came first.
    """
    # wrap any scalar (including numpy scalars) so a list is always returned
    if np.ndim(pv) == 0:
        pv = [pv]
    stars = []
    for p in pv:
        if p <= 0.001:
            stars.append('***')
        elif p <= 0.01:
            stars.append('**')
        elif p <= 0.05:
            stars.append('*')
        else:
            # p > 0.05 or NaN (every comparison with NaN is False)
            stars.append('ns')
    return stars

def _ttest(x, popmean=0, mult=1):
    """One-sample t-test of `x` against `popmean`, using pingouin's `ttest`.

    Parameters
    ----------
    x : array_like
        Sample values.
    popmean : float
        Mean under the null hypothesis.
    mult : float
        Rescaling factor applied before testing, as a patch for a badly
        resolved CI95% on very small values. Presumably pingouin rounds the
        CI bounds -- TODO confirm. The null mean is rescaled by the same
        factor, and the CI is divided back by `mult` so it is in the units
        of `x`.

    Returns
    -------
    df : pandas.DataFrame
        pingouin's t-test table, with the 'CI95%' entry replaced by its two
        bounds formatted as strings in scientific notation.
    """
    # popmean must be rescaled too, otherwise mult != 1 tests the wrong null
    df = ttest(mult * x, popmean * mult)
    ci = df['CI95%'].values[0].astype(float) / mult
    df['CI95%'] = [["{:.1e}".format(k) for k in ci]]
    return df

def df_ttest(df, groupby, colname, popmean=0, col=None, mult=1):
    """Run a one-sample t-test per group, then FDR-correct the p-values.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format data.
    groupby : str | list of str
        Column(s) defining the groups (one t-test per group).
    colname : str
        Column holding the values to test. Only this column is passed to the
        test; previously the whole group frame was passed, including the
        grouping column (which `mult` multiplied when it held strings).
    popmean : float
        Mean under the null hypothesis.
    col : None
        Unused, kept for backward compatibility.
    mult : float
        Rescaling factor forwarded to `_ttest`.

    Returns
    -------
    stats : pandas.DataFrame
        One t-test table per group (indexed by the group key plus pingouin's
        row label), with an extra 'p-val-corrected' column (FDR) and a
        'Signi' column holding the stars of the corrected p-values.
    """
    # compute t-test for each group on the values of `colname` only
    stats = df.groupby(groupby)[colname].apply(
        _ttest, popmean=popmean, mult=mult)

    # correct p-values for multiple comparisons
    stats['p-val-corrected'] = fdr_correction(stats['p-val'].values)[1]

    # get stars
    stats["Signi"] = pv_to_star(stats['p-val-corrected'].values)
    return stats


# pd.set_option('display.float_format', lambda x: '%.3f' % x)
# x = np.random.uniform(-10e-6, 10e-6, (1000,))
# _ttest(x, mult=10e6)#['CI95%'].values


# **Figures**
---
## Main text Fig. 5

In [None]:
###############################################################################
# Main text Fig. 5, three panels:
#   A. across-roi II of `use_roi` (segments colored by value) + significant
#      clusters
#   B. II of `use_roi` binned by the local specificity of the contact pairs
#      (mean +/- SEM across pairs)
#   C. distribution of the binned II averaged over `ii_times`, with one-sample
#      t-tests against 0 (stars from FDR-corrected p-values)
# roi to plot
use_roi = 'dlPFC-vmPFC'

# text parameters
nb_kw = dict(size=27, weight='bold')

# links to keep for the binned II
keep_ii = [
    "PPE-RPE",
    "Mixed"
]
palette = {
    "PPE-RPE": "C6",
    "Mixed": "C2" 
}

# time window [s] averaged for panel C and shaded in panel B
ii_times = tuple([.25, .6])
###############################################################################


# NOTE(review): this overwrites the global `times` (previously taken from
# `dt`) with the time axis of `ii`. Panel A below plots `es` / `clusters`
# (from `dt`) against it, so both are assumed to share the same time vector
# -- TODO confirm. The supplementary "all pairs" cell reuses this `times` too.
times = ii['times'].data
# time points of `use_roi` inside a significant cluster (the print raises
# IndexError if there is none)
times_range = times[~np.isnan(clusters.sel(roi=use_roi).data)]
print(
    f"Significant II between [{times_range[0]}; {times_range[-1]}]"
)

# __________________________________ FIGURE ___________________________________
fig = plt.figure(figsize=(22, 6))

# ________________________________ II + STATS _________________________________
ax = plt.subplot(1, 3, 1)
# plain II trace; `lineplot` below draws the value-colored segments + CI
fg = es.sel(roi=use_roi).plot(
    x='times', add_legend=False, color='C3', lw=2., ax=ax
)
plt.xticks([-.5, 0., .5, 1., 1.5])
plt.ylim(-0.00025, 0.00025)
plt.xlabel('Times [s]')
plt.ylabel('II [bits]')
plt.ticklabel_format(axis="y", style="sci", scilimits=(0,0))
plt.title('')
# plt.title(use_roi, fontweight='bold')
plt.axvline(0., color='C3')
ax.text(*tuple([-.2, 1.05]), 'A', transform=ax.transAxes, **nb_kw)
lineplot(times, es.sel(roi=use_roi).data, esci.sel(roi=use_roi), cm)

# plot significant clusters (at half of the height stored in `clusters`)
ln, = plt.plot(times, clusters.sel(roi=use_roi).data / 2, lw=8, color='C5')
ln.set_solid_capstyle('round')

# create the legend
custom_lines = [
    Line2D([0], [0], color="C5", lw=6, solid_capstyle='round')
]
titles = [r"$II_{PE}$"]
# plt.legend(
#     custom_lines, titles, ncol=1, fontsize=20,
#     title="Significant cluster of II (p<0.05)",
#     title_fontproperties=dict(weight='bold', size=20), loc=3
# )

# add colorbar
## ax_cb = fig.add_axes([.14, -0.02, .2, 0.04])
# ax_cb = fig.add_axes([.16, 0.21, .16, 0.04])
# cb = mpl.colorbar.ColorbarBase(
#     ax_cb, cmap=cmap, orientation='horizontal'
# )
# cb.set_ticks([0, 1])
# cb.set_ticklabels(['Redundancy', 'Synergy'])
## cb.set_ticklabels(['Redundancy-\ndominated', 'Synergy-\ndominated'])

# ________________________________ BINNED II __________________________________

# select binned II + replace local specificity names
# (the 'roi' dim is re-indexed by the specificity label of each contact pair,
# pairs with a non-specific 'ns.' contact are dropped, labels mapped through
# `ii_repl`)
ii_roi = ii.sel(roi=use_roi).set_index(roi='spe')
ii_roi = ii_roi.sel(roi=['ns.' not in k for k in ii_roi['roi'].data])
ii_roi['roi'] = [ii_repl[k] for k in ii_roi['roi'].data]

ax = plt.subplot(1, 3, 2)
# running extent of the CI bands, used for the symmetric y-limits below
__min, __max = 0, 0
for n_s, spe in enumerate(keep_ii):
    plt_spe = ii_roi.sel(roi=[k == spe for k in ii_roi['roi'].data])

    # plot the mean
    plt_spe.mean('roi').plot(
        x='times', color=palette[spe], lw=3,
        label=r"$II_{%s | %s}$" % (use_roi, spe)
    )

    # compute ci
    ci = confidence_interval(
        plt_spe, axis='roi', cis='sem', verbose=False).squeeze()
    plt.fill_between(
        ci['times'].data, ci.sel(bound='low').data, ci.sel(bound='high').data,
        color=palette[spe], alpha=.2
    )
    __min = min(__min, ci.data.min())
    __max = max(__max, ci.data.max())
__minmax = max(abs(__min), abs(__max))
plt.axvline(0., color='C3')
# shade the `ii_times` window averaged in panel C
plt.fill_betweenx(
    np.linspace(__min, __max, 10), ii_times[0], ii_times[1], color='C5', alpha=.1
)
plt.xticks([-.5, 0., .5, 1., 1.5])
plt.ylim(-__minmax, __minmax)
plt.xlabel('Times [s]')
plt.ylabel('II [bits]')
plt.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
# legend proxies, in the order of `palette`
custom_lines = [
    Line2D([0], [0], color=list(palette.values())[0], lw=6, solid_capstyle='round'),
    Line2D([0], [0], color=list(palette.values())[1], lw=6, solid_capstyle='round')
]
# plt.legend(
#     custom_lines, [r'$II_{%s|PPE-RPE}$' % use_roi, r'$II_{%s|Mixed}$' % use_roi],
#     bbox_to_anchor=(1.3, -.1), fontsize=20,
# )
plt.legend(
    custom_lines, [r'$II_{PPE-RPE}$', r'$II_{Mixed}$'],
    loc=3, fontsize=20,
)
ax.text(*tuple([-.2, 1.05]), 'B', transform=ax.transAxes, **nb_kw)

# ________________________________ II DIST __________________________________

# temporal selection of the II
ii_roi_no_time = ii_roi.sel(times=slice(*ii_times)).mean('times')
df_ii = ii_roi_no_time.to_dataframe('II').reset_index()

# stats on the selection
# NOTE(review): `set_index('roi').reset_index()` is a no-op. The t-tests and
# the FDR correction run on every specificity group of `df_ii`; only then are
# the `keep_ii` rows kept. The supplementary cell instead restricts to
# `keep_ii` before testing.
ii_stats = df_ttest(
    df_ii.set_index('roi').reset_index(), 'roi', 'II'
).loc[keep_ii]

# plot the distribution
ax = plt.subplot(1, 3, 3)
plot_dist(
    data=df_ii.set_index('roi').loc[keep_ii].reset_index(),
    x='roi', y='II', kw_all=dict(palette=palette.values()),
)
plt.ylabel('II [bits]')
ax.text(*tuple([-.2, 1.05]), 'C', transform=ax.transAxes, **nb_kw)
plt.ylim(-0.001, 0.002)
plt.xlabel('Local specificity')
plt.ticklabel_format(axis="y", style="sci", scilimits=(0,0))

# significance annotation under each distribution: a bracket at `yref` (data
# units) and, below it, the stars of the FDR-corrected p-value
yref = -.0006
for n_k, k in enumerate(ii_stats.index):
    plt.plot([n_k - .4, n_k + .4], [yref] * 2, color='k', lw=1)
    plt.plot([n_k - .4] * 2, [yref, yref + abs(yref * .07)], color='k', lw=1)
    plt.plot([n_k + .4] * 2, [yref, yref + abs(yref * .07)], color='k', lw=1)
    plt.text(n_k, yref - abs(yref * .3), ii_stats.loc[k, 'Signi'], ha='center',
             color='k', fontsize=22)


# export the figure (`save_to` / `cfg_export` are reused by later cells)
save_to = config['export']['save_to']
cfg_export = config['export']['cfg']
fig.savefig(f'{save_to}/fig_ii_pe_across-only_{use_roi}.png', **cfg_export)

ii_stats


## Supp: II between all pairs

In [None]:
# Supplementary figure: II of every across-roi pair, one panel per pair
# (3 per row), value-colored with `lineplot`, significant clusters in C5 and
# a shared redundancy / synergy colorbar.
fg = es.plot(
    x='times', col='roi', col_wrap=3, add_legend=False, color='C3', lw=2.,
    sharex=False, size=4
)
fig = plt.gcf()
plt.subplots_adjust(wspace=0.15)
fig = plt.gcf()
# plt.xticks([-.5, 0., .5, 1., 1.5])
# plt.xlim(-.5, 1.5)
# NOTE(review): acts on the current axes only; presumably propagated to the
# other panels by the facet grid's shared y axis -- confirm
plt.ylim(-0.0004, 0.0004)

for n_r, r in enumerate(es['roi'].data):
    ax = np.ravel(fg.axs)[n_r]
    plt.sca(ax)
    plt.xticks([-.5, 0., .5, 1., 1.5])
    plt.title(r, fontweight='bold')
    # NOTE(review): `times` is the global last set by the Fig. 5 cell (time
    # axis of `ii`), assumed identical to `es['times']`
    lineplot(times, es.sel(roi=r).data, esci.sel(roi=r), cm)

    # plot significant clusters
    ln, = plt.plot(times, clusters.sel(roi=r).data, lw=5, color='C5')
    ln.set_solid_capstyle('round')

    plt.axvline(0., color='C3')
    # axis labels on the outer panels only (assumes 2 rows of 3 panels, i.e.
    # the 6 across-roi pairs of `roi_order`)
    if n_r >= 3:
        plt.xlabel('Times [s]')
    if n_r in [0, 3]:
        plt.ylabel('II [bits]')

# create the legend
custom_lines = [
    Line2D([0], [0], color="C5", lw=6, solid_capstyle='round')
]
titles = [r"$II_{PE}$"]
plt.legend(
    custom_lines, titles, ncol=1, bbox_to_anchor=(.75, 0.01), fontsize=20,
    bbox_transform=fig.transFigure, title="Significant cluster of II (p<0.05)",
    title_fontproperties=dict(weight='bold', size=20)
)

# add colorbar (only its two end ticks are labelled)
ax_cb = fig.add_axes([1., 0.25, 0.02, .5])
cb = mpl.colorbar.ColorbarBase(
    ax_cb, cmap=cmap, orientation='vertical'
)
cb.set_ticks([0, 1])
cb.set_ticklabels(['Redundancy-\ndominated', 'Synergy-\ndominated'])

# `save_to` / `cfg_export` are defined in the Fig. 5 cell
fig.savefig(f'{save_to}/supp_ii_pe_across-only.png', **cfg_export)


## Supp: full binned II 

In [None]:
# display the binned II built in the Fig. 5 cell (its 'roi' dim holds the
# specificity category of each contact pair of `use_roi`)
ii_roi


In [None]:
# Supplementary figure: II of `use_roi` binned by local specificity, for the
# four categories of `keep_ii`: (A) mean +/- SEM over time, (B) distribution
# averaged over `ii_times` with FDR-corrected one-sample t-tests against 0.
###############################################################################
# roi to plot
use_roi = 'dlPFC-vmPFC'

# text parameters
nb_kw = dict(size=27, weight='bold')

# links to keep for the binned II
keep_ii = [
    "PPE-RPE",
    "Mixed",
    "PPE-PPE",
    "RPE-RPE",
]
palette = {
    "PPE-RPE": "C6",
    "Mixed": "C2",
    "PPE-PPE": "C0",
    "RPE-RPE": "C1",
}

# time window [s] averaged for panel B and shaded in panel A
ii_times = tuple([.25, .6])
###############################################################################


# NOTE(review): `times` comes from `ii` while `clusters` comes from `dt`;
# both are assumed to share the same time vector -- TODO confirm
times = ii['times'].data
times_range = times[~np.isnan(clusters.sel(roi=use_roi).data)]
print(
    f"Significant II between [{times_range[0]}; {times_range[-1]}]"
)

# __________________________________ FIGURE ___________________________________
fig = plt.figure(figsize=(16, 6))

# ________________________________ BINNED II __________________________________

# select binned II + replace local specificity names
ii_roi = ii.sel(roi=use_roi).set_index(roi='spe')
ii_roi = ii_roi.sel(roi=['ns.' not in k for k in ii_roi['roi'].data])
ii_roi['roi'] = [ii_repl[k] for k in ii_roi['roi'].data]

ax = plt.subplot(1, 2, 1)
# running extent of the CI bands (only used for the shaded window here)
__min, __max = 0, 0
for n_s, spe in enumerate(keep_ii):
    plt_spe = ii_roi.sel(roi=[k == spe for k in ii_roi['roi'].data])

    # plot the mean
    plt_spe.mean('roi').plot(
        x='times', color=palette[spe], lw=2,
        label=r"$II_{%s | %s}$" % (use_roi, spe)
    )

    # compute ci
    ci = confidence_interval(
        plt_spe, axis='roi', cis='sem', verbose=False).squeeze()
    plt.fill_between(
        ci['times'].data, ci.sel(bound='low').data, ci.sel(bound='high').data,
        color=palette[spe], alpha=.2
    )
    __min = min(__min, ci.data.min())
    __max = max(__max, ci.data.max())
__minmax = max(abs(__min), abs(__max))
plt.axvline(0., color='C3')
# shade the `ii_times` window averaged in panel B
plt.fill_betweenx(
    np.linspace(__min, __max, 10), ii_times[0], ii_times[1], color='C5', alpha=.1
)
plt.xticks([-.5, 0., .5, 1., 1.5])
# plt.ylim(-__minmax, __minmax)
plt.xlabel('Times [s]')
plt.ylabel('II [bits]')
plt.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
# one proxy per `palette` color, labelled from `keep_ii`: relies on both
# listing the categories in the same order
custom_lines = [
    Line2D([0], [0], color=col, lw=6, solid_capstyle='round') for col in palette.values()
]
titles = [r'$II_{%s|%s}$' % (use_roi, c) for c in keep_ii]
plt.legend(
    custom_lines, titles, loc=3, fontsize=20, ncol=2, bbox_to_anchor=(.4, -.5)
)
ax.text(*tuple([-.2, 1.05]), 'A', transform=ax.transAxes, **nb_kw)

# ________________________________ II DIST __________________________________

# temporal selection of the II
ii_roi_no_time = ii_roi.sel(times=slice(*ii_times)).mean('times')
df_ii = ii_roi_no_time.to_dataframe('II').reset_index()

# stats on the selection: restricted to `keep_ii` *before* testing, so the FDR
# correction covers these groups only; mult=10000 rescales the values before
# testing (see `_ttest`)
ii_stats = df_ttest(
    df_ii.set_index('roi').loc[keep_ii].reset_index(), 'roi', 'II', mult=10000
).loc[keep_ii]
# pd.reset_option('display.float_format')
# pd.set_option('display.precision', 2)
# global pandas display option: affects every later DataFrame display
pd.set_option('display.float_format', lambda x: '%.3f' % x)

# plot the distribution
ax = plt.subplot(1, 2, 2)
plot_dist(
    data=df_ii.set_index('roi').loc[keep_ii].reset_index(),
    x='roi', y='II', kw_all=dict(palette=palette.values()),
)
plt.ylabel('II [bits]')
ax.text(*tuple([-.2, 1.05]), 'B', transform=ax.transAxes, **nb_kw)
plt.ylim(-0.001, 0.002)
plt.xlabel('Local specificity')
plt.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))

# significance annotation under each distribution: a bracket at `yref` (data
# units) and, below it, the stars of the FDR-corrected p-value
yref = -.0006
for n_k, k in enumerate(ii_stats.index):
    plt.plot([n_k - .4, n_k + .4], [yref] * 2, color='k', lw=1)
    plt.plot([n_k - .4] * 2, [yref, yref + abs(yref * .07)], color='k', lw=1)
    plt.plot([n_k + .4] * 2, [yref, yref + abs(yref * .07)], color='k', lw=1)
    plt.text(n_k, yref - abs(yref * .3), ii_stats.loc[k, 'Signi'], ha='center',
             color='k', fontsize=22)


# export the figure
save_to = config['export']['save_to']
cfg_export = config['export']['cfg']
fig.savefig(f'{save_to}/supp_ii_specificity_{use_roi}.png', **cfg_export)

# 
ii_stats


In [None]:
ii_report = ii_stats.copy()


def _pv_to_str(pvals):
    """Format p-values as strings, suffixed with their significance stars.

    Significant p-values (according to `pv_to_star`) are rounded to 3 decimals
    and followed by their stars; non-significant ones are rounded to
    4 decimals and have no suffix.
    """
    return [
        f"{np.round(p, 3)}{s}" if s != 'ns' else f"{np.round(p, 4)}"
        for p, s in zip(pvals, pv_to_star(pvals))
    ]


# attach star to p-values
for pv_col in ('p-val', 'p-val-corrected'):
    ii_report[pv_col] = _pv_to_str(ii_report[pv_col].values)

# rename and keep only the reported columns. The final `reindex` already
# selects them, so no explicit `drop` is needed (a `drop` raises KeyError as
# soon as one of its listed columns is missing).
ii_report = (
    ii_report
    .reset_index()
    .rename(columns={
        "T": "T-value",
        "roi": "Specificity",
        "p-val": "P-value",
        "p-val-corrected": "P-value (FDR corrected)",
    })
    .set_index('Specificity')
    .reindex(columns=[
        "T-value", "P-value", "P-value (FDR corrected)", "dof", "CI95%"
    ])
    .round(3)
)

pd.set_option('display.float_format', lambda x: '%.2f' % x)
# same path convention as the figure exports (`os` was never imported)
ii_report.to_excel(f"{save_to}/ttest_stats.xlsx")


## Within and across

In [None]:
# ###############################################################################
# yscale = 1.2
# cis = 'sem'
# sigma = 0.1
# ###############################################################################

# def lineplot(x, y, ci, cm):
#     npts = len(x)
#     ax.set_prop_cycle(color=[cm.to_rgba(i) for i in y])
#     for i in range(npts - 1):
#         ax.plot(x[i:i + 2], y[i:i + 2])
#         low = ci.sel(bound='low').data[i:i+2]
#         high = ci.sel(bound='high').data[i:i+2]
#         plt.fill_between(
#             x[i:i+2], low, high, alpha=.5, zorder=1
#     )

# # effect-size, p-values and ci selection
# sstr = "%.3f" % sigma
# es = dt['mi_']
# pv = dt[f"pv__{sstr[2::]}"]
# esci = dt[f"ci_"].sel(ci=cis)

# # get shape and coords
# n_times, n_roi = es.shape
# times = dt['times'].data
# roi = dt['roi'].data

# # significant clusters
# maxmax = esci.data.min()
# clusters = xr.DataArray(
#     np.full(pv.shape, 1.1 * maxmax), dims=pv.dims, coords=pv.coords
# )
# clusters.data[pv.data >= 0.05] = np.nan

# # build the colormap
# # vmin, vmax = es.data.min(), es.data.max()
# vmin, vmax = np.percentile(es, [10, 90])
# minmax = max(abs(vmin), abs(vmax))
# vmin, vmax = -minmax, minmax
# cmap = plt.get_cmap('coolwarm')
# norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
# cm = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)

# fg = es.plot(x='times', col='roi', col_wrap=5, add_legend=False, color='C3', lw=2., sharex=False)
# fig = plt.gcf()
# plt.subplots_adjust(wspace=0.15)
# fig = plt.gcf()
# plt.xticks([-.5, 0., .5, 1., 1.5])
# # plt.xlim(-.5, 1.5)

# for n_r, r in enumerate(es['roi'].data):
#     ax = np.ravel(fg.axes)[n_r]
#     plt.sca(ax)
#     plt.title(r, fontweight='bold')
#     lineplot(times, es.sel(roi=r).data, esci.sel(roi=r), cm)

#     # plot ci
#     # plt.fill_between(
#     #     times, esci.sel(bound='low', roi=r).data,
#     #     esci.sel(bound='high', roi=r).data,
#     #     color='C3', alpha=.2, zorder=-1000
#     # )

#     # plot significant clusters
#     ln, = plt.plot(times, clusters.sel(roi=r).data, lw=5, color='C5')
#     ln.set_solid_capstyle('round')

#     plt.axvline(0., color='C3')
#     if n_r >= 5: plt.xlabel('Times (s)')
#     if n_r in [0, 5]: plt.ylabel('II (bits)')

# # create the legend
# custom_lines = [
#     Line2D([0], [0], color="C5", lw=6, solid_capstyle='round')
# ]
# titles = [r"$II_{PE}$"]
# plt.legend(
#     custom_lines, titles, ncol=1, bbox_to_anchor=(.73, 0.01), fontsize=20,
#     bbox_transform=fig.transFigure, title="Significant cluster of II (p<0.05)",
#     title_fontproperties=dict(weight=None, size=20)
# );

# # add colorbar
# ax_cb = fig.add_axes([1., 0.25, 0.02, .5])
# cb = mpl.colorbar.ColorbarBase(
#     ax_cb, cmap=cmap, orientation='vertical'
# )
# cb.set_ticks([0, 1])
# cb.set_ticklabels(['Redundancy', 'Synergy'])

# # annotations
# height = 1.22
# kw_ann = dict(
#     horizontalalignment='center', verticalalignment='center',
#     arrowprops=dict(arrowstyle="<|-|>", color='C3'), fontsize=18,
#     xycoords='figure fraction'
# )

# # within : arrow, text, image
# plt.gca().annotate(
#     '', xy=(0.086, height), xytext=(0.445, height), **kw_ann
# )
# plt.figtext(0.138, 1.08, 'Within ROI interactions', fontsize=22, color='C3',
#             fontweight=None)
# ax = fig.add_axes([.341, 1.05, .1, .1])
# plt_img('links_intra', ax)

# # across : arrow, text, image
# plt.gca().annotate(
#     '', xy=(0.477, height), xytext=(.98, height), **kw_ann
# )
# plt.figtext(0.568, 1.08, 'Across ROI interactions', fontsize=22, color='C3',
#             fontweight=None);
# ax = fig.add_axes([.802, 1.05, .1, .1])
# plt_img('links_inter', ax)


In [None]:
# save_to = config['export']['save_to']
# cfg_export = config['export']['cfg']

# fig.savefig(f'{save_to}/fig_ii_pe.png', **cfg_export)

### Lineplot **without** color gradient

In [None]:
# ###############################################################################
# yscale = 1.2
# cis = 90
# sigma = 0.1
# ###############################################################################

# # effect-size, p-values and ci selection
# sstr = "%.3f" % sigma
# es = dt['mi_']
# pv = dt[f"pv__{sstr[2::]}"]
# esci = ci[f"ci_"].sel(cis=cis)

# # get shape and coords
# n_times, n_roi = es.shape
# times = dt['times'].data
# roi = dt['roi'].data

# # significant clusters
# maxmax = esci.data.min()
# clusters = xr.DataArray(
#     np.full(pv.shape, 1.1 * maxmax), dims=pv.dims, coords=pv.coords
# )
# clusters.data[pv.data >= 0.05] = np.nan

# fg = es.plot(x='times', col='roi', col_wrap=5, add_legend=False, color='C3')
# fig = plt.gcf()
# plt.xticks([-.5, 0., .5, 1.])
# plt.xlim(-.5, 1.5)

# for n_r, r in enumerate(es['roi'].data):
#     plt.sca(np.ravel(fg.axes)[n_r])
#     plt.title(r, fontweight='bold')

#     # plot ci
#     plt.fill_between(
#         times, esci.sel(bound='low', roi=r).data,
#         esci.sel(bound='high', roi=r).data,
#         color='C3', alpha=.2, zorder=-1000
#     )

#     # plot significant clusters
#     ln, = plt.plot(times, clusters.sel(roi=r).data, lw=5, color='C3')
#     solid_capstyle='round'

#     plt.axvline(0., color='C3')
#     if n_r >= 5: plt.xlabel('Times (s)')
#     if n_r in [0, 5]: plt.ylabel('II (bits)')

# # # create the legend
# # custom_lines = [
# #     Line2D([0], [0], color="C1", lw=6),
# #     Line2D([0], [0], color="C0", lw=6),
# #     Line2D([0], [0], color="C5", lw=6)
# # ]
# # titles = [r"$II_{RPE}$", r"$II_{PPE}$", r"$II_{RPE} \neq II_{PPE}$"]
# # plt.legend(
# #     custom_lines, titles, ncol=5, bbox_to_anchor=(.73, 0.01), fontsize=20,
# #     bbox_transform=fig.transFigure, title="Significant cluster of II (p<0.05)",
# #     title_fontproperties=dict(weight=None, size=20)
# # );    

# # # annotations
# # height = 1.22
# # kw_ann = dict(
# #     horizontalalignment='center', verticalalignment='center',
# #     arrowprops=dict(arrowstyle="<|-|>", color='C3'), fontsize=18,
# #     xycoords='figure fraction'
# # )
# # plt.gca().annotate(
# #     '', xy=(0.096, height), xytext=(0.445, height), **kw_ann
# # )
# # plt.gca().annotate(
# #     '', xy=(0.467, height), xytext=(1., height), **kw_ann
# # )
# # plt.figtext(0.118, 1.08, 'Within ROI connections', fontsize=22, color='C3',
# #             fontweight=None)
# # plt.figtext(0.538, 1.08, 'Across ROI connections', fontsize=22, color='C3',
# #             fontweight=None);

# # ax = fig.add_axes([.321, 1.05, .1, .1])
# # plt_img('links_intra', ax)

# # ax = fig.add_axes([.772, 1.05, .1, .1])
# # plt_img('links_inter', ax)
