In [None]:
import os

import numpy as np
import xarray as xr
import pandas as pd

from mne.utils import ProgressBar

from frites.dataset import DatasetEphy
from frites.workflow import WfMi

import matplotlib.pyplot as plt
from ipywidgets import interact

# The seaborn styles were renamed 'seaborn-v0_8-*' in matplotlib 3.6 and the
# old 'seaborn-*' names were removed in 3.8: use whichever name this matplotlib
# knows (styles are purely cosmetic, so an unknown style is simply skipped)
for _style in ('seaborn-dark', 'seaborn-poster'):
    for _candidate in (_style.replace('seaborn', 'seaborn-v0_8', 1), _style):
        if _candidate in plt.style.available:
            plt.style.use(_candidate)
            break

---
# **--- ROOT PATH ---**

<div class="alert alert-info"><p>

Set `ROOT` below to the folder containing the dataset, i.e. the folder that holds the `hga/`, `anat/` and `beh/` sub-folders.
</p></div>

In [None]:
# Folder holding the `hga/`, `anat/` and `beh/` sub-folders read by `load_ss`.
# It can be set without editing the notebook through the FRITES_DATASET
# environment variable; otherwise the author's local path is used.
ROOT = os.environ.get(
    'FRITES_DATASET',
    '/run/media/etienne/DATA/Toolbox/BraiNets/CookingFrites/dataset/')

---
# **0 - Functions**

In [None]:
###############################################################################
###############################################################################
#                 Load the data of a single subject
###############################################################################
###############################################################################

def load_ss(subject_nb):
    """Read the three files describing one subject from ``ROOT``.

    The files ``hga/hga_s-<n>.nc``, ``anat/anat_s-<n>.xlsx`` and
    ``beh/beh_s-<n>.xlsx`` are read, in that order.

    Parameters
    ----------
    subject_nb : int
        Subject number (``load_ms`` iterates over subjects 0 to 11)

    Returns
    -------
    hga : xarray.DataArray
        High-gamma activity, as stored in the NetCDF file
    anat : pandas.DataFrame
        Anatomical table; its 'roi' column gives the brain region of each
        channel
    beh : pandas.DataFrame
        Behavioral table, with one row per trial
    """
    hga_file = os.path.join(ROOT, 'hga', f'hga_s-{subject_nb}.nc')
    anat_file = os.path.join(ROOT, 'anat', f'anat_s-{subject_nb}.xlsx')
    beh_file = os.path.join(ROOT, 'beh', f'beh_s-{subject_nb}.xlsx')

    # tuple items are evaluated left to right: neural data first, then the
    # anatomical and behavioral Excel tables
    return (xr.load_dataarray(hga_file),
            pd.read_excel(anat_file),
            pd.read_excel(beh_file))


###############################################################################
###############################################################################
#                 Load the data of multiple subjects
###############################################################################
###############################################################################

