In [None]:
import numpy as np
import xarray as xr
import pandas as pd

from pathta import Study

from frites import set_mpl_style
from frites.stats import confidence_interval

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.lines import Line2D
import matplotlib.image as mpimg

from ipywidgets import interact

import json
# Export settings (output folder, savefig kwargs) are read from config.json.
with open("config.json") as cfg_file:
    config = json.load(cfg_file)

set_mpl_style()

# Font-size overrides, applied after the frites style so they take precedence.
plt.rcParams.update({
    'xtick.labelsize': 'xx-large',
    'ytick.labelsize': 'xx-large',
    'axes.titlesize': 20,
    'axes.labelsize': 22,
})


---
# **PID per condition**
## I/O
### Load the information decomposition (PID)

In [None]:
###############################################################################
band = 'lga'
model = 'PE'

from_folder = f'conn/piid/{band}-{model}/pid_no-stats_bias-0/cond'
###############################################################################

st = Study('PBLT')
subjects = list(st.load_config('subjects.json').keys())

# Per-subject results are concatenated along 'roi', so that 'roi' is the
# dimension the plotting functions average over (ROIs pooled across subjects).
uni, conn = [], []
for s in subjects:
    # unique information (subjects without PID files are skipped)
    files = st.search(s, 'var-uni', folder=from_folder, verbose=False)
    if not len(files):
        continue
    assert len(files) == 1, f"{s}: expected 1 'var-uni' file, got {files}"
    _uni = xr.load_dataset(files[0]).to_array('cond')

    # total, synergistic and redundant information
    files = st.search(s, 'var-tot-red-syn', folder=from_folder, verbose=False)
    assert len(files) == 1, (
        f"{s}: expected 1 'var-tot-red-syn' file, got {files}")
    _conn = xr.load_dataset(files[0]).to_array('cond')

    uni.append(_uni)
    conn.append(_conn)

# fail with a clear message instead of an opaque error from xr.concat([])
assert uni, f"no PID files found in {from_folder!r}"
uni = xr.concat(uni, 'roi')
conn = xr.concat(conn, 'roi')


### Load the significant clusters of II

In [None]:
###############################################################################
# derived from `band` / `model` (set when loading the PID) so that both
# loaders always point at the same band/model
from_folder = f'conn/piid/{band}-{model}/ii/cond'
# sigma value used to select the II statistics (matched against the names of
# the stored variables)
sigma = 0.001
###############################################################################

def load_results(from_folder, study=None):
    """Load the within- and across-ROI II statistics and merge them.

    Parameters
    ----------
    from_folder : str
        Study folder expected to contain exactly one "relation-intra" and one
        "relation-inter" file.
    study : Study | None
        Study to search in. Defaults to the global ``st``.

    Returns
    -------
    dt : xr.Dataset
        One variable per variable of the "relation-inter" file, with the
        intra results followed by the inter results along 'roi'.
    """
    study = st if study is None else study

    def _load_one(relation):
        # exactly one file is expected per relation type
        files = study.search(relation, folder=from_folder)
        assert len(files) == 1, (
            f"expected 1 '{relation}' file in {from_folder!r}, got {files}")
        return xr.load_dataset(files[0])

    dt_inter = _load_one("relation-inter")
    dt_intra = _load_one("relation-intra")

    # merge intra and inter (only variables present in the inter file are kept)
    return xr.Dataset({
        k: xr.concat((dt_intra[k], dt_inter[k]), 'roi')
        for k in dt_inter.keys()
    })


# sigma is matched through its 3 decimals (e.g. 0.001 -> '001');
# assumes 0 <= sigma < 1
sstr = f"{sigma:.3f}"[2:]
ii_stats = load_results(from_folder)

# Pick the p-value variable of each condition explicitly, so the 'cond' labels
# cannot be swapped if the file stores the variables in another order.
conditions = ['rew', 'pun']
pv_keys = []
for cond in conditions:
    _k = [k for k in ii_stats.keys() if (f'pv_{cond}_' in k) and (sstr in k)]
    assert len(_k) == 1, f"expected 1 p-value variable for {cond!r}, got {_k}"
    pv_keys.append(_k[0])

# boolean mask of significant time points (p < 0.05), one row per condition
ii = ii_stats[pv_keys].to_array('cond') < 0.05
ii['cond'] = conditions


