In [None]:
import numpy as np
import xarray as xr
import pandas as pd

from pathta import Study
from frites import set_mpl_style
from frites.conn import conn_links
from research.study.pblt import get_model, get_anat_table
from visbrain.utils import normalize

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib.lines import Line2D
from matplotlib.gridspec import GridSpec
import seaborn as sns

from ipywidgets import interact

import json
# Project-wide configuration (the export settings are read by the save cell).
with open("config.json", mode='r') as cfg_file:
    cfg = json.load(cfg_file)

# Apply the frites matplotlib style first, then override the font sizes.
set_mpl_style('frites')
plt.rcParams.update({
    'xtick.labelsize': 20,
    'ytick.labelsize': 20,
    'axes.titlesize': 25,
    'axes.labelsize': 24,
})

---
# **I/O**
## Load images of the task and anatomy

In [None]:
# Static illustration panels reused by both figure layouts below.
img_task, img_anat = (
    mpimg.imread(f'images/pblt_{panel}.png') for panel in ('task', 'anat')
)

## Load behavioral information

In [None]:
st = Study('PBLT')
subjects = list(st.load_config('subjects.json').keys())

# outcome -> condition lookup
rp_ref = {
    '+0€': 'Rewarding', '+1€': 'Rewarding',
    '-0€': 'Punishing', '-1€': 'Punishing'
}


def _load_subject_behavior(subject):
    """Return the behavioral table of one subject, annotated for pooling.

    Renames 'outcome valence' to 'outcome', then adds the subject name, the
    condition derived from the outcome through ``rp_ref`` and the model
    prediction errors returned by ``get_model('PE', subject)``.
    """
    table = st.search(subject, folder='behavior', load=True)
    table = table.rename(columns={'outcome valence': 'outcome'})
    table['subject'] = subject
    table['condition'] = table['outcome'].replace(rp_ref)
    table['PE'] = get_model('PE', subject)
    return table


# stack the tables of all subjects into a single frame
beh = pd.concat([_load_subject_behavior(s) for s in subjects])
beh = beh.reset_index(drop=True)

## Number of contacts per ROI

In [None]:
###############################################################################
use_roi = ['aINS', 'dlPFC', 'vmPFC', 'lOFC']
###############################################################################

# Build two anatomical tables:
#   * anat_dirty : every contact of every subject, as returned by get_anat_table
#   * anat       : contacts that are not bad channels and lie in `use_roi`
anat, anat_dirty = [], []
for s in subjects:
    # load the bad channels
    bad_ch = st.search(s, folder='bad_channels', load=True)['ch_names']

    # load the anatomy
    _df = get_anat_table(s)
    anat_dirty.append(_df)

    # `isin` matches against the *values* of bad_ch whether it is a list or a
    # pandas Series (a bare `in` on a Series would test its index instead)
    is_good = ~_df['contact'].isin(bad_ch)
    # keep only the important brain regions
    is_roi = _df['ma_checked'].isin(use_roi)
    # .copy() so that adding `subject` below writes to an independent frame
    # rather than a filtered slice (avoids pandas' SettingWithCopyWarning)
    _df = _df.iloc[(is_good & is_roi).to_numpy()].copy()
    if not len(_df): continue

    _df['subject'] = s
    anat.append(_df)
anat = pd.concat(anat).reset_index(drop=True)
anat_dirty = pd.concat(anat_dirty).reset_index(drop=True)

# number of contacts per region, after (clean) / before (dirty) cleaning
df_contacts = anat.groupby('ma_checked').count()[['contact']].rename(
    columns={'contact': 'count'})
df_contacts_dirty = anat_dirty.groupby('ma_checked').count()[['contact']].rename(
    columns={'contact': 'count'})
n_ains = df_contacts.loc['aINS', 'count']
n_dlpfc = df_contacts.loc['dlPFC', 'count']
n_vmpfc = df_contacts.loc['vmPFC', 'count']
n_lofc = df_contacts.loc['lOFC', 'count']

