In [None]:
import os

import numpy as np
import xarray as xr
import pandas as pd

from mne.utils import ProgressBar

from frites.dataset import DatasetEphy
from frites.workflow import WfMi

import matplotlib.pyplot as plt

---
# **--- ROOT PATH ---**

<div class="alert alert-info"><p>

Define the path to the folder where the data are located!
</p></div>

In [None]:
# Folder containing the 'hga', 'anat' and 'beh' sub-folders read by load_ss.
# Set the FRITES_DATASET_ROOT environment variable to use another location
# without editing the notebook; the author's local path is the fallback.
ROOT = os.environ.get(
    'FRITES_DATASET_ROOT',
    '/run/media/etienne/DATA/Toolbox/BraiNets/CookingFrites/dataset/')

# **0. Functions**

In [None]:
###############################################################################
###############################################################################
#                 Load the data of a single subject
###############################################################################
###############################################################################

def load_ss(subject_nb):
    """Read the high-gamma activity, anatomy and behavior of one subject.

    All three files live under ``ROOT``, in the 'hga', 'anat' and 'beh'
    sub-folders respectively.

    Parameters
    ----------
    subject_nb : int
        Subject number (subjects 0 to 11 are the ones load_ms iterates over)

    Returns
    -------
    hga : xarray.DataArray
        High-gamma activity, read from ``hga/hga_s-<subject_nb>.nc``
    anat : pandas.DataFrame
        Anatomical information, read from ``anat/anat_s-<subject_nb>.xlsx``
    beh : pandas.DataFrame
        Behavioral information, read from ``beh/beh_s-<subject_nb>.xlsx``
    """
    hga_file = os.path.join(ROOT, 'hga', f'hga_s-{subject_nb}.nc')
    anat_file = os.path.join(ROOT, 'anat', f'anat_s-{subject_nb}.xlsx')
    beh_file = os.path.join(ROOT, 'beh', f'beh_s-{subject_nb}.xlsx')

    # neural data is stored as netCDF, side tables as Excel sheets
    hga = xr.load_dataarray(hga_file)
    anat, beh = [pd.read_excel(table) for table in (anat_file, beh_file)]

    return hga, anat, beh


###############################################################################
###############################################################################
#                 Load the data of multiple subjects
###############################################################################
###############################################################################

def load_ms(s_range=[0, 11], model='outcome', condition='rew',
            space='channels', mean_roi=True, prepend_suj_to_ch=True):
    """Load multiple subjects.

    Parameters
    ----------
    s_range : int or list
        Subjects to load. Use either an integer (e.g. 7) to load a single
        subject or an inclusive range of subjects (e.g. [5, 10] loads
        subjects 5 to 10). Bounds are clipped to the available subjects
        (0 to 11). The object passed in is never modified.
    model : {'outcome', 'pe', 'rt'}
        Model to use (case-insensitive). Use either :

            * 'outcome' : find differences in the neural activity between the
              outcomes
            * 'pe' : find regions with an activity correlating with the
              prediction error
            * 'rt' : find regions with an activity correlating with the
              reaction time
    condition : {'rew', 'pun', 'context', 'null'}
        Condition to load. Use either :

            * 'rew' : for outcomes {+0€; +1€}
            * 'pun' : for outcomes {-1€; -0€}
            * 'context' : for outcomes {-1€; +1€}
            * 'null' : for outcomes {-0€; +0€}
    space : {'channels', 'roi', 'parcels'}
        Specify if the spatial dimension should be described with channel
        names ('channels') or with brain region names ('roi' or 'parcels',
        the spatial dimension then taking that name)
    mean_roi : bool
        Only used with brain region names: take the mean high-gamma activity
        over the channels of each brain region
    prepend_suj_to_ch : bool
        Only used with channel names: prepend "suj<n>/" to each channel name
        so that channel names stay distinct across subjects

    Returns
    -------
    hga : list
        List of high-gamma activity across subjects (one xarray.DataArray per
        subject, whose trial dimension is renamed after `model`)
    """
    n_subjects = 12  # subjects are numbered 0 to 11

    # inputs checking. The bounds are read into locals rather than editing
    # `s_range` in place, which mutated the caller's list (and the shared
    # default) and failed on tuples.
    if isinstance(s_range, (int, np.integer)):
        s_first = s_last = int(s_range)
    else:
        s_first, s_last = s_range[0], s_range[1]
    # the upper bound is inclusive for the user but range() needs a stop
    s_first, s_stop = max(s_first, 0), min(s_last + 1, n_subjects)
    model = model.lower()
    assert space in ['channels', 'parcels', 'roi']

    # get the code of the condition
    outc = {
        'rew': (+1, +2),
        'pun': (-2, -1),
        'context': (-2, +2),
        'null': (-1, +1)
    }[condition]

    # get the behavioral column to use
    col = {
        'outcome': 'code',
        'pe': 'PE',
        'rt': 'RT'
    }[model]

    # the progress bar is only created once every input has been validated
    mesg = f"Subject %i | model={model} | condition={condition} | space={space}"
    pbar = ProgressBar(range(s_first, s_stop), mesg=mesg % s_first)

    # load the data
    hga = []
    for n_s in range(s_first, s_stop):
        # NOTE: `_tqdm` is a private attribute of mne's ProgressBar
        pbar._tqdm.desc = mesg % n_s
        # load the data of a single subject
        _hga, _anat, _beh = load_ss(n_s)
        _outc = _hga['trials'].data  # outcome code of each trial
        _ch = _hga['channels'].data

        # replace trial dimension with the model
        _hga = _hga.rename(trials=model)
        _hga[model] = list(_beh[col])

        # keep only the trials of the two outcomes of the condition. Selecting
        # by dimension name does not rely on trials being the first axis.
        keep_outc = np.isin(_outc, outc)
        _hga = _hga.isel({model: np.flatnonzero(keep_outc)})

        # replace with brain regions
        if space in ['parcels', 'roi']:
            _hga = _hga.rename(channels=space)
            _hga[space] = list(_anat['roi'])

            # take the mean of the hga per parcel
            if mean_roi:
                _hga = _hga.groupby(space).mean(space)
        elif prepend_suj_to_ch and (space == 'channels'):
            # prepend subject number to channel name
            _hga['channels'] = [f"suj{n_s}/{c}" for c in _ch]

        # make the underlying buffer C-contiguous (presumably for the
        # downstream frites computations)
        _hga.data = np.ascontiguousarray(_hga.data)

        hga.append(_hga)
        pbar.update_with_increment_value(1)

    return hga

