In [None]:
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import scipy.stats
import seaborn as sns
import toml
from statsmodels.stats.multicomp import pairwise_tukeyhsd

from src import settings
from src.utils import fileio

# Load the main analysis configuration (the TREATMENTS list is read from it below).
CONFIG_PATH = os.path.join(settings.CONFIG_DIR, 'main.toml')
# TOML documents are required to be UTF-8; pin the encoding so loading does not
# depend on the platform's default locale encoding (e.g. cp1252 on Windows).
with open(CONFIG_PATH, 'r', encoding='utf-8') as file:
    config = toml.load(file)

# Load one CSV of global measures per configured treatment and stack them into a
# single table, tagging every row with its treatment name.
INPUT_PATH = os.path.join(settings.RESULTS_DIR, 'global_measures')
# NOTE(review): assumes load_files_from_folder returns {file_name: file_path};
# verify against src.utils.fileio.
all_treatments = fileio.load_files_from_folder(INPUT_PATH)

CSV_SUFFIX = '.csv'

dataframes = []
loaded_treatments = set()
for file_name, treatment_path in all_treatments.items():
    # Strip only a trailing '.csv'; str.replace would also delete '.csv'
    # appearing elsewhere in the name.
    if file_name.endswith(CSV_SUFFIX):
        treatment_name = file_name[:-len(CSV_SUFFIX)]
    else:
        treatment_name = file_name
    if treatment_name in config['TREATMENTS']:
        df = pd.read_csv(treatment_path, index_col=0)
        df['Treatment'] = treatment_name
        dataframes.append(df)
        loaded_treatments.add(treatment_name)

# A configured treatment with no input file would otherwise vanish silently
# from every downstream test and plot.
missing_treatments = [t for t in config['TREATMENTS'] if t not in loaded_treatments]
if missing_treatments:
    warnings.warn(f'No input file found in {INPUT_PATH} for treatments: {missing_treatments}')

# Fail with an explicit message instead of pandas' "No objects to concatenate".
if not dataframes:
    raise ValueError(
        f'No files in {INPUT_PATH} match the configured TREATMENTS {config["TREATMENTS"]}.'
    )

combined_data = pd.concat(dataframes)
# Per-file indices may repeat across treatments; a fresh RangeIndex keeps
# boolean filtering and seaborn plotting unambiguous.
combined_data_reset = combined_data.reset_index()

TREATMENTS = config['TREATMENTS']
# Plot category order follows the configured treatment order.
order = TREATMENTS

# Every column except the added 'Treatment' tag is a measure to test.
measure_names = [column for column in combined_data.columns if column != 'Treatment']

for measure_name in measure_names:
    # Non-missing observations of this measure for each treatment, in config
    # order. NaNs are dropped because they would otherwise propagate into the
    # ANOVA F statistic and the Tukey mean differences. Treatments with no data
    # for this measure are left out rather than passed as empty samples.
    groups = {}
    for treatment in TREATMENTS:
        is_treatment = combined_data_reset['Treatment'] == treatment
        values = combined_data_reset.loc[is_treatment, measure_name].dropna()
        if len(values) > 0:
            groups[treatment] = values.to_numpy()

    # Both tests need at least two groups with data to compare.
    if len(groups) < 2:
        print(f'Skipping {measure_name}: fewer than two treatments have data.')
        continue

    anova_result = scipy.stats.f_oneway(*groups.values())

    # Flatten the groups and build the matching labels from the same filtered
    # data, so values and labels stay aligned by construction.
    all_data = np.concatenate(list(groups.values()))
    group_labels = [treatment for treatment, values in groups.items() for _ in values]

    tukey_results = pairwise_tukeyhsd(all_data, group_labels)

    # Only report and plot measures with at least one significant pairwise difference.
    if any(tukey_results.reject):
        print('=' * 90)
        print(measure_name)
        print(f'One-way ANOVA: F = {anova_result.statistic:.4g}, p = {anova_result.pvalue:.4g}')
        print(tukey_results)

        fig, ax = plt.subplots(figsize=(6, 4))
        sns.pointplot(
            data=combined_data_reset,
            x='Treatment',
            y=measure_name,
            hue='Treatment',
            errorbar='sd',
            order=order,
            ax=ax,
        )
        ax.set_xlabel('Treatment')
        ax.set_ylabel(measure_name)
        ax.set_title(f'Distribution of {measure_name}', fontsize=18)
        # hue duplicates x, so the legend adds nothing; remove it outright rather
        # than replacing it with an empty legend via plt.legend('').
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
        plt.show()
        print('=' * 90)

        # Alternative view kept for reference: per-sample boxplot.
        # sns.boxplot(data=combined_data_reset, x='Treatment', y=measure_name, hue='index', order=order)
        # plt.xlabel('Treatment')
        # plt.ylabel(measure_name)
        # plt.title(f'Distribution of {measure_name}', fontsize=18)
        # plt.legend('')
        # plt.show() 
        # print('='*90)