# In [None]:
import json
import numpy as np
from scipy import stats
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import glob

def load_results(base_dir):
    """Load per-run result files from ``base_dir`` and group them by model.

    Every ``*.json`` file directly inside ``base_dir`` is parsed (the search
    is not recursive). A file is kept only when its top level is a JSON
    object whose ``model_name`` is one of the recognised models; any other
    JSON file in the directory is ignored instead of aborting the load.

    Args:
        base_dir: Directory containing the result JSON files.

    Returns:
        A dict with the keys ``'gnn_baseline'``, ``'gat_ablation'`` and
        ``'dagnn_main'``, each mapping to the list of parsed run dicts for
        that model, ordered by file path.

    Raises:
        json.JSONDecodeError: If a ``*.json`` file is not valid JSON.
        OSError: If a matched file cannot be opened.
    """
    results = {
        'gnn_baseline': [],
        'gat_ablation': [],
        'dagnn_main': []
    }

    # glob returns files in filesystem order; sort so the order of runs
    # (and therefore of every downstream value list) is reproducible.
    for filepath in sorted(glob.glob(f"{base_dir}/*.json")):
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Unrelated JSON files (lists, scalars, objects without a model name)
        # may live in the same directory; they are not result files.
        if not isinstance(data, dict):
            continue

        model_name = data.get('model_name')
        # The isinstance check also guards against unhashable values, which
        # would make the dict membership test raise TypeError.
        if isinstance(model_name, str) and model_name in results:
            results[model_name].append(data)

    return results

def calculate_statistics(results):
    """Summarise the F1 and accuracy scores of each model across its runs.

    Args:
        results: Mapping of model name to a list of run dicts. Each run dict
            must provide the keys ``'f1_macro'`` and ``'accuracy'``.

    Returns:
        A dict keyed by model name (models with no runs are left out). Each
        value maps ``'f1_macro'`` and ``'accuracy'`` to a dict holding the
        ``'mean'`` (``np.mean``), the ``'std_err'`` (``scipy.stats.sem``) and
        the raw ``'values'`` list of that metric.
    """
    summary = {}

    for model_type, runs in results.items():
        if runs:
            per_metric = {}
            for metric in ('f1_macro', 'accuracy'):
                # One score per run, in the order the runs were supplied.
                scores = [run[metric] for run in runs]
                per_metric[metric] = {
                    'mean': np.mean(scores),
                    'std_err': stats.sem(scores),
                    'values': scores,
                }
            summary[model_type] = per_metric

    return summary

def perform_significance_tests(stats_dict):
    """Run pairwise independent two-sample t-tests between the models.

    For each metric (``'f1_macro'``, ``'accuracy'``) the pairs
    dagnn_main/gnn_baseline, dagnn_main/gat_ablation and
    gat_ablation/gnn_baseline are compared with ``scipy.stats.ttest_ind``
    using its default settings. A pair is skipped when either model is
    missing from ``stats_dict`` or has no values for the metric. Progress is
    printed for every metric and every pair, including skipped ones.

    Args:
        stats_dict: Mapping of model name to per-metric dicts that contain a
            ``'values'`` list, as produced by ``calculate_statistics``.

    Returns:
        A DataFrame with the columns ``metric``, ``model1``, ``model2`` and
        ``p_value``, one row per performed test. The columns are present
        even when no test could be performed. ``p_value`` can be NaN when
        scipy cannot compute the test for the given samples.
    """
    metrics = ['f1_macro', 'accuracy']
    matchups = [
        ('dagnn_main', 'gnn_baseline'),
        ('dagnn_main', 'gat_ablation'),
        ('gat_ablation', 'gnn_baseline')
    ]
    columns = ['metric', 'model1', 'model2', 'p_value']

    significance_results = []

    for metric in metrics:
        print(f"\nTesting metric: {metric}")
        for model1, model2 in matchups:
            print(f"Comparing {model1} vs {model2}")

            if model1 not in stats_dict or model2 not in stats_dict:
                continue

            values1 = stats_dict[model1][metric]['values']
            values2 = stats_dict[model2][metric]['values']
            if len(values1) == 0 or len(values2) == 0:
                continue

            # Only the p-value is reported; the t statistic is not needed.
            _, p_value = stats.ttest_ind(values1, values2)

            significance_results.append({
                'metric': metric,
                'model1': model1,
                'model2': model2,
                'p_value': p_value
            })

    # Passing the columns explicitly keeps the schema (and the CSV header)
    # stable when no comparison was possible.
    return pd.DataFrame(significance_results, columns=columns)