def load_ms(s_range=[0, 11], model='outcome', condition='rew',
            space='channels', mean_roi=True, prepend_suj_to_ch=True):
    """Load multiple subjects.

    Parameters
    ----------
    s_range : int or list
        Subjects to load. Use either an integer (e.g. 7) to load a single
        subject or an inclusive range of subjects (e.g. [5, 10] loads subjects
        5 to 10). The range is clipped to the available subjects [0, 11]. The
        object passed in is never modified, so the same list can safely be
        reused across calls.
    model : {'outcome', 'pe', 'rt'}
        Model to use (case insensitive). Use either :

            * 'outcome' : find differences in the neural activity between the
              outcomes
            * 'pe' : find regions with an activity correlating with the
              prediction error
            * 'rt' : find regions with an activity correlating with the
              reaction time
    condition : {'rew', 'pun', 'context', 'null'}
        Condition to load. Use either :

            * 'rew' : for outcomes {+0€; +1€}
            * 'pun' : for outcomes {-1€; -0€}
            * 'context' : for outcomes {-1€; +1€}
            * 'null' : for outcomes {-0€; +0€}
    space : {'channels', 'roi'}
        Specify if the spatial dimension should be described with channel names
        or with brain region names ('parcels' is also accepted and behaves like
        'roi', except for the name of the dimension)
    mean_roi : bool
        Specify if you want to take the mean high-gamma activity inside a brain
        region (only used when space is 'roi' or 'parcels')
    prepend_suj_to_ch : bool
        Add subject name to each channel name, e.g. 'suj3/<channel>' (only used
        when space is 'channels')

    Returns
    -------
    hga : list
        List of high-gamma activity across subjects (one xarray.DataArray per
        subject). The trial dimension is renamed to ``model`` and its
        coordinate holds the behavioral values of that model.

    Raises
    ------
    AssertionError
        If ``space`` is not one of 'channels', 'parcels', 'roi'
    KeyError
        If ``model`` or ``condition`` is not one of the supported values
    """
    n_subjects = 12  # subjects are numbered 0 .. n_subjects - 1

    # inputs checking, done before any file is read or progress bar is shown
    model = model.lower()
    assert space in ['channels', 'parcels', 'roi']

    # outcome codes (as stored in the 'trials' coordinate) kept per condition
    outc = {
        'rew': (+1, +2),
        'pun': (-2, -1),
        'context': (-2, +2),
        'null': (-1, +1)
    }[condition]

    # behavioral column whose values replace the trial coordinate
    col = {
        'outcome': 'code',
        'pe': 'PE',
        'rt': 'RT'
    }[model]

    # subject range: read the bounds into locals so that neither the caller's
    # list nor the mutable default argument is modified in place. The upper
    # bound is inclusive.
    if isinstance(s_range, (int, np.integer)):
        first, last = s_range, s_range
    else:
        first, last = s_range[0], s_range[1]
    subjects = range(max(first, 0), min(last + 1, n_subjects))

    mesg = f"Subject %i | model={model} | condition={condition} | space={space}"
    pbar = ProgressBar(subjects, mesg=mesg % subjects.start)

    # load the data
    hga = []
    for n_s in subjects:
        pbar._tqdm.desc = mesg % n_s
        # load the data of a single subject
        _hga, _anat, _beh = load_ss(n_s)
        # outcome code of each trial, read before the coordinate is replaced
        _outc = _hga['trials'].data
        _ch = _hga['channels'].data

        # replace trial dimension with the model
        _hga = _hga.rename(trials=model)
        _hga[model] = list(_beh[col])

        # keep only the trials whose outcome belongs to the condition
        _hga = _hga[np.isin(_outc, outc), ...]

        if space in ['parcels', 'roi']:
            # replace channel names with brain regions
            _hga = _hga.rename(channels=space)
            _hga[space] = list(_anat['roi'])

            # take the mean of the hga per parcel
            if mean_roi:
                _hga = _hga.groupby(space).mean(space)
        elif prepend_suj_to_ch:
            # space is 'channels' here: prepend subject number to channel name
            _hga['channels'] = [f"suj{n_s}/{c}" for c in _ch]

        # C-contiguous memory layout (presumably expected / faster for the
        # downstream frites computations)
        _hga.data = np.ascontiguousarray(_hga.data)

        hga.append(_hga)
        pbar.update_with_increment_value(1)

    return hga


###############################################################################
###############################################################################
#                           Plotting the results
###############################################################################
###############################################################################