In [None]:
# Contacts per ROI after (clean) and before (dirty) removing bad channels,
# plus how many were dropped.
df_cleaning = (
    df_contacts.loc[use_roi]
    .rename(columns={'count': 'count (clean)'})
    .assign(**{
        'count (dirty)': df_contacts_dirty.loc[use_roi, 'count'],
        'count (diff)': lambda d: d['count (dirty)'] - d['count (clean)'],
    })
)
df_cleaning


## Load connectivity links

In [None]:
# One row per connectivity link (e.g. 'aINS-dlPFC') per subject, restricted to
# good contacts in `use_roi` and to pairs inside the same hemisphere.
df_links = []
for s in subjects:
    # load the bad channels
    bad_ch = st.search(s, folder='bad_channels', load=True)['ch_names']

    # load the anatomy with lateralized region labels
    _df = get_anat_table(s, lr=True)
    # drop bad channels; `isin` matches values for both lists and Series
    # (a bare `in` on a Series would test its index). `.copy()` so the column
    # assignments below do not write into a filtered slice.
    _df = _df.iloc[(~_df['contact'].isin(bad_ch)).to_numpy()].copy()

    # hemisphere = first character of the label; then strip the 'L_' / 'R_'
    # prefix so the region name can be matched against `use_roi`
    _df['hemi'] = [r[0] for r in _df['ma_checked'].values]
    _df['ma_checked'] = _df['ma_checked'].replace({
        'L\\_': '', 'R\\_': ''
    }, regex=True)

    # keep only the important brain regions
    _df = _df.iloc[_df['ma_checked'].isin(use_roi).to_numpy()]
    if not len(_df): continue

    # get variables
    roi = _df['ma_checked'].values.astype(str)
    hemi = _df['hemi'].values

    # links between contacts of the same region (same hemisphere only)
    roi_intra = conn_links(roi, roi_relation='intra', hemisphere=hemi,
                           hemi_links='intra')[1]
    # links between different regions (same hemisphere only)
    # NOTE(review): `> 2` skips subjects with exactly two distinct regions,
    # whose inter-region links are then never counted; confirm this is intended
    if len(np.unique(roi)) > 2:
        roi_inter = conn_links(roi, roi_relation='inter', hemisphere=hemi,
                               hemi_links='intra', directed=False, sep='-')[1]
    else:
        roi_inter = []

    # concatenate as plain lists so the string link names are never combined
    # with an empty float array (np.r_ with np.array([]) mixes dtypes)
    _df = pd.DataFrame({'links': list(roi_intra) + list(roi_inter)})
    df_links.append(_df)

df_links = pd.concat(df_links).reset_index(drop=True)

---
# **Data preparation**
## Prepare behavioral tables

In [None]:
# Number of trials per subject x outcome (non-missing 'choice' entries),
# keeping the condition alongside each outcome.
df_ntr = (
    beh.groupby(['subject', 'outcome', 'condition'])
    .count()
    .reset_index()
    .loc[:, ['subject', 'outcome', 'choice', 'condition']]
    .rename(columns={'choice': '# Trials'})
)



## Prepare connectivity matrix

In [None]:
# Number of links per region pair, pooled over subjects.
conn = df_links.groupby('links').size().reset_index(name='count')

# 'source-target' link names -> separate source / target columns
endpoints = [link.split('-') for link in conn['links'].values]
conn['sources'] = [src for src, _ in endpoints]
conn['targets'] = [tgt for _, tgt in endpoints]

# targets x sources count matrix; pairs without any link are NaN
conn = conn.pivot(index='targets', columns='sources', values='count')
conn


---
# **Make the figure**
## Plotting functions (currently disabled, kept for reference)

In [None]:
# from scipy.stats import gaussian_kde

# def get_density(data):
#     density = gaussian_kde(data)
#     xs = np.linspace(data.min() * .1, data.max() * 1.2, 200)
#     density.covariance_factor = lambda : .3
#     density._compute_covariance()
#     return xs, density(xs)

# def plt_half_violin(cond):
#     outc = ['+1€', '+0€'] if cond == 'Rewarding' else ['-1€', '-0€']
#     offset = 1. if cond == 'Rewarding' else 2.
#     palette = ['C1', 'C5'] if cond == 'Rewarding' else ['C0', 'C5']

