<div class="alert alert-success">
<h1>Context-dependent information transfer</h1>
$TE_{Reward}$ vs. $TE_{Punishment}$
</div>

In [None]:
import numpy as np
import xarray as xr
import pandas as pd

from pathta import Study

from frites import set_mpl_style
from frites.utils import kernel_smoothing

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.lines import Line2D

from ipywidgets import interact
from research.study.pblt import plot_nx

import json
# Project-wide settings (export folder, savefig options, ...) live in config.json.
with open("config.json") as f:
    config = json.load(f)

# Global figure style: frites defaults first, then larger tick/title/label fonts.
set_mpl_style()
plt.rcParams.update({
    'xtick.labelsize': 'xx-large',
    'ytick.labelsize': 'xx-large',
    'axes.titlesize': 21,
    'axes.labelsize': 22,
})

---
# **I/O**

In [None]:
###############################################################################
# te settings
min_delay = 30
max_delay = 60
biascorrection = True

# stat settings
savgol = 10
n_perm = 200

from_folder = 'conn/te/stats/cond'
###############################################################################

st = Study('PBLT')

# File-name fragments passed to `Study.search` to pick the statistics file
# computed with the settings above ("2" presumably is the delay step — TODO
# confirm against the script that produced the files).
args = (
    f"nperm-{n_perm}_", f"savgol-{savgol}",
    f"bias-{int(biascorrection)}", f"{min_delay}-2-{max_delay}"
)

# load significance testing: exactly one file must match all fragments,
# otherwise report what was searched and what was found
f = st.search(*args, folder=from_folder)
assert len(f) == 1, (
    f"Expected exactly one file in '{from_folder}' matching {args}, "
    f"found {len(f)}: {f}"
)
dt = xr.load_dataset(f[0])

---
# **Figure**
## Lineplot
### All subplots

In [None]:
# NOTE(review): this cell is disabled — every line is commented out. It draws
# essentially the same line plot as the next cell, but for every ROI on a
# 2 x 6 grid and without restricting to ROIs with significant clusters.
# Kept for reference; consider deleting it or folding it into the next cell
# as an option.
# ###############################################################################
# es = 'mi'
# ci_l = 'sem'
# conds = ['pun', 'rew']
# stats = 'contrast'  # {contrast, condition}
# sigma = 0.01
# ###############################################################################

# if stats == 'contrast':
#     pv_conds = ['pun>rew', 'rew>pun']
# elif stats == 'condition':
#     pv_conds = ['pun', 'rew']

# # select effect size and p-values
# sstr = "%.3f" % sigma
# mi = dt[[f"{es}_{c}" for c in conds]].to_array('cond')
# mi_ci = dt[[f"ci_{c}" for c in conds]].to_array('cond').sel(ci=ci_l)
# pv = dt[[f"pv_{c}_{sstr[2::]}" for c in pv_conds]].to_array('cond')

# mi['cond'] = pv['cond'] = mi_ci['cond'] = ['pun', 'rew']
# times = mi['times'].data

# # cluster definition
# cl = xr.full_like(mi, np.nan)
# minmin = mi_ci.data.min()
# delta = (mi_ci.data.max() - minmin) / 100.
# cl.data[0, pv.data[0, ...] < 0.05] = minmin - delta
# cl.data[1, pv.data[1, ...] < 0.05] = minmin - 2 * delta

# # prepare the figure
# fig, axs = plt.subplots(
#     nrows=2, ncols=6, sharex=True, sharey=True, figsize=(16, 6)
# )
# axs = np.ravel(axs)

# # plot the results
# q = 0
# for n_r, r in enumerate(np.unique(mi['roi'].data)):
#     for n_c, c in enumerate(mi['cond'].data):
#         # effect size and p-value selection
#         _es, _cl = mi.sel(roi=r, cond=c), cl.sel(roi=r, cond=c)
#         _ci = mi_ci.sel(roi=r, cond=c)
#         _ci_l, _ci_h = _ci.sel(bound='low'), _ci.sel(bound='high')

#         plt.sca(axs[q])
#         plt.plot(times, _es.data, color=f"C{n_c}")
#         ln, = plt.plot(times, _cl.data, color=f"C{n_c}", lw=5)
#         ln.set_solid_capstyle('round')
#         plt.fill_between(times, _ci_l.data, _ci_h.data, color=f"C{n_c}",
#                          alpha=.3)
#         plt.axvline(0., color='C3')
#         plt.title(str(r).replace('->', r'$\rightarrow$'), fontweight='bold')
#         if q in [0, 6]: plt.ylabel('TE (bits)')
#         if q >= 6: plt.xlabel('Times (s)')

