In [9]:
import sys
import warnings
from pathlib import Path
from typing import Optional

def find_project_root(start_path: Optional[Path] = None, marker: str = 'pyproject.toml') -> Optional[Path]:
    """Walk upwards from ``start_path`` until a directory containing ``marker`` is found.

    Args:
        start_path: Directory to start searching from. When omitted, the current
            working directory *at call time* is used (a ``Path.cwd()`` default in the
            signature would be frozen at definition time and ignore later ``chdir``).
        marker: Name of the file or directory that identifies the project root.

    Returns:
        The resolved path of the closest ancestor (including ``start_path`` itself)
        that contains ``marker``, or ``None`` when no ancestor contains it.
    """
    if start_path is None:
        start_path = Path.cwd()
    current_path = start_path.resolve()
    # Check the starting directory first, then each ancestor up to the filesystem root.
    for candidate in (current_path, *current_path.parents):
        if (candidate / marker).exists():
            return candidate
    return None
        
def add_project_root_to_sys_path(marker: str = 'pyproject.toml') -> None:
    """Put the project root at the front of ``sys.path`` so project packages import.

    The root is located with ``find_project_root(marker=marker)``. If no root is
    found, ``sys.path`` is left untouched and a warning is emitted; without this
    guard the literal string ``'None'`` would be inserted as a search path.

    Args:
        marker: File or directory name that identifies the project root.
    """
    project_root = find_project_root(marker=marker)
    if project_root is None:
        warnings.warn(
            f"Project root not found: no '{marker}' in the current directory or its "
            "parents; sys.path was not modified.",
            stacklevel=2,
        )
        return

    root_entry = str(project_root)
    # Avoid stacking duplicate entries when the cell is re-run.
    if root_entry not in sys.path:
        sys.path.insert(0, root_entry)

# Run at notebook start so that project-local packages (the ``src.*`` imports in the
# following cell) can be resolved from the project root.
add_project_root_to_sys_path()


In [10]:
import torch
import torch.nn as nn
import wandb
from tqdm import tqdm
from torch.utils.data import DataLoader
from src.asym_ensembles.data_loaders import load_dataset
from src.asym_ensembles.modeling.training import (
    set_global_seed,
    train_one_model,
    evaluate_model,
    evaluate_ensemble,
    average_pairwise_distance
)
from src.asym_ensembles.modeling.models import MLP, WMLP
import numpy as np
import copy
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

In [3]:
# Load the aggregated results table and preview its first rows.
# NOTE(review): the preview shows an "Unnamed: 0" column, which suggests the CSV was
# written together with its index; reading with index_col=0 may be cleaner — confirm.
agg_data = pd.read_csv("../data/processed/agg1.csv")
agg_data.head()



Unnamed: 0,dataset_name,hidden_dim,repeat_index,metric_type,avg_dist_mlp,avg_dist_wmlp,mean_mlp_metric,std_mlp_metric,min_mlp_metric,max_mlp_metric,mean_wmlp_metric,std_wmlp_metric,min_wmlp_metric,max_wmlp_metric,ens_size,mlp_ens_metric,wmlp_ens_metric
0,california,64,1,rmse,22.415121,36.568644,0.532812,0.007367,0.515546,0.550353,0.538742,0.014441,0.52436,0.629947,2,0.532046,0.517073
1,california,64,1,rmse,22.415121,36.568644,0.532812,0.007367,0.515546,0.550353,0.538742,0.014441,0.52436,0.629947,4,0.527157,0.515449
2,california,64,1,rmse,22.415121,36.568644,0.532812,0.007367,0.515546,0.550353,0.538742,0.014441,0.52436,0.629947,8,0.524177,0.513151
3,california,64,1,rmse,22.415121,36.568644,0.532812,0.007367,0.515546,0.550353,0.538742,0.014441,0.52436,0.629947,16,0.519491,0.513767
4,california,64,1,rmse,22.415121,36.568644,0.532812,0.007367,0.515546,0.550353,0.538742,0.014441,0.52436,0.629947,32,0.518588,0.513353


In [4]:
# Load the per-model results table and preview its first rows.
# NOTE(review): as with the aggregated table, the "Unnamed: 0" column in the preview
# looks like a saved index — confirm whether index_col=0 is appropriate.
models_data = pd.read_csv("../data/processed/mod1.csv")
models_data.head()

Unnamed: 0,dataset_name,hidden_dim,repeat_index,model_index,metric_type,metric,masked_ratio,model_type,train_time,epochs_until_stop
0,california,64,1,1,rmse,0.535835,0.0,mlp,83.632415,117
1,california,64,1,2,rmse,0.536083,0.0,mlp,85.995286,124
2,california,64,1,3,rmse,0.536121,0.0,mlp,81.994559,117
3,california,64,1,4,rmse,0.533073,0.0,mlp,95.108512,145
4,california,64,1,5,rmse,0.526311,0.0,mlp,94.028229,145


