In [None]:
import json
import numpy as np
import xarray as xr
import pandas as pd

from pathta import Study

from frites import set_mpl_style

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.lines import Line2D

from ipywidgets import interact

# Apply the frites matplotlib style (once: a second call is redundant), then
# enlarge the fonts used by every figure of this notebook.
set_mpl_style()

# project-level settings (the export cell at the end of the notebook reads
# config['export']['save_to'] and config['export']['cfg'])
with open("config.json", 'r') as config_file:
    config = json.load(config_file)

plt.rcParams.update({
    'xtick.labelsize': 'xx-large',
    'ytick.labelsize': 'xx-large',
    'axes.titlesize': 21,
    'axes.labelsize': 22,
})


---
# **Load the stats**

In [None]:
###############################################################################
model = 'PE'
# pattern selecting the single results file inside the model folder
# (another pattern previously used here: 'alpha-001')
pattern = 'center'
###############################################################################

st = Study('PBLT')

# load the (unique) dataset matching `pattern` for the selected model
folder = 'cfi/psd-mi-%s' % model
f = st.search(pattern, folder=folder)
assert len(f) == 1, (
    f"expected exactly one file matching {pattern!r} in {folder!r}, "
    f"found {len(f)}: {f}"
)
dt = xr.load_dataset(f[0])


---
# **Make the figure**
## Lineplot only

In [None]:
###############################################################################
sigma = 0.1
p = 0.05
ci_type = 'sem'  # confidence interval to display (value of the `ci` dim)
###############################################################################

# Variable names in the dataset end with the 3 decimals of sigma (0.1 -> '100').
# Dropping the leading "0." is only unambiguous for sigma < 1 (1.1 would also
# give '100'), so refuse values that would silently load the wrong variables.
assert 0 <= sigma < 1, f"sigma must be in [0, 1), got {sigma}"
sstr = "%.3f" % sigma
suffix = sstr[2:]

# extract variables (reward first, punishment second) stacked along 'cond'
mi = dt[[f'mi_rew_{suffix}', f'mi_pun_{suffix}']].to_array('cond')
tv = dt[[f'tv_rew_{suffix}', f'tv_pun_{suffix}']].to_array('cond')  # not plotted
pv = dt[[f'pv_rew_{suffix}', f'pv_pun_{suffix}']].to_array('cond')
ci = dt[[f'ci_rew_{suffix}', f'ci_pun_{suffix}']].to_array('cond').sel(ci=ci_type)
freqs = mi['freqs'].data

# rename condition
mi['cond'] = pv['cond'] = ci['cond'] = tv['cond'] = ['rew', 'pun']

# identify significant clusters: where pv < p, put a constant value below the
# lowest CI bound (one level per condition), NaN elsewhere, so that significant
# frequencies are drawn as thick horizontal segments under the curves.
# xr.where aligns by label, so this does not depend on the dim order of `pv`.
minmin = ci.data.min()
delta = (ci.data.max() - minmin) / 20.
cl_level = xr.DataArray(
    [minmin - delta, minmin - 2 * delta], dims='cond',
    coords={'cond': ['rew', 'pun']}
)
cl = xr.where(pv < p, cl_level, np.nan)

# plot the results: one panel per region of interest
rois = ['aINS', 'dlPFC', 'lOFC', 'vmPFC']
fig, axs = plt.subplots(
    nrows=1, ncols=len(rois), figsize=(20, 5), sharex=True, sharey=True
)
axs = np.ravel(axs)

for n_r, (r, ax) in enumerate(zip(rois, axs)):
    # punishment is drawn with C0 and reward with C1 (matches the legend below)
    for n_c, c in enumerate(['pun', 'rew']):
        color = f'C{n_c}'
        ax.plot(freqs, mi.sel(roi=r, cond=c), color=color)
        ax.fill_between(
            freqs, ci.sel(bound='low', cond=c, roi=r).data,
            ci.sel(bound='high', cond=c, roi=r).data, color=color,
            alpha=.2, zorder=-1
        )
        ax.plot(freqs, cl.sel(cond=c, roi=r).data, color=color, lw=6,
                solid_capstyle='round')

    ax.set_title(f"{r}", fontweight='bold')
    ax.set_xlabel('Frequency [Hz]')
    if n_r == 0:
        ax.set_ylabel('MI [bits]')

# create the legend, positioned in figure coordinates below the panels; the
# title reports the threshold actually used above
custom_lines = [
    Line2D([0], [0], color="C1", lw=6, solid_capstyle='round'),
    Line2D([0], [0], color="C0", lw=6, solid_capstyle='round')
]
titles = [r"$MI_{RPE}$", r"$MI_{PPE}$"]
axs[-1].legend(
    custom_lines, titles, ncol=len(titles), bbox_to_anchor=(.65, -0.03),
    fontsize=20, bbox_transform=fig.transFigure,
    title=f"Significant cluster of MI (p<{p})",
    title_fontproperties=dict(weight='bold', size=20)
);


## Lineplot + MI proportion

In [None]:
###############################################################################
sigma = 0.1
p = 0.05
ci_type = 'sem'  # confidence interval to display (value of the `ci` dim)
nb_kw = dict(size=27, weight='bold')
fbins = [50, 100, 150, 200]  # edges of the frequency bands [Hz]
###############################################################################

