In [None]:
import numpy as np
import xarray as xr
import pandas as pd

from pathta import Study

from frites import set_mpl_style

import matplotlib.pyplot as plt
import seaborn as sns

from ipywidgets import interact

import json
# Project settings; the export cell later reads config['export'].
with open("config.json") as config_file:
    config = json.load(config_file)

# Global matplotlib style provided by frites.
set_mpl_style()

---
# **Load the cross-frequency redundancy**

In [None]:
###############################################################################
sigma = 0.001
from_folder = 'cfi/psd-PE'
###############################################################################

st = Study('PBLT')

# Result files are tagged with the decimal digits of sigma (0.001 -> '001').
# The tag is kept under its own name so that `sigma` stays a float.
# NOTE(review): this only works for 0 < sigma < 1 printed without scientific
# notation; e.g. str(1e-05) == '1e-05' would give the tag '-05'.
sigma_tag = str(float(sigma))[2:]


def load_condition(cond_tag):
    """Load the single result file matching `cond_tag` for the current sigma.

    Parameters
    ----------
    cond_tag : str
        Substring identifying the condition in the file name
        (e.g. '-rew_', '-pun_', '-rewpun_').

    Returns
    -------
    xarray.DataArray
        The file's data variables stacked along a new 'var' dimension.

    Raises
    ------
    AssertionError
        If the search does not return exactly one file.
    """
    files = st.search(cond_tag, f'sigma-{sigma_tag}.', folder=from_folder)
    assert len(files) == 1, (
        f"expected exactly one file for {cond_tag!r} / sigma-{sigma_tag} in "
        f"{from_folder!r}, found {len(files)}: {files}"
    )
    return xr.load_dataset(files[0]).to_array('var')


dt_rew = load_condition('-rew_')     # reward condition
dt_pun = load_condition('-pun_')     # punishment condition
dt_rp = load_condition('-rewpun_')   # merged reward / punishment

# stack reward and punishment along a new 'cond' dimension
dt = xr.Dataset({'rew': dt_rew, 'pun': dt_pun}).to_array('cond')

---
# **Plot the cross-frequency maps per ROI (rows: RPE / PPE)**

In [None]:
###############################################################################
es = 'mi'
kw_text = dict(rotation=90, fontsize=20, fontweight='bold', va='center',
               color='#555555')
# frequencies (Hz) marked with dashed white guide lines on every panel
freq_marks = [100, 150]
###############################################################################

rois = np.sort(dt['roi'].data)
conds = ['rew', 'pun']
# row label and its vertical figure position, one per condition (row)
row_labels = {'rew': ('RPE', .75), 'pun': ('PPE', .3)}

# One colour scale shared by every panel, so the single colorbar added below
# is valid for all of them. A per-ROI scale would make the colorbar describe
# only the last ROI. Robust bounds: 5th / 99th percentiles of the measure.
dt_es = dt.sel(var=es)
vmin, vmax = np.nanpercentile(dt_es.data, [5, 99])

# squeeze=False keeps `axs` 2-D even when there is a single ROI;
# width scales with the number of ROI columns (15.5 in for 4 ROIs)
fig, axs = plt.subplots(
    nrows=len(conds), ncols=len(rois), sharex=True, sharey=True,
    figsize=(3.875 * len(rois), 7), squeeze=False
)

for n_r, roi in enumerate(rois):
    for n_c, cond in enumerate(conds):
        ax = axs[n_c, n_r]
        df_es = dt_es.sel(cond=cond, roi=roi).to_pandas()

        ax.grid(False)
        im = ax.pcolormesh(df_es.columns, df_es.index, df_es.values,
                           cmap='Spectral_r', vmin=vmin, vmax=vmax)
        for v in freq_marks:
            ax.axvline(v, lw=2, color='w', linestyle='--')
            ax.axhline(v, lw=2, color='w', linestyle='--')

        if n_c == 0:
            ax.set_title(roi, fontweight='bold', color='#555555')
        if n_c == len(conds) - 1:
            ax.set_xlabel('Frequency (Hz)')
        if n_r == 0:
            ax.set_ylabel('Frequency (Hz)')

# condition labels, written once to the left of each row
for cond in conds:
    label, y_pos = row_labels[cond]
    fig.text(-0.02, y_pos, label, **kw_text)
fig.tight_layout()

# make room on the right for the shared colorbar
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.82, 0.3, 0.025, 0.4])
cbar = fig.colorbar(im, cax=cbar_ax)

# **Export the figure**

In [None]:
# Save the figure into the export folder, using the savefig options
# defined in config.json.
export_section = config['export']
save_to, cfg_export = export_section['save_to'], export_section['cfg']
fig.savefig('{}/supp_gamma_red.png'.format(save_to), **cfg_export)

In [None]:
# Removed a leftover scratch expression, `str(int(1000 * 0.01))`; its output
# was not used by any other cell.

In [None]:
# Quick look at where p-values fall below `alpha`: one row per condition
# (rew / pun), one column per ROI. The trailing ';' hides the FacetGrid repr.
alpha = 0.05
(dt.sel(var='pv') < alpha).plot(x='f_start', y='f_end', col='roi', row='cond');