In [None]:
import numpy as np
import xarray as xr
import pandas as pd

from pathta import Study

from frites import set_mpl_style

import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.image as mpimg
from matplotlib.lines import Line2D

from ipywidgets import interact

# from paper_fcn import load_piid, plot_nx_rp

import json
# Notebook-wide configuration; its 'export' section is read by the export
# cells further down.
with open("config.json", 'r') as f:
    config = json.load(f)

# Apply the frites matplotlib style, then override tick, title and label font
# sizes for the figures of this notebook.
set_mpl_style()
plt.rcParams.update({
    'xtick.labelsize': 'xx-large',
    'ytick.labelsize': 'xx-large',
    'axes.titlesize': 23,
    'axes.labelsize': 22,
})


---
# **Parameters and functions**
## Overall parameters

In [None]:
###############################################################################
# Analysis identifiers.
bands = 'lga'
metrics = 'ii'
model = 'PE'

# Region pairs in display order. With the 5-column layout of the main figure,
# each row starts with two within-region pairs followed by three
# between-region pairs.
roi_order = [
    # first row
    'aINS-aINS', 'dlPFC-dlPFC',
    'aINS-dlPFC', 'aINS-lOFC', 'aINS-vmPFC',
    # second row
    'lOFC-lOFC', 'vmPFC-vmPFC',
    'dlPFC-vmPFC', 'dlPFC-lOFC', 'lOFC-vmPFC',
]
###############################################################################

# Study handle used by load_results to locate the result files.
st = Study('PBLT')

## Loading function

In [None]:
def load_results(from_folder):
    """Load the intra- and inter-region statistics and merge them.

    Parameters
    ----------
    from_folder : str
        Study folder passed to ``st.search`` to locate the two result files.

    Returns
    -------
    dt : xarray.Dataset
        One variable per key of the inter-region dataset, with the intra- and
        inter-region results concatenated (intra first) along the 'roi'
        dimension and reordered according to ``roi_order``.

    Raises
    ------
    AssertionError
        If the folder does not contain exactly one "relation-inter" file and
        exactly one "relation-intra" file.
    """
    def _load_single(pattern):
        # exactly one file is expected per relation type
        files = st.search(pattern, folder=from_folder)
        assert len(files) == 1, (
            f"Expected exactly one '{pattern}' file in '{from_folder}', "
            f"found {len(files)}"
        )
        return xr.load_dataset(files[0])

    # load stats intra and inter
    dt_inter = _load_single("relation-inter")
    dt_intra = _load_single("relation-intra")

    # merge intra and inter
    dt = {
        k: xr.concat((dt_intra[k], dt_inter[k]), 'roi')
        for k in dt_inter.keys()
    }

    return xr.Dataset(dt).sel(roi=roi_order)


---
# **I/O**

In [None]:
###############################################################################
# Study folder searched by load_results for the "relation-intra" and
# "relation-inter" files.
# NOTE(review): this equals f'conn/piid/{bands}-{model}/{metrics}/cond' for the
# parameter values defined earlier, but it is hard-coded here — keep both in
# sync when changing the parameters.
from_folder = 'conn/piid/lga-PE/ii/cond'
###############################################################################

# merged intra + inter statistics, with regions ordered as in roi_order
dt = load_results(from_folder)

---
# **Lineplot**
## Plotting functions

In [None]:
def plt_img(name, ax):
    """Show ``images/{name}.png`` in a bare axes (no grid, spines or ticks).

    Parameters
    ----------
    name : str
        File name, without the '.png' extension, inside the 'images' folder.
    ax : matplotlib.axes.Axes
        Axes in which the image is drawn.
    """
    img = mpimg.imread(f'images/{name}.png')
    # Draw on the axes that was passed in rather than on pyplot's current
    # axes, so the result does not depend on which axes happens to be active.
    ax.imshow(img)
    ax.grid(False)
    for side in ('top', 'left', 'bottom', 'right'):
        ax.spines[side].set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])

