<div class="alert alert-success">
<h1>Information transfer delays and $TE_{Rew}$, $TE_{Pun}$ networks</h1>
<hr>
<ol>
<li>$TE_{max}$ as a function of the delay</li>
<li>Significant TE networks during the reward and punishment conditions</li>
</ol>
</div>

In [None]:
import numpy as np
import xarray as xr
import pandas as pd

from scipy.stats import ttest_1samp, ttest_ind

from pathta import Study

from frites import set_mpl_style
from frites.stats import confidence_interval

import matplotlib.pyplot as plt
import seaborn as sns
from research.study.pblt import plot_nx

import string
import json
# Project-wide settings read from config.json (e.g. the export section used
# when saving figures).
with open("config.json") as fid:
    config = json.load(fid)

# Upper-case letters used to tag figure panels (A, B, C, ...).
alphabet = string.ascii_uppercase

# frites plotting style first, then larger fonts for ticks, titles and labels.
set_mpl_style()
plt.rcParams.update({
    'xtick.labelsize': 'xx-large',
    'ytick.labelsize': 'xx-large',
    'axes.titlesize': 23,
    'axes.labelsize': 22,
})


---
# **I/O**
## Load delays

In [None]:
###############################################################################
sfreq = 256.
from_folder = 'conn/te/delays'
###############################################################################

st = Study('PBLT')
subjects = list(st.load_config('subjects.json').keys())

# For every subject, load the TE-vs-delay array of each condition, stack the
# conditions along a new 'cond' dimension, then concatenate subjects along
# 'roi'.
tedelays, skipped = [], []
for s in subjects:
    _delays = {}
    for cond in ['pun', 'rew']:
        f = st.search(s, f'-{cond}_', folder=from_folder)
        # Exactly one file is expected per (subject, condition). Missing or
        # ambiguous matches are skipped, but reported below instead of being
        # dropped silently.
        if len(f) != 1:
            skipped.append((s, cond, len(f)))
            continue
        _delays[cond] = xr.load_dataarray(f[0])
    if _delays:
        tedelays.append(xr.Dataset(_delays).to_array('cond'))

if skipped:
    print(f"Skipped {len(skipped)} (subject, condition, n_files) entries "
          f"without exactly one matching file: {skipped}")
assert tedelays, f"no delay file found in '{from_folder}'"

tedelays = xr.concat(tedelays, 'roi')
# delays: samples -> milliseconds
tedelays['delays'] = 1000 * tedelays['delays'].data / sfreq

---
# **Supplementary figure**
## Delays across regions

In [None]:
###############################################################################
# Period of interest highlighted on the figure, in ms.
# NOTE(review): this cell used to compute `1000 * np.array(terange) / sfreq`
# (= [117.2, 234.4] ms for terange = [30, 60] samples) and then immediately
# overwrite it with the hand-picked [116, 236] ms below, so `terange` never
# affected the figure. The exported values are kept; `terange` is left defined
# for reference only. Confirm which range is intended.
terange = [30, 60]
drange = np.array([116, 236])
text_color = 'C5'
ci = '95'
n_boots = 1000
###############################################################################

# mean TE over regions, per condition and delay
temean = tedelays.mean('roi')

# long-format dataframe for seaborn
df_delays = temean.to_dataframe('TE').reset_index().rename(columns={
    'cond': 'Condition'
})
df_delays['Condition'] = df_delays['Condition'].replace({
    'rew': 'Reward',
    'pun': 'Punishment'
})

# confidence intervals across regions ('sem' and 95); only the `ci` level is
# kept
cis = confidence_interval(
    tedelays, axis='roi', cis=['sem', 95], n_boots=n_boots, random_state=0,
).sel(ci=ci)
# nanmax: the bounds may contain NaN (e.g. a subject loaded with only one
# condition); a plain max would then produce a NaN y-limit
yrange = [0., 1.01 * np.nanmax(cis.data)]

# make the figure
fig, ax = plt.subplots(figsize=(8, 7))
sns.lineplot(
    data=df_delays, x='delays', y='TE', hue='Condition',
    hue_order=['Reward', 'Punishment'], palette=['C1', 'C0'], ax=ax
)
ax.set_xlabel('Delays [ms]')
ax.set_ylabel('TE [bits]')
ax.set_xticks([0, 50, 100, 150, 200, 250, 300, 350])
ax.set_xlim(0., 350)
ax.set_ylim(yrange)