## Plotting
### General plotting functions

In [None]:
###############################################################################
roi_1 = 'aINS'
roi_2 = 'dlPFC'
condition = 'pun'
###############################################################################


def single_line(name, x, color='C0', cis='sem'):
    """Plot the mean of `x` over 'roi' with a shaded confidence band.

    Draws on the current matplotlib axes. ``name`` is the legend label of the
    mean line and ``cis`` is forwarded to frites' ``confidence_interval``.
    """
    times = x['times'].data

    bounds = confidence_interval(x, axis='roi', cis=cis,
                                 verbose=False).squeeze()
    low, high = (bounds.sel(bound=b) for b in ('low', 'high'))

    plt.plot(times, x.mean('roi').data, lw=2, color=color, label=name)
    # zorder=61 puts the band above artists drawn with a default zorder
    plt.fill_between(times, low.data, high.data, color=color, alpha=.3,
                     zorder=61)


def subplot(roi_1, roi_2, condition, legend=True):
    """Overlay the PID components of the pair ``roi_1``-``roi_2`` for one
    condition on the current axes.

    Draws, in this order, the unique information of each ROI, then the
    synergy, redundancy and total information of the pair, and a vertical
    line at t=0. Set ``legend=False`` to skip the per-axes legend.
    """
    pair = '%s-%s' % (roi_1, roi_2)
    # (legend label, source array, selection, color)
    layers = (
        (r'$Unique_{%s}$' % roi_1, uni,
         dict(roi=roi_1, cond='un_%s' % condition), 'C0'),
        (r'$Unique_{%s}$' % roi_2, uni,
         dict(roi=roi_2, cond='un_%s' % condition), 'C1'),
        (r'$Synergy$', conn, dict(roi=pair, cond='syn_%s' % condition), 'C5'),
        (r'$Redundancy$', conn,
         dict(roi=pair, cond='red_%s' % condition), 'C4'),
        (r'$Total$', conn, dict(roi=pair, cond='tot_%s' % condition), 'C3'),
    )
    for label, source, selection, color in layers:
        single_line(label, source.sel(**selection), color=color)

    plt.axvline(0., color='C3', lw=2)
    if legend:
        plt.legend()


def plt_img(name, ax, img_dir='images'):
    """Display ``{img_dir}/{name}.png`` in ``ax`` as a bare image.

    Grid, spines and ticks of ``ax`` are removed.
    """
    img = mpimg.imread(f'{img_dir}/{name}.png')
    # draw on `ax` explicitly rather than relying on it being the current axes
    ax.imshow(img)
    ax.grid(False)
    for side in ('top', 'left', 'bottom', 'right'):
        ax.spines[side].set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])


### Plotting results per condition for all pairs

In [None]:
###############################################################################
condition = 'rew'

# Panel order of the 2 x 5 grid, filled row by row.
roi_order = [
    # row 1
    'aINS-aINS', 'dlPFC-dlPFC',
    'aINS-dlPFC', 'aINS-lOFC', 'aINS-vmPFC',
    # row 2
    'lOFC-lOFC', 'vmPFC-vmPFC',
    'dlPFC-vmPFC', 'dlPFC-lOFC', 'lOFC-vmPFC',
]
###############################################################################

fig, axs = plt.subplots(nrows=2, ncols=5, figsize=(24, 9), sharex=True,
                        sharey=True)
axs = np.ravel(axs)
plt.subplots_adjust(wspace=0.15)

# long name of the condition, used in the legend title
title_cond = dict(rew='reward', pun='punishment')[condition]

for n_ax, r in enumerate(roi_order):
    ax = axs[n_ax]
    # subplot() draws through pyplot, so make this panel the current axes
    plt.sca(ax)
    roi_1, roi_2 = r.split('-')
    subplot(roi_1, roi_2, condition, legend=False)
    ax.set_title(r"$\bf{R_{1}=%s; R_{2}=%s}$" % (roi_1, roi_2),
                 fontweight='bold')

    # orange bar slightly below zero over the time points of a significant
    # II cluster (if any)
    cl = ii.sel(cond=condition, roi=r)
    if cl.data.any():
        line = np.where(cl.data, -0.001, np.nan)
        ax.plot(cl['times'].data, line, color='#FFA500', lw=6,
                solid_capstyle='round')

    # y label on the first column, x label on the bottom row
    if n_ax in (0, 5):
        ax.set_ylabel('Information (bits)')
    if n_ax >= 5:
        ax.set_xlabel('Times (s)')

