In [2]:
import sys
import warnings
from pathlib import Path
from typing import Optional

def find_project_root(start_path: Optional[Path] = None, marker: str = 'pyproject.toml') -> Optional[Path]:
    """Return the nearest directory at or above ``start_path`` that contains ``marker``.

    Args:
        start_path: Directory (``Path`` or ``str``) to start searching from. Defaults to the
            current working directory *at call time*. A ``Path.cwd()`` default would be
            frozen when this cell is first executed.
        marker: File or directory name whose presence identifies the project root.

    Returns:
        The resolved project-root directory, or ``None`` if neither ``start_path`` nor
        any of its ancestors contains ``marker``.
    """
    current_path = (Path.cwd() if start_path is None else Path(start_path)).resolve()
    for parent in (current_path, *current_path.parents):
        if (parent / marker).exists():
            return parent
    return None


def add_project_root_to_sys_path(marker: str = 'pyproject.toml',
                                 start_path: Optional[Path] = None) -> Optional[Path]:
    """Prepend the project root to ``sys.path`` so project packages (e.g. ``src.asym_ensembles``) import.

    Args:
        marker: File name that identifies the project root (see ``find_project_root``).
        start_path: Where to start searching; defaults to the current working directory.

    Returns:
        The project root that was found (whether or not it was already on ``sys.path``),
        or ``None`` if no root was found. In that case ``sys.path`` is left unchanged
        and a warning is emitted, instead of inserting the string ``'None'``.
    """
    project_root = find_project_root(start_path=start_path, marker=marker)
    if project_root is None:
        search_from = Path.cwd() if start_path is None else start_path
        warnings.warn(
            f"No '{marker}' found at or above {search_from}; sys.path left unchanged.",
            stacklevel=2,
        )
        return None
    root_str = str(project_root)
    if root_str not in sys.path:
        sys.path.insert(0, root_str)
    return project_root


add_project_root_to_sys_path()


In [3]:
import torch
import torch.nn as nn
import wandb
from tqdm import tqdm
from torch.utils.data import DataLoader
from src.asym_ensembles.data_loaders import load_dataset
from src.asym_ensembles.modeling.training import (
    set_global_seed,
    train_one_model,
    evaluate_model,
    evaluate_ensemble,
    average_pairwise_distance
)
from src.asym_ensembles.modeling.models import MLP, WMLP
import numpy as np
import copy
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

In [4]:
# Aggregated ensemble-level results; the column layout is shown in the preview below.
agg_data = pd.read_csv(filepath_or_buffer="../data/processed/agg3.csv")
agg_data.head(5)



Unnamed: 0,dataset_name,hidden_dim,repeat_index,metric_type,avg_dist_mlp,avg_dist_wmlp,mean_mlp_metric,std_mlp_metric,min_mlp_metric,max_mlp_metric,mean_wmlp_metric,std_wmlp_metric,min_wmlp_metric,max_wmlp_metric,ens_size,mlp_ens_metric,wmlp_ens_metric,mlp_interp_metric,wmlp_interp_metric
0,california,64,1,rmse,22.415121,36.387406,0.532812,0.007367,0.515546,0.550353,0.536074,0.008734,0.520336,0.569327,2,0.532046,0.520385,1.609295,2.417275
1,california,64,1,rmse,22.415121,36.387406,0.532812,0.007367,0.515546,0.550353,0.536074,0.008734,0.520336,0.569327,4,0.527157,0.51538,2.071703,2.642471
2,california,64,1,rmse,22.415121,36.387406,0.532812,0.007367,0.515546,0.550353,0.536074,0.008734,0.520336,0.569327,8,0.524177,0.51359,2.207988,2.750956
3,california,64,1,rmse,22.415121,36.387406,0.532812,0.007367,0.515546,0.550353,0.536074,0.008734,0.520336,0.569327,16,0.519491,0.513435,2.22534,2.639558
4,california,64,1,rmse,22.415121,36.387406,0.532812,0.007367,0.515546,0.550353,0.536074,0.008734,0.520336,0.569327,32,0.518588,0.513099,2.227415,2.30659