# better legend
sns.move_legend(ax, "upper right", bbox_to_anchor=(1.1, 1.05), frameon=True,
                title_fontproperties=dict(weight='bold', size=18))

# confidence band of each condition, coloured like its line
for cond, color in [('rew', 'C1'), ('pun', 'C0')]:
    cond_ci = cis.sel(cond=cond)
    ax.fill_between(
        cis['delays'].data, cond_ci.sel(bound='low').data,
        cond_ci.sel(bound='high').data, color=color, alpha=.1
    )

# draw period of interest and its centre
mdelay = drange.mean()
ax.fill_betweenx(
    yrange, drange[0], drange[1], zorder=+1000, color=text_color, alpha=.1
)
ax.axvline(mdelay, linestyle='--', color=text_color)
kw_txt = dict(fontsize=22, color=text_color, fontweight='bold')
# label y positions are in data units (bits)
ax.text(mdelay + 5, 0.00015, f"{int(np.round(mdelay))}ms", **kw_txt)
ax.text(drange[0] - 5, 0.00001, f"{int(np.round(drange[0]))}ms", ha='right',
        **kw_txt)
ax.text(drange[1] + 5, 0.00001, f"{int(np.round(drange[1]))}ms", ha='left',
        **kw_txt);


In [None]:
# Destination folder and `savefig` keyword arguments, both from config.json.
export = config['export']
save_to, cfg_export = export['save_to'], export['cfg']

# `fig` must still be the supplementary figure built just above; it is
# reassigned by the complete-figure cell further down.
fig.savefig(f'{save_to}/supp_te_delays.png', **cfg_export)

## Complete figure with stats
### Load the stats

In [None]:
###############################################################################
savgol = 10
n_perm = 200
sigma = 0.01
biascorrection = True

from_folder = 'conn/te/stats/cond'
###############################################################################

st = Study('PBLT')

# Fragments that must all appear in the statistics file name.
# NOTE(review): the meaning of the trailing "30" fragment is not documented
# here — confirm.
args = (
    f"nperm-{n_perm}_",
    f"bias-{int(biascorrection)}",
    f"savgol-{savgol}",
    "30",
)

# sigma with 3 decimals (0.01 -> "0.010"). The p-value variables are suffixed
# with this string minus its first two characters, i.e. the digits after "0."
# when sigma < 1 (e.g. "pv_rew_010").
sstr = "%.3f" % sigma
sigma_tag = sstr[2:]

# exactly one statistics file must match
f = st.search(*args, folder=from_folder)
assert len(f) == 1
stats = xr.load_dataset(f[0])

# Keep the effect sizes and the p-values at the chosen sigma, under names that
# no longer depend on sigma.
dt = xr.Dataset({
    'mi_rew': stats['mi_rew'],
    'pv_rew': stats['pv_rew_' + sigma_tag],
    'mi_pun': stats['mi_pun'],
    'pv_pun': stats['pv_pun_' + sigma_tag],
})
del stats


### Make the figure

In [None]:
###############################################################################
# lineplot settings
n_boot = 20
delta = 50  # half-width (ms) of the highlighted window around the peak delay

# text properties (panel letters)
kw_text = dict(size=30, weight='bold')
###############################################################################

# -----------------------------------------------------------------------------
#                                PREPROCESSING
# -----------------------------------------------------------------------------
# long-format delays dataframe (one row per condition, region and delay)
df_delays = tedelays.to_dataframe('TE').reset_index().rename(columns={
    'cond': 'Condition'
})
df_delays['Condition'] = df_delays['Condition'].replace({
    'rew': 'Reward',
    'pun': 'Punishment'
})

# delay at which TE, averaged over conditions and regions, is maximal, and the
# +/- delta window around it
delays = tedelays['delays'].data
mdelay = tedelays['delays'][tedelays.mean(['cond', 'roi']).argmax()].data
drange = [mdelay - delta, mdelay + delta]

