In [None]:
import numpy as np
import xarray as xr
import pandas as pd

from pathta import Study

from frites import set_mpl_style
from research.study.pblt import get_anat_table
from frites.utils import nonsorted_unique

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from nilearn import plotting
import seaborn as sns

set_mpl_style()

# NOTE(review): `json` is imported here rather than in the import block above;
# consider moving it up with the other imports.
import json

# Notebook-wide settings; the export cell at the end reads config['export'].
# JSON is UTF-8 by specification, so decode explicitly instead of relying on
# the platform's locale encoding (which differs e.g. on Windows).
with open("config.json", 'r', encoding='utf-8') as f:
    config = json.load(f)

---
# **I/O**
## Load the significance results, the conjunction and the contacts' anatomy

In [None]:
###############################################################################
model = 'PE'
use_roi = ['aINS', 'dlPFC', 'vmPFC', 'lOFC']
band = 'lga'
from_folder = f'mi/group/{model}'
###############################################################################

st = Study('PBLT')

# load significance testing (exactly one matching file is expected)
# NOTE(review): `dt` is not used further down in this notebook
f = st.search('savgol-10.nc', band, folder=from_folder)
assert len(f) == 1, (
    f"expected 1 'savgol-10.nc' file for band {band!r} in {from_folder!r}, "
    f"found {len(f)}: {f}"
)
dt = xr.load_dataset(f[0])

# load conjunction results and keep the punishment / reward conditions
# (thresholded at 0.05 further down, so presumably p-values -- confirm)
f = st.search('conj', band, folder=from_folder)
assert len(f) == 1, (
    f"expected 1 'conj' file for band {band!r} in {from_folder!r}, "
    f"found {len(f)}: {f}"
)
conj = xr.load_dataset(f[0]).to_array('cond').sel(cond=['pun', 'rew'])

# display-friendly model name ('A_B' -> 'A | B')
# NOTE(review): `model` is not used after this line in this notebook
model = model.replace('_', ' | ')

# load (x, y, z) coordinates of the retained contacts of every subject
subjects = list(st.load_config('subjects.json').keys())
anat = []
for s in subjects:
    # load the bad channels
    # NOTE(review): assumes `ch_names` is a list-like of contact names
    bad_ch = st.search(s, folder='bad_channels', load=True)['ch_names']

    # load the anatomy and drop the bad channels
    _df = get_anat_table(s)
    _df = _df[~_df['contact'].isin(bad_ch)]

    # keep only the important brain regions
    _df = _df[_df['ma_checked'].isin(use_roi)]
    if _df.empty:
        continue

    # explicit copy: the column assignment below must act on an independent
    # frame, not on a filtered view (avoids SettingWithCopyWarning)
    _df = _df.copy()
    _df['subject'] = s

    anat.append(_df)

# pd.concat([]) would fail with an unhelpful message; say what went wrong
if not anat:
    raise ValueError(f"no valid contacts found in {use_roi} for any subject")
anat = pd.concat(anat).reset_index(drop=True)

# reorder the anat table so that its ROI sequence follows the one of `conj`
# (unique ROI labels of conj, presumably in order of first appearance)
u_roi = nonsorted_unique(conj['roi'].data)

# Index by ROI once (loop-invariant). `.loc[[r]]` (list indexer) always yields
# a DataFrame; the scalar form `.loc[r]` returns a Series when an ROI has a
# single contact, and `reset_index()` on it produces a malformed frame.
anat_by_roi = anat.set_index('ma_checked')
anat_n = [anat_by_roi.loc[[r]].reset_index() for r in u_roi]
anat = pd.concat(anat_n).reset_index(drop=True)

# sanity check: the contact-wise ROI labels must line up with `conj`
# NOTE(review): this checks ROI labels only, not the contact order inside an
# ROI -- confirm that both tables list contacts in the same order
np.testing.assert_array_equal(
    conj['roi'].data, anat['ma_checked'].values
)

## Assign a role to each contact and compute role proportions per ROI

In [None]:
# a contact is flagged for a condition as soon as the conjunction value drops
# below 0.05 at any time point
conj_bin = (conj < 0.05).any('times')
is_rew, is_pun = conj_bin.sel(cond='rew'), conj_bin.sel(cond='pun')

# pure reward, pure punishment, and contacts significant for both
rew = (is_rew & ~is_pun).data
pun = (~is_rew & is_pun).data
rewpun = (is_rew & is_pun).data

# one role label per contact; contacts in none of the groups get 'None'
role = np.select(
    [rewpun, rew, pun], ['R/PPE', 'RPE', 'PPE'], default='None'
).tolist()