## Actual figure
### Within- and across-ROI interactions

In [None]:
###############################################################################
yscale = 1.2
cis = 'sem'
sigma = 0.001
nb_kw = dict(size=35, weight='bold')
###############################################################################

# The p-value variable names end with the decimals of sigma (0.001 -> "001").
sstr = f"{sigma:.3f}"

# effect size and its confidence interval, one entry per condition
es = dt[[f"mi_{s}" for s in ['pun', 'rew']]].to_array('cond')
esci = dt[[f"ci_{s}" for s in ['pun', 'rew']]].to_array('cond').sel(ci=cis)
esci['cond'] = ['pun', 'rew']
es['cond'] = ['pun', 'rew']

# p-values: the two conditions followed by their difference
pv_cond = dt[[f"pv_{s}_{sstr[2:]}" for s in ['pun', 'rew']]].to_array('cond')
pv_diff = dt[f"pv_rew!=pun_{sstr[2:]}"].expand_dims('cond', axis=0)
pv_diff['cond'] = ['diff']
pv = xr.concat((pv_cond, pv_diff), 'cond')
pv['cond'] = ['pun', 'rew', 'diff']

# get shape and coords
_, n_times, n_roi = es.shape
times = dt['times'].data
roi = dt['roi'].data

# Significant clusters: one constant-height line per condition (pun, rew,
# diff), obtained by scaling the smallest CI value (despite its name, `maxmax`
# holds a minimum). Non-significant samples are masked with NaN so they are
# not drawn.
maxmax = esci.data.min()
clusters = xr.DataArray(
    np.full(pv.shape, maxmax), dims=pv.dims, coords=pv.coords
)
for _n_c, _scale in enumerate((1.17, 1.1, 1.245)):
    clusters[_n_c, ...] = _scale * maxmax
clusters.data[pv.data >= 0.05] = np.nan

# One panel per region pair (5 per row) with one line per condition.
fg = es.plot(
    x='times', col='roi', col_wrap=5, hue='cond', add_legend=False, sharex=False
)
plt.subplots_adjust(wspace=0.15, hspace=0.5, top=0.96)
fig = plt.gcf()

# Color of each condition, shared by the CI band and the cluster line. The
# 'pun' and 'rew' colors match the first two hue colors of the lines above.
cond_colors = {'pun': 'C0', 'rew': 'C1', 'diff': 'C5'}

for n_r, r in enumerate(es['roi'].data):
    plt.sca(np.ravel(fg.axs)[n_r])
    plt.xticks([-.5, 0., .5, 1., 1.5])
    plt.title(r, fontweight='bold')
    for cond, _color in cond_colors.items():
        # plot ci (only available for the two conditions, not the difference)
        if cond in ['rew', 'pun']:
            plt.fill_between(
                times, esci.sel(bound='low', roi=r, cond=cond).data,
                esci.sel(bound='high', roi=r, cond=cond).data,
                color=_color, alpha=.2, zorder=-1000)

        # plot significant clusters
        ln, = plt.plot(times, clusters.sel(roi=r, cond=cond).data,
                       lw=5, color=_color)
        ln.set_solid_capstyle('round')

    plt.axvline(0., color='C3')
    # x labels on the second row only, y labels on the first column only
    if n_r >= 5: plt.xlabel('Times [s]')
    if n_r in [0, 5]: plt.ylabel('II [bits]')

# create the legend (handles listed in the same order as the labels below)
custom_lines = [
    Line2D([0], [0], color=_c, lw=6, solid_capstyle='round')
    for _c in ("C1", "C0", "C5")
]
titles = [r"$II_{RPE}$", r"$II_{PPE}$", r"$II_{RPE} \neq II_{PPE}$"]
plt.legend(
    custom_lines, titles,
    title="Significant cluster of II (p<0.05)",
    title_fontproperties=dict(weight='bold', size=20),
    fontsize=20, ncol=5,
    bbox_to_anchor=(.73, 0.01), bbox_transform=fig.transFigure,
)

