# Example that uses the code implemented for FUCONE

In [9]:
import os.path as osp
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import pickle
import gzip
import warnings

from sklearn.base import clone
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.ensemble import StackingClassifier
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import (
    LogisticRegression,
)
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder
from sklearn.svm import SVC

from pyriemann.spatialfilters import CSP
from pyriemann.tangentspace import TangentSpace
from pyriemann.classification import FgMDM

from moabb.datasets import (
    Schirrmeister2017,  
)
from moabb.paradigms import LeftRightImagery


from fc_pipeline import (
    FunctionalTransformer,
    EnsureSPD,
    FC_DimRed,
    GetDataMemory,
)

# Silence scikit-learn ConvergenceWarning messages (the logistic regressions
# below use the iterative "saga" solver) to keep the notebook output readable.
warnings.simplefilter("ignore", category=ConvergenceWarning)


In [17]:
# Run from the "Database" folder when the notebook is started from the repo root.
if osp.basename(os.getcwd()) == "FUCONE":
    os.chdir("Database")
basedir = os.getcwd()

datasets = Schirrmeister2017()

_BANNER_RULE = "#################"


def _show_selection(label, selection):
    """Print *selection* under the heading *label*, framed by banner rules."""
    print(f"{_BANNER_RULE}\n{label}: \n{selection}\n{_BANNER_RULE}")


spectral_met = ["cov", "imcoh", "instantaneous"]
_show_selection("List of pre-selected FC metrics", spectral_met)

freqbands = {"defaultBand": [8, 35]}
_show_selection("List of pre-selected Frequency bands", freqbands)

# All four available classes: ["left_hand", "right_hand", "feet", "rest"]
events = ["right_hand", "feet"]
_show_selection("List of selected events", events)

threshold = [0.05]
percent_nodes = [10, 20, 30]

#################
List of pre-selected FC metrics: 
['cov', 'imcoh', 'instantaneous']
#################
#################
List of pre-selected Frequency bands: 
{'defaultBand': [8, 35]}
#################
#################
List of selected events: 
['right_hand', 'feet']
#################


In [None]:
## Evaluation of functional-connectivity (FC) pipelines with dimension reduction
# For every frequency band and subject, one pipeline is built per FC metric in
# `spectral_met`; they are scored together by WithinSessionEvaluationFCDR and the
# collected results are written to RedDim_Schirrmeister.csv.

# NOTE(review): `GetData` and `WithinSessionEvaluationFCDR` were used but never
# imported; they are assumed to be provided by fc_pipeline alongside
# FunctionalTransformer / FC_DimRed — confirm.
from fc_pipeline import GetData, WithinSessionEvaluationFCDR

# Classification heads: tangent-space projection of the SPD matrices followed by
# an elastic-net logistic regression. `step_cov` follows the plain covariance
# estimator, `step_fc` follows the FC dimension-reduction step.
step_cov = [
    ("tg", TangentSpace(metric="riemann")),
    (
        "LogistReg",
        LogisticRegression(
            penalty="elasticnet", l1_ratio=0.15, intercept_scaling=1000.0, solver="saga"
        ),
    ),
]
step_fc = [
    ("tg", TangentSpace(metric="riemann")),
    (
        "LogistReg",
        LogisticRegression(
            penalty="elasticnet", l1_ratio=0.15, intercept_scaling=1000.0, solver="saga"
        ),
    ),
]


def _fresh_steps(steps):
    """Return a copy of `steps` holding unfitted clones of each estimator.

    Prevents several pipelines from sharing (and refitting) the very same
    estimator instances.
    """
    return [(name, clone(est)) for name, est in steps]


# Experiment parameters for this cell; they override the values printed by the
# configuration cell.
freqbands = {
    #     "delta": [2, 4],
    #     "theta": [4, 8],
    #     "alpha": [8, 12],
    #     "beta": [15, 30],
    #     "gamma": [30, 45],
    "defaultBand": [8, 35],
}
spectral_met = ["coh", "imcoh"]
threshold = [0.05]  # forwarded to FC_DimRed
percent_nodes = [10]  # converted below to a node count (% of channels) for FC_DimRed
subjects = datasets.subject_list  # restrict (e.g. [1]) for a quick run

dataset_res = []
for fmin, fmax in freqbands.values():
    # NOTE(review): LeftRightImagery discriminates left_hand vs right_hand; the
    # `events` list printed by the configuration cell (right_hand, feet) is not
    # used here — confirm which task is intended.
    paradigm = LeftRightImagery(fmin=fmin, fmax=fmax)
    for subject in tqdm(subjects, desc="subject"):
        # Epochs of the *current* subject provide the channel count and names
        # used to size the dimension-reduction step.
        ep_, _, _ = paradigm.get_data(
            dataset=datasets, subjects=[subject], return_epochs=True
        )
        nchan = ep_.info["nchan"]
        nb_nodes = [int(p / 100.0 * nchan) for p in percent_nodes]

        ppl_DR = {}
        gd = GetData(paradigm, datasets, subject)
        for sm in spectral_met:
            ft = FunctionalTransformer(
                delta=1, ratio=0.5, method=sm, fmin=fmin, fmax=fmax
            )
            if sm == "cov":
                ppl_DR["cov+elasticnet"] = Pipeline(
                    steps=[("gd", gd), ("sm", ft)] + _fresh_steps(step_cov)
                )
            else:
                ft_DR = FC_DimRed(
                    threshold=threshold,
                    nb_nodes=nb_nodes,
                    chan_names=ep_.info["ch_names"],
                    method=sm,
                    classifier=FgMDM(metric="riemann", tsupdate=False),
                )
                ppl_DR[sm + "+DR+elasticnet"] = Pipeline(
                    steps=[
                        ("gd", gd),
                        ("sm", ft),
                        ("spd", EnsureSPD()),
                        ("DR", ft_DR),
                    ]
                    + _fresh_steps(step_fc)
                )

        # Score all pipelines of this subject in one call. `process` is assumed,
        # like moabb's BaseEvaluation.process, to take a {name: pipeline} dict
        # and return a pandas DataFrame — confirm.
        # NOTE(review): the pipelines are bound to `subject` through `gd`;
        # confirm WithinSessionEvaluationFCDR scores only that subject rather
        # than every subject of `datasets`.
        evaluation = WithinSessionEvaluationFCDR(
            fmin=fmin,
            fmax=fmax,
            paradigm=paradigm,
            datasets=[datasets],
            n_jobs=-1,
            random_state=42,
            return_epochs=True,
            overwrite=True,
        )
        dataset_res.append(evaluation.process(ppl_DR))

# One results table per evaluation call: stack them before saving
# (a plain list has no `to_csv`).
results = pd.concat(dataset_res, ignore_index=True)
results.to_csv("RedDim_Schirrmeister.csv")

subject:   0%|                                                                                                               | 0/14 [00:00<?, ?it/s]Downloading data from 'https://web.gin.g-node.org/robintibor/high-gamma-dataset/raw/master/data/train/1.mat' to file '/Users/marieconstance.corsi/mne_data/MNE-schirrmeister2017-data/robintibor/high-gamma-dataset/raw/master/data/train/1.mat'.
SHA256 hash of downloaded file: 41dd2171d8806658e053a81e51960e1434f949615221afc4afb22adb46f4ceee
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
Downloading data from 'https://web.gin.g-node.org/robintibor/high-gamma-dataset/raw/master/data/test/1.mat' to file '/Users/marieconstance.corsi/mne_data/MNE-schirrmeister2017-data/robintibor/high-gamma-dataset/raw/master/data/test/1.mat'.


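## References:
- Papers describing the functional connectivity estimators used here (TODO: add citations)
- ICASSP paper (TODO: add citation)
- TBME paper, with a link to the associated repository (TODO: add citation and link)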