#     q += 1

# plt.gca().set_xticks([0., 0.5, 1., 1.5])
# plt.tight_layout()

# # create the legend
# custom_lines = [Line2D([0], [0], color="C0", lw=6, solid_capstyle='round'),
#                 Line2D([0], [0], color="C1", lw=6, solid_capstyle='round')]

# pv_conds = [r'$TE_{Pun} > TE_{Rew}$', r'$TE_{Rew} > TE_{Pun}$']
# legend = plt.legend(
#     custom_lines, pv_conds, ncol=2, bbox_to_anchor=(.73, 0.02),
#     fontsize=20, bbox_transform=fig.transFigure, title="Significant clusters of TE (p<0.05)",
#     title_fontproperties=dict(weight='bold', size=18)
# );


### Only the ROIs with significant clusters

In [None]:
###############################################################################
es = 'mi'
ci_l = 'sem'
conds = ['pun', 'rew']
stats = 'contrast'  # {contrast, condition}
sigma = 0.01
###############################################################################

plt.rcParams['xtick.labelsize'] = 22
plt.rcParams['ytick.labelsize'] = 22
plt.rcParams['axes.titlesize'] = 28
plt.rcParams['axes.labelsize'] = 26

# p-value variables to load and the matching legend labels for each kind of
# test. Entries are paired positionally with `conds` (pun first, rew second).
if stats == 'contrast':
    pv_conds = ['pun>rew', 'rew>pun']
    legend_labels = [r'$TE_{Pun} > TE_{Rew}$', r'$TE_{Rew} > TE_{Pun}$']
elif stats == 'condition':
    pv_conds = ['pun', 'rew']
    legend_labels = [r'$TE_{Pun}$', r'$TE_{Rew}$']
else:
    raise ValueError(f"stats must be 'contrast' or 'condition', got {stats!r}")

# select effect size and p-values; p-value variable names embed `sigma`
# formatted with 3 decimals and without its leading "0." (0.01 -> "010")
sstr = "%.3f" % sigma
mi = dt[[f"{es}_{c}" for c in conds]].to_array('cond')
mi_ci = dt[[f"ci_{c}" for c in conds]].to_array('cond').sel(ci=ci_l)
pv = dt[[f"pv_{c}_{sstr[2::]}" for c in pv_conds]].to_array('cond')

# share the same condition labels so the three arrays can be co-indexed
mi['cond'] = pv['cond'] = mi_ci['cond'] = conds
times = mi['times'].data

# subselect ROIs with at least one significant sample (any condition/time);
# the mask is built on `pv`, so index `pv`'s own roi coordinate
signi_roi = pv['roi'].data[(pv < 0.05).any(('times', 'cond')).data]
if signi_roi.size == 0:
    raise ValueError("No ROI has a significant cluster (p<0.05); nothing to plot.")
mi, pv, mi_ci = mi.sel(roi=signi_roi), pv.sel(roi=signi_roi), mi_ci.sel(roi=signi_roi)

# cluster definition: significant samples are drawn as a flat bar just below
# the lowest confidence bound, one bar level per condition. `xr.where` aligns
# by dimension name, so no assumption is made on the dims order of mi vs pv.
minmin = np.nanmin(mi_ci.data)
delta = (np.nanmax(mi_ci.data) - minmin) / 100.
levels = xr.DataArray(
    minmin - delta * np.arange(1, len(conds) + 1),
    dims='cond', coords={'cond': conds}
)
cl = xr.where(pv < 0.05, levels, np.nan).transpose(*mi.dims)

# prepare the figure: one panel per significant ROI
rois = np.unique(mi['roi'].data)
fig, axs = plt.subplots(
    nrows=1, ncols=len(rois), sharex=True, sharey=True,
    figsize=(7 * len(rois), 6), squeeze=False
)
axs = np.ravel(axs)

# plot the results
for ax, r in zip(axs, rois):
    for n_c, c in enumerate(mi['cond'].data):
        color = f"C{n_c}"
        # effect size, cluster bar and confidence interval selection
        _es, _cl = mi.sel(roi=r, cond=c), cl.sel(roi=r, cond=c)
        _ci = mi_ci.sel(roi=r, cond=c)
        _ci_l, _ci_h = _ci.sel(bound='low'), _ci.sel(bound='high')

        ax.plot(times, _es.data, color=color)
        ax.plot(times, _cl.data, color=color, lw=8, solid_capstyle='round')
        ax.fill_between(times, _ci_l.data, _ci_h.data, color=color, alpha=.3)
    ax.axvline(0., color='C3')
    ax.set_title(str(r).replace('->', r'$\rightarrow$'), fontweight='bold')
    ax.set_xlabel('Times [s]')