# axes are shared, so ticks / limits set on the last panel apply to all
plt.xticks([-.5, 0., .5, 1., 1.5])
plt.ylim(-0.005, 0.035)

# Shared legend built from proxy handles, one per PID component, using the
# same colors as the lines drawn by subplot().
custom_lines = [
    Line2D([0], [0], color=c, lw=2, solid_capstyle='round')
    for c in ("C0", "C1", "C5", "C4", "C3")
]
titles = [r"$Unique_{R_{1}}$", r"$Unique_{R_{2}}$", r"$Synergy$",
          r"$Redundancy$", r"$Total$"]
# anchored in figure coordinates, near the bottom of the figure
plt.legend(custom_lines, titles, ncol=5, fontsize=20,
           bbox_to_anchor=(.77, 0.03), bbox_transform=fig.transFigure,
           title=f"Partial information decomposition about the {title_cond} PE",
           title_fontproperties=dict(weight='bold', size=20))

# Header above the grid: for the within-ROI and the across-ROI panel groups,
# a double-headed arrow, a caption and a small schematic image.
height = 1.06
kw_ann = dict(
    horizontalalignment='center', verticalalignment='center',
    arrowprops=dict(arrowstyle="<|-|>", color='C3'), fontsize=18,
    xycoords='figure fraction'
)
for x_from, x_to in ((0.066, 0.368), (0.38, .842)):
    plt.gca().annotate('', xy=(x_from, height), xytext=(x_to, height),
                       **kw_ann)
for x_txt, caption in ((0.188, 'Within ROI interactions'),
                       (0.548, 'Across ROI interactions')):
    plt.figtext(x_txt, 1., caption, fontsize=22, color='C3',
                fontweight=None)
for rect, img_name in (([.315, .97, .08, .08], 'links_intra'),
                       ([.7, .97, .08, .08], 'links_inter')):
    ax = fig.add_axes(rect)
    plt_img(img_name, ax)


In [None]:
# output folder and savefig keyword arguments come from config.json
export = config['export']
save_to, cfg_export = export['save_to'], export['cfg']

fig.savefig(f'{save_to}/fig_pid-{condition}.png', **cfg_export)


### Plot specific pairs

In [None]:
###############################################################################
pair_rois = ('aINS', 'vmPFC')
pair_cond = 'rew'
###############################################################################

# the title and file name are derived from the plotted condition so they
# cannot drift from what is actually drawn
pe_name = {'rew': 'RPE', 'pun': 'PPE'}[pair_cond]

fig = plt.figure(figsize=(10, 8))
subplot(*pair_rois, pair_cond, legend=True)
plt.xlabel('Times (s)')
plt.ylabel('Information (bits)')
plt.title(r"$\bf{PID_{%s-%s}}$ about %s" % (*pair_rois, pe_name),
          fontweight='bold', fontsize=24)

save_to = config['export']['save_to']
cfg_export = config['export']['cfg']

fig.savefig(
    f'{save_to}/fig_{pair_rois[0]}-{pair_rois[1]}_{pe_name.lower()}.png',
    **cfg_export
)


---
# **PID across conditions**
## I/O
### Loading the information decomposition

In [None]:
###############################################################################
band = 'lga'
model = 'PE'

from_folder = f'conn/piid/{band}-{model}/pid_no-stats_bias-0/full'
###############################################################################

st = Study('PBLT')
subjects = list(st.load_config('subjects.json').keys())

# Per-subject results are concatenated along 'roi', so that 'roi' is the
# dimension the plotting functions average over (ROIs pooled across subjects).
uni, conn = [], []
for s in subjects:
    # unique information (subjects without PID files are skipped)
    files = st.search(s, 'var-uni', folder=from_folder, verbose=False)
    if not len(files):
        continue
    assert len(files) == 1, f"{s}: expected 1 'var-uni' file, got {files}"
    _uni = xr.load_dataarray(files[0])

    # total, synergistic and redundant information, stacked along 'comp'
    files = st.search(s, 'var-tot-red-syn', folder=from_folder, verbose=False)
    assert len(files) == 1, (
        f"{s}: expected 1 'var-tot-red-syn' file, got {files}")
    _conn = xr.load_dataset(files[0]).to_array('comp')

    uni.append(_uni)
    conn.append(_conn)