# annotations: two double-headed arrows expressed in figure fractions; a
# height above 1 puts them above the top edge of the figure
height = 1.22
kw_ann = dict(
    xycoords='figure fraction', fontsize=18,
    horizontalalignment='center', verticalalignment='center',
    arrowprops=dict(arrowstyle="<|-|>", color='C3'),
)
for _start, _stop in [(0.096, 0.445), (0.467, 1.)]:
    plt.gca().annotate(
        '', xy=(_start, height), xytext=(_stop, height), **kw_ann
    )

# captions of the two groups of panels
for _x, _label in [(0.105, 'Within-region interactions'),
                   (0.52, 'Between-regions interactions')]:
    plt.figtext(_x, 1.08, _label, fontsize=22, color='C3', fontweight=None)

# small illustration next to each caption
for _left, _name in [(.331, 'links_intra_b&w'), (.807, 'links_inter_b&w')]:
    ax = fig.add_axes([_left, 1.05, .1, .1])
    plt_img(_name, ax)

# add text: panel letters on the first and third panels of the first row
pos = [-.1, 1.5]
ax = fg.axs[0, 0]
ax.text(*pos, 'A', transform=ax.transAxes, **nb_kw)
ax = fg.axs[0, 2]
ax.text(*pos, 'B', transform=ax.transAxes, **nb_kw)

## Export the figure

In [None]:
# Destination folder and savefig options come from the 'export' section of
# the configuration.
save_to, cfg_export = config['export']['save_to'], config['export']['cfg']

fig.savefig(f'{save_to}/fig_ii_rppe.png', **cfg_export)

### Across-ROI interactions only

In [None]:
###############################################################################
yscale = 1.2
cis = 'sem'
sigma = 0.001
###############################################################################

# The p-value variable names end with the decimals of sigma (0.001 -> "001").
sstr = f"{sigma:.3f}"

# effect size and its confidence interval, one entry per condition
es = dt[[f"mi_{s}" for s in ['pun', 'rew']]].to_array('cond')
esci = dt[[f"ci_{s}" for s in ['pun', 'rew']]].to_array('cond').sel(ci=cis)
esci['cond'] = ['pun', 'rew']
es['cond'] = ['pun', 'rew']

# p-values: the two conditions followed by their difference
pv_cond = dt[[f"pv_{s}_{sstr[2:]}" for s in ['pun', 'rew']]].to_array('cond')
pv_diff = dt[f"pv_rew!=pun_{sstr[2:]}"].expand_dims('cond', axis=0)
pv_diff['cond'] = ['diff']
pv = xr.concat((pv_cond, pv_diff), 'cond')
pv['cond'] = ['pun', 'rew', 'diff']

# subselect across roi interactions: keep the pairs whose two region names
# differ (boolean mask over the 'roi' coordinate)
keep = [
    _r1 != _r2
    for _r1, _r2 in (r.split('-') for r in es['roi'].data)
]
es = es.sel(roi=keep)
esci = esci.sel(roi=keep)

# get shape and coords
_, n_times, n_roi = es.shape
times = dt['times'].data
roi = dt['roi'].data

# Significant clusters: one constant-height line per condition (pun, rew,
# diff), obtained by scaling the smallest CI value of the kept pairs (despite
# its name, `maxmax` holds a minimum). Non-significant samples are masked with
# NaN so they are not drawn. `pv` and `clusters` still contain every region
# pair; they are indexed by region label when plotting.
maxmax = esci.data.min()
clusters = xr.DataArray(
    np.full(pv.shape, maxmax), dims=pv.dims, coords=pv.coords
)
for _n_c, _scale in enumerate((1.17, 1.1, 1.245)):
    clusters[_n_c, ...] = _scale * maxmax
clusters.data[pv.data >= 0.05] = np.nan

