In [None]:
import numpy as np
import xarray as xr
import pandas as pd

from pathta import Study
from frites import set_mpl_style
from frites.conn import conn_links
from research.study.pblt import get_model, get_anat_table
from visbrain.utils import normalize

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib.lines import Line2D
from matplotlib.gridspec import GridSpec
import seaborn as sns

from ipywidgets import interact

import json
# project-wide settings (export folders, savefig options, ...)
with open("../config.json", 'r') as cfg_file:
    cfg = json.load(cfg_file)

set_mpl_style('frites')

# larger fonts for publication figures
plt.rcParams.update({
    'xtick.labelsize': 20,
    'ytick.labelsize': 20,
    'axes.titlesize': 25,
    'axes.labelsize': 24,
})

# **$II(\gamma; R/PPE)$**
---
## I/O

In [None]:
from_folder = "conn/piid/lga-PE/ii/cond_ss"

st = Study('PBLT')

# For every subject, load the within-region ("intra") then between-regions
# ("inter") results; subjects with neither file are skipped.
dt = []
for s in range(20):
    parts = []
    for relation in ("intra", "inter"):
        found = st.search(f"subject-{s}.nc", f"relation-{relation}",
                          folder=from_folder)
        if len(found):
            parts.append(xr.load_dataset(found[0]))
    if len(parts) == 0:
        continue

    # stack both relation types along the roi dimension
    subject_ds = xr.concat(parts, dim='roi')

    # keep only the part of each roi label before the first '/'
    subject_ds['roi'] = [label.split('/')[0]
                         for label in subject_ds['roi'].data]

    dt.append(subject_ds)

## Compute proportions

In [None]:
###############################################################################
sigma = "100"
###############################################################################

# ``sigma`` selects which p-value variables (``pv_<cond>_<sigma>``) are used


def compute_stat(pv, mi=None):
    """Flag significant redundant and synergistic entries.

    Parameters
    ----------
    pv : array_like of bool
        Significance mask (in this notebook ``p-value < 0.05``), aligned
        element-wise with ``mi``.
    mi : array_like
        Interaction-information values. Negative values count as redundancy,
        positive values as synergy, and zeros as neither. Although it
        defaults to ``None`` (kept for backward compatibility), it is
        required.

    Returns
    -------
    numpy.ndarray of bool, shape (2,)
        ``[any significant redundant entry, any significant synergistic entry]``.

    Raises
    ------
    ValueError
        If ``mi`` is not given.
    """
    if mi is None:
        raise ValueError("compute_stat requires the `mi` values")
    pv, mi = np.asarray(pv), np.asarray(mi)
    pv_red = pv[mi < 0].any()
    pv_syn = pv[mi > 0].any()
    return np.stack([pv_red, pv_syn])

# compute stat


def _subject_ii(mi, pval, roi, alpha=0.05):
    """Redundancy/synergy significance flags of one subject for one condition.

    Parameters
    ----------
    mi : xarray.DataArray
        Interaction-information values with a 'roi' dimension.
    pval : xarray.DataArray
        P-values aligned with ``mi``.
    roi : xarray.DataArray
        The subject's 'roi' coordinate. It is used to restore the 'roi'
        dimension when the subject has a single entry.
    alpha : float
        Significance threshold, applied as ``pval < alpha``.

    Returns
    -------
    xarray.DataArray
        Boolean flags with a 'roi' dimension and a new 'itype' dimension of
        length 2 (redundancy, synergy), as produced by ``compute_stat``.
    """
    pv = pval < alpha
    if len(mi['roi']) > 1:
        # entries sharing the same (cleaned) roi name are reduced together
        mi_in, pv_in = mi.groupby('roi'), pv.groupby('roi')
    else:
        mi_in, pv_in = mi, pv
    out = xr.apply_ufunc(
        compute_stat, pv_in, mi_in,
        input_core_dims=[['roi'], ['roi']], vectorize=True,
        output_core_dims=[['itype']]
    )
    # Without grouping, 'roi' is consumed as a core dim, so put it back.
    # Testing for the dim itself (rather than ``ndim == 2``) does not assume
    # anything about the number of remaining dimensions.
    if 'roi' not in out.dims:
        out = out.expand_dims('roi', 1)
        out['roi'] = roi
    return out


ii_rew, ii_pun = [], []
for dt_s in dt:
    ii_rew.append(
        _subject_ii(dt_s["mi_rew"], dt_s[f"pv_rew_{sigma}"], dt_s['roi'])
    )
    ii_pun.append(
        _subject_ii(dt_s["mi_pun"], dt_s[f"pv_pun_{sigma}"], dt_s['roi'])
    )

