In [None]:
import numpy as np
import xarray as xr
import pandas as pd

from pathta import Study
from frites import set_mpl_style
from frites.conn import conn_links
from research.study.pblt import get_model, get_anat_table
from visbrain.utils import normalize

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib.lines import Line2D
from matplotlib.gridspec import GridSpec
import seaborn as sns

from ipywidgets import interact

import json
# project configuration (later read for the export folder and savefig options)
with open("../config.json", 'r') as f:
    cfg = json.load(f)

# figure look & feel: frites style, then larger fonts for ticks/titles/labels
set_mpl_style('frites')
plt.rcParams.update({
    'xtick.labelsize': 20,
    'ytick.labelsize': 20,
    'axes.titlesize': 25,
    'axes.labelsize': 24,
})

# **Anatomy**
---

In [None]:
###############################################################################
use_roi = ['aINS', 'dlPFC', 'vmPFC', 'lOFC']
###############################################################################

st = Study('PBLT')
subjects = list(st.load_config('subjects.json').keys())

# ------------------------------- GROUP ANATOMY -------------------------------
# One long table with a row per good contact located in one of `use_roi`,
# with columns: ma_checked, contact, Hemi, subject, subject #.

anat = []
for n_s, s in enumerate(subjects):
    # load the bad channels
    bad_ch = st.search(s, folder='bad_channels', load=True)['ch_names']

    # load the anatomy and drop the bad contacts
    anat_s = get_anat_table(s)
    is_good = [c not in bad_ch for c in anat_s['contact'].values]
    anat_s = anat_s.iloc[is_good, :]

    # get hemispheres from the lateralized table, taking the first character
    # of its 'ma_checked' label. NOTE(review): assumes that table lists the
    # same contacts in the same order as the non-lateralized one — confirm.
    anat_lr = get_anat_table(s, lr=True).iloc[is_good, :]
    # `.assign` builds a new frame instead of writing into a slice (avoids
    # pandas' chained-assignment / SettingWithCopyWarning pitfall)
    anat_s = anat_s.assign(Hemi=[r[0] for r in anat_lr["ma_checked"].values])

    # keep only the important brain regions
    in_roi = [r in use_roi for r in anat_s['ma_checked'].values]
    anat_s = anat_s.iloc[in_roi, :].loc[:, ['ma_checked', 'contact', 'Hemi']]
    if not len(anat_s):
        continue

    # 'subject #' is the 1-based position in `subjects`, so numbering is kept
    # even when subjects without any contact in `use_roi` are skipped
    anat_s = anat_s.assign(**{'subject': s, 'subject #': n_s + 1})

    anat.append(anat_s)

# concatenate subjects
if not anat:
    raise ValueError(f"No subject has a good contact in {use_roi}")
anat = pd.concat(anat).reset_index(drop=True)

# -------------------------------- # SUBJECTS ---------------------------------

# number of contacts per subject and brain region: `count` tallies the
# non-null entries of the remaining columns ('contact', 'Hemi') in each
# (subject, subject #, roi) group
anat_suj = anat.groupby([
    'subject', 'subject #', 'ma_checked'
]).count().reset_index()

# number of subjects per brain region


def prop_unique_suj(df):
    """Count the unique subjects of a single brain region.

    Despite its name, this returns a count of subjects, not a proportion.

    Parameters
    ----------
    df : pd.DataFrame
        Rows of a single brain region, with at least the columns 'subject'
        and 'ma_checked'.

    Returns
    -------
    pd.DataFrame
        One-row frame with columns '# subjects' (number of unique subjects)
        and 'ma_checked' (the brain region of the first row).
    """
    n_suj = df['subject'].nunique()
    return pd.DataFrame({
        '# subjects': [n_suj],
        'ma_checked': df['ma_checked'].values[0],
    })


# one row per brain region, '# subjects' being the number of unique subjects
# with at least one good contact in it. A plain `nunique` is used instead of
# `groupby.apply(prop_unique_suj)`: letting `apply` see the grouping column is
# deprecated since pandas 2.2 (`include_groups`) and `prop_unique_suj` reads it.
anat_usuj = (
    anat_suj.groupby('ma_checked')['subject'].nunique()
    .rename('# subjects')
    .reset_index()
    .loc[:, ['# subjects', 'ma_checked']]
)