---
# **1. Data loading**
## 1.1 Load the data of a single subject

In [None]:
# an integer s_range loads that subject only (here subject 6)
hga = load_ms(
    s_range=6,
    model='outcome',
    condition='rew',
    space='channels',
)

## 1.2 Load the data of multiple subjects

In [None]:
# both bounds of s_range are inclusive: subjects 6, 7, 8, 9 and 10
hga = load_ms(
    s_range=[6, 10],
    model='outcome',
    condition='rew',
    space='channels',
)

## 1.3 Switch condition and model

In [None]:
# prediction-error model, punishment condition, brain-region spatial labels
hga = load_ms(s_range=6, space='roi', model='pe', condition='pun')

---
# **2. Measuring information**
## 2.1 Model-free analysis for a single subject

In [None]:
# define the DatasetEphy
hga = load_ms(s_range=6, model='outcome', condition='rew', space='channels')
ds = DatasetEphy(hga, y='outcome', roi='channels', times='times')

# measure information
mi, _ = WfMi(mi_type='cd', inference='ffx').fit(ds, mcp=None)

# plot the result
plt.figure(figsize=(10, 8))
mi.plot(x='times', hue='roi')
plt.axvline(0., color='k');

print("THE ANSWER IS CHANNEL O'7-O'6")

## 2.2 Model-free analysis across all of the subjects

In [None]:
# define the DatasetEphy
hga = load_ms(model='outcome', condition='context', space='roi')
ds = DatasetEphy(hga, y='outcome', roi='roi', times='times')

# measure information
mi, _ = WfMi(mi_type='cd', inference='ffx').fit(ds, mcp=None)

# plot the result
plt.figure(figsize=(10, 8))
mi.plot(x='times', hue='roi')
plt.axvline(0., color='k');

"""
The Lateral Orbital Frontal Cortex (lOFC) seems to be the brain region sharing
the most information with the contextual outcome. Said differently, we could
say that the lOFC is particulary involved in differentiating whether the
subject is going to experienced rewards or punishments
"""

## 2.3 Model-based analysis for a single subject

In [None]:
# define the DatasetEphy
hga = load_ms(s_range=6, model='rt', condition='rew', space='channels')
ds = DatasetEphy(hga, y='rt', roi='channels', times='times')

# measure information
mi, _ = WfMi(mi_type='cc', inference='ffx').fit(ds, mcp=None)

# plot the result
plt.figure(figsize=(10, 8))
mi.plot(x='times', hue='roi')
plt.axvline(0., color='k');

"""
The answer is channel O'7-O'6 !
"""

## 2.4 Model-based analysis for multiple subjects

In [None]:
# define the DatasetEphy
hga = load_ms(model='pe', condition='pun', space='roi')
ds = DatasetEphy(hga, y='pe', roi='roi', times='times')

# measure information
mi, _ = WfMi(mi_type='cc', inference='ffx').fit(ds, mcp=None)

# plot the result
plt.figure(figsize=(10, 8))
mi.plot(x='times', hue='roi')
plt.axvline(0., color='k')

"""
The anterior insula (aINS) is the region that carry the most information about
the prediction error during the punishment condition !
"""

## 2.5 Avoid averaging the neural activity inside brain regions

In [None]:
# define the DatasetEphy
hga = load_ms(model='rt', condition='rew', space='roi', mean_roi=False)
ds = DatasetEphy(hga, y='rt', roi='roi', times='times')

# measure information
mi, _ = WfMi(mi_type='cc', inference='ffx').fit(ds, mcp=None)

# plot the result
plt.figure(figsize=(10, 8))
mi.plot(x='times', hue='roi')
plt.axvline(0., color='k');