In [5]:
# Per-model results (one metric value per trained model); columns shown in the preview below.
models_data = pd.read_csv(filepath_or_buffer="../data/processed/mod3.csv")
models_data.head(5)

Unnamed: 0,dataset_name,hidden_dim,repeat_index,model_index,metric_type,metric,masked_ratio,model_type,train_time,epochs_until_stop
0,california,64,1,1,rmse,0.535835,0.0,mlp,23.368304,117
1,california,64,1,2,rmse,0.536083,0.0,mlp,24.004768,124
2,california,64,1,3,rmse,0.536121,0.0,mlp,22.268934,117
3,california,64,1,4,rmse,0.533073,0.0,mlp,27.549265,145
4,california,64,1,5,rmse,0.526311,0.0,mlp,27.935298,145


In [6]:
import os
import matplotlib.pyplot as plt
import pandas as pd

def _plot_band(ax, x, mean, std, label, color):
    """Draw ``mean`` against ``x`` on ``ax`` with a shaded ``mean ± std`` band."""
    ax.plot(x, mean, label=label, color=color)
    ax.fill_between(x, mean - std, mean + std, color=color, alpha=0.2)


def _ensemble_improvement_pct(baseline, ensemble, metric_type):
    """Return the percentage improvement of ``ensemble`` over ``baseline``.

    For ``metric_type == 'rmse'`` lower is better, so a decrease counts as improvement.
    Any other metric type is treated as higher-is-better.
    """
    if metric_type == 'rmse':
        return (baseline - ensemble) / baseline * 100
    return (ensemble - baseline) / baseline * 100