# Variable names in the dataset end with the 3 decimals of sigma (0.1 -> '100').
# Dropping the leading "0." is only unambiguous for sigma < 1 (1.1 would also
# give '100'), so refuse values that would silently load the wrong variables.
assert 0 <= sigma < 1, f"sigma must be in [0, 1), got {sigma}"
sstr = "%.3f" % sigma
suffix = sstr[2:]

# extract variables (reward first, punishment second) stacked along 'cond'
mi = dt[[f'mi_rew_{suffix}', f'mi_pun_{suffix}']].to_array('cond')
tv = dt[[f'tv_rew_{suffix}', f'tv_pun_{suffix}']].to_array('cond')  # not plotted
pv = dt[[f'pv_rew_{suffix}', f'pv_pun_{suffix}']].to_array('cond')
ci = dt[[f'ci_rew_{suffix}', f'ci_pun_{suffix}']].to_array('cond').sel(ci=ci_type)
freqs = mi['freqs'].data

# rename condition
mi['cond'] = pv['cond'] = ci['cond'] = tv['cond'] = ['rew', 'pun']

# identify significant clusters: where pv < p, put a constant value below the
# lowest CI bound (one level per condition), NaN elsewhere, so that significant
# frequencies are drawn as thick horizontal segments under the curves.
# xr.where aligns by label, so this does not depend on the dim order of `pv`.
minmin = ci.data.min()
delta = (ci.data.max() - minmin) / 20.
cl_level = xr.DataArray(
    [minmin - delta, minmin - 2 * delta], dims='cond',
    coords={'cond': ['rew', 'pun']}
)
cl = xr.where(pv < p, cl_level, np.nan)

# compute MI proportion: mean |MI| inside each frequency band, normalized so
# that the bands sum to 100% along the band dimension.
# Bins are right-closed; include_lowest=True keeps the first edge (50 Hz) in
# the first band, as its closed "[lo, hi]" label states.
fbins_lbl = [f'[{lo}, {hi}]' for lo, hi in zip(fbins[:-1], fbins[1:])]
mi_gp = np.abs(mi).groupby_bins('freqs', fbins, include_lowest=True).mean()
mi_gp = 100 * mi_gp / mi_gp.sum('freqs_bins')
mi_gp['freqs_bins'] = fbins_lbl
mi_gp = mi_gp.rename(freqs_bins='freqs')

# plot the results: top row = MI spectra, bottom row = MI proportion per band
rois = ['aINS', 'dlPFC', 'lOFC', 'vmPFC']
fig, axs = plt.subplots(
    nrows=2, ncols=len(rois), figsize=(22, 11), sharex='row', sharey='row'
)
fig.subplots_adjust(wspace=0.1)

for n_r, r in enumerate(rois):
    # lineplot (punishment drawn with C0, reward with C1)
    ax = axs[0, n_r]
    for n_c, c in enumerate(['pun', 'rew']):
        color = f'C{n_c}'
        ax.plot(freqs, mi.sel(roi=r, cond=c), color=color)
        ax.fill_between(
            freqs, ci.sel(bound='low', cond=c, roi=r).data,
            ci.sel(bound='high', cond=c, roi=r).data, color=color,
            alpha=.2, zorder=-1
        )
        ax.plot(freqs, cl.sel(cond=c, roi=r).data, color=color, lw=8,
                solid_capstyle='round')

    ax.set_title(f"{r}", fontweight='bold')
    ax.set_xlabel('Frequency [Hz]')
    if n_r == 0:
        ax.set_ylabel('MI [bits]')

    # proportion (same color code as the lineplot: reward C1, punishment C0)
    ax = axs[1, n_r]
    g = sns.barplot(
        data=mi_gp.sel(roi=r).to_dataframe('MI [%]').reset_index(),
        x='freqs', y='MI [%]', hue='cond', hue_order=['rew', 'pun'],
        palette=['C1', 'C0'], ax=ax
    )
    g.legend_.remove()
    if n_r > 0:
        ax.set_ylabel('')
    ax.set_xlabel('Frequency [Hz]')

# shared across the top row (sharey='row')
axs[0, 0].set_yticks([0., 0.005, 0.01, 0.015, 0.02])

# create the legend, positioned in figure coordinates; the title reports the
# threshold actually used above
custom_lines = [
    Line2D([0], [0], color="C1", lw=6, solid_capstyle='round'),
    Line2D([0], [0], color="C0", lw=6, solid_capstyle='round')
]
titles = [r"$MI_{RPE}$", r"$MI_{PPE}$"]
axs[0, 0].legend(
    custom_lines, titles, ncol=len(titles), bbox_to_anchor=(.65, 0.04),
    fontsize=20, bbox_transform=fig.transFigure,
    title=f"Significant cluster of MI (p<{p})",
    title_fontproperties=dict(weight='bold', size=20)
);

# panel letters on the first column of each row
for ax, letter in zip(axs[:, 0], 'AB'):
    ax.text(-.3, 1.05, letter, transform=ax.transAxes, **nb_kw)


## Export the figure

In [None]:
# Save the figure currently held in `fig` (when run top-to-bottom, the
# "Lineplot + MI proportion" figure) into the configured export folder.
export_section = config['export']
save_to, cfg_export = export_section['save_to'], export_section['cfg']

out_file = "{}/supp_gamma_mi.png".format(save_to)
fig.savefig(out_file, **cfg_export)
