In [None]:
import numpy as np
import xarray as xr
import pandas as pd

from pathta import Study
from frites import set_mpl_style
from frites.conn import conn_links
from research.study.pblt import get_model, get_anat_table
from visbrain.utils import normalize

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib.lines import Line2D
from matplotlib.gridspec import GridSpec
import seaborn as sns

from ipywidgets import interact

import json
# Notebook-wide settings are stored in a JSON file one level above the
# notebook; the export cell reads cfg['export'] from it.
with open("../config.json", 'r') as f:
    cfg = json.load(f)

# Apply the frites matplotlib style first, then override the font sizes so
# the later assignments take precedence over the style defaults.
set_mpl_style('frites')
plt.rcParams.update({
    'xtick.labelsize': 20,
    'ytick.labelsize': 20,
    'axes.titlesize': 25,
    'axes.labelsize': 24,
})

# **I/O**
---
## Number of recordings per subject

In [None]:
###############################################################################
# brain regions retained in the figure
use_roi = ['aINS', 'dlPFC', 'vmPFC', 'lOFC']
###############################################################################

st = Study('PBLT')
subjects = list(st.load_config('subjects.json').keys())

# ------------------------------- GROUP ANATOMY -------------------------------
# Build one long table with one row per retained contact and the columns
# 'ma_checked', 'contact', 'Hemi', 'subject' and 'subject #'.

anat = []
for n_s, s in enumerate(subjects):
    # load the bad channels
    bad_ch = st.search(s, folder='bad_channels', load=True)['ch_names']

    # load the anatomy and drop the contacts flagged as bad
    _df = get_anat_table(s)
    keep = [c not in bad_ch for c in _df['contact'].values]
    # Explicit copy: a column is added just below, and assigning into the
    # result of a boolean `iloc` selection raises SettingWithCopyWarning.
    _df = _df.iloc[keep, :].copy()

    # get hemispheres: first character of the `lr=True` region label
    # (presumably 'L' / 'R' -- TODO confirm against get_anat_table)
    __df = get_anat_table(s, lr=True).iloc[keep, :]
    _df['Hemi'] = [r[0] for r in __df["ma_checked"].values]

    # keep only the important brain regions (copy again, for the same reason)
    keep = [r in use_roi for r in _df['ma_checked'].values]
    _df = _df.iloc[keep, :].loc[:, ['ma_checked', 'contact', 'Hemi']].copy()
    if not len(_df):
        # this subject has no usable contact in the regions of interest
        continue

    _df['subject'] = s
    # 1-based position in the full subject list, so numbers are not
    # re-packed when a subject is skipped above
    _df['subject #'] = n_s + 1

    anat.append(_df)

# concatenate subjects
anat = pd.concat(anat).reset_index(drop=True)

# -------------------------------- # SUBJECTS ---------------------------------

# number of contacts per subject and brain region: `count` replaces every
# non-key column (here 'contact' and 'Hemi') by its number of non-null rows
group_keys = ['subject', 'subject #', 'ma_checked']
anat_suj = anat.groupby(group_keys).count().reset_index()

# number of subjects per brain region
def prop_unique_suj(df):
    """Count the distinct subjects found in a single-region table.

    Despite the name, this returns a count, not a proportion.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows of one brain region; must contain the columns 'subject' and
        'ma_checked' and have at least one row.

    Returns
    -------
    pandas.DataFrame
        One-row frame with the columns '# subjects' (number of unique values
        in 'subject') and 'ma_checked' (label taken from the first row).
    """
    unique_subjects = np.unique(df['subject'])
    summary = {
        '# subjects': [len(unique_subjects)],
        'ma_checked': df['ma_checked'].values[0],
    }
    return pd.DataFrame(summary)

# Hand each per-region sub-table to the helper by iterating over the groups
# rather than through `groupby(...).apply`: the helper reads the grouping
# column 'ma_checked', and `DataFrameGroupBy.apply` operating on grouping
# columns is deprecated since pandas 2.2. Iterating always yields the full
# sub-table and gives the same one-row-per-region result.
anat_usuj = pd.concat(
    [prop_unique_suj(df_roi) for _, df_roi in anat_suj.groupby('ma_checked')]
).reset_index(drop=True)

# ----------------------------- # SUBJECTS / CONN -----------------------------

# compute number of links per subject


def compute_links(x):
    """List the link names that `conn_links` builds from one contact table.

    Parameters
    ----------
    x : pandas.DataFrame
        Contact table with the columns 'ma_checked' and 'Hemi'; in this
        notebook it is the per-subject group of `anat`.

    Returns
    -------
    pandas.DataFrame
        Single column 'roi' holding the second output of
        `frites.conn.conn_links`, called with `directed=False` and
        `hemi_links="intra"`.
    """
    roi_labels = x["ma_checked"].values
    hemi_labels = x["Hemi"].values
    # the first output is not needed here; only the link names are kept
    _, link_names = conn_links(
        roi_labels,
        directed=False,
        hemisphere=hemi_labels,
        hemi_links="intra",
        verbose=False,
    )
    return pd.DataFrame({"roi": link_names})