# fail with a clear message instead of an opaque error from xr.concat([])
assert uni, f"no PID files found in {from_folder!r}"
uni = xr.concat(uni, 'roi')
conn = xr.concat(conn, 'roi')


### Loading the significant clusters of II

In [None]:
###############################################################################
# derived from `band` / `model` (set when loading the PID) so that both
# loaders always point at the same band/model
from_folder = f'conn/piid/{band}-{model}/ii/full'
# sigma value used to select the II statistics (matched against the names of
# the stored variables)
sigma = 0.1
###############################################################################

def load_results(from_folder, study=None):
    """Load the within- and across-ROI II statistics and merge them.

    Parameters
    ----------
    from_folder : str
        Study folder expected to contain exactly one "relation-intra" and one
        "relation-inter" file.
    study : Study | None
        Study to search in. Defaults to the global ``st``.

    Returns
    -------
    dt : xr.Dataset
        One variable per variable of the "relation-inter" file, with the
        intra results followed by the inter results along 'roi'.
    """
    study = st if study is None else study

    def _load_one(relation):
        # exactly one file is expected per relation type
        files = study.search(relation, folder=from_folder)
        assert len(files) == 1, (
            f"expected 1 '{relation}' file in {from_folder!r}, got {files}")
        return xr.load_dataset(files[0])

    dt_inter = _load_one("relation-inter")
    dt_intra = _load_one("relation-intra")

    # merge intra and inter (only variables present in the inter file are kept)
    return xr.Dataset({
        k: xr.concat((dt_intra[k], dt_inter[k]), 'roi')
        for k in dt_inter.keys()
    })


# sigma is matched through its 3 decimals (e.g. 0.1 -> '100');
# assumes 0 <= sigma < 1
sstr = f"{sigma:.3f}"[2:]
ii_stats = load_results(from_folder)

_key = [k for k in ii_stats.keys() if ('pv_' in k) and (sstr in k)]
assert len(_key) == 1, (
    f"expected 1 p-value variable for sigma={sigma}, got {_key}")

# boolean mask of significant time points (p < 0.05)
ii = ii_stats[_key[0]] < 0.05


## Plotting
### General plotting functions

In [None]:
def single_line(name, x, color='C0', cis='sem'):
    """Plot the mean of `x` over 'roi' with a shaded confidence band.

    Draws on the current matplotlib axes. ``name`` is the legend label of the
    mean line and ``cis`` is forwarded to frites' ``confidence_interval``.
    """
    times = x['times'].data

    bounds = confidence_interval(x, axis='roi', cis=cis,
                                 verbose=False).squeeze()
    low, high = (bounds.sel(bound=b) for b in ('low', 'high'))

    plt.plot(times, x.mean('roi').data, lw=2, color=color, label=name)
    # zorder=61 puts the band above artists drawn with a default zorder
    plt.fill_between(times, low.data, high.data, color=color, alpha=.3,
                     zorder=61)


def subplot(roi_1, roi_2, legend=True):
    """Overlay the across-conditions PID components of the pair
    ``roi_1``-``roi_2`` on the current axes.

    Draws, in this order, the unique information of each ROI, then the
    synergy, redundancy and total information of the pair, and a vertical
    line at t=0. Set ``legend=False`` to skip the per-axes legend.
    """
    pair = '%s-%s' % (roi_1, roi_2)
    # (legend label, source array, selection, color)
    layers = (
        (r'$Unique_{%s}$' % roi_1, uni, dict(roi=roi_1), 'C0'),
        (r'$Unique_{%s}$' % roi_2, uni, dict(roi=roi_2), 'C1'),
        (r'$Synergy$', conn, dict(roi=pair, comp='syn'), 'C5'),
        (r'$Redundancy$', conn, dict(roi=pair, comp='red'), 'C4'),
        (r'$Total$', conn, dict(roi=pair, comp='tot'), 'C3'),
    )
    for label, source, selection, color in layers:
        single_line(label, source.sel(**selection), color=color)

    plt.axvline(0., color='C3', lw=2)
    if legend:
        plt.legend()