def create_comparison_table(stats_dict):
    """Build a table of ``mean ± standard error`` strings per model.

    Args:
        stats_dict: Mapping of model name to per-metric dicts that contain
            ``'mean'`` and ``'std_err'``, as produced by
            ``calculate_statistics``.

    Returns:
        A DataFrame with the columns ``Model``, ``f1_macro`` and
        ``accuracy``, one row per model in ``stats_dict`` order. The metric
        cells are strings formatted to four decimals. The columns are
        present even when ``stats_dict`` is empty.
    """
    metrics = ['f1_macro', 'accuracy']

    rows = []
    for model, model_stats in stats_dict.items():
        row = {'Model': model}
        for metric in metrics:
            mean = model_stats[metric]['mean']
            se = model_stats[metric]['std_err']
            # The plus-minus sign is written as an escape so it cannot be
            # corrupted again if the file is re-saved in another encoding.
            row[metric] = f"{mean:.4f} \u00b1 {se:.4f}"
        rows.append(row)

    return pd.DataFrame(rows, columns=['Model', *metrics])

def plot_comparison(stats_dict, save_dir=None):
    """Draw side-by-side box plots of each metric, one box per model.

    One subplot is created per metric (``'f1_macro'``, ``'accuracy'``); the
    boxes are built from the raw ``'values'`` lists in ``stats_dict``. The
    figure is then shown with ``plt.show()``.

    Args:
        stats_dict: Mapping of model name to per-metric dicts that contain a
            ``'values'`` list.
        save_dir: If truthy, the figure is also written to
            ``<save_dir>/model_comparison.png`` before it is shown.
    """
    metric_names = ['f1_macro', 'accuracy']
    model_names = list(stats_dict.keys())

    fig, axes = plt.subplots(1, len(metric_names), figsize=(10, 5))

    for ax, metric_name in zip(axes, metric_names):
        # Flatten to (model, score) pairs so labels and scores stay aligned.
        pairs = [
            (model_name, score)
            for model_name in model_names
            for score in stats_dict[model_name][metric_name]['values']
        ]
        group_labels = [model_name for model_name, _ in pairs]
        scores = [score for _, score in pairs]

        sns.boxplot(x=group_labels, y=scores, ax=ax)
        ax.set_title(metric_name)
        ax.set_xlabel('Model')
        ax.set_ylabel('Score')

    plt.tight_layout()

    if save_dir:
        plt.savefig(f"{save_dir}/model_comparison.png")
    plt.show()

def run_statistical_analysis(base_dir, save_dir=None):
    """Run the whole analysis: load, summarise, test, plot and export.

    Args:
        base_dir: Directory that holds the result JSON files.
        save_dir: Directory for the outputs. When ``None`` the outputs are
            written into ``base_dir``.

    Returns:
        A dict with ``'comparison'`` (the mean/standard-error table),
        ``'significance_tests'`` (the t-test table) and ``'detailed_stats'``
        (the per-model statistics dict).
    """
    output_dir = base_dir if save_dir is None else save_dir

    stats_dict = calculate_statistics(load_results(base_dir))
    comparison_table = create_comparison_table(stats_dict)
    significance_tests = perform_significance_tests(stats_dict)

    # The plot is produced (and shown) before the CSV files are written.
    plot_comparison(stats_dict, output_dir)

    if output_dir:
        csv_outputs = {
            f"{output_dir}/model_comparison.csv": comparison_table,
            f"{output_dir}/significance_tests.csv": significance_tests,
        }
        for csv_path, frame in csv_outputs.items():
            frame.to_csv(csv_path)

    return {
        'comparison': comparison_table,
        'significance_tests': significance_tests,
        'detailed_stats': stats_dict
    }

# In [None]:
# Directory scanned for the result JSON files. No save_dir is passed below,
# so the CSV and PNG outputs are written into this directory as well.
base_dir = '/content'
# Run the full pipeline; the returned dict holds the comparison table, the
# significance-test table and the detailed per-model statistics.
analysis_results = run_statistical_analysis(base_dir)