def compute_prop_links(x):
    """Count the distinct subjects present in a table of links.

    Despite the name, this returns a count, not a proportion.

    Parameters
    ----------
    x : pandas.DataFrame
        Table with a 'subject' column; in this notebook it is the group of
        `links` rows sharing one link name.

    Returns
    -------
    pandas.DataFrame
        One-row frame with the single column 'Count'.
    """
    n_subjects = len(np.unique(x["subject"]))
    return pd.DataFrame({"Count": [n_subjects]})


# one row per (subject, link); `reset_index` turns the 'subject' group key
# back into a regular column
per_subject = anat.groupby('subject')
links = per_subject.apply(compute_links).reset_index()

# one row per link name, with the number of subjects that have this link
per_link = links.groupby("roi")
prop_links = per_link.apply(compute_prop_links).reset_index()

# fill connectivity array

roi_order = ["aINS", "dlPFC", "lOFC", "vmPFC"]

# square (sources x targets) array, NaN everywhere until filled below
conn = xr.DataArray(
    np.full((len(roi_order), len(roi_order)), np.nan),
    dims=('sources', 'targets'),
    coords=(roi_order, roi_order)
)

# index the per-link subject counts by link name once, outside the loops
link_counts = prop_links.set_index("roi")["Count"]

for n_s, s in enumerate(roi_order):
    for n_t, t in enumerate(roi_order):
        # Only the lower triangle (diagonal included) is filled; the upper
        # triangle stays NaN. Positions are compared rather than the names,
        # so the pattern does not depend on `roi_order` being alphabetical.
        if n_s < n_t:
            continue
        # The pair may be stored under either ordering of the two names.
        # An explicit membership test replaces the former bare `except:`,
        # which would also have swallowed unrelated errors.
        link = f"{s}-{t}"
        if link not in link_counts.index:
            link = f"{t}-{s}"
        # still raises KeyError when the pair is absent under both orderings
        conn.data[n_s, n_t] = link_counts.loc[link]

# **Plot**
---

In [None]:
###############################################################################
# one color per brain region, and the text style of the panel letters
palette = [
    "#1de8b5", "#42a5f5", "#ef5350", "#ffca28",
]
nb_kw = dict(size=27, weight='bold')
###############################################################################

fig = plt.figure(figsize=(24, 12))
gs = GridSpec(2, 8, figure=fig, wspace=.7, hspace=.4)

# ---------------------------- # SUBJECTS PER ROI -----------------------------

# top row, left: number of subjects per brain region
ax1 = fig.add_subplot(gs[0, 2:4])
sns.barplot(
    anat_usuj, x="ma_checked", y="# subjects", palette=palette, ax=ax1
)
ax1.bar_label(ax1.containers[0], fontsize=13, fontweight='bold', color='C3')
ax1.set_xlabel('')
ax1.text(-.3, 1.1, 'A', transform=ax1.transAxes, **nb_kw)
ax1.set_yticks(np.arange(2, 15, 2))

# top row, right: number of subjects per pair of brain regions
ax2 = fig.add_subplot(gs[0, 4:6])
sns.heatmap(
    conn.to_pandas(),
    square=True,
    cbar=False,
    cbar_kws=dict(shrink=.55, label='# subjects'),
    cmap='viridis',
    annot=True,
    annot_kws=dict(fontweight='bold', fontsize=20),
    fmt='.0f',
    ax=ax2,
)
ax2.set_xlabel("")
ax2.set_ylabel("")

# ------------------------ # SITES PER SUBJECT PER ROI ------------------------

# bottom row: number of contacts per subject, one bar per brain region
ax3 = fig.add_subplot(gs[1, :])
sns.barplot(
    anat_suj, x="subject #", y="contact", hue="ma_checked", palette=palette,
    ax=ax3
)
ax3.set_yticks(np.arange(2, 21, 2))
# label each bar container with its own palette color
for k, color in enumerate(palette):
    ax3.bar_label(
        ax3.containers[k], fontsize=13, fontweight='bold', color=color
    )
sns.move_legend(ax3, "upper right")
ax3.set_ylabel("# iEEG sites")
ax3.set_xlabel("Subject #")
# dashed separators between consecutive bar groups
# NOTE(review): the count of 16 is hard-coded -- presumably tied to the number
# of plotted subjects; confirm if the subject list changes.
for h in range(16):
    ax3.axvline(h + .5, color='lightgray', linestyle='--', alpha=.5)
ax3.text(-.07, 1.1, 'B', transform=ax3.transAxes, **nb_kw)
ax3.legend_.set_title(None)


In [None]:
# Output folder and `savefig` keyword arguments both come from the 'export'
# section of the JSON configuration.
export_cfg = cfg['export']
save_to = export_cfg['review']
cfg_export = export_cfg['cfg']

fig.savefig(f'{save_to}/anat_ss.png', **cfg_export)

In [None]:
# NOTE(review): scratch arithmetic (evaluates to 3.3); nothing else in the
# visible notebook uses this value -- presumably a leftover, confirm and remove.
30 * 11 / 100