# stack all subjects along 'roi' and name the two flags
ii_rew = xr.concat(ii_rew, dim='roi')
ii_pun = xr.concat(ii_pun, dim='roi')
ii_rew['itype'] = ii_pun['itype'] = ['Redundancy', 'Synergy']


def compute_prop(x):
    """Sum of ``x`` over 'roi' as a percentage of the number of 'roi' entries.

    For boolean flags this is the percentage (0-100) of True entries.
    """
    n_total = len(x['roi'])
    n_true = x.sum('roi')
    return 100 * n_true / n_total


# compute proportions per roi (``GroupBy.map``; ``GroupBy.apply`` is only a
# legacy alias that emits a PendingDeprecationWarning)
ii_rew_prop = ii_rew.groupby('roi').map(compute_prop)
ii_pun_prop = ii_pun.groupby('roi').map(compute_prop)

# compute number of subjects


def compute_nsubjects(x):
    """Count the entries of one roi group and return it as a 1-element DataArray.

    ``x`` is a single-roi group of the flags, with 'roi', 'times' and
    'itype' dims. The result keeps one 'roi' entry, so concatenating the
    groups yields one value per roi. It holds ``len(x['roi'])`` and has the
    scalar 'times'/'itype' coordinates removed.
    """
    x_m = x.isel(times=0, roi=[0], itype=0).drop_vars(["times", "itype"])
    x_m.data = [len(x['roi'])]
    return x_m


n_subjects = ii_rew.groupby('roi').map(compute_nsubjects)

In [None]:
# peak over time of the vmPFC-vmPFC proportions, per interaction type
ii_rew_prop.max('times').sel(roi="vmPFC-vmPFC")

## Lineplot

In [None]:
def plt_img(name, ax, folder='../images'):
    """Draw ``{folder}/{name}.png`` on ``ax`` with no grid, spines or ticks.

    Parameters
    ----------
    name : str
        Image file name, without the ``.png`` extension.
    ax : matplotlib.axes.Axes
        Axes to draw on. The image is drawn on ``ax`` directly, so ``ax``
        does not need to be the current pyplot axes.
    folder : str
        Directory holding the images (default ``'../images'``).
    """
    img = mpimg.imread(f'{folder}/{name}.png')
    ax.imshow(img)
    ax.grid(False)
    for side in ('top', 'left', 'bottom', 'right'):
        ax.spines[side].set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])

In [None]:
###############################################################################
roi_order = [
    "aINS-aINS", "dlPFC-dlPFC", "aINS-dlPFC", "aINS-lOFC", "aINS-vmPFC",
    "lOFC-lOFC", "vmPFC-vmPFC", "dlPFC-vmPFC", "dlPFC-lOFC", "lOFC-vmPFC"
]
###############################################################################

# one panel per pair of regions, laid out on 2 rows x 5 columns
fig, axs = plt.subplots(
    2, 5, figsize=(24, 8), sharey=True, sharex=True,
    gridspec_kw=dict(hspace=.35, wspace=0.25, top=0.96)
)
axs = np.ravel(axs)

# colour + proportions for each condition (rewarding first, then punishing)
cond_specs = [("C1", ii_rew_prop), ("C0", ii_pun_prop)]
# line style + sign per interaction type (redundancy is mirrored below zero)
itype_specs = {
    "Redundancy": (dict(lw=2), -1.),
    "Synergy": (dict(lw=2, linestyle='--'), 1.),
}

for n_r, r in enumerate(roi_order):
    ax_r = axs[n_r]
    ax_r.set_title(
        f"{r}\n" + r"$(n_{Subjects}=%i)$" % n_subjects.sel(roi=r).data,
        fontweight='bold'
    )
    ax_r.axvline(0., color='C3')

    for color, prop in cond_specs:
        for itype, (kw_line, fact) in itype_specs.items():
            ax_r.plot(
                prop['times'].data, fact * prop.sel(roi=r, itype=itype).data,
                color=color, **kw_line
            )
    # only the bottom row gets an x label
    if n_r >= 5:
        ax_r.set_xlabel("Times [s]")

# axes are shared, so limits and ticks set on one panel apply to all
last_ax = axs[-1]
last_ax.set_ylim(-70, 70)
last_ax.set_xlim(-.5, 1.5)
last_ax.set_xticks([-.5, 0., 0.5, 1., 1.5])
last_ax.set_yticks([-60, -30, 0, 30, 60])
# absolute percentages on both sides of zero
axs[0].set_yticklabels([60, 30, 0, 30, 60])
# leave the last panel current for the pyplot calls that follow
plt.sca(last_ax)