def plot_and_save(df, dataset_name, hidden_dim,
                  save_dir="../reports/figures",
                  plot_wmlp_ensemble_metrics=True,
                  plot_mlp_ensemble_metrics=True,
                  plot_wmlp_interp_metrics=False,
                  plot_mlp_interp_metrics=False,
                  plot_mean_wmlp_metric=True,
                  plot_mean_mlp_metric=True):
    """Save ensemble-size curves and ensemble-vs-mean improvement plots for one configuration.

    Filters ``df`` to rows matching ``dataset_name`` and ``hidden_dim``, then computes the
    mean and std of each metric column over all matching rows per ``ens_size``. Writes two
    PNGs to ``save_dir``, which is created if missing:

    * ``{dataset_name}_{hidden_dim}_ensemble_metrics.png``: mean ± std curves for every
      series whose ``plot_*`` flag is True.
    * ``{dataset_name}_{hidden_dim}_ensemble_improvement.png``: percent improvement of the
      MLP / WMLP ensemble metric over the corresponding mean single-model metric.

    Raises:
        ValueError: if no rows match ``dataset_name`` / ``hidden_dim``.
    """
    subset = df[(df['dataset_name'] == dataset_name) & (df['hidden_dim'] == hidden_dim)]
    if subset.empty:
        raise ValueError(f"No rows for dataset_name={dataset_name!r}, hidden_dim={hidden_dim!r}")
    os.makedirs(save_dir, exist_ok=True)

    # The metric type is assumed constant within one (dataset, hidden_dim) slice.
    metric_type = subset['metric_type'].iloc[0]
    metric_label = "RMSE" if metric_type == "rmse" else "Accuracy"

    grouped = (
        subset.groupby('ens_size')
        .agg(
            mlp_ens_mean=('mlp_ens_metric', 'mean'),
            mlp_ens_std=('mlp_ens_metric', 'std'),
            wmlp_ens_mean=('wmlp_ens_metric', 'mean'),
            wmlp_ens_std=('wmlp_ens_metric', 'std'),
            mean_mlp_metric_mean=('mean_mlp_metric', 'mean'),
            mean_mlp_metric_std=('mean_mlp_metric', 'std'),
            mean_wmlp_metric_mean=('mean_wmlp_metric', 'mean'),
            mean_wmlp_metric_std=('mean_wmlp_metric', 'std'),
            mlp_interp_mean=('mlp_interp_metric', 'mean'),
            mlp_interp_std=('mlp_interp_metric', 'std'),
            wmlp_interp_mean=('wmlp_interp_metric', 'mean'),
            wmlp_interp_std=('wmlp_interp_metric', 'std'),
        )
        .reset_index()
        .sort_values('ens_size')
    )
    ens_sizes = grouped['ens_size']
    filename_base = f"{dataset_name}_{hidden_dim}"

    # (enabled?, column prefix in ``grouped``, legend label, colour), in drawing order.
    series = [
        (plot_mlp_ensemble_metrics, 'mlp_ens', 'MLP Ensemble', 'blue'),
        (plot_wmlp_ensemble_metrics, 'wmlp_ens', 'WMLP Ensemble', 'orange'),
        (plot_mlp_interp_metrics, 'mlp_interp', 'MLP Interpolated', 'yellow'),
        (plot_wmlp_interp_metrics, 'wmlp_interp', 'WMLP Interpolated', 'black'),
        (plot_mean_mlp_metric, 'mean_mlp_metric', 'Mean MLP 64 models Metric', 'green'),
        (plot_mean_wmlp_metric, 'mean_wmlp_metric', 'Mean WMLP 64 models Metric', 'red'),
    ]

    # Open a fresh figure explicitly so artists from any figure left open elsewhere in
    # the kernel cannot leak into this PNG.
    fig, ax = plt.subplots()
    drew_any = False
    for enabled, prefix, label, color in series:
        if enabled:
            _plot_band(ax, ens_sizes, grouped[f'{prefix}_mean'], grouped[f'{prefix}_std'], label, color)
            drew_any = True
    ax.set_title(f'Ensemble Performance on {dataset_name.capitalize()} with Hidden Dim {hidden_dim}')
    ax.set_xlabel('Ensemble Size')
    ax.set_ylabel(metric_label)
    if drew_any:  # avoid matplotlib's "no artists with labels" warning
        ax.legend()
    ax.grid(True, linestyle='--', alpha=0.5)
    ax.set_xticks(ens_sizes)
    fig.tight_layout()
    fig.savefig(os.path.join(save_dir, f"{filename_base}_ensemble_metrics.png"))
    plt.close(fig)

    mlp_improvement = _ensemble_improvement_pct(
        grouped['mean_mlp_metric_mean'], grouped['mlp_ens_mean'], metric_type)
    wmlp_improvement = _ensemble_improvement_pct(
        grouped['mean_wmlp_metric_mean'], grouped['wmlp_ens_mean'], metric_type)

    fig, ax = plt.subplots()
    ax.plot(ens_sizes, mlp_improvement, marker='o', label='MLP Improvement (%)', color='blue')
    ax.plot(ens_sizes, wmlp_improvement, marker='o', label='WMLP Improvement (%)', color='orange')
    ax.set_xlabel('Ensemble Size')
    ax.set_ylabel('Improvement Percentage (%)')
    ax.set_title(f'Improvement of Ensemble vs Mean for {dataset_name.capitalize()} with Hidden Dim {hidden_dim}')
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.5)
    ax.set_xticks(ens_sizes)
    fig.tight_layout()
    fig.savefig(os.path.join(save_dir, f"{filename_base}_ensemble_improvement.png"))
    plt.close(fig)


# Percent improvement of each ensemble over the mean single-model metric:
# one row per (dataset, hidden_dim), with an MLP and a WMLP column per ensemble size.
REPORTS_DIR = "../reports/figures"
os.makedirs(REPORTS_DIR, exist_ok=True)  # to_csv / savefig below fail if the folder is missing

improvement_table = []
datasets = agg_data['dataset_name'].unique()