In [5]:
def _plot_mean_with_band(ax, x, mean, std, label, color):
    """Draw a mean line on ``ax`` with a translucent +/- one-std band around it."""
    ax.plot(x, mean, label=label, color=color)
    ax.fill_between(x, mean - std, mean + std, color=color, alpha=0.2)


def plot_and_save(df, dataset_name, hidden_dim, 
                  save_dir="../reports/figures", 
                  plot_wmlp_ensemble_metrics=True, 
                  plot_mlp_ensemble_metrics=True, 
                  plot_mean_wmlp_metric=True, 
                  plot_mean_mlp_metric=True):
    """Plot metric vs. ensemble size for one (dataset, hidden_dim) pair and save a PNG.

    Rows of ``df`` matching ``dataset_name`` and ``hidden_dim`` are grouped by
    ``ens_size``; for every enabled series the per-group mean is drawn as a line with
    a shaded band of +/- one standard deviation.

    Args:
        df: DataFrame with the columns ``dataset_name``, ``hidden_dim``,
            ``metric_type``, ``ens_size``, ``mlp_ens_metric``, ``wmlp_ens_metric``,
            ``mean_mlp_metric`` and ``mean_wmlp_metric``.
        dataset_name: Value of ``dataset_name`` to select.
        hidden_dim: Value of ``hidden_dim`` to select.
        save_dir: Output directory; created if it does not exist.
        plot_wmlp_ensemble_metrics: Draw the ``wmlp_ens_metric`` series.
        plot_mlp_ensemble_metrics: Draw the ``mlp_ens_metric`` series.
        plot_mean_wmlp_metric: Draw the ``mean_wmlp_metric`` series.
        plot_mean_mlp_metric: Draw the ``mean_mlp_metric`` series.

    Returns:
        None. The figure is written to
        ``<save_dir>/<dataset_name>_<hidden_dim>_ensemble_metrics.png``. If no rows
        match, a message is printed and nothing is saved.
    """
    subset = df[(df['dataset_name'] == dataset_name) & (df['hidden_dim'] == hidden_dim)]

    if subset.empty:
        # User-facing message (Russian): "No data for dataset ... and hidden_dim ...".
        print(f"Нет данных для датасета {dataset_name} и hidden_dim {hidden_dim}")
        return

    # The y-axis label is taken from the first matching row's metric type.
    metric_type = subset['metric_type'].iloc[0]
    metric_label = "RMSE" if metric_type == "rmse" else "Accuracy"

    grouped = subset.groupby('ens_size').agg(
        mlp_ens_mean=('mlp_ens_metric', 'mean'),
        mlp_ens_std=('mlp_ens_metric', 'std'),
        wmlp_ens_mean=('wmlp_ens_metric', 'mean'),
        wmlp_ens_std=('wmlp_ens_metric', 'std'),
        mean_mlp_metric_mean=('mean_mlp_metric', 'mean'),
        mean_mlp_metric_std=('mean_mlp_metric', 'std'),
        mean_wmlp_metric_mean=('mean_wmlp_metric', 'mean'),
        mean_wmlp_metric_std=('mean_wmlp_metric', 'std'),
    ).reset_index().sort_values('ens_size')

    # (enabled flag, mean column, std column, legend label, colour), in draw order.
    series = [
        (plot_mlp_ensemble_metrics, 'mlp_ens_mean', 'mlp_ens_std',
         'MLP Ensemble', 'blue'),
        (plot_wmlp_ensemble_metrics, 'wmlp_ens_mean', 'wmlp_ens_std',
         'WMLP Ensemble', 'orange'),
        (plot_mean_mlp_metric, 'mean_mlp_metric_mean', 'mean_mlp_metric_std',
         'Mean MLP 64 models Metric', 'green'),
        (plot_mean_wmlp_metric, 'mean_wmlp_metric_mean', 'mean_wmlp_metric_std',
         'Mean WMLP 64 models Metric', 'red'),
    ]
    enabled_series = [entry for entry in series if entry[0]]

    # Always create a dedicated figure: previously it was only created inside the
    # MLP-ensemble branch, so disabling that flag drew on whatever figure was current.
    fig, ax = plt.subplots(figsize=(10, 6))
    for _, mean_col, std_col, label, color in enabled_series:
        _plot_mean_with_band(
            ax, grouped['ens_size'], grouped[mean_col], grouped[std_col], label, color
        )

    ax.set_title(f'Ensemble Performance on {dataset_name.capitalize()} with Hidden Dim {hidden_dim}')
    ax.set_xlabel('Ensemble Size')
    ax.set_ylabel(metric_label)
    if enabled_series:
        # Skip the legend when nothing was drawn to avoid matplotlib's empty-legend warning.
        ax.legend()
    ax.grid(True, linestyle='--', alpha=0.5)
    ax.set_xticks(grouped['ens_size'])
    fig.tight_layout()

    filename_base = f"{dataset_name}_{hidden_dim}"
    # Make sure the target directory exists before writing the figure.
    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, f"{filename_base}_ensemble_metrics.png"))
    plt.close(fig)
        

