In [None]:
import numpy as np
import pandas as pd

from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
import pickle

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from fau_colors import register_cmaps

# Must run before set_theme() below.
# NOTE(review): presumably this registers the "faculties" palette requested in set_theme — confirm.
register_cmaps()

# Global seaborn theme for all figures created in this notebook.
sns.set_theme(context="talk", style="white", palette="faculties", font_scale=1.2)

# Select the interactive "widget" matplotlib backend for figure output.
%matplotlib widget

In [None]:
from sleep_analysis.datasets.mesadataset import MesaDataset

In [None]:
# Instantiate the MESA dataset with its default arguments.
dataset = MesaDataset()

In [None]:
# Exploratory check: display the ground truth of the dataset entry at index 5.
dataset[5].ground_truth

In [None]:
# Per-algorithm result files live two directory levels above the current working
# directory, under exports/results_per_algorithm.
res_path = Path.cwd().parents[1] / "exports" / "results_per_algorithm"

In [None]:
def _load_benchmark_results(results_dir, alg, modality, stage):
    """Load one per-subject benchmark CSV and return it in long format.

    Reads ``<results_dir>/<alg>/<alg>_benchmark_<modality>_<stage>.csv`` with the
    first row as column labels and the first column as row labels.

    Parameters
    ----------
    results_dir : pathlib.Path
        Directory containing one sub-folder per algorithm.
    alg : str
        Algorithm key as used in folder and file names (e.g. ``"rf"``).
    modality : str
        Modality key as used in the file name (e.g. ``"acc_hrv"``).
    stage : str
        Staging key as used in the file name (e.g. ``"binary"``).

    Returns
    -------
    pandas.DataFrame
        A single column ``"data"`` indexed by ``(metric, subject)``. The
        ``confusion_matrix`` row of the file is dropped.

    Raises
    ------
    FileNotFoundError
        If the CSV file does not exist (raised by ``pd.read_csv``).
    KeyError
        If the file has no ``confusion_matrix`` row (raised by ``DataFrame.drop``).
    """
    results = pd.read_csv(
        results_dir.joinpath(alg, f"{alg}_benchmark_{modality}_{stage}.csv"),
        header=[0],
        index_col=[0],
    )
    # NOTE(review): the columns are only named "subject" two lines further down, so
    # ``level`` does not refer to an existing level name here; the call is kept
    # with its original arguments and simply orders the columns — confirm.
    results = results.sort_index(axis=1, level="subject")
    results.index.name = "metric"
    results.columns.name = "subject"
    results = results.drop("confusion_matrix")
    # Move the subject columns into the index: one row per (metric, subject).
    stacked = results.stack(level="subject")
    return pd.DataFrame(stacked, columns=["data"])


# (key used in folder/file names, label used in the combined index) per algorithm type.
algorithms_by_type = {
    "ML": [
        ("rf", "Random Forest"),
        ("mlp", "MLP"),
        ("svm", "SVM"),
        ("adaboost", "AdaBoost"),
        ("xgb", "XGBoost"),
    ],
    "DL": [("LSTM", "LSTM"), ("TCN", "TCN")],
}
# (key used in file names, label used in the combined index)
modalities = [
    ("acc", "ACT"),
    ("acc_hrv", "ACT + HRV"),
    ("acc_hrv_rrv", "ACT + HRV + RRV"),
    ("acc_hrv_edr", "ACT + HRV + EDR"),
]
stages = [("binary", "Binary"), ("3stage", "3stage"), ("5stage", "5stage")]

# Build the combined table from the inside out: each pd.concat over a dict adds the
# dict keys as a new outermost index level (stage -> modality -> algorithm -> type).
dict_alg_type = {}
for alg_type, algorithms in algorithms_by_type.items():
    dict_alg = {}
    for alg, alg_name in algorithms:
        dict_modality = {}
        for modality, modality_name in modalities:
            dict_stage = {
                stage_name: _load_benchmark_results(res_path, alg, modality, stage)
                for stage, stage_name in stages
            }
            dict_modality[modality_name] = pd.concat(dict_stage, names=["stage"])
        dict_alg[alg_name] = pd.concat(dict_modality, names=["modality"])
    dict_alg_type[alg_type] = pd.concat(dict_alg, names=["algorithm"])

# Long-format results indexed by
# (algorithm type, algorithm, modality, stage, metric, subject).
df = pd.concat(dict_alg_type, names=["algorithm type"])

In [None]:
# Start from the per-type frames rather than from `df` itself, so that re-running
# this cell does not scale the percentage metrics by 100 a second time.
metrics_long = pd.concat(dict_alg_type, names=["algorithm type"])

# Correct the misspelled metric label.
metrics_long = metrics_long.rename(index={"specifity": "specificity"})

# One column per metric; cast to float before scaling.
metrics_wide = metrics_long.unstack("metric").astype(float)

# Express these metrics on a 0-100 scale; all other metrics are left as they are.
metrics_wide.loc[:, ("data", ["accuracy", "recall", "precision", "specificity", "f1"])] *= 100

# Back to long format with "metric" as the innermost index level.
df = metrics_wide.stack("metric")

In [None]:
# Display the combined long-format results table.
df