axs[0].set_ylabel('TE [bits]')

axs[-1].set_xticks([0., 0.5, 1., 1.5])
fig.tight_layout()

# create the legend (placed in figure coordinates, below the panels)
custom_lines = [
    Line2D([0], [0], color=f"C{n_c}", lw=6, solid_capstyle='round')
    for n_c in range(len(conds))
]
legend = axs[-1].legend(
    custom_lines, legend_labels, ncol=len(conds), bbox_to_anchor=(.82, 0.02),
    fontsize=24, bbox_transform=fig.transFigure,
    title="Significant clusters of TE (p<0.05)",
    title_fontproperties=dict(weight='bold', size=26)
);


## Export the figure

In [None]:
# Output folder and savefig keyword arguments both come from config.json.
# `fig` must still be the line-plot figure: the network cell below rebinds it.
save_to, cfg_export = config['export']['save_to'], config['export']['cfg']
fig.savefig(
    f'{save_to}/fig_te_context.png',
    **cfg_export,
)

## Time windows of the significant clusters

In [None]:
def _significant_runs(mask):
    """Return the inclusive (start, stop) index pairs of each run of True.

    Parameters
    ----------
    mask : array_like of bool, 1D

    Returns
    -------
    list of (int, int)
        One pair per contiguous run, in order; empty if `mask` has no True.
    """
    padded = np.concatenate(([False], np.asarray(mask, dtype=bool), [False]))
    # +1 marks a run start, -1 marks the sample right after a run end
    edges = np.flatnonzero(np.diff(padded.astype(np.int8)))
    return [(int(a), int(b) - 1) for a, b in zip(edges[::2], edges[1::2])]


# Print the time window (ms) of every contiguous run of significant samples
# (p < 0.05) per condition and ROI, using `pv` from the line-plot cell.
# A ROI can host several disjoint clusters, so each run is reported on its own
# line rather than as one span from the first to the last significant sample.
for c in pv['cond'].data:
    for r in pv['roi'].data:
        pv_sel = pv.sel(cond=c, roi=r)
        sel_times = pv_sel['times'].data
        for start, stop in _significant_runs(pv_sel.data < 0.05):
            period = [sel_times[start] * 1000, sel_times[stop] * 1000]
            print(f"Cond={c}; ROI={r}; PERIOD=[{period[0]}, {period[-1]}]ms")

In [None]:
###############################################################################
es_rew = 'rew'
es_pun = 'pun'
###############################################################################

# effect size and p-value selection
mi = xr.Dataset({'pun': dt[f'mi_{es_pun}'], 'rew': dt[f'mi_{es_rew}']}).to_array('cond')
pv = xr.Dataset({'pun': dt[f'pv_{es_pun}'], 'rew': dt[f'pv_{es_rew}']}).to_array('cond')

# drop non-significant effects
mi.data[pv.data >= 0.05] = np.nan
mi = mi.mean('times')
maxmax, minmin = np.nanmax(mi.data), np.nanmin(mi.data)
tomax = 8 * np.array([np.nanmax(mi.sel(cond='rew').data) / maxmax, np.nanmax(mi.sel(cond='pun').data) / maxmax])
tomin = np.array([np.nanmin(mi.sel(cond='rew').data) / minmin, np.nanmin(mi.sel(cond='pun').data) / minmin])

fig, axs = plt.subplots(
    nrows=2, ncols=1, sharex=True, sharey=True, figsize=(6, 10)
)

for n_c, cond in enumerate(['rew', 'pun']):
    # color selection
    if cond == 'rew':
        node_color = 'C1'
        edge_cmap = plt.cm.Blues
        title = r'$TE_{rew} > TE_{pun}$'
    elif cond == 'pun':
        node_color = 'C0'
        edge_cmap = plt.cm.Reds
        title = r'$TE_{pun} > TE_{rew}$'
    
    # condition selection
    
    
    plt.sca(axs[n_c])
    plot_nx(mi.sel(cond=cond), node_color='C3', edge_cmap=edge_cmap, tomin=tomin[n_c], tomax=tomax[n_c])
    plt.title(title, fontweight='bold')

plt.tight_layout()