# One panel per between-region pair (3 per row) with one line per condition.
fg = es.plot(
    x='times', col='roi', col_wrap=3, hue='cond', add_legend=False,
)
plt.subplots_adjust(wspace=0.15)
fig = plt.gcf()
# NOTE(review): ticks are set once on the current axes; this presumably
# reaches every panel through the facet grid's shared x axis — confirm.
plt.xticks([-.5, 0., .5, 1., 1.5])

# Color of each condition, shared by the CI band and the cluster line. The
# 'pun' and 'rew' colors match the first two hue colors of the lines above.
cond_colors = {'pun': 'C0', 'rew': 'C1', 'diff': 'C5'}

for n_r, r in enumerate(es['roi'].data):
    # `axs` is the current name of the facet grid's axes array (`axes` is
    # deprecated) and is what the other figure of this notebook uses.
    plt.sca(np.ravel(fg.axs)[n_r])
    plt.title(r, fontweight='bold')
    for cond, _color in cond_colors.items():
        # plot ci (only available for the two conditions, not the difference)
        if cond in ['rew', 'pun']:
            plt.fill_between(
                times, esci.sel(bound='low', roi=r, cond=cond).data,
                esci.sel(bound='high', roi=r, cond=cond).data,
                color=_color, alpha=.2, zorder=-1000)

        # plot significant clusters
        ln, = plt.plot(times, clusters.sel(roi=r, cond=cond).data,
                       lw=5, color=_color)
        ln.set_solid_capstyle('round')

    plt.axvline(0., color='C3')
    # x labels on the second row only, y labels on the first column only
    if n_r >= 3: plt.xlabel('Times (s)')
    if n_r in [0, 3]: plt.ylabel('II (bits)')

# create the legend (handles listed in the same order as the labels below)
custom_lines = [
    Line2D([0], [0], color=_c, lw=6, solid_capstyle='round')
    for _c in ("C1", "C0", "C5")
]
titles = [r"$II_{RPE}$", r"$II_{PPE}$", r"$II_{RPE} \neq II_{PPE}$"]
# the trailing semicolon keeps the notebook from echoing the Legend object
plt.legend(
    custom_lines, titles,
    title="Significant cluster of II (p<0.05)",
    title_fontproperties=dict(weight=None, size=20),
    fontsize=20, ncol=5,
    bbox_to_anchor=(.87, 0.01), bbox_transform=fig.transFigure,
);


In [None]:
# Destination folder and savefig options come from the 'export' section of
# the configuration.
save_to, cfg_export = config['export']['save_to'], config['export']['cfg']

fig.savefig(f'{save_to}/fig_ii_rppe_across-only.png', **cfg_export)

---
# **Network plot**

In [None]:
# One network panel per information-theoretic measure.
fig, axs = plt.subplots(
    nrows=1, ncols=3, sharex=True, sharey=True, figsize=(16, 5.2)
)
axs = np.ravel(axs)

# NOTE(review): `load_piid` and `plot_nx_rp` only appear in a commented-out
# import, and `band`, `n_perm` and `savgol` are not defined in the visible
# part of this notebook — this cell presumably fails on a fresh kernel.
# It also rebinds the notebook-level `dt` on every iteration.
for _ax, (_metric, _title) in zip(axs, [
        ('syn', 'Synergy'),
        ('redu', 'Redundancy'),
        ('ii', 'Interaction information')]):
    plt.sca(_ax)
    dt, ci = load_piid(band, model, _metric, n_perm, sigma, savgol)
    # boolean result: at least one p-value below 0.05 along 'times'
    res = (dt.sel(var='pv') < 0.05).any('times')
    plot_nx_rp(res)
    plt.title(_title, fontweight='bold')

plt.tight_layout()

## Export the figure

In [None]:
# Destination folder and savefig options come from the 'export' section of
# the configuration.
save_to, cfg_export = config['export']['save_to'], config['export']['cfg']

fig.savefig(f'{save_to}/fig1_networks.png', **cfg_export)

In [None]:
# NOTE(review): stand-alone arithmetic whose result is not used anywhere in
# the visible notebook — looks like leftover scratch work.
15 / 75