#     data = df_ntr[df_ntr['condition'] == cond]['# Trials']
#     xs, xsden = get_density(data)
#     xmean = np.percentile(data, 50)
#     xs_closest = np.abs(xs - xmean).argmin()
#     xsden = - normalize(xsden, 0., .5) + offset
#     plt.fill_betweenx(xs, [xsden.max()] * 200, xsden, color='C3')
#     plt.plot([xsden[xs_closest], offset], [xmean, xmean], color='w', linestyle='--')
    
#     for n_o, o in enumerate(outc):
#         data = df_ntr[df_ntr['outcome'] == o]['# Trials']
#         xmean = data.mean()
#         density = gaussian_kde(data)
#         xs, xsden = get_density(data)
#         xs_closest = np.abs(xs - xmean).argmin()
#         xsden = normalize(xsden, 0., .25) + offset
#         zorder = -1000 if '0' in o else 0
#         plt.fill_betweenx(xs, xsden, [xsden.min()] * 200, color=palette[n_o],
#                           alpha=.2, zorder=zorder)
#         plt.plot([xsden[xs_closest], offset], [xmean, xmean],
#                  color=palette[n_o], linestyle='--', lw=2)

# plt.figure()
# plt_half_violin('Rewarding')
# plt_half_violin('Punishing')
# plt.xticks([1, 2])
# plt.gca().set_xticklabels(['Rewarding', 'Punishing'])
# plt.ylabel('# Trials')
# plt.xlim(.5, 2.5)

## Actual plotting
### Version 1: task panel first

In [None]:
###############################################################################
nb_kw = dict(size=27, weight='bold')  # panel-letter font
off_violin = 260  # y position of the condition brackets on the violin panel
kw_legend = dict(marker='o', color='w', markersize=15)
###############################################################################


fig = plt.figure(figsize=(18, 18))
gs = GridSpec(4, 3, left=0.05, bottom=0.03, right=0.99, top=0.90,
              wspace=.15, hspace=0.25)

# ---------------------------------- (A) task ---------------------------------
ax = fig.add_subplot(gs[0:2, 0:2])
ax.imshow(img_task)
ax.grid(False)
for side in ('top', 'left', 'bottom', 'right'):
    ax.spines[side].set_visible(False)
ax.set_xticks([])
ax.set_yticks([])
ax.text(0., 1.05, 'A', transform=ax.transAxes, **nb_kw)

# ------------------------ (B) number of trials / outcome ---------------------
ax = fig.add_subplot(gs[0, -1])
sns.violinplot(
    data=df_ntr, x='outcome', y='# Trials', order=['-1€', '-0€', '+0€', '+1€'],
    palette=["#E24A33", "#f0a499", "#add0e4", "#348ABD"], ax=ax
)
# bracket + label grouping the two outcomes of each condition
for x_start, x_stop, label, color in ((-.25, 1.25, 'Punishing', 'C0'),
                                      (1.75, 3.25, 'Rewarding', 'C1')):
    ax.plot([x_start, x_stop], [off_violin, off_violin], color=color)
    ax.plot([x_start, x_start], [off_violin - 5, off_violin], color=color)
    ax.plot([x_stop, x_stop], [off_violin - 5, off_violin], color=color)
    ax.text((x_start + x_stop) / 2, off_violin + 10, label, ha='center',
            fontsize=22, color=color)
ax.set_xlabel('Outcomes')
ax.text(-0.25, 1.05, 'B', transform=ax.transAxes, **nb_kw)

# ----------------------------------- (C) PE ----------------------------------
ax = fig.add_subplot(gs[1, -1])
sns.lineplot(
    data=beh, x='trial index', y='PE', hue='condition', n_boot=100,
    hue_order=['Punishing', 'Rewarding'], ax=ax
)
ax.autoscale(tight=True)
sns.move_legend(ax, "upper right", title=None, frameon=False)
ax.set_xlabel('Trial number')
ax.set_ylabel('Prediction Error')
ax.text(-0.25, 1.05, 'C', transform=ax.transAxes, **nb_kw)

# -------------------------------- (D) anatomy --------------------------------
ax = fig.add_subplot(gs[2::, 0:2])
ax.imshow(img_anat)
ax.grid(False)
for side in ('top', 'left', 'bottom', 'right'):
    ax.spines[side].set_visible(False)