In [None]:
# MESA subject IDs as zero-padded 4-character strings. The list is pickled to
# "test_idx.pkl" and passed to MesaDataset.get_subset(mesa_id=...) further down.
# NOTE(review): presumably the held-out test split of the benchmark — confirm.
test_idx_list = [
    "0027",
    "0077",
    "0111",
    "0169",
    "0193",
    "0197",
    "0204",
    "0269",
    "0306",
    "0372",
    "0388",
    "0393",
    "0408",
    "0474",
    "0526",
    "0548",
    "0586",
    "0599",
    "0672",
    "0683",
    "0807",
    "0856",
    "0889",
    "0921",
    "0923",
    "0934",
    "0935",
    "0962",
    "0967",
    "0968",
    "1080",
    "1113",
    "1164",
    "1187",
    "1209",
    "1294",
    "1297",
    "1308",
    "1395",
    "1453",
    "1474",
    "1476",
    "1497",
    "1502",
    "1552",
    "1563",
    "1570",
    "1584",
    "1589",
    "1620",
    "1672",
    "1677",
    "1704",
    "1707",
    "1735",
    "1766",
    "1768",
    "1797",
    "1821",
    "1844",
    "1856",
    "1874",
    "1878",
    "1884",
    "1921",
    "1964",
    "2003",
    "2043",
    "2119",
    "2139",
    "2145",
    "2163",
    "2193",
    "2251",
    "2279",
    "2372",
    "2388",
    "2397",
    "2429",
    "2464",
    "2519",
    "2604",
    "2614",
    "2659",
    "2685",
    "2701",
    "2738",
    "2762",
    "2780",
    "2820",
    "2834",
    "2913",
    "2930",
    "2952",
    "2987",
    "2988",
    "2995",
    "3003",
    "3006",
    "3028",
    "3053",
    "3066",
    "3094",
    "3104",
    "3112",
    "3224",
    "3297",
    "3317",
    "3337",
    "3344",
    "3352",
    "3375",
    "3415",
    "3423",
    "3486",
    "3516",
    "3520",
    "3529",
    "3537",
    "3622",
    "3630",
    "3634",
    "3656",
    "3664",
    "3690",
    "3717",
    "3745",
    "3760",
    "3793",
    "3795",
    "3803",
    "3855",
    "3892",
    "3971",
    "3974",
    "3976",
    "4017",
    "4128",
    "4190",
    "4199",
    "4240",
    "4277",
    "4301",
    "4330",
    "4334",
    "4379",
    "4394",
    "4480",
    "4488",
    "4500",
    "4515",
    "4541",
    "4563",
    "4580",
    "4592",
    "4641",
    "4648",
    "4677",
    "4723",
    "4729",
    "4777",
    "4826",
    "4888",
    "4980",
    "5002",
    "5006",
    "5009",
    "5096",
    "5103",
    "5104",
    "5131",
    "5167",
    "5261",
    "5292",
    "5298",
    "5304",
    "5318",
    "5351",
    "5362",
    "5393",
    "5427",
    "5440",
    "5532",
    "5550",
    "5608",
    "5656",
    "5680",
    "5722",
    "5784",
    "5792",
    "5847",
    "5882",
    "5888",
    "5896",
    "5906",
    "6000",
    "6009",
    "6027",
    "6029",
    "6050",
    "6115",
    "6205",
    "6262",
    "6274",
    "6280",
    "6291",
    "6292",
    "6298",
    "6306",
    "6311",
    "6333",
    "6384",
    "6460",
    "6462",
    "6501",
    "6509",
    "6566",
    "6610",
    "6632",
    "6671",
    "6697",
    "6726",
    "6784",
    "6807",
]
# Persist the ID list as a pickle file in the current working directory.
with open("test_idx.pkl", mode="wb") as pkl_file:
    pickle.dump(test_idx_list, pkl_file)

In [None]:
# Re-instantiate the dataset and request the subset identified by test_idx_list.
# NOTE(review): the result of get_subset() is only shown as the cell output and is not
# assigned, so `dataset` stays unfiltered unless get_subset mutates in place —
# confirm which behavior is intended.
dataset = MesaDataset()
dataset.get_subset(mesa_id=test_idx_list)

In [None]:
# test_set_info = [subj.information for subj in dataset]

In [None]:
# df_test_info = pd.concat(test_set_info, axis=1)
# df_test_info = df_test_info.T
# df_test_info.columns.name = "info"
# df_test_info = df_test_info.rename(columns={'race1c':'race',
#                                 "gender1":"gender",
#                                 "sleepage5c":"age",
#                                 "overall5":"PSG_quality",
#                                 "whiirs5c": "WHIIRS_score",
#                                 "slpapnea5":"sleep_apnea",
#                                 "insmnia5":"insomnia",
#                                 "rstlesslgs5":"resstles_legs",
#                                 "actquality5":"quality_actigraphy",
#                                 "ahi_a0h4": "AH-Index",
#                                 "extrahrs5":"extra_work_hours"})
# df_test_info.index.name = "subject"
# df_test_info.to_pickle("mesa_test_info.pkl")
# df_test_info

In [None]:
# Load the cached per-subject information table; the commented-out cell above
# shows how "mesa_test_info.pkl" was written. Unpickling executes code from the
# file, so only load a pickle you created yourself.
df_test_info = pd.read_pickle("mesa_test_info.pkl")

In [None]:
# Attach the per-subject information to every result row by matching the
# "subject" level/column of `df` against the index of df_test_info.
df_study = df.join(df_test_info, on=["subject"])

In [None]:
# Map the numeric demographic codes to readable labels. The result is assigned back
# to the column instead of calling ``replace(..., inplace=True)`` on
# ``df_study[col]``: an in-place method on a selected column triggers pandas'
# chained-assignment warning and no longer updates ``df_study`` under Copy-on-Write.
df_study["gender"] = df_study["gender"].replace(to_replace=[0, 1], value=["female", "male"])
df_study["race"] = df_study["race"].replace(
    to_replace=[1, 2, 3, 4], value=["White", "Asian", "Afro-american", "Hispanic"]
)

df_study

In [None]:
# Persist the joined results + subject information table to "full_df.pkl" in the
# current working directory.
df_study.to_pickle("full_df.pkl")