# In [None]:
import numpy as np
import torch
import pandas as pd
from function import * 
import time
import warnings

# Suppress every warning for the whole run so the per-run progress output
# stays readable.
# NOTE(review): this also hides warnings raised by numpy/pandas/torch —
# confirm that blanket suppression is intended.
warnings.filterwarnings('ignore')

# ====================================================================
# 1. EXPERIMENT CONFIGURATION
# ====================================================================

# Prefer the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# parameters settings
# Each grid is built from an integer step and rounded to its own decimal
# precision. Accumulating floats (0.01 + 0.01 * i) yields values such as
# 0.060000000000000005, which would end up verbatim in the result table and
# break exact comparisons like `df.p == 0.06`.
P2 = [round(0.01 * i, 2) for i in range(1, 11)]    # 0.01 .. 0.10, diagonal of the probability matrix
Q2 = [round(0.001 * i, 3) for i in range(1, 11)]   # 0.001 .. 0.010, off-diagonal of the probability matrix
PI = [round(0.01 * i, 2) for i in range(1, 11)]    # 0.01 .. 0.10, forwarded to initialization()
SEEDS = [i * 10 for i in range(1, 11)]             # 10, 20, ..., 100

# experiment scenarios (Balanced, Imbalanced)
# Each list is passed as `n_units` to the graph generator; its length is used
# as the number of classes (presumably one entry per group size — confirm).
GRAPH_TYPES = {
    'balanced': [100, 100, 100, 100, 100],
    'imbalanced': [150, 150, 50, 50, 50, 50]
}
WEIGHT_STRATEGIES = ['constant', 'linear'] # W_k=1, W_k=k

# hyperparameters
epochs = 15
lr = 0.4
# Total number of individual (configuration, seed) runs; only used for the
# progress counter printed by the main loop.
total_runs = len(GRAPH_TYPES) * len(P2) * len(Q2) * len(PI) * len(WEIGHT_STRATEGIES) * len(SEEDS)

# ====================================================================
# 2. MAIN EXPERIMENT LOOP
# ====================================================================
# Long-format result rows: one dict per (configuration, model, metric).
all_results = []
# 1-based counter over individual (configuration, seed) runs, for progress output.
run_id = 0

# Single flat stream of parameter combinations. The iteration order is the
# same as the equivalent nested loops: graph type outermost, weight strategy
# innermost.
param_grid = (
    (graph_name, n_units, p_val, q_val, pi_val, weight_strat)
    for graph_name, n_units in GRAPH_TYPES.items()
    for p_val in P2
    for q_val in Q2
    for pi_val in PI
    for weight_strat in WEIGHT_STRATEGIES
)

for graph_name, n_units, p_val, q_val, pi_val, weight_strat in param_grid:

    # Per-seed metrics for this configuration: one {model: {metric: value}} dict per seed.
    seed_results = []
    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # system clock adjustments during a long run.
    start_time_params = time.perf_counter()

    print("-" * 60)
    print(f"{graph_name}, p={p_val:.2f}, q={q_val:.3f}, pi={pi_val:.2f}, w_k='{weight_strat}'")

    for seed in SEEDS:
        run_id += 1
        print(f"  ({run_id}/{total_runs}) Seed: {seed}")

        # --- Graph generation and initialization ---
        n_L = len(n_units)
        # Probability matrix: p_val on the diagonal, q_val everywhere else.
        prob_mat = np.full((n_L, n_L), q_val)
        np.fill_diagonal(prob_mat, p_val)
        # NOTE(review): `seed` is only passed to the graph generator here;
        # whether initialization() and HOI_training() are deterministic for a
        # given seed cannot be told from this file — confirm.
        G, classes = generating_graph_with_simplices(n_units, prob_mat, seed)
        g_sim, m_sim, am_sim = generate_simplex_sets(G, budget=500)

        # NOTE(review): `nx` has no explicit import in the visible part of this
        # file; presumably it is provided by `from function import *` — confirm.
        true_labels = np.array(list(nx.get_node_attributes(G, 'label').values()))
        init_prob, init_pred, x_known = initialization(G, classes, g_sim, pi_val)
        init_metrics = calculate_all_metrics(true_labels, init_pred, init_prob, n_L)

        # --- training ---
        n_max = len(g_sim)
        coefs = precompute_hoi_coefficients(n_max, n_L, device)

        # "PI" starts as n_max empty lists and keeps only g_sim[1] (when it
        # exists); the other variants use the simplex sets as returned.
        models_to_run = {"ALL": g_sim, "MAX": m_sim, "Aug-MAX": am_sim, "PI": [[] for _ in range(n_max)]}
        if len(g_sim) > 1:
            models_to_run["PI"][1] = g_sim[1]

        # The pre-training metrics are stored alongside the trained models so
        # they are aggregated and reported in exactly the same way.
        run_metrics = {'Initial_RW': init_metrics}

        for name, simplices in models_to_run.items():
            final_prob, final_pred = HOI_training(epochs, device, simplices, init_prob, x_known, lr, coefs, weight_strat)
            run_metrics[name] = calculate_all_metrics(true_labels, final_pred, final_prob, n_L)

        seed_results.append(run_metrics)

    # Tuple keys (seed index, model name) become a two-level row index, with
    # one column per metric.
    df_seed = pd.DataFrame.from_dict(
        {
            (seed_idx, model_key): metric_values
            for seed_idx, per_model in enumerate(seed_results)
            for model_key, metric_values in per_model.items()
        },
        orient='index',
    )
    # Aggregate over seeds per model (index level 1). fillna(0) replaces NaN
    # aggregates, e.g. the std when only a single seed is available.
    agg_results = df_seed.groupby(level=1).agg(['mean', 'std']).fillna(0)

    for model_name in agg_results.index:
        for metric in ['accuracy', 'macro_f1', 'roc_auc', 'kappa']:
            # Single-step .loc with a (metric, statistic) column tuple reads
            # the cell directly instead of materialising the whole row first.
            all_results.append({
                'graph_type': graph_name, 'p': p_val, 'q': q_val, 'pi': pi_val,
                'weight_strategy': weight_strat, 'model': model_name, 'metric': metric,
                'mean': agg_results.loc[model_name, (metric, 'mean')],
                'std': agg_results.loc[model_name, (metric, 'std')]
            })

    end_time_params = time.perf_counter()
    print(f"  Time spent: {end_time_params - start_time_params:.2f} seconds")


# ====================================================================
# 3. SAVE FINAL RESULTS
# ====================================================================
# Visual separator between the per-run progress log and the final summary.
print(f"\n{'=' * 60}")

# Collect all long-format rows into one table and persist it. The
# 'utf-8-sig' encoding writes a BOM at the start of the file.
final_df = pd.DataFrame(all_results)
final_df.to_csv('full_experiment_results.csv', index=False, encoding='utf-8-sig')

# Show the first and last rows as a quick sanity check of the output.
print(final_df.head(), final_df.tail(), sep="\n")