ax.set_xticks([])
ax.set_yticks([])
ax.text(0., 1.05, 'D', transform=ax.transAxes, **nb_kw)

# one colored marker per region, labelled with its number of contacts
roi_colors = [('aINS', n_ains, "#1de8b5"), ('dlPFC', n_dlpfc, "#42a5f5"),
              ('lOFC', n_lofc, "#ef5350"), ('vmPFC', n_vmpfc, "#ffca28")]
custom_lines = [
    Line2D([0], [0], markerfacecolor=hex_color, label=f'{name} (n={n_cont})',
           **kw_legend)
    for name, n_cont, hex_color in roi_colors
]
ax.legend(
    handles=custom_lines, ncol=4, bbox_to_anchor=(.63, .03),
    fontsize=20, bbox_transform=fig.transFigure, handletextpad=0.001,
    columnspacing=0.3
)

# ---------------------------- (E) number of links ----------------------------
ax = fig.add_subplot(gs[-2::, -1])
sns.heatmap(
    conn, square=True, cbar_kws=dict(shrink=.55, label='# links'), annot=True,
    cmap='viridis', annot_kws=dict(fontweight='bold', fontsize=20), fmt='.0f',
    ax=ax
)
ax.set_xlabel('')
ax.set_ylabel('')
ax.text(-0.25, 1.05, 'E', transform=ax.transAxes, **nb_kw)

### Version 2: anatomy panel first

In [None]:
###############################################################################
nb_kw = dict(size=30, weight='bold')  # panel-letter font
off_violin = 260  # y position of the condition brackets on the violin panel
kw_legend = dict(marker='o', color='w', markersize=15)
###############################################################################


fig = plt.figure(figsize=(18, 18))
gs = GridSpec(4, 3, left=0.05, bottom=0.03, right=0.99, top=0.90,
              wspace=.25, hspace=0.35)

# -------------------------------- (A) anatomy --------------------------------
ax = fig.add_subplot(gs[0:2, 0:2])
ax.imshow(img_anat)
ax.grid(False)
for side in ('top', 'left', 'bottom', 'right'):
    ax.spines[side].set_visible(False)
ax.set_xticks([])
ax.set_yticks([])
ax.text(0., 1.05, 'A', transform=ax.transAxes, **nb_kw)

# one colored marker per region, labelled with its number of contacts
roi_colors = [('aINS', n_ains, "#1de8b5"), ('dlPFC', n_dlpfc, "#42a5f5"),
              ('lOFC', n_lofc, "#ef5350"), ('vmPFC', n_vmpfc, "#ffca28")]
custom_lines = [
    Line2D([0], [0], markerfacecolor=hex_color, label=f'{name} (n={n_cont})',
           **kw_legend)
    for name, n_cont, hex_color in roi_colors
]
ax.legend(
    handles=custom_lines, ncol=4, bbox_to_anchor=(.63, .5),
    fontsize=20, bbox_transform=fig.transFigure, handletextpad=0.001,
    columnspacing=0.3
)

# ---------------------------- (B) number of links ----------------------------
ax = fig.add_subplot(gs[0:2, -1])
sns.heatmap(
    conn, square=True, cbar_kws=dict(shrink=.55, label='# links'), annot=True,
    cmap='viridis', annot_kws=dict(fontweight='bold', fontsize=20), fmt='.0f',
    ax=ax
)
ax.set_xlabel('')
ax.set_ylabel('')
ax.text(-0.25, 1.05, 'B', transform=ax.transAxes, **nb_kw)

# ---------------------------------- (C) task ---------------------------------
ax = fig.add_subplot(gs[2::, 0:2])
ax.imshow(img_task)
ax.grid(False)
for side in ('top', 'left', 'bottom', 'right'):
    ax.spines[side].set_visible(False)
ax.set_xticks([])
ax.set_yticks([])
ax.text(0., 1.0, 'C', transform=ax.transAxes, **nb_kw)

