# Classification - TSST vs. fTSST

## Imports and Helper Functions

In [None]:
import re
import json
import warnings

from pathlib import Path

import pandas as pd
import numpy as np
import pingouin as pg

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.exceptions import ConvergenceWarning
from sklearn.model_selection import GroupKFold
from tqdm.auto import tqdm

from fau_colors import cmaps

import biopsykit as bp
from biopsykit.stats import StatsPipeline
from biopsykit.classification.model_selection import SklearnPipelinePermuter
from biopsykit.classification.utils import prepare_df_sklearn

from empkins_d03_macro_analysis.classification.utils import flatten_wide_format_column_names
from empkins_d03_macro_analysis.classification.hyperparameter_search.macro_prestudy import (
    get_model_dict,
    get_hyper_para_dict,
    get_hyper_search_dict,
)

%load_ext autoreload
%autoreload 2
%matplotlib widget

In [None]:
# Close any figures left open from previous executions of this notebook.
plt.close("all")

# Default seaborn palette built from the fau_colors "faculties" colormap.
palette = sns.color_palette(cmaps.faculties)
sns.set_theme(context="notebook", style="ticks", palette=palette)

# Notebook-wide matplotlib defaults, applied one key at a time.
_rc_overrides = {
    "figure.figsize": (8, 4),  # default figure size (inches)
    "pdf.fonttype": 42,  # embed TrueType fonts in PDF exports
    "mathtext.default": "regular",  # mathtext uses the regular (non-italic) font
}
for _rc_key, _rc_value in _rc_overrides.items():
    plt.rcParams[_rc_key] = _rc_value

# Last expression: display the palette as the cell output.
palette

## Set Up Paths

In [None]:
# Input root directory, relative to the notebook's working directory.
input_path = Path("../../../00_general")
# Destination directory for the classification results of this notebook.
output_path = Path("../../output/classification")
# Directory containing the exported motion features.
feature_path = input_path.joinpath("feature_export/motion_features")
# parents=True: without it, mkdir raises FileNotFoundError when an intermediate
# directory (e.g. ../../output) does not exist yet; exist_ok=True keeps re-runs safe.
output_path.mkdir(parents=True, exist_ok=True)

## Load Features and Prepare DataFrame for scikit-learn

In [None]:
# Narrow feature_path from the feature directory to the cleaned CSV file;
# its stem is reused later when exporting the data.
feature_path = feature_path / "motion_features_cleaned.csv"
data = bp.io.load_long_format_csv(feature_path)

In [None]:
# Every index level except "subject" and "condition" is pivoted into the columns.
# list.remove raises ValueError if either level is missing from the index.
levels_unstack = list(data.index.names)
for level in ("subject", "condition"):
    levels_unstack.remove(level)

# Rows stay indexed by (subject, condition); the multi-level column names produced
# by unstack are presumably joined into single strings by the helper.
data_wide = flatten_wide_format_column_names(data["data"].unstack(levels_unstack))
data_wide.head()

In [None]:
# Split the wide table into features X, labels y (taken from "condition"), and the
# group assignment used by GroupKFold below. group_keys presumably lists one entry
# per subject (its length is passed as num_subjects) — confirm against biopsykit.
X, y, groups, group_keys = prepare_df_sklearn(data_wide, label_col="condition", print_summary=True)

## Specify Estimator Combinations and Parameters for Hyperparameter Search

In [None]:
# Number of distinct groups (presumably subjects) — the hyperparameter grid depends on it.
n_subjects = len(group_keys)

# Candidate estimators per pipeline step, their parameter grids, and the
# search strategy per estimator.
model_dict = get_model_dict()
params_dict = get_hyper_para_dict(num_subjects=n_subjects)
hyper_search_dict = get_hyper_search_dict()

In [None]:
# Display the model dictionary returned by get_model_dict() for inspection.
model_dict

## Set Up PipelinePermuter and Cross-Validation for Model Evaluation

In [None]:
# Build all pipeline permutations from the model, parameter, and search dictionaries;
# random_state=0 is passed for reproducibility.
pipeline_permuter = SklearnPipelinePermuter(model_dict, params_dict, hyper_search_dict, random_state=0)

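Fit all pipelines using nested, grouped cross-validation (5-fold `GroupKFold` for both the outer and the inner loop).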

In [None]:
# Nested cross-validation: the outer folds evaluate the pipelines, the inner folds are
# used for the hyperparameter search. GroupKFold keeps all samples of one group in the
# same fold, so no group appears in both training and test data.
outer_cv = GroupKFold(5)
inner_cv = GroupKFold(5)

# fit() must run INSIDE the catch_warnings context: the filter only applies while the
# context is active, so calling fit() after it exits would not suppress anything.
# NOTE(review): if fitting runs in separate worker processes, those processes may not
# inherit this filter — confirm if warnings still appear.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=ConvergenceWarning)
    pipeline_permuter.fit(X, y, outer_cv=outer_cv, inner_cv=inner_cv, groups=groups)

In [None]:
# Save a copy of the loaded long-format feature data next to the results,
# under the same file name as the input CSV.
data.to_csv(output_path.joinpath(f"{feature_path.stem}.csv"))
# Persist the fitted pipeline permuter. Plain string: the original f-string had no placeholders.
pipeline_permuter.to_pickle(output_path.joinpath("classification_general_pipeline_permuter.pkl"))