In [6]:
# Save one ensemble-metrics figure for every (dataset, hidden_dim) combination
# present in the aggregated table.
datasets = agg_data['dataset_name'].unique()

for dataset in datasets:
    # Only iterate over the hidden sizes that actually occur for this dataset.
    hidden_dims = agg_data[agg_data['dataset_name'] == dataset]['hidden_dim'].unique()
    for hidden_dim in hidden_dims:
        plot_and_save(agg_data, dataset, hidden_dim)


In [11]:
import pandas as pd
import matplotlib.pyplot as plt
import os

def _plot_moe_metric_by_num_experts(csv_path, save_dir, experiment_name):
    """Shared implementation of the MoE plots: metric vs. number of experts.

    Reads ``csv_path`` (columns ``dataset_name``, ``model_type``, ``num_experts``,
    ``metric``), averages ``metric`` per (dataset, model type, number of experts) and
    writes one PNG per dataset into ``save_dir``. ``experiment_name`` is used in the
    plot title and as the file-name suffix.
    """
    os.makedirs(save_dir, exist_ok=True)
    df = pd.read_csv(csv_path)
    grouped = (
        df.groupby(["dataset_name", "model_type", "num_experts"])
        .agg(mean_metric=("metric", "mean"), std_metric=("metric", "std"))
        .reset_index()
        .sort_values(["dataset_name", "num_experts"])
    )

    # (value of ``model_type`` in the CSV, legend label, colour), in draw order.
    model_styles = (("mlp", "MLP", "blue"), ("wmlp", "wMLP", "orange"))

    for dataset in grouped["dataset_name"].unique():
        subset = grouped[grouped["dataset_name"] == dataset]

        fig, ax = plt.subplots(figsize=(8, 5))
        for model_type, label, color in model_styles:
            rows = subset[subset["model_type"] == model_type]
            ax.plot(rows["num_experts"], rows["mean_metric"], label=label, color=color)
            # Shaded band: mean +/- one standard deviation.
            ax.fill_between(
                rows["num_experts"],
                rows["mean_metric"] - rows["std_metric"],
                rows["mean_metric"] + rows["std_metric"],
                color=color, alpha=0.2
            )
        ax.set_title(f"{dataset} ({experiment_name})")
        ax.set_xlabel("Number of Experts")
        ax.set_ylabel("Metric (mean ± std)")
        ax.grid(True, linestyle="--", alpha=0.5)
        ax.legend()
        fig.tight_layout()
        fig.savefig(os.path.join(save_dir, f"{dataset}_{experiment_name}.png"))
        # Close explicitly so figures do not accumulate across datasets.
        plt.close(fig)


def plot_moe_fix_total_hidden_size(csv_path, save_dir="../reports/figures/moe_fix_total_hidden_size"):
    """Plot MLP vs. wMLP metric against the number of experts, one PNG per dataset.

    Args:
        csv_path: CSV with columns ``dataset_name``, ``model_type``, ``num_experts``
            and ``metric``.
        save_dir: Output directory (created if missing). Files are named
            ``<dataset>_moe_fix_total_hidden_size.png``.
    """
    _plot_moe_metric_by_num_experts(csv_path, save_dir, "moe_fix_total_hidden_size")


def plot_moe_num_experts(csv_path, save_dir="../reports/figures/moe_num_experts"):
    """Plot MLP vs. wMLP metric against the number of experts, one PNG per dataset.

    Args:
        csv_path: CSV with columns ``dataset_name``, ``model_type``, ``num_experts``
            and ``metric``.
        save_dir: Output directory (created if missing). Files are named
            ``<dataset>_moe_num_experts.png``.
    """
    _plot_moe_metric_by_num_experts(csv_path, save_dir, "moe_num_experts")



In [12]:
# Generate the per-dataset MoE figures for both experiment result files.
csv_fix_total_hidden_size = "../data/processed/moe_fix_total_hidden_size.csv"
csv_num_experts = "../data/processed/moe_num_experts.csv"
plot_moe_fix_total_hidden_size(csv_fix_total_hidden_size)
plot_moe_num_experts(csv_num_experts)