# **Load conjunction results and compute the proportion of subjects with significant contacts**
---

In [None]:
###############################################################################
model = 'PE'
band = 'lga'
from_folder = f'mi/group/{model}'
###############################################################################

st = Study('PBLT')

# load significance testing
# NOTE(review): `dt` is not used by the cells below — drop the load if it is
# not needed elsewhere
f = st.search('savgol-10.nc', band, folder=from_folder)
assert len(f) == 1, (
    f"expected exactly one 'savgol-10.nc' file for band {band!r} in "
    f"{from_folder!r}, found {len(f)}: {f}"
)
dt = xr.load_dataset(f[0])

# load conjunction, keeping only the 'pun' and 'rew' conditions (in that
# order, which also fixes the plotting colors later on)
f = st.search('conj', band, folder=from_folder)
assert len(f) == 1, (
    f"expected exactly one 'conj' file for band {band!r} in "
    f"{from_folder!r}, found {len(f)}: {f}"
)
conj = xr.load_dataset(f[0]).to_array('cond').sel(cond=['pun', 'rew'])

# display-friendly model name (NOTE(review): not used by the cells below)
model = model.replace('_', ' | ')

# get unique subject for conjunction
conj_suj = conj.copy()

# Sort the anatomy so that its rows follow the contacts of `conj`: ROIs in the
# order in which they first appear along conj['roi'], then by subject # inside
# each ROI. The ROI order is read from `conj` instead of being hard-coded so it
# cannot drift from the saved results; the assertion below still checks the
# contact-by-contact ROI alignment.
roi_order = list(pd.unique(conj_suj['roi'].data))
anat_sorted = anat.sort_values(
    by=['ma_checked', 'subject #']
).set_index('ma_checked').loc[roi_order].reset_index()
np.testing.assert_array_equal(
    conj_suj['roi'].data, anat_sorted['ma_checked'].values
)


def compute_prop_suj(x, alpha=0.05):
    """Percentage of subjects with at least one significant contact in a ROI.

    Meant to be mapped over ``conj_suj.groupby('roi')``: ``x`` holds the
    contacts of a single ROI. Contacts are relabelled with their subject
    number, read from the global ``anat_sorted``, which is assumed to list the
    contacts in the same order as ``x``. A subject counts as significant when
    any of its contacts has ``x < alpha``. The count is expressed as a
    percentage of the subjects having a contact in this ROI.

    Parameters
    ----------
    x : xr.DataArray
        Values of a single ROI with a 'roi' dimension (one entry per contact).
        They are compared to ``alpha`` (presumably p-values — TODO confirm).
    alpha : float
        Significance threshold (strict ``<``). Default 0.05.

    Returns
    -------
    xr.DataArray
        Percentage of subjects, with the 'roi' dimension reduced.
    """
    # select roi in the anat. The list indexer always returns a DataFrame,
    # even for a single-contact ROI (a scalar label would return a Series and
    # break both the length check and the column access below).
    anat_r = anat_sorted.set_index('ma_checked').loc[[x['roi'].values[0]]]
    assert len(anat_r) == len(x['roi'])

    # relabel contacts with their subject number, on a new object rather than
    # the group handed over by groupby
    subj = anat_r['subject #'].values
    x = x.assign_coords(roi=subj)
    n_suj = len(np.unique(subj))

    # a subject is significant if any of its contacts is; % of unique subjects
    prop = 100 * (x < alpha).groupby('roi').any().sum('roi') / n_suj

    return prop


# percentage of subjects with at least one significant contact, per ROI
# (`map` is the current name of xarray's pending-deprecation `GroupBy.apply`)
prop = conj_suj.groupby('roi').map(compute_prop_suj)

# **Compute proportions of subjects per significance pattern (condition-specific, shared, none)**
---

