In [None]:
import os

import numpy as np
import scipy
import xarray as xr
import pandas as pd

from frites.conn import (conn_dfc, conn_covgc, conn_reshape_undirected,
                         conn_reshape_directed, define_windows, plot_windows,
                         conn_ravel_directed)

import matplotlib.pyplot as plt
from ipywidgets import interact, fixed

# Matplotlib 3.6 deprecated the bundled seaborn style sheets and renamed them
# 'seaborn-v0_8-<name>' (the legacy names were removed in later releases).
# Prefer the new name when this Matplotlib provides it and fall back to the
# legacy name otherwise. The styles are applied in order: dark, then poster.
plt.style.use([
    name if name in plt.style.available else name.replace('-v0_8', '')
    for name in ('seaborn-v0_8-dark', 'seaborn-v0_8-poster')
])

%load_ext autoreload
%autoreload 2

---
# **--- ROOT PATH ---**

<div class="alert alert-info"><p>

Define the path to the folder where the data are located.
</p></div>

In [None]:
# Root folder of the dataset. It must contain the 'hga', 'anat' and 'beh'
# sub-folders read by `load_ss`. Set the COOKING_FRITES_ROOT environment
# variable to point to your own copy of the data without editing the notebook;
# the original location is kept as the default.
ROOT = os.environ.get(
    'COOKING_FRITES_ROOT',
    '/run/media/etienne/DATA/Toolbox/BraiNets/CookingFrites/dataset/')

---
# **0 - Functions**

In [None]:
###############################################################################
###############################################################################
#                 Load the data of a single subject
###############################################################################
###############################################################################

def load_ss(subject_nb):
    """Read the recordings, anatomy and behavior of one subject.

    The three files are looked up under `ROOT` in the 'hga', 'anat' and 'beh'
    sub-folders.

    Parameters
    ----------
    subject_nb : int
        Subject number [0, 12], inserted in the file names

    Returns
    -------
    hga : xarray.DataArray
        High-gamma activity, with every label of the 'channels' coordinate
        shortened to the part preceding the first '-'
    anat : pandas.DataFrame
        Table of anatomical information
    beh : pandas.DataFrame
        Table of behavioral information
    """
    # high-gamma activity (netCDF file)
    hga = xr.load_dataarray(
        os.path.join(ROOT, 'hga', f'hga_s-{subject_nb}.nc'))

    # brain region of each contact (excel file)
    anat = pd.read_excel(
        os.path.join(ROOT, 'anat', f'anat_s-{subject_nb}.xlsx'))

    # behavioral table (excel file)
    beh = pd.read_excel(
        os.path.join(ROOT, 'beh', f'beh_s-{subject_nb}.xlsx'))

    # shorten the channel names : drop everything from the first '-' onward
    short_names = []
    for name in hga['channels'].data:
        short_names.append(name.split('-')[0])
    hga['channels'] = short_names

    return hga, anat, beh


###############################################################################
###############################################################################
#                 Load the data of multiple subjects
###############################################################################
###############################################################################