for dataset in datasets:
    hidden_dims = agg_data[agg_data['dataset_name'] == dataset]['hidden_dim'].unique()
    for hidden_dim in hidden_dims:
        subset = agg_data[(agg_data['dataset_name'] == dataset) & (agg_data['hidden_dim'] == hidden_dim)]
        metric_type = subset['metric_type'].iloc[0]
        grouped = subset.groupby('ens_size').agg(
            mlp_ens_mean=('mlp_ens_metric', 'mean'),
            wmlp_ens_mean=('wmlp_ens_metric', 'mean'),
            mean_mlp_metric_mean=('mean_mlp_metric', 'mean'),
            mean_wmlp_metric_mean=('mean_wmlp_metric', 'mean')
        ).reset_index().sort_values('ens_size')

        # RMSE: lower is better, so a decrease is an improvement; otherwise higher is better.
        if metric_type == 'rmse':
            mlp_imp = (grouped['mean_mlp_metric_mean'] - grouped['mlp_ens_mean']) / grouped['mean_mlp_metric_mean'] * 100
            wmlp_imp = (grouped['mean_wmlp_metric_mean'] - grouped['wmlp_ens_mean']) / grouped['mean_wmlp_metric_mean'] * 100
        else:
            mlp_imp = (grouped['mlp_ens_mean'] - grouped['mean_mlp_metric_mean']) / grouped['mean_mlp_metric_mean'] * 100
            wmlp_imp = (grouped['wmlp_ens_mean'] - grouped['mean_wmlp_metric_mean']) / grouped['mean_wmlp_metric_mean'] * 100

        row_dict = {"dataset": dataset, "hidden_dim": hidden_dim}
        for ens_size, mlp_value, wmlp_value in zip(grouped['ens_size'], mlp_imp, wmlp_imp):
            row_dict[f"ens_size_{int(ens_size)}_mlp_improvement (%)"] = mlp_value
            row_dict[f"ens_size_{int(ens_size)}_wmlp_improvement (%)"] = wmlp_value
        improvement_table.append(row_dict)

improvement_df = pd.DataFrame(improvement_table)
improvement_df.to_csv(os.path.join(REPORTS_DIR, "ensemble_improvement_table.csv"), index=False)


for dataset in datasets:
    hidden_dims = agg_data[agg_data['dataset_name'] == dataset]['hidden_dim'].unique()
    for hidden_dim in hidden_dims:
        plot_and_save(agg_data, dataset, hidden_dim,
                      save_dir=REPORTS_DIR,
                      plot_wmlp_interp_metrics=False,
                      plot_mlp_interp_metrics=False)


In [16]:
import pandas as pd
import matplotlib.pyplot as plt
import os