# -----------------------------------------------------------------------------
# define the figure: delays panel + one network panel per condition
fig, axs = plt.subplots(
    nrows=1, ncols=3, sharex=False, sharey=False, figsize=(17, 6)
)
axs = np.ravel(axs)

# -----------------------------------------------------------------------------
#                                  DELAYS
# -----------------------------------------------------------------------------
ax = axs[0]
sns.lineplot(
    data=df_delays, x='delays', y='TE', hue='Condition', n_boot=n_boot,
    hue_order=['Reward', 'Punishment'], palette=['C1', 'C0'], legend=True,
    ax=ax
)
ax.set_xlabel('Delays (ms)')
ax.set_ylabel(r'$TE_{max}$ (bits)')
ax.autoscale(tight=True)
ax.text(-0.3, 1.05, alphabet[0], transform=ax.transAxes, **kw_text)
sns.move_legend(ax, "upper right", bbox_to_anchor=(1.1, 1.1), frameon=True,
                title_fontproperties=dict(weight='bold', size=18))
ax.set_xticks([0, 50, 100, 150, 200, 250, 300, 350])

# Highlight the window around the peak delay. `axvspan` covers the full panel
# height whatever the TE scale (no hard-coded y extent), and the label uses an
# axes-fraction y (x stays in data units) so it remains visible at any scale.
ax.axvspan(drange[0], drange[1], zorder=-1000, color='C3', alpha=.1)
ax.axvline(mdelay, linestyle='--', color='C3')
ax.text(mdelay + 5, 0.02, f"{int(np.round(mdelay))}ms", fontsize=20,
        color='C3', transform=ax.get_xaxis_transform())

# -----------------------------------------------------------------------------
#                                  NETWORKS
# -----------------------------------------------------------------------------
axs[1].text(0., 1.05, alphabet[1], transform=axs[1].transAxes, **kw_text)

# effect sizes and p-values, stacked along a 'cond' dimension
mi = xr.Dataset({'pun': dt['mi_pun'], 'rew': dt['mi_rew']}).to_array('cond')
pv = xr.Dataset({'pun': dt['pv_pun'], 'rew': dt['pv_rew']}).to_array('cond')

# Keep significant effects only (p < 0.05), then average over time. `where`
# also masks NaN p-values, which a `pv >= 0.05` mask would keep (the
# comparison is False for NaN).
mi = mi.where(pv < 0.05).mean('times')

# Per-condition extrema relative to the extrema over both conditions (the max
# ratio is scaled by 8). NOTE(review): how plot_nx interprets tomin / tomax is
# defined in research.study.pblt — confirm.
conds = ['rew', 'pun']
maxmax, minmin = np.nanmax(mi.data), np.nanmin(mi.data)
tomax = 8 * np.array([np.nanmax(mi.sel(cond=c).data) / maxmax for c in conds])
tomin = np.array([np.nanmin(mi.sel(cond=c).data) / minmin for c in conds])

# Edge colormap and title per condition.
# NOTE(review): panel A draws reward in 'C1' and punishment in 'C0', while the
# network edges use Blues (reward) / Reds (punishment) — confirm this mapping.
net_style = {
    'rew': (plt.cm.Blues, r'$TE_{rew}$'),
    'pun': (plt.cm.Reds, r'$TE_{pun}$'),
}
for n_c, cond in enumerate(conds):
    edge_cmap, title = net_style[cond]
    # plot_nx takes no axes argument, so make the target panel current
    plt.sca(axs[n_c + 1])
    plot_nx(mi.sel(cond=cond), node_color='C3', edge_cmap=edge_cmap,
            tomin=tomin[n_c], tomax=tomax[n_c])
    plt.title(title, fontweight='bold')

fig.tight_layout()

In [None]:
# Per-region diagnostic: TE vs delay for each condition, averaged over the
# entries that share the same region name, one facet per region.
# The trailing `;` hides the FacetGrid repr.
tedelays.groupby('roi').mean().plot(x='delays', hue='cond', col='roi', col_wrap=6);

In [None]:
# Removed a leftover interactive lookup (`confidence_interval?`, IPython-only
# help syntax). Refer to the frites documentation of
# `frites.stats.confidence_interval` for its parameters.