In [None]:
import os

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import scipy.stats

from src import settings
from src.utils import fileio


# Project configuration; its "TREATMENTS" entry selects which folders are used.
config = fileio.get_config(settings.CONFIG_NAME)

TREATMENTS = config["TREATMENTS"]

# Folders found under RESULTS_DIR/local_measures, keeping only the entries whose
# key is one of the configured treatments.
INPUT_PATH = os.path.join(settings.RESULTS_DIR, "local_measures")

all_treatments = {
    treatment: folder
    for treatment, folder in fileio.load_multiple_folders(INPUT_PATH).items()
    if treatment in TREATMENTS
}

# Destination for the generated figures (created if it does not exist yet).
OUTPUT_DIR = os.path.join(settings.REPORTS_DIR, "figures", "local_measures")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Read every group CSV of every selected treatment and stack them into one
# table indexed by (Fly, Treatment, Group); `combined_data_reset` is the same
# table with those index levels turned back into ordinary columns.
group_frames = []
for treatment_name, treatment_path in all_treatments.items():
    all_groups = fileio.load_files_from_folder(treatment_path)

    for group_name, group_path in all_groups.items():
        df = pd.read_csv(group_path, index_col=0)
        # rename_axis returns a new frame; it must be reassigned to take effect.
        df = df.rename_axis("Fly")
        df["Treatment"] = treatment_name
        df["Group"] = group_name.replace(".csv", "")
        group_frames.append(df)

if not group_frames:
    raise ValueError(
        f"No group CSV files found for treatments {list(all_treatments)} under {INPUT_PATH}"
    )

# Concatenate once instead of growing a frame inside the loop: avoids quadratic
# copying and the empty seed frame (which also dropped the index name and
# triggers a FutureWarning in recent pandas).
treatment_dataframes = pd.concat(group_frames)
treatment_dataframes = treatment_dataframes.set_index(["Treatment", "Group"], append=True)
# Drop measures that are missing for every fly.
treatment_dataframes = treatment_dataframes.dropna(axis=1, how="all")
combined_data_reset = treatment_dataframes.reset_index()


# order = ["Cs_10D", "CsCh", "Cs_5DIZ", "LDA_5DIZ", "OCT_5DIZ", "LDA_OCT_5DIZ"]


def _y_limits(values):
    """Return ``(bottom, top)`` y-axis limits for *values* with 10% headroom.

    Non-negative data keeps a zero baseline, i.e. ``(0, 1.1 * max)``. Data that
    goes below zero gets a 10% margin under its minimum instead of being
    clipped at 0, and data that never exceeds zero is capped at 0.
    """
    low, high = values.min(), values.max()
    bottom = 0 if low >= 0 else low * 1.1
    top = high * 1.1 if high > 0 else 0
    return bottom, top


def _finish_axis(ax, title, xlabel, ylabel, ylim):
    """Apply the title, axis labels, x tick rotation and y-limits shared by every panel."""
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    # Restrict the rotation to x tick labels: tick_params() defaults to
    # axis="both", which would also turn the y tick labels sideways.
    ax.tick_params(axis="x", rotation=90)
    ax.set_ylim(*ylim)


# String categories are placed on the x axis in order of first appearance, so
# the i-th distinct group is drawn at x == i.
group_positions = {
    group: position
    for position, group in enumerate(pd.unique(combined_data_reset["Group"]))
}
# Centre each treatment's tick on the groups that actually belong to it. This
# does not assume an equal number of groups per treatment, nor that the data
# follows the order of config["TREATMENTS"].
treatment_tick_positions = (
    combined_data_reset[["Treatment", "Group"]]
    .drop_duplicates()
    .assign(position=lambda frame: frame["Group"].map(group_positions))
    .groupby("Treatment", sort=False)["position"]
    .mean()
)

for measure_name in treatment_dataframes.columns:
    values = combined_data_reset[measure_name]

    # Only numeric measures can be aggregated and plotted; a constant measure
    # carries nothing worth a figure.
    if not pd.api.types.is_numeric_dtype(values):
        continue
    if values.min() == values.max():
        continue

    ylim = _y_limits(values)

    fig, axes = plt.subplots(2, 2, figsize=(14, 11))
    fig.suptitle(f"Distribution of {measure_name}", fontsize=18)

    # Top row: per-treatment point estimate with SD (left) and SE (right) error bars.
    for ax, errorbar, title in (
        (axes[0, 0], "sd", "Plot using SD"),
        (axes[0, 1], "se", "Plot using SE"),
    ):
        sns.pointplot(
            data=combined_data_reset,
            x="Treatment",
            y=measure_name,
            dodge=False,
            hue="Treatment",
            errorbar=errorbar,
            ax=ax,
        )
        _finish_axis(ax, title, "Treatment", measure_name, ylim)

    sns.boxplot(
        data=combined_data_reset,
        x="Treatment",
        y=measure_name,
        dodge=False,
        hue="Treatment",
        ax=axes[1, 0],
    )
    _finish_axis(axes[1, 0], "Boxplot", "Treatment", measure_name, ylim)

    # Bottom right: every fly, one x category per group, coloured by treatment.
    sns.scatterplot(
        data=combined_data_reset,
        x="Group",
        y=measure_name,
        hue="Treatment",
        ax=axes[1, 1],
        s=50,
        alpha=0.6,
        markers=True,
        style="Treatment",
    )
    axes[1, 1].set_xticks(treatment_tick_positions.to_numpy())
    axes[1, 1].set_xticklabels(list(treatment_tick_positions.index))
    _finish_axis(axes[1, 1], f"Scatter plot: {measure_name}", "Treatment name", measure_name, ylim)

    fig.tight_layout()

    save_path = os.path.join(OUTPUT_DIR, f"{measure_name}.png")
    fig.savefig(save_path)

    plt.show()
    # Release the figure so figures do not accumulate across many measures
    # on backends where show() does not close them.
    plt.close(fig)