def load_ms(s_range=[0, 11], model='outcome', condition='rew',
            space='channels', mean_roi=True, prepend_suj_to_ch=True):
    """Load multiple subjects.

    Parameters
    ----------
    s_range : int or list
        Subjects to load. Use either an integer (e.g. 7) to load a single
        subject or an inclusive range of subjects (e.g. [5, 10]). The range is
        clipped to the subjects 0 to 11. The input is never modified
    model : {'outcome', 'pe', 'rt'}
        Model to use (case insensitive). Use either :

            * 'outcome' : find differences in the neural activity between the
              outcomes (behavioral column 'code')
            * 'pe' : find regions with an activity correlating with the
              prediction error (behavioral column 'PE')
            * 'rt' : find regions with an activity correlating with the
              reaction time (behavioral column 'RT')
    condition : {'rew', 'pun', 'context', 'null'}
        Condition to load. Only the trials whose code matches one of the two
        outcomes of the condition are kept. Use either :

            * 'rew' : for outcomes {+0€; +1€}
            * 'pun' : for outcomes {-1€; -0€}
            * 'context' : for outcomes {-1€; +1€}
            * 'null' : for outcomes {-0€; +0€}
    space : {'channels', 'parcels', 'roi'}
        Specify if the spatial dimension should be described with channel names
        ('channels') or with brain region names ('parcels' or 'roi', the
        spatial dimension is then renamed accordingly)
    mean_roi : bool
        Specify if you want to take the mean high-gamma activity inside a brain
        region. Only used when `space` is 'parcels' or 'roi'
    prepend_suj_to_ch : bool
        Add subject name to each channel name. Only used when `space` is
        'channels'

    Returns
    -------
    hga : list
        List of high-gamma activity across subjects (one xarray.DataArray per
        loaded subject)

    Raises
    ------
    AssertionError
        If `space` is not one of the supported values
    KeyError
        If `model` or `condition` is not one of the supported values
    """
    # convert the inclusive [first, last] request into a half-open
    # [start, stop) range. Work on local integers : incrementing `s_range` in
    # place would modify the caller's list (and the shared default list)
    if isinstance(s_range, int):
        first, last = s_range, s_range
    else:
        first, last = s_range[0], s_range[1]
    start, stop = max(first, 0), min(last + 1, 12)

    # inputs checking
    mesg = f"Subject %i | model={model} | condition={condition} | space={space}"
    model = model.lower()
    assert space in ['channels', 'parcels', 'roi']

    # get the code of the condition (pair of trial codes to keep)
    outc = {
        'rew': (+1, +2),
        'pun': (-2, -1),
        'context': (-2, +2),
        'null': (-1, +1)
    }[condition]

    # get the behavioral column to use
    col = {
        'outcome': 'code',
        'pe': 'PE',
        'rt': 'RT'
    }[model]

    # NOTE(review): `ProgressBar` is not imported in the import cell of this
    # notebook; it presumably comes from `mne.utils` - confirm and import it
    pbar = ProgressBar(range(start, stop), mesg=mesg % 0)

    # load the data
    hga = []
    for n_s in range(start, stop):
        # show the subject currently being loaded in the progress bar
        pbar._tqdm.desc = mesg % n_s
        # load the data of a single subject
        _hga, _anat, _beh = load_ss(n_s)
        # keep the original trial codes and channel names before renaming
        _outc = _hga['trials'].data
        _ch = _hga['channels'].data

        # replace trial dimension with the model
        _hga = _hga.rename(trials=model)
        _hga[model] = list(_beh[col])

        # get which outcome to keep (boolean mask applied on the first axis)
        keep_outc = np.logical_or(_outc == outc[0], _outc == outc[1])
        _hga = _hga[keep_outc, ...]

        # replace with brain regions
        if space in ['parcels', 'roi']:
            _hga = _hga.rename(channels=space)
            _hga[space] = list(_anat['roi'])

            # take the mean of the hga per parcel
            if mean_roi:
                _hga = _hga.groupby(space).mean(space)
        elif prepend_suj_to_ch and (space == 'channels'):
            # prepend subject number to channel name
            _hga['channels'] = [f"suj{n_s}/{c}" for c in _ch]

        # make the underlying array C-contiguous
        _hga.data = np.ascontiguousarray(_hga.data)

        hga.append(_hga)
        pbar.update_with_increment_value(1)

    return hga


###############################################################################
###############################################################################
#                           Plotting the results
###############################################################################
###############################################################################

def plot_conn(conn, figsize=(13, 10), cmap='Spectral_r', interactive=False):
    """Plot the connectivity array.

    Parameters
    ----------
    conn : xr.DataArray
        Output of a function to estimate the FC. The array is considered
        directed when it has a 'type' attribute different from 'dfc'
    figsize : tuple
        Figure size (static plot only)
    cmap : string
        Colormap (static plot only)
    interactive : bool
        If False (default), plot the FC averaged across trials as an image.
        If True, create a widget to select one entry of the 'roi' dimension
        and plot its trial-averaged time course with the 95% confidence
        interval across trials. The interactive mode requires the 'trials',
        'roi' and 'times' dimensions
    """
    # `import scipy` alone does not guarantee that the `stats` submodule is
    # loaded, hence the explicit import
    from scipy import stats

    # get if the connectivity array is directed or not
    directed = conn.attrs.get('type', 'dfc') != 'dfc'

    # split between dynamic interactive or static
    if interactive:
        if directed:
            conn = conn_ravel_directed(conn.copy())

        @interact(roi=conn['roi'].data, demean=True, conn=fixed(conn))
        def plot(conn=None, roi=conn['roi'].data[0], demean=True):
            sub_times = conn['times'].data
            roi_idx = conn['roi'].data.tolist().index(roi)

            # compute confidence interval across trials. The standard error
            # is taken along the axis of the 'trials' dimension so that it
            # always matches the mean, whatever the position of that dimension
            confidence = 0.95
            n_trials = len(conn['trials'])
            conn_m = conn.mean('trials')
            se = stats.sem(conn.data, axis=conn.get_axis_num('trials'))
            h = se * stats.t.ppf((1 + confidence) / 2., n_trials - 1)
            clow, chigh = conn_m.data - h, conn_m.data + h

            # baseline shift : subtract, for each roi, the minimum over time
            # of the trial-average (also applied to the confidence bounds)
            if demean:
                cmin = conn_m.min('times', keepdims=True)
                conn_m.data -= cmin.data
                clow -= cmin.data
                chigh -= cmin.data

            # common y-limits across rois, with a 5% margin on each side
            cmin, cmax = clow.min(), chigh.max()
            dp = (cmax - cmin) / 20
            plt.plot(sub_times, conn_m.sel(roi=roi))
            plt.grid(True)
            plt.axvline(0., color='k', linestyle='--')
            plt.ylim(cmin - dp, cmax + dp)
            plt.xlim(sub_times[0], sub_times[-1])
            plt.fill_between(sub_times, clow[roi_idx, :], chigh[roi_idx, :],
                             alpha=.1)
    else:
        # average across the trial dimension
        if 'trials' in conn.dims:
            conn = conn.mean('trials')

        # a single time point is reshaped with the frites helpers before
        # plotting, otherwise the array is plotted as it is
        if len(conn['times']) == 1:
            if not directed:
                df = conn_reshape_undirected(conn).squeeze().to_pandas()
            else:
                df = conn_reshape_directed(conn).squeeze().to_pandas()
        else:
            df = conn.to_pandas()

        # clip the color range to the [1, 99] percentiles (NaN ignored)
        vmin = np.nanpercentile(df.values, 1)
        vmax = np.nanpercentile(df.values, 99)

        plt.figure(figsize=figsize)
        plt.imshow(df.values, cmap=cmap, vmin=vmin, vmax=vmax)
        ax = plt.gca()
        ax.set_xticks(np.arange(len(df.columns)))
        ax.set_yticks(np.arange(len(df.index)))
        ax.set_xticklabels(df.columns)
        ax.set_yticklabels(df.index)
        plt.colorbar()
        if directed:
            plt.xlabel('Targets')
            plt.ylabel('Sources')