def plt_img(name, ax, img_dir='images'):
    """Display ``{img_dir}/{name}.png`` in ``ax`` as a bare image.

    Grid, spines and ticks of ``ax`` are removed.
    """
    img = mpimg.imread(f'{img_dir}/{name}.png')
    # draw on `ax` explicitly rather than relying on it being the current axes
    ax.imshow(img)
    ax.grid(False)
    for side in ('top', 'left', 'bottom', 'right'):
        ax.spines[side].set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])

# subplot('aINS', 'dlPFC')


### Plot results across conditions

In [None]:
###############################################################################
# Panel order of the 2 x 5 grid, filled row by row.
roi_order = [
    # row 1
    'aINS-aINS', 'dlPFC-dlPFC',
    'aINS-dlPFC', 'aINS-lOFC', 'aINS-vmPFC',
    # row 2
    'lOFC-lOFC', 'vmPFC-vmPFC',
    'dlPFC-vmPFC', 'dlPFC-lOFC', 'lOFC-vmPFC',
]
###############################################################################

fig, axs = plt.subplots(nrows=2, ncols=5, figsize=(24, 9), sharex=True,
                        sharey=True)
axs = np.ravel(axs)
plt.subplots_adjust(wspace=0.15)

for n_ax, r in enumerate(roi_order):
    ax = axs[n_ax]
    # subplot() draws through pyplot, so make this panel the current axes
    plt.sca(ax)
    roi_1, roi_2 = r.split('-')
    subplot(roi_1, roi_2, legend=False)
    ax.set_title(r"$\bf{R_{1}=%s; R_{2}=%s}$" % (roi_1, roi_2),
                 fontweight='bold')

    # orange bar slightly below zero over the time points of a significant
    # II cluster (if any)
    cl = ii.sel(roi=r)
    if cl.data.any():
        line = np.where(cl.data, -0.001, np.nan)
        ax.plot(cl['times'].data, line, color='#FFA500', lw=6,
                solid_capstyle='round')

    # y label on the first column, x label on the bottom row
    if n_ax in (0, 5):
        ax.set_ylabel('Information (bits)')
    if n_ax >= 5:
        ax.set_xlabel('Times (s)')

# axes are shared, so ticks / limits set on the last panel apply to all
plt.xticks([-.5, 0., .5, 1., 1.5])
plt.ylim(-0.005, 0.015)

# Shared legend built from proxy handles, one per PID component, using the
# same colors as the lines drawn by subplot().
custom_lines = [
    Line2D([0], [0], color=c, lw=2, solid_capstyle='round')
    for c in ("C0", "C1", "C5", "C4", "C3")
]
titles = [r"$Unique_{R_{1}}$", r"$Unique_{R_{2}}$", r"$Synergy$",
          r"$Redundancy$", r"$Total$"]
# anchored in figure coordinates, near the bottom of the figure
plt.legend(custom_lines, titles, ncol=5, fontsize=20,
           bbox_to_anchor=(.77, 0.03), bbox_transform=fig.transFigure,
           title="Partial information decomposition for both R/PPE",
           title_fontproperties=dict(weight='bold', size=20))

# Header above the grid: for the within-ROI and the across-ROI panel groups,
# a double-headed arrow, a caption and a small schematic image.
height = 1.06
kw_ann = dict(
    horizontalalignment='center', verticalalignment='center',
    arrowprops=dict(arrowstyle="<|-|>", color='C3'), fontsize=18,
    xycoords='figure fraction'
)
for x_from, x_to in ((0.066, 0.368), (0.38, .842)):
    plt.gca().annotate('', xy=(x_from, height), xytext=(x_to, height),
                       **kw_ann)
for x_txt, caption in ((0.188, 'Within ROI interactions'),
                       (0.548, 'Across ROI interactions')):
    plt.figtext(x_txt, 1., caption, fontsize=22, color='C3',
                fontweight=None)
for rect, img_name in (([.315, .97, .08, .08], 'links_intra'),
                       ([.7, .97, .08, .08], 'links_inter')):
    ax = fig.add_axes(rect)
    plt_img(img_name, ax)


In [None]:
# output folder and savefig keyword arguments come from config.json
export = config['export']
save_to, cfg_export = export['save_to'], export['cfg']

fig.savefig(f'{save_to}/fig_pid-rppe.png', **cfg_export)


In [None]:
# Inspect the across-conditions unique information (subjects concatenated
# along 'roi').
uni