# ------------------------ (D) number of trials / outcome ---------------------
ax = fig.add_subplot(gs[2, -1])
sns.violinplot(
    data=df_ntr, x='outcome', y='# Trials', order=['+1€', '+0€', '-0€', '-1€'],
    palette=["#348ABD", "#add0e4", "#f0a499", "#E24A33"], ax=ax
)
# bracket + label grouping the two outcomes of each condition
for x_start, x_stop, label, color in ((-.25, 1.25, 'Rewarding', 'C1'),
                                      (1.75, 3.25, 'Punishing', 'C0')):
    ax.plot([x_start, x_stop], [off_violin, off_violin], color=color)
    ax.plot([x_start, x_start], [off_violin - 5, off_violin], color=color)
    ax.plot([x_stop, x_stop], [off_violin - 5, off_violin], color=color)
    ax.text((x_start + x_stop) / 2, off_violin + 10, label, ha='center',
            fontsize=22, color=color)

ax.set_xlabel('Outcomes')
ax.text(-0.25, 1.0, 'D', transform=ax.transAxes, **nb_kw)
# both zero outcomes are displayed as plain '0€'
ax.set_xticklabels(['+1€', '0€', '0€', '-1€'])
ax.set_yticks([0, 50, 100, 150, 200])

# ----------------------------------- (E) PE ----------------------------------
ax = fig.add_subplot(gs[3, -1])
sns.lineplot(
    data=beh.replace({'Rewarding': 'RPE', 'Punishing': 'PPE'}), x='trial index',
    y='PE', hue='condition', n_boot=100, hue_order=['PPE', 'RPE'], ax=ax
)
ax.autoscale(tight=True)
sns.move_legend(ax, "upper right", title=None, frameon=False)
ax.set_xlabel('Trials')
ax.set_ylabel('Prediction Error')
ax.text(-0.25, 1.05, 'E', transform=ax.transAxes, **nb_kw);



In [None]:
# Save the most recently created `fig` (version 2 when the notebook is run
# top to bottom) with the export options from config.json.
export_opts = cfg['export']
save_to, cfg_export = export_opts['save_to'], export_opts['cfg']
fig.savefig(f'{save_to}/fig_anat_beh.png', **cfg_export)

In [None]:
def ntr_stats(x):
    """Summarize the trial counts of one group as rounded integers.

    Parameters
    ----------
    x : pd.DataFrame
        Rows belonging to one group; only the ``'# Trials'`` column is used.

    Returns
    -------
    pd.DataFrame
        A single row with ``"Mean"`` and ``"Std"`` rounded to integers. A
        statistic that is undefined (e.g. the std of a single-row group) is
        reported as ``pd.NA`` instead of raising on ``int(nan)``.
    """
    def _round_int(value):
        # int(nan) raises ValueError; keep undefined statistics as missing
        return int(np.round(value)) if pd.notna(value) else pd.NA

    n_trials = x['# Trials']
    return pd.DataFrame({
        "Mean": [_round_int(n_trials.mean())],
        "Std": [_round_int(n_trials.std())]
    })


# Mean / std (across subjects) of the number of trials per outcome. Selecting
# '# Trials' keeps the grouping column out of the frames passed to ntr_stats
# (apply() operating on grouping columns is deprecated since pandas 2.2).
df_ntr.groupby('outcome')[['# Trials']].apply(ntr_stats)


In [None]:
# NOTE(review): hard-coded scratch arithmetic; the operands presumably were read
# off a table displayed above — derive them from the data if this result is kept
142 + 93

In [None]:
# Average (across subjects) of the total number of trials per condition.
# Only '# Trials' is aggregated: summing every column concatenates the string
# columns, and mean() over those object columns raises a TypeError in
# pandas >= 2.0 (numeric_only now defaults to False).
(
    df_ntr.groupby(['subject', 'condition'])[['# Trials']].sum()
    .reset_index()
    .groupby('condition')[['# Trials']].mean()
)

In [None]:
df_ntr  # full subject x outcome table of trial counts


In [None]:
# Per-trial PE values, with conditions relabelled as PPE / RPE.
pe_labels = {'Rewarding': 'RPE', 'Punishing': 'PPE'}
sns.stripplot(
    data=beh.replace(pe_labels), x='trial index', y='PE',
    hue='condition', hue_order=['PPE', 'RPE']
)