In [None]:
def compute_prop_split_suj(x, alpha=0.05):
    """Percentage of subjects per significance pattern across 'rew' / 'pun'.

    This function has a distinct name so that it no longer shadows the
    ``compute_prop_suj`` used to compute ``prop``.

    Meant to be mapped over ``conj_suj.groupby('roi')``: ``x`` holds the
    contacts of a single ROI and has a 'cond' dimension containing 'rew' and
    'pun'. Contacts are relabelled with their subject number, read from the
    global ``anat_sorted``, which is assumed to list the contacts in the same
    order as ``x``. A subject is significant in a condition when any of its
    contacts has ``x < alpha``.

    Parameters
    ----------
    x : xr.DataArray
        Values of a single ROI with 'roi' (one entry per contact) and 'cond'
        dimensions (presumably p-values — TODO confirm).
    alpha : float
        Significance threshold (strict ``<``). Default 0.05.

    Returns
    -------
    xr.DataArray
        Percentage of subjects along a new 'cond' dimension:
        'pun' (significant for 'pun' only), 'rew' ('rew' only),
        'shared' (both) and 'none' (neither).
    """
    # select roi in the anat; the list indexer always returns a DataFrame,
    # even for a single-contact ROI
    anat_r = anat_sorted.set_index('ma_checked').loc[[x['roi'].values[0]]]
    assert len(anat_r) == len(x['roi'])

    # relabel contacts with their subject number (new object, no mutation)
    subj = anat_r['subject #'].values
    x = x.assign_coords(roi=subj)
    n_suj = len(np.unique(subj))

    # per subject and condition: is any contact significant?
    is_sig = (x < alpha).groupby('roi').any()
    # drop the scalar 'cond' coordinate so both masks combine cleanly
    is_rew = is_sig.sel(cond='rew', drop=True)
    is_pun = is_sig.sel(cond='pun', drop=True)

    def _pct(mask):
        """Percentage of unique subjects for which `mask` is True."""
        return 100 * mask.sum('roi') / n_suj

    return xr.Dataset({
        'pun': _pct(~is_rew & is_pun),
        'rew': _pct(is_rew & ~is_pun),
        'shared': _pct(is_rew & is_pun),
        'none': _pct(~is_rew & ~is_pun),
    }).to_array('cond')


# percentage of subjects per ROI whose significance is specific to 'pun',
# specific to 'rew', shared by both, or absent
prop_split = conj_suj.groupby('roi').map(compute_prop_split_suj)

# **Plot**
---

In [None]:
# one panel per ROI: percentage of subjects over time, one line per condition
fig, axs = plt.subplots(1, 4, figsize=(20, 4), sharey=True, sharex=True)
axs = np.ravel(axs)
n_suj_per_roi = anat_usuj.set_index('ma_checked')['# subjects']

for ax, roi in zip(axs, ["aINS", "dlPFC", "lOFC", "vmPFC"]):
    prop.sel(roi=roi).plot(x='times', hue='cond', add_legend=False, ax=ax)
    # vertical reference line at t = 0
    ax.axvline(0., color='C3')
    ax.set_ylim(-1., 110)
    n_suj_r = n_suj_per_roi.loc[roi]
    ax.set_title(
        f"{roi}\n" + r"($n_{subjects}=%i$)" % n_suj_r, fontweight='bold'
    )
    ax.set_xlabel("Times [s]")
    ax.set_xticks([-.5, 0., 0.5, 1., 1.5])
    ax.set_xlim(-.5, 1.5)
    if ax is axs[0]:
        ax.set_ylabel('Proportion of\nsubjects [%]')

# custom legend: the hue lines use the default color cycle in the order of
# the 'cond' coordinate ('pun' -> C0, 'rew' -> C1)
custom_lines = [
    Line2D([0], [0], color=color, lw=6, solid_capstyle='round')
    for color in ("C1", "C0")
]
pv_conds = ["RPE", "PPE"]

# legend anchored in figure coordinates, below the panels
legend = axs[-1].legend(
    custom_lines, pv_conds, ncol=2, bbox_to_anchor=(.6, -.05),
    fontsize=24, bbox_transform=fig.transFigure,
    title_fontproperties=dict(weight='bold', size=26)
)

In [None]:
export_cfg = cfg['export']
save_to, cfg_export = export_cfg['review'], export_cfg['cfg']

# write the figure into the review folder with the shared savefig options
fig.savefig(f'{save_to}/local_encoding_ss.png', **cfg_export)

In [None]:
# peak percentage over time; exact zeros are masked to NaN first, so entries
# that stay at 0 for every time point show up as NaN
prop.where(prop != 0).max(dim="times")