In [None]:
import json
import os

import numpy as np
import xarray as xr
import pandas as pd

from pathta import Study

from frites import set_mpl_style
from frites.stats import confidence_interval

import matplotlib.pyplot as plt
import seaborn as sns

from ipywidgets import interact
# Project settings shared by the notebook (the `export` section is read when
# saving figures). JSON is UTF-8 by specification, so decode it explicitly
# instead of relying on the platform's default locale encoding.
with open("config.json", 'r', encoding='utf-8') as f:
    config = json.load(f)

# frites' matplotlib style, then enlarged tick / title / axis-label fonts
set_mpl_style()

plt.rcParams['xtick.labelsize'] = 'xx-large'
plt.rcParams['ytick.labelsize'] = 'xx-large'
plt.rcParams['axes.titlesize'] = 23
plt.rcParams['axes.labelsize'] = 22

In [None]:
###############################################################################
band = 'lga'
from_folder = f'pow/{band}/st-av'
###############################################################################

st = Study('PBLT')
subjects = list(st.load_config('subjects.json').keys())

# Load the first file found in `from_folder` for every subject. Subjects with
# no matching file are skipped, but their names are kept in `missing` so the
# skip is visible instead of silent.
pow_per_subject, missing = [], []
for s in subjects:
    files = st.search(s, folder=from_folder)
    if len(files) == 0:
        missing.append(s)
        continue
    pow_per_subject.append(xr.load_dataarray(files[0]))

if missing:
    print(f"{len(missing)} subject(s) without data in '{from_folder}': {missing}")
if not pow_per_subject:
    # xr.concat([]) would otherwise fail with an unhelpful error
    raise FileNotFoundError(
        f"No file found in '{from_folder}' for any of the {len(subjects)} subjects")

# Stack every subject's regions along 'roi' and keep the two conditions used
# below. NOTE: `pow` shadows the Python builtin; the name is kept because the
# following cells rely on it.
pow = (
    xr.concat(pow_per_subject, 'roi')
    .rename(trials='Condition')
    .sel(Condition=['rew', 'pun'])
)

In [None]:
###############################################################################
# passed as `cis` to frites' confidence_interval (presumably 'sem' = standard
# error of the mean — see the frites documentation)
cil = 'sem'
###############################################################################

# display names of the two conditions kept in `pow`
titles = dict(rew='Rewarding', pun='Punishing')

# gather the traces that share a region label, then average them per region
pow_gp = pow.groupby('roi')
pow_m = pow_gp.mean()

# per-region confidence interval helper (used with groupby below)
def compute_ci(x):
    """Return the confidence interval of ``x`` computed across its ``roi`` axis.

    The interval type is read from the module-level ``cil`` setting at call
    time, and ``random_state`` is pinned to 0 (presumably so that any
    bootstrap-based interval is reproducible). Size-1 dimensions of the
    result are squeezed out before returning.
    """
    bounds = confidence_interval(x, axis='roi', cis=cil, random_state=0)
    return bounds.squeeze()
# GroupBy.map is the supported name; GroupBy.apply is its deprecated alias
ci = pow_gp.map(compute_ci)

# plot the results: one panel per condition, one mean line per region
fg = pow_m.plot(x='times', col='Condition', hue='roi', size=6)
fig = plt.gcf()
axes = np.ravel(fg.axes)

# vertical reference line at t = 0 on every panel
for ax in axes:
    ax.axvline(0., color='C3')

for n_c, cond in enumerate(pow['Condition'].data):
    # style the panel through its Axes object instead of pyplot's current axes
    ax = axes[n_c]
    ax.set_xlabel('Times (s)')
    if n_c == 0:
        ax.set_ylabel("Relative gamma power")
    ax.set_title(f"{titles[cond]} condition", fontweight='bold')
    ci_c = ci.sel(Condition=cond)

    # Shade each region's interval with the colour of its mean line. pow_m and
    # ci both come from pow_gp, so their roi order matches; this assumes the
    # n-th line drawn by xarray uses colour C{n} of the property cycle.
    for n_r, r in enumerate(ci_c['roi'].data):
        ci_cr = ci_c.sel(roi=r)
        ax.fill_between(
            ci_cr['times'].data,
            ci_cr.sel(bound='low').data, ci_cr.sel(bound='high').data,
            color=f'C{n_r}', alpha=.1, zorder=61,
        )


In [None]:
save_to = config['export']['save_to']
# extra keyword arguments forwarded to savefig (presumably dpi, bbox, ...)
cfg_export = config['export']['cfg']

# savefig does not create missing folders, so make sure the target exists;
# os.path.join also avoids writing to '/' when save_to is an empty string
if save_to:
    os.makedirs(save_to, exist_ok=True)
fig.savefig(os.path.join(save_to, 'fig_gamma_pow.png'), **cfg_export)