In [None]:
import numpy as np
import xarray as xr
import pandas as pd

from pathta import Study

from frites import set_mpl_style

import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.patches as patches
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

from ipywidgets import interact

import json
# project-wide settings; used at the end of the notebook for the export
# destination and the savefig keyword arguments
with open("config.json") as config_file:
    config = json.load(config_file)

# global plotting style from frites (presumably matplotlib rcParams)
set_mpl_style()


---
# **I/O**
## Load the significance-testing and conjunction results

In [None]:
###############################################################################
model = 'PE'
band = 'lga'
from_folder = f'mi/group/{model}'
###############################################################################

st = Study('PBLT')

# load significance testing (exactly one file is expected for this band/folder)
files = st.search('savgol-10.nc', band, folder=from_folder)
assert len(files) == 1, (
    f"expected 1 significance file for band={band!r} in {from_folder!r}, "
    f"found {len(files)}: {files}"
)
# NOTE(review): `dt` is not used by the visible cells of this notebook
dt = xr.load_dataset(files[0])

# load conjunction: stack its data variables along a new 'cond' dimension and
# keep only the 'pun' and 'rew' conditions
files = st.search('conj', band, folder=from_folder)
assert len(files) == 1, (
    f"expected 1 conjunction file for band={band!r} in {from_folder!r}, "
    f"found {len(files)}: {files}"
)
conj = xr.load_dataset(files[0]).to_array('cond').sel(cond=['pun', 'rew'])

# display form of the model name ('_' separators shown as ' | ')
# NOTE(review): `model` is not used again in the visible cells
model = model.replace('_', ' | ')


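## Compute, for each region, the proportion of contacts that are task-irrelevant, RPE-specific, mixed (RPE and PPE) or PPE-specific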

In [None]:
def atoms(x, alpha=0.05):
    """Split the contacts of one region into four mutually exclusive categories.

    Parameters
    ----------
    x : xr.DataArray
        Values compared against ``alpha`` (assumed to be p-values), with a
        'cond' dimension containing 'rew' and 'pun', a 'roi' dimension (one
        entry per contact of the region) and a 'times' dimension.
        NOTE(review): assumed to have no other dimensions — confirm.
    alpha : float
        Significance threshold. A contact is significant for a condition when
        its value is below ``alpha`` at least once over time. Defaults to 0.05.

    Returns
    -------
    res : xr.DataArray
        Percentage of contacts (0-100) along an 'atoms' dimension ordered as
        ['irrelevant', 'rpe', 'mix', 'ppe']. The categories are exclusive and
        exhaustive, so the four values sum to 100.
    """
    # get total number of contacts
    n_contacts = len(x['roi'])

    # significant for a condition = below threshold at any time point
    x_s = (x < alpha).any('times')
    is_rew = x_s.sel(cond='rew', drop=True)
    is_pun = x_s.sel(cond='pun', drop=True)

    def _perc(mask):
        # percentage of the region's contacts for which `mask` is True
        return 100. * float(mask.sum('roi')) / n_contacts

    # order matters: it must match the colors used by the donut plot
    perc = {
        'irrelevant': _perc(~(is_rew | is_pun)),  # neither condition
        'rpe': _perc(is_rew & ~is_pun),           # reward only
        'mix': _perc(is_rew & is_pun),            # both conditions
        'ppe': _perc(~is_rew & is_pun),           # punishment only
    }
    return xr.DataArray(
        list(perc.values()), dims=('atoms',), coords={'atoms': list(perc)}
    )

# per-region breakdown, indexed by 'roi' and 'atoms'
# (`GroupBy.map` replaces the deprecated `GroupBy.apply`)
proportions = conj.groupby('roi').map(atoms)
# the four atoms partition each region's contacts, so they must sum to 100 %
np.testing.assert_allclose(proportions.sum('atoms').data, 100.)


---
# **Plotting**
## Plotting function

In [None]:
def plot_node(name, perc, center=(0, 0), fz=22, ax=None):
    """Draw one brain region as a donut plot with its name in the middle.

    Parameters
    ----------
    name : str
        Brain region name, written at the centre of the donut.
    perc : array_like
        Percentages ordered as [perc_irrelevant, perc_rpe, perc_mix, perc_ppe].
        Expected to sum to 100: the task-irrelevant wedge is only centred on
        the top of the donut under that assumption.
    center : tuple of float
        (x, y) position of the donut, in data coordinates.
    fz : int
        Font size of the region name.
    ax : matplotlib.axes.Axes | None
        Axes to draw on. Defaults to the current axes (``plt.gca()``).

    Returns
    -------
    wedges, texts : list, list
        Wedge patches and percentage labels created by ``Axes.pie``.
    """
    if ax is None:
        ax = plt.gca()

    # same order as `perc`: irrelevant ('#9999' = #RGBA grey shorthand),
    # rpe, mix, ppe
    colors = ['#9999', 'C1', 'C2', 'C0']

    # start angle such that the task-irrelevant wedge is centred on 90 deg
    startangle = 90 - 360 * (perc[0] / 2) / 100

    # donut = pie with a ring width < 1; labels are rounded percentages
    # (the local name `wedges` avoids shadowing the `patches` module alias)
    wedges, texts = ax.pie(
        perc, colors=colors, startangle=startangle, wedgeprops=dict(width=0.47),
        center=center, labels=[f"{int(np.round(k))}%" for k in perc],
        textprops=dict(color="w", fontweight='bold', fontsize=16,
                       ha='center', va='center'), labeldistance=.77
    )
    ax.text(
        center[0], center[1], name, ha='center', va='center', fontsize=fz,
        fontweight='bold', color='C3'
    )
    return wedges, texts


# plot_node('vmPFC', [30, 40, 20, 10])


## Summary figure: one donut per region, with arrows for PPE / RPE redundancy and PE synergy

In [None]:
###############################################################################
# node layout in data coordinates, centred on the origin
dist = 3.      # horizontal offset of aINS / lOFC; also scales the vertical one
ycoef = .9     # vertical offset of dlPFC / vmPFC, as a fraction of `dist`
ydelta = .12   # trimmed from both ends of the vertical synergy arrow
centers = {
    "aINS": (-dist, 0),
    "dlPFC": (0, ycoef * dist),
    "vmPFC": (0, -ycoef * dist),
    "lOFC": (dist, 0),
}
###############################################################################


fig = plt.figure(figsize=(10, 4))

# one donut per region, fed with its [irrelevant, rpe, mix, ppe] percentages
for roi, center in centers.items():
    plot_node(roi, proportions.sel(roi=roi).data, center=center)
ax = plt.gca()

# arrows between nodes, drawn unclipped so they can extend past the axes
arrow_kw = dict(
    arrowstyle="Simple, tail_width=0.5, head_width=20, head_length=16",
    clip_on=False, lw=8,
)
arrows = [
    # PPE redundancy: curved arrow from near aINS to near dlPFC
    patches.FancyArrowPatch(
        (-dist, 1), (-1, ycoef * dist), connectionstyle="arc3,rad=-.3",
        color='C0', **arrow_kw),
    # RPE redundancy: curved arrow from near vmPFC to near lOFC
    patches.FancyArrowPatch(
        (1, -ycoef * dist), (dist, -1), connectionstyle="arc3,rad=.3",
        color='C1', **arrow_kw),
    # PE synergy: dashed vertical arrow from vmPFC towards dlPFC
    patches.FancyArrowPatch(
        (0, -2 * ycoef + ydelta), (0, 2 * ycoef - ydelta),
        color='C5', clip_on=False, lw=8, linestyle='--'),
]

# annotation of each arrow: (x, y, text, rotation, color)
txt_kw = dict(fontsize=24, fontweight='bold', clip_on=False, ha='center', va='center')
annotations = [
    (-1.8, 1.6, 'Redundancy\nPPE', 45, 'C0'),
    (2., -1.8, 'Redundancy\nRPE', 45, 'C1'),
    (0., 0, 'Synergy\n\nPE', 90, 'C5'),
]
for x_txt, y_txt, label, angle, col in annotations:
    ax.text(x_txt, y_txt, label, rotation=angle, color=col, **txt_kw)

for arrow in arrows:
    ax.add_patch(arrow)

# legend using the donut colour code, anchored below the figure
custom_leg = [
    Patch(facecolor='C1', label='RPE-specific'),
    Patch(facecolor='C2', label='Mix R/PPE'),
    Patch(facecolor='C0', label='PPE-specific'),
    Patch(facecolor='#9999', label='Task-irrelevant'),
]
legend = ax.legend(
    handles=custom_leg, ncol=2, bbox_to_anchor=(.6, -0.65),
    fontsize=20, bbox_transform=fig.transFigure,
    title="Proportion of significant sEEG contacts over time (p<0.05)",
    title_fontproperties=dict(weight='bold', size=18),
)


In [None]:
# destination folder and savefig keyword arguments both come from config.json
export = config['export']
save_to, cfg_export = export['save_to'], export['cfg']

fig.savefig(f'{save_to}/fig_summary.png', **cfg_export)