def plot_moe_fix_total_hidden_size(csv_path, save_dir="../reports/figures/moe_fix_total_hidden_size",
                                   total_hidden_dim=1024):
    """Plot metric vs. number of experts for MoE runs sharing a fixed total hidden size.

    Reads ``csv_path``, which must have the columns ``dataset_name``, ``hidden_dim``,
    ``gating_type``, ``model_type``, ``num_experts`` and ``metric``. Computes the mean and
    std of ``metric`` per group and saves one PNG per (dataset, gating type) into
    ``save_dir``, which is created if missing. Each PNG has a mean ± std band per model type.

    Args:
        csv_path: Path to the per-run results CSV.
        save_dir: Output directory.
        total_hidden_dim: Total hidden size shared across experts. Used only in the plot
            title and file name; it is not checked against the data.
    """
    os.makedirs(save_dir, exist_ok=True)
    df = pd.read_csv(csv_path)
    grouped = df.groupby(["dataset_name", "hidden_dim", "gating_type", "model_type", "num_experts"]).agg(
        mean_metric=("metric", "mean"),
        std_metric=("metric", "std")
    ).reset_index().sort_values(["dataset_name", "hidden_dim", "gating_type", "num_experts"])

    colors = {
        "mlp": "blue",
        "wmlp": "orange",
        "imlp": "green",
        "iwmlp": "red"
    }

    keys = grouped[["dataset_name", "gating_type"]].drop_duplicates()
    for dataset, gating in keys.itertuples(index=False):
        subset = grouped[
            (grouped["dataset_name"] == dataset) &
            (grouped["gating_type"] == gating)
        ]
        fig, ax = plt.subplots(figsize=(8, 5))
        for m_type in subset["model_type"].unique():
            # ``grouped`` is sorted by hidden_dim before num_experts, so sort explicitly to
            # draw the line and band in num_experts order.
            sub = subset[subset["model_type"] == m_type].sort_values("num_experts")
            color = colors.get(m_type)
            ax.plot(sub["num_experts"], sub["mean_metric"], label=m_type, color=color)
            ax.fill_between(
                sub["num_experts"],
                sub["mean_metric"] - sub["std_metric"],
                sub["mean_metric"] + sub["std_metric"],
                color=color,
                alpha=0.2
            )
        ax.set_title(f"{dataset} Total Hidden Dim {total_hidden_dim}, Gating {gating}")
        ax.set_xlabel("Number of Experts")
        ax.set_xticks(sorted(subset["num_experts"].unique()))
        ax.set_ylabel("Metric (mean ± std)")
        ax.grid(True, linestyle="--", alpha=0.5)
        ax.legend()
        fig.tight_layout()
        # The file name previously used an undefined ``hidden_dim`` that only resolved to a
        # stale notebook global; use the configured total hidden size instead.
        filename = f"{dataset}_{total_hidden_dim}_{gating}_moe_fix_total_hidden_size.png"
        fig.savefig(os.path.join(save_dir, filename))
        plt.close(fig)

