# **3. Compute the Dynamic Functional Connectivity**
---

This notebook illustrates how to estimate the single-trial Dynamic Functional Connectivity (DFC) for each participant.

In [None]:
import os
import argparse

import numpy as np
import xarray as xr

from frites.utils import savgol_filter
from frites.conn import conn_dfc, define_windows
from frites.estimator import GCMIEstimator, CorrEstimator, DcorrEstimator


# **Global variables**
---

In [None]:
###############################################################################
# Alignment settings
# ------------------

# Choose the event the data are aligned on:
# - Stimulus presentation : "sample_stim"
# - Subject's response    : "sample_resp"
reference = 'sample_stim'


# Frequency settings
# ------------------

# Frequency band to consider:
# - Gamma : "f50f150"
# - Beta  : "f8f24"
freq = "f50f150"

# amount of temporal smoothing (tag used in the input folder name)
smoothing = "sm0"


# DFC settings
# ------------

# Sliding window settings, expressed in the unit of the data's time vector
# (presumably seconds -- TODO confirm)
wlen = .5    # length of each sliding window
wstep = .01  # step between the starts of consecutive windows

# Estimator of information
# 'gcmi'  : Gaussian-Copula Mutual Information
# 'corr'  : Pearson correlation
# 'dcorr' : Distance correlation
estimator = 'gcmi'

# temporal range, per alignment, in which the sliding windows are defined
timings_ref = {
    'sample_stim': slice(-1., 2.),  # NOTE(review): (0., 2.) was noted here
    'sample_resp': slice(-2., 1.),
}

# Smoothing parameter passed to frites.utils.savgol_filter before computing
# the DFC; any non-numeric value (e.g. None) disables the smoothing
savgol = 1.

# Folder settings
# ---------------

# root folder
root = '/hpc/brainets/data/db_ebrains/seeg'

# define where the data are located
from_folder = f'{root}/epochs/{reference}/{freq}-{smoothing}/data'

# define where to save the dfc
to_folder = f'{root}/conn/dfc-dyn/dfc-{estimator}/st'

# define how to save the file. The ':g' format keeps distinct smoothing values
# distinct in the file name (1.5 -> '1.5' instead of being truncated to '1'),
# while integer-valued floats stay compact (1. -> '1')
sav_str = f'{savgol:g}' if isinstance(savgol, (int, float)) else savgol
_save_as = (f"dfc_est-{estimator}_{reference}-{freq}-{smoothing}_"
            f"savgol-{sav_str}_%s.nc")
save_as = os.path.join(to_folder, _save_as)
###############################################################################

# get the list of data files: bare file names (needed to derive the subject
# ids), sorted for a reproducible order and without hidden files
files = sorted(
    fname for fname in os.listdir(from_folder) if not fname.startswith('.'))
if not files:
    raise FileNotFoundError(f"No data file found in {from_folder!r}")

# get the temporal vector from the first file; os.listdir returns bare names,
# so the folder must be prepended to obtain a loadable path
times = xr.load_dataarray(os.path.join(from_folder, files[0]))['times'].data

# define windows
timings = timings_ref[reference]
win, _ = define_windows(
    times, slwin_start=timings.start, slwin_stop=timings.stop,
    slwin_len=wlen, slwin_step=wstep)

# build the estimator; fail early on an unknown name instead of leaving `est`
# undefined
if estimator == 'gcmi':
    est = GCMIEstimator(mi_type='cc', copnorm=False, biascorrect=False)
elif estimator == 'corr':
    est = CorrEstimator()
elif estimator == 'dcorr':
    est = DcorrEstimator(implementation='frites')
else:
    raise ValueError(
        f"Unknown estimator {estimator!r}; expected 'gcmi', 'corr' or "
        f"'dcorr'")


# **Compute Functional Connectivity**
---

## Low-level function for computing FC

In [None]:
def compute_dfc(s, from_folder, save_as, estimator, win, savgol, n_jobs):
    """Compute the single-trial DFC of a single subject and save it to disk.

    Parameters
    ----------
    s : str
        Subject identifier. A data file belongs to the subject when the part
        of its name before the first underscore equals ``s`` (the same
        convention used to derive subject ids from file names).
    from_folder : str
        Folder containing the epoched data files.
    save_as : str
        Output path template with a single ``%s`` placeholder replaced by
        ``s``. Missing parent folders are created.
    estimator : object
        Estimator instance forwarded to ``frites.conn.conn_dfc``.
    win : array_like
        Sliding windows forwarded to ``conn_dfc`` as ``win_sample``.
    savgol : int | float | other
        When a number, the data are smoothed with
        ``frites.utils.savgol_filter(data, savgol, ...)`` before computing the
        DFC; any other value (e.g. None) disables the smoothing.
    n_jobs : int
        Number of jobs forwarded to ``conn_dfc``.

    Returns
    -------
    None
        The DFC is written to ``save_as % s``. Subjects whose output file
        already exists are skipped.

    Raises
    ------
    ValueError
        If ``from_folder`` does not contain exactly one file for ``s``.
    """
    out_file = save_as % s

    # -------------------------------------------------------------------------
    # skip if already computed
    if os.path.isfile(out_file):
        print(f'---- SUBJECT {s} SKIPPED BECAUSE ALREADY COMPUTED ----')
        return None

    # -------------------------------------------------------------------------
    # locate the subject's data file. An exact match on the file-name prefix
    # avoids ambiguities such as "s1" matching "s10_...".
    matches = sorted(
        os.path.join(from_folder, fname) for fname in os.listdir(from_folder)
        if fname.split('_')[0] == s)
    if len(matches) != 1:
        raise ValueError(
            f"Expected exactly one data file for subject {s!r} in "
            f"{from_folder!r}, found {len(matches)}: {matches}")

    # load the DataArray fully in memory; unlike open_dataarray, this closes
    # the underlying file once the data are read
    _da = xr.load_dataarray(matches[0]).astype(np.float32)
    trials = _da['trials']
    sfreq = _da.attrs['sfreq']

    # if needed, smooth the data
    if isinstance(savgol, (int, float)):
        _da = savgol_filter(_da, savgol, axis='times', sfreq=sfreq)

    # rebuild the DataArray with integer trial labels and string ROI names
    # (assumes the dims of `_da` are ordered trials, roi, times)
    da = xr.DataArray(
        _da.data, dims=('trials', 'roi', 'times'),
        coords=(np.arange(len(trials)), _da['roi'].data.astype(str),
                _da['times'].data))

    # drop ROIs whose name contains "parcel"
    keep = np.array(['parcel' not in r for r in da['roi'].data], dtype=bool)
    da = da.isel(roi=keep)

    # compute the dfc and restore the original trial labels
    dfc = conn_dfc(da, times='times', roi='roi', estimator=estimator,
                   n_jobs=n_jobs, win_sample=win)
    dfc = dfc.astype(np.float32)
    dfc['trials'] = trials

    # save, creating the output folder when needed
    out_dir = os.path.dirname(out_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    dfc.to_netcdf(out_file)


## Compute the single-subject, single-trial DFC

In [None]:
# One call per data file: the subject id is the file-name prefix before the
# first underscore, and all available jobs (-1) are forwarded to conn_dfc.
subjects = (fname.split('_')[0] for fname in files)
for subject in subjects:
    compute_dfc(subject, from_folder, save_as, est, win, savgol, n_jobs=-1)