# contact table enriched with |X|, the role of each contact and its color
role_colors = {
    'RPE': '#348ABD',
    'PPE': '#E24A33',
    'R/PPE': '#2ECC71',
    'None': '#999999'
}
anat_sub = (
    anat.loc[:, ['contact', 'X', 'Y', 'Z', 'subject', 'ma_checked']]
    .assign(abs_X=lambda d: np.abs(d['X']), role=role)
)
anat_sub['color'] = anat_sub['role'].replace(role_colors)


In [None]:
def gp_fcn(x):
    """Percentage of contacts per role within one group of contacts.

    Parameters
    ----------
    x : pandas.DataFrame
        Contacts of one group; must contain a ``'role'`` column.

    Returns
    -------
    pandas.DataFrame
        One row per role (sorted by role) with columns ``'role'`` and
        ``'perc'`` (percentage of ``len(x)``, rounded to 2 decimals).
    """
    n_contacts = len(x)
    counts = x.groupby('role').size()
    perc = np.round(100 * counts / n_contacts, 2)
    # name the value column explicitly instead of relying on the Series name
    return perc.reset_index(name='perc')

# role percentages within each ROI (displayed as the cell output)
perc_by_roi = anat_sub.groupby('ma_checked').apply(gp_fcn)
perc_by_roi


---
# **Plot**
## Nilearn glass brain plotting

In [None]:
###############################################################################
# NOTE: overrides `use_roi` from the I/O cell (same ROIs, different order)
use_roi = ['aINS', 'dlPFC', 'lOFC', 'vmPFC']
# alternative ROI selections:
# use_roi = ['aINS']
# use_roi = ['vmPFC', 'lOFC']
# use_roi = ['dlPFC']
# use_roi = ['dlPFC', 'lOFC', 'vmPFC']
# use_roi = ['dlPFC', 'lOFC', 'vmPFC', 'aINS']
###############################################################################

# empty glass brain (no image) on which contacts are drawn as markers; the
# brain views are placed in the upper part of the figure, leaving room below
# for the legend
# NOTE(review): the name `display` shadows IPython's `display()` function
glass_kw = dict(display_mode='lzry', axes=(0., .2, 1., .8), title=None)
# alternative layout: glass_kw['display_mode'] = 'ortho'
fig = plt.figure(figsize=(20, 10))
display = plotting.plot_glass_brain(None, figure=fig, **glass_kw)


# contacts of the selected ROIs with their role and (x, y, z) coordinates
anat_plt = anat_sub.set_index('ma_checked').loc[use_roi]
roles = anat_plt['role']
xyz = anat_plt[['X', 'Y', 'Z']].values

# marker style per role ('None' contacts get size 0, i.e. are not visible)
plt_cfg = {
    'RPE': {
        'size': 70,
        'color': '#348ABD'
    },
    'PPE': {
        'size': 70,
        'color': '#E24A33'
    },
    'R/PPE': {
        'size': 70,
        'color': '#2ECC71'
    },
    'None': {
        'size': 0,
        'color': '#999999'
    }
}

for cond in ['RPE', 'PPE', 'R/PPE', 'None']:
    xyz_cond = xyz[(roles == cond).to_numpy(), :]
    print(cond, len(xyz_cond))
    # nothing to draw for this role: don't hand nilearn an empty coordinate
    # list, which its marker projection may not handle
    if not len(xyz_cond):
        continue
    display.add_markers(
        xyz_cond.tolist(), marker_color=plt_cfg[cond]['color'],
        marker_size=plt_cfg[cond]['size']
    )

# Create the legend from the very same colors as the markers (`plt_cfg`).
# Cycle names like "C0"/"C1"/"C5" resolve against the active style's color
# cycle, which is not guaranteed to match the hex colors used for the markers.
# The 'None' role is deliberately not listed.
kw_lines = dict(marker='o', color='w', markersize=15)
legend_items = [
    ('RPE', "RPE specific"),
    ('PPE', "PPE specific"),
    ('R/PPE', "R/PPE"),
]
custom_lines = [
    Line2D([0], [0], markerfacecolor=plt_cfg[cond]['color'], **kw_lines)
    for cond, _ in legend_items
]
titles = [label for _, label in legend_items]
plt.legend(
    custom_lines, titles, ncol=3, bbox_to_anchor=(.7, 0.25), fontsize=20,
    bbox_transform=fig.transFigure, title="Significant sEEG contacts (p<0.05)",
    title_fontproperties=dict(weight='bold', size=20)
);


## Export the figure

In [None]:
# destination folder and savefig keyword arguments both come from config.json
export_cfg = config['export']
save_to = export_cfg['save_to']
cfg_export = export_cfg['cfg']

out_file = f'{save_to}/supp_anat_repartition.png'
fig.savefig(out_file, **cfg_export)


In [None]:
# NOTE(review): leftover scratch arithmetic (evaluates to 60.0); the meaning of
# 4, 1500 and 100 is not documented here -- explain it or remove this cell
4 * 1500 / 100