def plot_moe_num_experts(csv_path, save_dir="../reports/figures/moe_num_experts", baseline_num_experts=2):
    """Plot MoE metric vs. number of experts and its relative improvement over a baseline.

    Reads ``csv_path``, which must have the columns ``dataset_name``, ``hidden_dim``,
    ``gating_type``, ``model_type``, ``metric_type``, ``num_experts`` and ``metric``. Computes
    the mean and std of ``metric`` per group. For each (dataset, hidden_dim, gating, metric
    type) it writes two PNGs into ``save_dir``, which is created if missing:

    * ``..._moe_num_experts.png``: mean ± std curves per model type.
    * ``..._moe_num_experts_improvement.png``: percent improvement of each model type over
      its own ``baseline_num_experts`` result. For ``'rmse'`` a decrease counts as
      improvement; otherwise an increase does. Model types without a baseline row are skipped.

    The improvement values are also written to ``save_dir/improvement_table.csv``.

    Args:
        csv_path: Path to the per-run results CSV.
        save_dir: Output directory.
        baseline_num_experts: Number of experts used as the 0 % reference.
    """
    os.makedirs(save_dir, exist_ok=True)
    df = pd.read_csv(csv_path)

    grouped = df.groupby(["dataset_name", "hidden_dim", "gating_type", "model_type", "metric_type", "num_experts"]).agg(
        mean_metric=("metric", "mean"),
        std_metric=("metric", "std")
    ).reset_index().sort_values(["dataset_name", "hidden_dim", "gating_type", "num_experts", "metric_type"])

    colors = {
        "mlp": "blue",
        "wmlp": "orange",
        "imlp": "green",
        "iwmlp": "red"
    }
    improvement_table_rows = []

    keys = grouped[["dataset_name", "hidden_dim", "gating_type", "metric_type"]].drop_duplicates()
    for dataset, hidden_dim, gating, metric_type in keys.itertuples(index=False):
        # Filter on metric_type too, so different metric types are never mixed on one axis.
        subset = grouped[
            (grouped["dataset_name"] == dataset) &
            (grouped["hidden_dim"] == hidden_dim) &
            (grouped["gating_type"] == gating) &
            (grouped["metric_type"] == metric_type)
        ]
        model_types = subset["model_type"].unique()
        x_ticks = sorted(subset["num_experts"].unique())

        fig, ax = plt.subplots(figsize=(8, 5))
        for m_type in model_types:
            sub = subset[subset["model_type"] == m_type].sort_values("num_experts")
            color = colors.get(m_type)
            ax.plot(sub["num_experts"], sub["mean_metric"], label=m_type, color=color)
            ax.fill_between(
                sub["num_experts"],
                sub["mean_metric"] - sub["std_metric"],
                sub["mean_metric"] + sub["std_metric"],
                color=color,
                alpha=0.2
            )
        ax.set_title(f"{dataset} (Hidden Dim: {hidden_dim}, Gating: {gating})")
        ax.set_xlabel("num_experts")
        ax.set_xticks(x_ticks)
        ax.set_ylabel("Metric (mean ± std)")
        ax.grid(True, linestyle="--", alpha=0.5)
        ax.legend()
        fig.tight_layout()
        fig.savefig(os.path.join(save_dir, f"{dataset}_{hidden_dim}_{gating}_moe_num_experts.png"))
        plt.close(fig)

        fig, ax = plt.subplots(figsize=(8, 5))
        for m_type in model_types:
            sub = subset[subset["model_type"] == m_type].sort_values("num_experts").copy()

            baseline_rows = sub[sub["num_experts"] == baseline_num_experts]
            if baseline_rows.empty:
                continue
            baseline = baseline_rows["mean_metric"].iloc[0]

            if metric_type == 'rmse':
                sub["improvement_pct"] = (baseline - sub["mean_metric"]) / baseline * 100
            else:
                sub["improvement_pct"] = (sub["mean_metric"] - baseline) / baseline * 100

            ax.plot(sub["num_experts"], sub["improvement_pct"], label=m_type,
                    color=colors.get(m_type), marker='o')

            imp_dict = {
                "dataset_name": dataset,
                "hidden_dim": hidden_dim,
                "gating_type": gating,
                "model_type": m_type
            }
            for n_experts, pct in zip(sub["num_experts"], sub["improvement_pct"]):
                imp_dict[f"impr_{int(n_experts)}"] = pct
            improvement_table_rows.append(imp_dict)

        ax.set_title(f"Relative improvement over the {baseline_num_experts} experts \n"
                     f"{dataset} (Hidden Dim: {hidden_dim}, Gating: {gating})")
        ax.set_xlabel("num_experts")
        ax.set_xticks(x_ticks)
        ax.set_ylabel("% Improvement")
        ax.grid(True, linestyle="--", alpha=0.5)
        if ax.get_legend_handles_labels()[0]:  # every model type may lack a baseline row
            ax.legend()
        fig.tight_layout()
        fig.savefig(os.path.join(save_dir, f"{dataset}_{hidden_dim}_{gating}_moe_num_experts_improvement.png"))
        plt.close(fig)

    improvement_table = pd.DataFrame(improvement_table_rows)
    output_csv = os.path.join(save_dir, "improvement_table.csv")
    improvement_table.to_csv(output_csv, index=False)

In [17]:
# csv_fix_total_hidden_size = "../data/processed/moe_fix_total_hidden_size.csv"
# csv_num_experts = "../data/processed/moe_num_experts.csv"



In [18]:
# Metric vs. num_experts curves (plus improvement over the smallest-expert baseline) for this results file.
plot_moe_num_experts(csv_path="../data/processed/table3_gumbel_standart_64fix.csv")

In [43]:
# Fixed-total-hidden-size runs, one results file per gating variant.
plot_moe_fix_total_hidden_size(csv_path="../data/processed/table3_gumbel_1024_div_num_exp.csv")
plot_moe_fix_total_hidden_size(csv_path="../data/processed/table3_standart_1024_div_num_exp.csv")
# Fixed-per-expert-hidden-size runs, one results file per gating variant.
plot_moe_num_experts(csv_path="../data/processed/table3_gumbel_64fix.csv")
plot_moe_num_experts(csv_path="../data/processed/table3_standart_64fix.csv")