---
# **1. Undirected FC between channels**
## 1.1 Load and prepare the data

In [None]:
# load the high-gamma activity of subject 4 (first element returned by
# `load_ss`; the anatomy and behavior tables are not needed here)
hga = load_ss(4)[0]

# select the temporal period between [0, 1.5]s
hga_s = hga.sel(times=slice(0., 1.5))

## 1.2 Compute the undirected FC

In [None]:
# compute the undirected FC on the selected period, using the 'channels'
# dimension of `hga_s` as the spatial dimension and 'times' as the time axis
dfc = conn_dfc(hga_s, roi='channels', times='times')

## 1.3 Plot the results

In [None]:
# plot the results (static image of the FC averaged across trials, as
# `interactive` is False by default)
plot_conn(dfc)

"""
The pair of contacts that is the most strongly connected is
X6-X7
"""

## 1.4 Mean DFC across trials

In [None]:
# average the FC across the 'trials' dimension
dfc_m = dfc.mean('trials')

## 1.5 Reshape the connectivity matrix

In [None]:
# reshape the trial-averaged FC with the frites helper (presumably from one
# entry per pair of channels to a square channels x channels matrix - confirm
# against the frites documentation)
dfc_rsh = conn_reshape_undirected(dfc_m)

---
# **2. Undirected FC between brain regions**
## 2.1 Data preparation

In [None]:
# 1. load the data of subject 2 (high-gamma activity, anatomy and behavior)
# NOTE: this overwrites the `hga` variable defined in section 1
hga, anat, beh = load_ss(2)

# 2. rename the channel dimension
hga = hga.rename(channels='roi')

# 3. fill this dimension with the name of the brain regions ('roi' column of
# the anatomical table)
hga['roi'] = list(anat['roi'])

## 2.2 Mean HGA inside brain regions

In [None]:
# group the contacts by brain region and take the mean of the high-gamma
# activity inside each region
hga_r = hga.groupby('roi').mean('roi')

"""
The variable `hga_r` has four brain regions (aINS, dlPFC, vmPFC and lOFC)
"""

## 2.3 Compute the FC between brain regions

In [None]:
# compute the static FC between brain regions, restricted to the [0, 1.5]s
# period
dfc_r = conn_dfc(hga_r.sel(times=slice(0, 1.5)), roi='roi', times='times')

"""
The output `dfc_r` has 6 pairs of brain regions
"""

## 2.4 Plot the results

In [None]:
# plot the connectivity matrix between brain regions (static image of the FC
# averaged across trials)
plot_conn(dfc_r)

"""
The connection between the aINS and vmPFC is the strongest
"""

## 2.5 Define sliding windows

In [None]:
# 1. get the time vector
times = hga_r['times'].data

# 2. define sliding windows of length 0.3 shifted by steps of 0.03
# (presumably in seconds, the unit of `times` - confirm). Only the first
# output is kept
ws, _ = define_windows(times, slwin_len=.3, slwin_step=.03)

# 3. plot the sliding windows (the trailing ';' hides the returned object)
plot_windows(times, ws);

## 2.6 Compute the dynamic undirected FC

In [None]:
# compute the DFC between brain regions inside each of the sliding windows
# `ws`
dfc_us = conn_dfc(hga_r, roi='roi', times='times', win_sample=ws)

# plot the result (interactive widget : time course of the selected pair with
# its confidence interval across trials)
plot_conn(dfc_us, interactive=True)

"""
The connection between the aINS and dlPFC has the strongest connectivity and
its maximum is around 300ms
"""