def add_text(ax):
    """Make ``ax`` current and write the vertical y-axis captions.

    'Redundant\\n[%]' is written at y=-45 and 'Synergistic\\n[%]' at y=45,
    both at x=-1 in data coordinates.
    """
    plt.sca(ax)
    text_kw = {"color": "#333333", "va": 'center', "ha": "center",
               "rotation": 90, "fontsize": 20}
    for y_pos, caption in ((-45., "Redundant\n[%]"),
                           (45., "Synergistic\n[%]")):
        plt.text(-1., y_pos, caption, **text_kw)

# y-axis captions on the first panel of each row
for row_start in (0, 5):
    add_text(axs[row_start])

# create the legend: one colour per condition
pv_conds = [r"$II_{RPE}$", r"$II_{PPE}$"]
custom_lines = [
    Line2D([0], [0], color=line_color, lw=6, solid_capstyle='round')
    for line_color in ("C1", "C0")
]
legend = plt.legend(
    custom_lines, pv_conds, ncol=2, bbox_to_anchor=(.6, .03),
    fontsize=24, bbox_transform=fig.transFigure,
    title_fontproperties=dict(weight='bold', size=26)
)

# annotations: double-headed arrows above the panel grid (figure fraction)
height = 1.17
kw_ann = dict(
    horizontalalignment='center', verticalalignment='center',
    arrowprops=dict(arrowstyle="<|-|>", color='C3'), fontsize=18,
    xycoords='figure fraction'
)
ann_ax = plt.gca()
for x_from, x_to in ((0.062, 0.355), (0.387, .84)):
    ann_ax.annotate('', xy=(x_from, height), xytext=(x_to, height), **kw_ann)

# section headers above the arrows
kw_txt = dict(fontsize=22, color='C3', fontweight=None)
for x_txt, header in ((0.16, 'Within-region interactions'),
                      (0.52, 'Between-regions interactions')):
    plt.figtext(x_txt, 1.15, header, **kw_txt)

# small illustrative images next to each header
for x_img, img_name in ((.31, 'links_intra_b&w'), (.72, 'links_inter_b&w')):
    ax = fig.add_axes([x_img, 1.115, .1, .1])
    plt_img(img_name, ax)


In [None]:
# destination folder and savefig options come from ../config.json
export = cfg['export']
save_to, cfg_export = export['review'], export['cfg']

fig.savefig(f'{save_to}/ii_rppe_ss.png', **cfg_export)

# **$II(\gamma; PE)$**
---
## I/O

In [None]:
from_folder = "conn/piid/lga-PE/ii/full_ss"

st = Study('PBLT')

# Same loading scheme as for the condition-split results: intra then inter
# files per subject, skipping subjects with neither.
dt = []
for s in range(20):
    loaded = []
    for rel in ["intra", "inter"]:
        hits = st.search(f"subject-{s}.nc", f"relation-{rel}",
                         folder=from_folder)
        if len(hits):
            loaded.append(xr.load_dataset(hits[0]))
    if len(loaded) == 0:
        continue

    # stack both relation types along the roi dimension
    ds_subject = xr.concat(loaded, dim='roi')

    # keep only the part of each roi label before the first '/'
    ds_subject['roi'] = [name.split('/')[0] for name in ds_subject['roi'].data]

    dt.append(ds_subject)

## Compute proportions

In [None]:
###############################################################################
sigma = "100"
###############################################################################

# ``sigma`` selects which p-value variable (``pv__<sigma>``) is used


def compute_stat(pv, mi=None):
    """Flag significant redundant and synergistic entries.

    Parameters
    ----------
    pv : array_like of bool
        Significance mask (in this notebook ``p-value < 0.05``), aligned
        element-wise with ``mi``.
    mi : array_like
        Interaction-information values. Negative values count as redundancy,
        positive values as synergy, and zeros as neither. Although it
        defaults to ``None`` (kept for backward compatibility), it is
        required.

    Returns
    -------
    numpy.ndarray of bool, shape (2,)
        ``[any significant redundant entry, any significant synergistic entry]``.

    Raises
    ------
    ValueError
        If ``mi`` is not given.
    """
    if mi is None:
        raise ValueError("compute_stat requires the `mi` values")
    pv, mi = np.asarray(pv), np.asarray(mi)
    pv_red = pv[mi < 0].any()
    pv_syn = pv[mi > 0].any()
    return np.stack([pv_red, pv_syn])