def _single_subplot(mi, pv, pv_on_top, pv_pos, lw=5, color='r', alpha=0.05,
                    label=None, xlabel=True, ylabel=True):
    """Single subplot plotting function.

    On the current matplotlib axes, draws the MI time course of one region as
    a thin line. It then highlights the time points whose p-value is below
    ``alpha`` with a line of width ``lw``.

    Parameters
    ----------
    mi, pv : xr.DataArray
        MI and p-values of a single region along the 'times' coordinate
    pv_on_top : bool
        If True, significant time points are drawn as a horizontal bar at the
        height ``pv_pos``; otherwise the MI curve itself is thickened there
    pv_pos : float
        Height of the significance bar (only used when ``pv_on_top``)
    lw : float
        Line width of the significant segments
    color : str
        Matplotlib color shared by the curve and its significant segments
    alpha : float
        Significance threshold applied to ``pv``
    label : str | None
        Legend label of the MI curve
    xlabel, ylabel : bool
        Whether to write the x / y axis labels
    """
    times = mi['times'].data
    # non-significant time points get masked (NaN) in the highlight line
    not_signi = np.asarray(pv.data) >= alpha

    # thin line: complete MI time course
    plt.plot(times, mi.data, color=color, label=label, lw=1.5)

    # thick line: significant time points only. It is built as a float array
    # so NaN can be assigned: np.full with an int `pv_pos` would give an int
    # array that rejects NaN. In "along the line" mode it is a copy of the MI,
    # so the input DataArray is left untouched.
    if pv_on_top:
        highlight = np.full((len(times),), pv_pos, dtype=float)
    else:
        highlight = np.array(mi.data, dtype=float)
    highlight[not_signi] = np.nan
    plt.plot(times, highlight, color=color, lw=lw)

    # decorations
    plt.axvline(0., color='k', linestyle='--', lw=1)
    if xlabel:
        plt.xlabel('Times')
    if ylabel:
        plt.ylabel('MI (bits)')
    plt.xlim(-.5, 1.5)
    plt.grid(True)


def plot_results(mi, pv, alpha=0.05, split=True, pv_on_top=True, lw=5):
    """Plot significant results.

    Parameters
    ----------
    mi, pv : xr.DataArray
        Measure of information and corrected p-values of shape (n_times, n_roi)
    alpha : float
        Significance threshold. Time points with a p-value below ``alpha`` are
        highlighted. In split mode, the title of each region with at least one
        such time point is written in bold.
    split : bool
        Specify whether results should be presented in separate subplots, one
        per brain region (True), or superimposed in the same subplot (False)
    pv_on_top : bool
        Specify whether the significant values should be plotted as a bar on
        top of the curves (True) or as a thicker segment along the line (False)
    lw : float
        Line width of significant results
    """
    roi = np.sort(mi['roi'].data)
    mimax, mimin = mi.data.max(), mi.data.min()
    # vertical offset used to place the significance bars above the highest MI
    dp = (mimax - mimin) / 20.

    if split:
        # figure creation: at most `n_per_row` subplots per row
        n_per_row = 5
        ncols = min(n_per_row, len(roi))
        nrows = int(np.ceil(len(roi) / n_per_row))
        width, height = int(5 * ncols), int(4 * nrows)
        fig, axs = plt.subplots(
            nrows=nrows, ncols=ncols, sharex=True, sharey=True,
            figsize=(width, height))
        axs = np.ravel(axs)

        # subplot filling
        for n_r, ax in enumerate(axs):
            plt.sca(ax)
            if n_r >= len(roi):
                # unused trailing axes of the last row
                plt.axis(False)
                continue
            r = roi[n_r]
            pv_r = pv.sel(roi=r)
            _single_subplot(
                mi.sel(roi=r), pv_r, pv_on_top, mimax + dp,
                lw=lw, color=f'C{n_r}', alpha=alpha, label=r,
                # x label only on the lowest visible subplot of each column,
                # y label only on the first column
                xlabel=n_r >= len(roi) - n_per_row, ylabel=n_r % n_per_row == 0
            )
            # bold title when the region has at least one significant time
            # point, using the same `alpha` as the highlighted segments
            fw = 'bold' if np.any(pv_r.data < alpha) else None
            plt.title(r, fontweight=fw)
    else:
        # all regions in one subplot; each region's significance bar gets its
        # own height so that the bars don't overlap
        fig = plt.figure(figsize=(15, 8))
        for n_r, r in enumerate(roi):
            _single_subplot(
                mi.sel(roi=r), pv.sel(roi=r), pv_on_top, mimax + dp * (n_r + 1),
                lw=lw, color=f'C{n_r}', alpha=alpha, label=r
            )
        plt.legend()

    # add figure title
    attrs = mi.attrs
    mi_type, inference, mcp = attrs['mi_type'], attrs['inference'], attrs['mcp']
    fig.suptitle(
        (f"Significant results using {inference.upper()} model and p-values "
         f"corrected using {mcp} (p < {alpha}; mi_type={mi_type})"),
        fontweight='bold', fontsize=18, y=1.02
    )

---
# **1. Fixed- and random-effect inference**
## 1.1 Difference between the two null outcomes (-0€ vs. +0€)

In [None]:
# high-gamma activity of every subject, averaged within each brain region and
# restricted to the trials whose outcome is -0€ or +0€ ('null' condition)
hga = load_ms(space='roi', model='outcome', condition='null')
ds = DatasetEphy(hga, y='outcome', times='times', roi='roi')

# random-effect (rfx) inference of the MI between the neural activity and the
# discrete outcome ('cd'), followed by the permutation-based statistics
wf = WfMi(mi_type='cd', inference='rfx')
mi, pv = wf.fit(ds, random_state=0, n_jobs=-1, n_perm=50)

In [None]:
plot_results(mi, pv, split=True)

# Based on these results, the answer seems to be "no": there is no difference
# in neural activity between the -0€ and +0€ outcomes.
# (Written as a comment: a bare string as the last line of a cell is echoed
# as the cell's output.)

## 1.2 Correlation with reaction time 

In [None]:
# high-gamma activity of every subject, averaged within each brain region, for
# the trials whose outcome is +0€ or +1€ ('rew' condition); each trial is
# labelled with its reaction time
hga = load_ms(space='roi', model='rt', condition='rew')
ds = DatasetEphy(hga, y='rt', times='times', roi='roi')

# fixed-effect (ffx) inference of the MI between two continuous variables,
# the neural activity and the reaction time ('cc')
wf = WfMi(mi_type='cc', inference='ffx')
mi, pv = wf.fit(ds, random_state=0, n_jobs=-1, n_perm=50)

In [None]:
# one subplot per brain region, significant time points drawn as bars on top
plot_results(mi=mi, pv=pv, split=True)

---
# **2. Correcting for multiple comparisons**
## 2.1 The single time-point problem (advanced)

In [None]:
# load a single subject (subject 2), keeping the original channel names
hga = load_ms(s_range=2, model='outcome', condition='null', space='channels',
              prepend_suj_to_ch=False)

# build the baseline dataset: every trial contributes two samples, its
# activity at t=0s (period=0, "baseline") and at t=0.25s (period=1, "period
# of interest"); the workflow below then compares the two periods
hga_tr = []
for hga_s in hga:
    # build baseline and period-of-interest high-gamma.
    # `method='nearest'` tolerates time vectors that do not hold exactly 0. /
    # .25 (floating-point sampling) instead of raising a KeyError.
    # `drop_vars` replaces the deprecated `drop`.
    hga_bsl = hga_s.sel(times=0., method='nearest').drop_vars('times')
    hga_poi = hga_s.sel(times=.25, method='nearest').drop_vars('times')

    # rename the outcome dimension
    hga_bsl = hga_bsl.rename(outcome='period')
    hga_poi = hga_poi.rename(outcome='period')

    # use the code 0 = baseline; 1 = poi
    hga_bsl['period'] = [0] * len(hga_bsl['period'])
    hga_poi['period'] = [1] * len(hga_poi['period'])

    # concatenate both and add back a singleton time dimension (t=0.) as the
    # third axis
    _hga_tr = xr.concat([hga_bsl, hga_poi], 'period')
    _hga_tr = _hga_tr.expand_dims('times', axis=2)
    _hga_tr['times'] = [0.]

    # append to the full list
    hga_tr.append(_hga_tr)

# build the DatasetEphy, using the channels as the spatial dimension
ds = DatasetEphy(hga_tr, y='period', times='times', roi='channels')

# run the computations and stats, correcting for multiple comparisons with
# the maximum statistics ('maxstat')
wf = WfMi(inference='ffx', mi_type='cd')
mi, pv = wf.fit(ds, n_perm=50, n_jobs=-1, random_state=0, mcp='maxstat')

In [None]:
# contacts whose p-value at the single time point is below 0.05.
# isel(times=0) drops only the time axis; squeeze() would also drop 'roi' when
# there is a single contact, and the boolean indexing would then fail.
signi_c = mi['roi'].data[(pv.isel(times=0) < 0.05).data].tolist()

print(f"List of significant contacts : {', '.join(signi_c)}")