# compute stat
ii = []
for dt_s in dt:
    # no rewarding/punishing split here: a single ``mi_`` / ``pv__<sigma>`` pair
    mi = dt_s["mi_"]
    pv = dt_s[f"pv__{sigma}"] < 0.05
    if len(mi['roi']) > 1:
        # entries sharing the same (cleaned) roi name are reduced together
        mi_gp, pv_gp = mi.groupby('roi'), pv.groupby('roi')
    else:
        mi_gp, pv_gp = mi, pv
    _ii = xr.apply_ufunc(
        compute_stat, pv_gp, mi_gp,
        input_core_dims=[['roi'], ['roi']], vectorize=True,
        output_core_dims=[['itype']]
    )
    # Without grouping, 'roi' is consumed as a core dim, so restore it.
    # Testing for the dim itself does not assume the remaining rank.
    if 'roi' not in _ii.dims:
        _ii = _ii.expand_dims('roi', 1)
        _ii['roi'] = dt_s['roi']

    ii.append(_ii)

# stack all subjects along 'roi' and name the two flags
ii = xr.concat(ii, dim='roi')
ii["itype"] = ['Redundancy', 'Synergy']

def compute_prop(x):
    """Sum of ``x`` over 'roi' as a percentage of the number of 'roi' entries.

    For boolean flags this is the percentage (0-100) of True entries.
    """
    n_total = len(x['roi'])
    n_true = x.sum('roi')
    return 100 * n_true / n_total

# compute proportions per roi (``GroupBy.map``; ``GroupBy.apply`` is only a
# legacy alias that emits a PendingDeprecationWarning)
ii_prop = ii.groupby('roi').map(compute_prop)

# compute number of subjects


def compute_nsubjects(x):
    """Count the entries of one roi group and return it as a 1-element DataArray.

    ``x`` is a single-roi group of the flags, with 'roi', 'times' and
    'itype' dims. The result keeps one 'roi' entry, so concatenating the
    groups yields one value per roi. It holds ``len(x['roi'])`` and has the
    scalar 'times'/'itype' coordinates removed.
    """
    x_m = x.isel(times=0, roi=[0], itype=0).drop_vars(["times", "itype"])
    x_m.data = [len(x['roi'])]
    return x_m


n_subjects = ii.groupby('roi').map(compute_nsubjects)

In [None]:
r = "dlPFC-vmPFC"

fig = plt.figure(figsize=(10, 7))
ax_pe = plt.gca()

ax_pe.set_title(
    f"{r}\n" + r"$(n_{Subjects}=%i)$" % n_subjects.sel(roi=r).data,
    fontweight='bold'
)
ax_pe.axvline(0., color='C3')

# (line width, sign): redundancy is drawn thick and mirrored below zero
line_specs = {"Redundancy": (4, -1.), "Synergy": (2, 1.)}
for itype, (lw, fact) in line_specs.items():
    ax_pe.plot(
        ii['times'].data, fact * ii_prop.sel(roi=r, itype=itype).data,
        color="C3", lw=lw
    )

ax_pe.set_xlabel("Times [s]")
ax_pe.set_ylim(-70, 70)
ax_pe.set_xlim(-.5, 1.5)
ax_pe.set_xticks([-.5, 0., 0.5, 1., 1.5])
ax_pe.set_yticks([-60, -30, 0, 30, 60])
# absolute percentages on both sides of zero
ax_pe.set_yticklabels([60, 30, 0, 30, 60])


def add_text(ax):
    """Make ``ax`` current and write the vertical y-axis captions.

    'Redundant [%]' is written at y=-40 and 'Synergistic [%]' at y=40, both
    at x=-0.75 in data coordinates.
    """
    plt.sca(ax)
    text_kw = {"color": "#333333", "va": 'center', "ha": "center",
               "rotation": 90, "fontsize": 20}
    for y_pos, caption in ((-40., "Redundant [%]"),
                           (40., "Synergistic [%]")):
        plt.text(-.75, y_pos, caption, **text_kw)


add_text(plt.gca())

In [None]:
# Optional, disabled visual checks for the single-region figure above
# (uncomment to inspect the per-subject flags behind the proportions):
# ii_prop.sel(roi=r, itype=itype).plot()
# ii.sel(roi=r, itype="Synergy").plot(x='times', hue="roi")
# plt.plot(ii.sel(roi=r, itype="Synergy").data[:, 3])