In [1]:
import sys
from pathlib import Path
import os
import matplotlib.pyplot as plt
import pandas as pd

def find_project_root(start_path: Path = Path.cwd(), marker: str = 'pyproject.toml') -> Path:
    current_path = start_path.resolve()
    for parent in [current_path] + list(current_path.parents):
        if (parent / marker).exists():
            return parent
        
def add_project_root_to_sys_path(marker: str = 'pyproject.toml'):
    project_root = find_project_root(marker=marker)
    if str(project_root) not in sys.path:
        sys.path.insert(0, str(project_root))

add_project_root_to_sys_path()


In [3]:
import torch
import torch.nn as nn
import wandb
from tqdm import tqdm
from torch.utils.data import DataLoader
from src.asym_ensembles.data_loaders import load_dataset
from src.asym_ensembles.modeling.training import (
    set_global_seed,
    train_one_model,
    evaluate_model,
    evaluate_ensemble,
    average_pairwise_distance
)
from src.asym_ensembles.modeling.models import MLP, WMLP
import numpy as np
import copy
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

In [34]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.ticker as ticker

# Глобальные параметры: Times New Roman, жирный, увеличенный размер шрифта
mpl.rcParams['font.family'] = 'Times New Roman'
mpl.rcParams['font.weight'] = 'bold'
mpl.rcParams['font.size'] = 16

# Используем только следующие датасеты
DATASETS = ["california", "otto", "mnist", "adult", "churn"]

# =============================
# 1. Composite figure for deep ensemble percentage improvement
#    (3 подграфика: по hidden size [64, 128, 256], улучшение, усредненное по всем датасетам)
# =============================
def plot_deep_ensemble_composite_improvement(agg_csv="../data/processed/agg3.csv", save_dir="../reports/figures"):
    os.makedirs(save_dir, exist_ok=True)
    df = pd.read_csv(agg_csv)
    df = df[df['dataset_name'].str.lower().isin([d.lower() for d in DATASETS])]

    # Приводим к числовому типу необходимые столбцы
    for col in ['mlp_ens_metric', 'wmlp_ens_metric', 'mean_mlp_metric', 'mean_wmlp_metric', 'ens_size']:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')
    
    # Группируем по ключам; агрегируем только требуемые столбцы
    grouped = df.groupby(['dataset_name', 'hidden_dim', 'ens_size', 'metric_type'], as_index=False).agg({
        'mlp_ens_metric': 'mean',
        'wmlp_ens_metric': 'mean',
        'mean_mlp_metric': 'mean',
        'mean_wmlp_metric': 'mean'
    })
    
    def compute_improvement(row):
        if row['metric_type'] == 'rmse':
            mlp_imp = (row['mean_mlp_metric'] - row['mlp_ens_metric']) / row['mean_mlp_metric'] * 100
            wmlp_imp = (row['mean_wmlp_metric'] - row['wmlp_ens_metric']) / row['mean_wmlp_metric'] * 100
        else:
            mlp_imp = (row['mlp_ens_metric'] - row['mean_mlp_metric']) / row['mean_mlp_metric'] * 100
            wmlp_imp = (row['wmlp_ens_metric'] - row['mean_wmlp_metric']) / row['mean_wmlp_metric'] * 100
        return pd.Series({'mlp_imp': mlp_imp, 'wmlp_imp': wmlp_imp})
    
    improvements = grouped.apply(compute_improvement, axis=1)
    grouped = pd.concat([grouped, improvements], axis=1)
    
    # Усредняем по всем датасетам для каждой комбинации hidden_dim и ens_size
    avg_imp = grouped.groupby(['hidden_dim', 'ens_size'], as_index=False)[['mlp_imp', 'wmlp_imp']].mean()
    
    hidden_sizes = [64, 128, 256]
    fig, axs = plt.subplots(1, len(hidden_sizes), figsize=(5 * len(hidden_sizes), 4))
    for ax, hs in zip(axs, hidden_sizes):
        sub = avg_imp[avg_imp['hidden_dim'] == hs]
        ax.plot(sub['ens_size'], sub['mlp_imp'], marker='o', label='MLP ENSEMBLE', color='blue', linewidth=2)
        ax.plot(sub['ens_size'], sub['wmlp_imp'], marker='s', label='WMLP ENSEMBLE', color='orange', linewidth=2)
        ax.set_title(f"HIDDEN SIZE = {hs}", fontsize=16, fontweight='bold')
        ax.set_xlabel("Ensemble Size", fontsize=16, fontweight='bold')
        ax.set_ylabel("Improvement (%)", fontsize=16, fontweight='bold')
        ax.grid(True, linestyle='--', alpha=0.5)
        # Устанавливаем логарифмический масштаб по оси x с основанием 2
        ax.set_xscale('log', base=2)
        # Для наглядности можно установить тики по существующим значениям:
        xticks = sorted(sub['ens_size'].unique())
        ax.set_xticks(xticks)
        ax.get_xaxis().set_major_formatter(ticker.ScalarFormatter())
        plt.setp(ax.get_xticklabels(), fontweight='bold')
        plt.setp(ax.get_yticklabels(), fontweight='bold')
    
    # Легенда – фиксированный порядок
    desired_order = ['MLP ENSEMBLE', 'WMLP ENSEMBLE']
    handles, labels = axs[0].get_legend_handles_labels()
    ordered_handles = []
    ordered_labels = []
    for lab in desired_order:
        for h, l in zip(handles, labels):
            if l == lab:
                ordered_handles.append(h)
                ordered_labels.append(l)
                break
    fig.legend(ordered_handles, ordered_labels, loc="upper center", ncol=2, frameon=False, fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.93])
    composite_path = os.path.join(save_dir, "deep_ensemble_improvement_composite.pdf")
    plt.savefig(composite_path, dpi=300)
    plt.close()


# =============================
# 2. Composite figure for MoE/MoIE percentage improvement
#    (2 подграфика: по GATING type [STANDARD, GUMBEL], для MODEL_TYPE: MLP, WMLP, IMLP, IWMLP)
# =============================
def plot_moe_composite_improvement(csv_path="../data/processed/table3_gumbel_standart_64fix.csv", 
                                     save_dir="../reports/figures/moe"):
    os.makedirs(save_dir, exist_ok=True)
    df = pd.read_csv(csv_path)
    df = df[df['dataset_name'].str.lower().isin([d.lower() for d in DATASETS])]
    for col in ['metric', 'num_experts', 'hidden_dim']:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')
    df = df[df['hidden_dim'] == 64]
    
    improvements = []
    # Для каждого gating, dataset и model – вычисляем improvement относительно baseline (num_experts=2)
    for gating in df['gating_type'].unique():
        for dataset in df['dataset_name'].unique():
            for model in df['model_type'].unique():
                sub = df[(df['gating_type'] == gating) &
                         (df['dataset_name'] == dataset) &
                         (df['model_type'] == model)]
                if sub.empty or (sub['num_experts'] == 2).sum() == 0:
                    continue
                baseline = sub[sub['num_experts'] == 2]['metric'].mean()
                sub = sub.sort_values('num_experts').copy()
                if sub['metric_type'].iloc[0] == 'rmse':
                    sub['improvement'] = (baseline - sub['metric']) / baseline * 100
                else:
                    sub['improvement'] = (sub['metric'] - baseline) / baseline * 100
                improvements.append(sub)
    if improvements:
        imp_df = pd.concat(improvements)
    else:
        print("No data for MoE improvement composite.")
        return
    # Группируем только столбец 'improvement'
    avg_imp = imp_df.groupby(['gating_type', 'model_type', 'num_experts'], as_index=False)['improvement'].mean()
    
    model_order = ['mlp', 'wmlp', 'imlp', 'iwmlp']
    # Приводим метки модели к верхнему регистру
    model_order_upper = [m.upper() for m in model_order]
    model_colors = {"mlp": "blue", "wmlp": "orange", "imlp": "green", "iwmlp": "red"}
    
    # Фиксированный порядок gating type: сначала standard, затем gumbel
    gating_types = ["standard", "gumbel"]
    fig, axs = plt.subplots(1, len(gating_types), figsize=(6 * len(gating_types), 4))
    for ax, gating in zip(axs, gating_types):
        sub_gate = avg_imp[avg_imp['gating_type'] == gating]
        for model in model_order:
            sub_model = sub_gate[sub_gate['model_type'] == model]
            if sub_model.empty:
                continue
            ax.plot(sub_model['num_experts'], sub_model['improvement'], marker='o', label=model.upper(),
                    color=model_colors.get(model, None), linewidth=2)
        ax.set_title(f"GATING: {gating.upper()}", fontsize=16, fontweight='bold')
        ax.set_xlabel("Number of Experts", fontsize=16, fontweight='bold')
        ax.set_ylabel("Improvement (%)", fontsize=16, fontweight='bold')
        ax.grid(True, linestyle='--', alpha=0.5)
        # Логарифмический масштаб по оси x с основанием 2
        ax.set_xscale('log', base=2)
        xticks = sorted(sub_gate['num_experts'].unique())
        ax.set_xticks(xticks)
        ax.get_xaxis().set_major_formatter(ticker.ScalarFormatter())
        plt.setp(ax.get_xticklabels(), fontweight='bold')
        plt.setp(ax.get_yticklabels(), fontweight='bold')
    handles, labels = axs[0].get_legend_handles_labels()
    ordered_handles = []
    ordered_labels = []
    for model in model_order_upper:
        for h, l in zip(handles, labels):
            if l == model:
                ordered_handles.append(h)
                ordered_labels.append(l)
                break
    fig.legend(ordered_handles, ordered_labels, loc="upper center", ncol=4, frameon=False, fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.93])
    composite_path = os.path.join(save_dir, "moe_improvement_composite.pdf")
    plt.savefig(composite_path, dpi=300)
    plt.close()


# =============================
# 3. Composite figure for deep ensemble absolute metric values
#    (Grid: Columns = Hidden Size [64,128,256]; Rows = Datasets: california, otto, mnist, adult, churn)
#    Доверительные интервалы строятся по среднему ± std, и добавляются линии базового уровня:
#    (Mean MLP 64 Models Metric и Mean WMLP 64 Models Metric)
# =============================
def plot_deep_ensemble_absolute(agg_csv="../data/processed/agg3.csv", save_dir="../reports/figures"):
    os.makedirs(save_dir, exist_ok=True)
    df = pd.read_csv(agg_csv)
    for col in ['mlp_ens_metric', 'wmlp_ens_metric', 'ens_size', 'mean_mlp_metric', 'mean_wmlp_metric']:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')
    # Фиксированный порядок датасетов
    datasets = ["california", "otto", "mnist", "adult", "churn"]
    hidden_sizes = [64, 128, 256]
    nrows, ncols = len(datasets), len(hidden_sizes)
    fig, axs = plt.subplots(nrows, ncols, figsize=(4 * ncols, 3 * nrows))
    for i, ds in enumerate(datasets):
        for j, hs in enumerate(hidden_sizes):
            ax = axs[i, j] if (nrows > 1 and ncols > 1) else axs[max(i, j)]
            sub = df[(df['dataset_name'].str.lower() == ds) & (df['hidden_dim'] == hs)]
            if sub.empty:
                ax.set_title(f"{ds}, Hidden Size = {hs}\n(No Data)", fontsize=12)
                continue
            sub = sub.sort_values('ens_size')
            # Ensemble метрики
            grouped = sub.groupby('ens_size').agg(
                mlp_ens_mean=('mlp_ens_metric', 'mean'),
                mlp_ens_std=('mlp_ens_metric', 'std'),
                wmlp_ens_mean=('wmlp_ens_metric', 'mean'),
                wmlp_ens_std=('wmlp_ens_metric', 'std')
            ).reset_index()
            # Базовые метрики (средние по 64 моделям)
            grouped_baseline = sub.groupby('ens_size').agg(
                mean_mlp_metric_mean=('mean_mlp_metric', 'mean'),
                mean_mlp_metric_std=('mean_mlp_metric', 'std'),
                mean_wmlp_metric_mean=('mean_wmlp_metric', 'mean'),
                mean_wmlp_metric_std=('mean_wmlp_metric', 'std')
            ).reset_index()
            ax.plot(grouped['ens_size'], grouped['mlp_ens_mean'], marker='o', label='MLP ENSEMBLE', color='blue', linewidth=2)
            ax.fill_between(grouped['ens_size'],
                            grouped['mlp_ens_mean'] - grouped['mlp_ens_std'],
                            grouped['mlp_ens_mean'] + grouped['mlp_ens_std'],
                            color='blue', alpha=0.2)
            ax.plot(grouped['ens_size'], grouped['wmlp_ens_mean'], marker='s', label='WMLP ENSEMBLE', color='orange', linewidth=2)
            ax.fill_between(grouped['ens_size'],
                            grouped['wmlp_ens_mean'] - grouped['wmlp_ens_std'],
                            grouped['wmlp_ens_mean'] + grouped['wmlp_ens_std'],
                            color='orange', alpha=0.2)
            # Базовые линии
            ax.plot(grouped_baseline['ens_size'], grouped_baseline['mean_mlp_metric_mean'], 
                    label='MEAN MLP 64 MODELS METRIC', color='green', linewidth=2)
            ax.fill_between(grouped_baseline['ens_size'],
                            grouped_baseline['mean_mlp_metric_mean'] - grouped_baseline['mean_mlp_metric_std'],
                            grouped_baseline['mean_mlp_metric_mean'] + grouped_baseline['mean_mlp_metric_std'],
                            color='green', alpha=0.2)
            ax.plot(grouped_baseline['ens_size'], grouped_baseline['mean_wmlp_metric_mean'], 
                    label='MEAN WMLP 64 MODELS METRIC', color='red', linewidth=2)
            ax.fill_between(grouped_baseline['ens_size'],
                            grouped_baseline['mean_wmlp_metric_mean'] - grouped_baseline['mean_wmlp_metric_std'],
                            grouped_baseline['mean_wmlp_metric_mean'] + grouped_baseline['mean_wmlp_metric_std'],
                            color='red', alpha=0.2)
            ax.set_title(f"{ds.upper()}, HIDDEN SIZE = {hs}", fontsize=14, fontweight='bold')
            ax.grid(True, linestyle='--', alpha=0.5)
            # Логарифмический масштаб по оси x
            ax.set_xscale('log', base=2)
            xticks = sorted(grouped['ens_size'].unique())
            ax.set_xticks(xticks)
            ax.get_xaxis().set_major_formatter(ticker.ScalarFormatter())
            plt.setp(ax.get_xticklabels(), fontweight='bold')
            plt.setp(ax.get_yticklabels(), fontweight='bold')
            # Подписи осей
            if j == 0:
                if ds.lower() == "california":
                    ax.set_ylabel("RMSE", fontsize=14, fontweight='bold')
                else:
                    ax.set_ylabel("Accuracy", fontsize=14, fontweight='bold')
            if i == nrows - 1:
                ax.set_xlabel("Ensemble Size", fontsize=14, fontweight='bold')
    # Легенда – фиксированный порядок
    handles, labels = axs[0,0].get_legend_handles_labels() if (nrows > 1 and ncols > 1) else ax.get_legend_handles_labels()
    desired_order = ['MLP ENSEMBLE', 'WMLP ENSEMBLE', 'MEAN MLP 64 MODELS METRIC', 'MEAN WMLP 64 MODELS METRIC']
    ordered_handles = []
    ordered_labels = []
    for lab in desired_order:
        for h, l in zip(handles, labels):
            if l == lab:
                ordered_handles.append(h)
                ordered_labels.append(l)
                break
    fig.legend(ordered_handles, ordered_labels, loc="upper center", ncol=2, frameon=False, fontsize=14)
    plt.tight_layout(rect=[0, 0, 1, 0.93])
    composite_path = os.path.join(save_dir, "deep_ensemble_absolute_composite.pdf")
    plt.savefig(composite_path, dpi=300)
    plt.close()


# =============================
# 4. Composite figure for MoE/MoIE absolute metric values
#    (Grid: столбцы = GATING TYPE [STANDARD, GUMBEL]; строки = Дatasets: california, otto, mnist, adult, churn)
# =============================
def plot_moe_absolute(csv_path="../data/processed/table3_gumbel_standart_64fix.csv", 
                      save_dir="../reports/figures/moe"):
    os.makedirs(save_dir, exist_ok=True)
    df = pd.read_csv(csv_path)
    for col in ['metric', 'num_experts', 'hidden_dim']:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')
    grouped = df.groupby(["dataset_name", "gating_type", "model_type", "num_experts", "hidden_dim", "metric_type"],
                         as_index=False).agg(
        mean_metric=("metric", "mean"),
        std_metric=("metric", "std")
    )
    # Фиксированный порядок датасетов и gating types
    datasets = ["california", "otto", "mnist", "adult", "churn"]
    gating_types = ["standard", "gumbel"]
    nrows, ncols = len(datasets), len(gating_types)
    fig, axs = plt.subplots(nrows, ncols, figsize=(6 * ncols, 4 * nrows))
    colors = {"mlp": "blue", "wmlp": "orange", "imlp": "green", "iwmlp": "red"}
    for i, ds in enumerate(datasets):
        for j, gate in enumerate(gating_types):
            ax = axs[i, j] if (nrows > 1 and ncols > 1) else axs[max(i, j)]
            sub = grouped[(grouped['dataset_name'].str.lower() == ds) & (grouped['gating_type'] == gate)]
            if sub.empty:
                ax.set_title(f"{ds.upper()}, GATING: {gate.upper()}\n(No Data)", fontsize=12)
                continue
            for model in ['mlp', 'wmlp', 'imlp', 'iwmlp']:
                sub_model = sub[sub['model_type'] == model].sort_values("num_experts")
                if sub_model.empty:
                    continue
                ax.plot(sub_model['num_experts'], sub_model['mean_metric'], marker='o',
                        label=model.upper(), color=colors.get(model, None), linewidth=2)
                ax.fill_between(sub_model['num_experts'],
                                sub_model['mean_metric'] - sub_model['std_metric'],
                                sub_model['mean_metric'] + sub_model['std_metric'],
                                color=colors.get(model, None), alpha=0.2)
            ax.set_title(f"{ds.upper()}, GATING: {gate.upper()}", fontsize=12, fontweight='bold')
            ax.grid(True, linestyle='--', alpha=0.5)
            # Логарифмический масштаб по оси x
            ax.set_xscale('log', base=2)
            xticks = sorted(sub['num_experts'].unique())
            ax.set_xticks(xticks)
            ax.get_xaxis().set_major_formatter(ticker.ScalarFormatter())
            plt.setp(ax.get_xticklabels(), fontweight='bold')
            plt.setp(ax.get_yticklabels(), fontweight='bold')
            if i == nrows - 1:
                ax.set_xlabel("Number of Experts", fontsize=12, fontweight='bold')
            if j == 0:
                if ds.lower() == "california":
                    ax.set_ylabel("RMSE", fontsize=12, fontweight='bold')
                else:
                    ax.set_ylabel("Accuracy", fontsize=12, fontweight='bold')
    try:
        handles, labels = axs[0,0].get_legend_handles_labels()
    except:
        handles, labels = ax.get_legend_handles_labels()
    desired_order = ['MLP', 'WMLP', 'IMLP', 'IWMLP']
    ordered_handles = []
    ordered_labels = []
    for lab in desired_order:
        for h, l in zip(handles, labels):
            if l == lab:
                ordered_handles.append(h)
                ordered_labels.append(l)
                break
    fig.legend(ordered_handles, ordered_labels, loc="upper center", ncol=len(desired_order), fontsize=12, frameon=False)
    plt.tight_layout(rect=[0, 0, 1, 0.93])
    composite_path = os.path.join(save_dir, "moe_absolute_composite.pdf")
    plt.savefig(composite_path, dpi=300)
    plt.close()


# =============================
# 5. Deep Ensemble Relative Improvement (ALL DATASETS)
#    (Grid: Rows = Datasets, Columns = Hidden Size [64, 128, 256])
# =============================
def plot_deep_ensemble_relative_improvement_all(agg_csv="../data/processed/agg3.csv", save_dir="../reports/figures"):
    os.makedirs(save_dir, exist_ok=True)
    df = pd.read_csv(agg_csv)
    for col in ['mlp_ens_metric', 'wmlp_ens_metric', 'mean_mlp_metric', 'mean_wmlp_metric', 'ens_size']:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')
    df = df[df['dataset_name'].str.lower().isin([d.lower() for d in DATASETS])]
    
    # Фиксированный порядок датасетов (как в DATASETS) и hidden_sizes (предполагается, что их значения совпадают)
    datasets = DATASETS
    hidden_sizes = sorted(df['hidden_dim'].unique())
    
    grouped = df.groupby(['dataset_name', 'hidden_dim', 'ens_size', 'metric_type'], as_index=False).agg({
        'mlp_ens_metric': 'mean',
        'wmlp_ens_metric': 'mean',
        'mean_mlp_metric': 'mean',
        'mean_wmlp_metric': 'mean'
    })
    
    def compute_improvement(row):
        if row['metric_type'] == 'rmse':
            mlp_imp = (row['mean_mlp_metric'] - row['mlp_ens_metric']) / row['mean_mlp_metric'] * 100
            wmlp_imp = (row['mean_wmlp_metric'] - row['wmlp_ens_metric']) / row['mean_wmlp_metric'] * 100
        else:
            mlp_imp = (row['mlp_ens_metric'] - row['mean_mlp_metric']) / row['mean_mlp_metric'] * 100
            wmlp_imp = (row['wmlp_ens_metric'] - row['mean_wmlp_metric']) / row['mean_wmlp_metric'] * 100
        return pd.Series({'MLP_REL': mlp_imp, 'WMLP_REL': wmlp_imp})
    improvements = grouped.apply(compute_improvement, axis=1)
    grouped = pd.concat([grouped, improvements], axis=1)
    
    avg_imp = grouped.groupby(['dataset_name', 'hidden_dim', 'ens_size'], as_index=False)[['MLP_REL', 'WMLP_REL']].mean()
    
    nrows = len(datasets)
    ncols = len(hidden_sizes)
    fig, axs = plt.subplots(nrows, ncols, figsize=(5 * ncols, 4 * nrows), squeeze=False)
    
    for i, ds in enumerate(datasets):
        for j, hs in enumerate(hidden_sizes):
            ax = axs[i][j]
            sub = avg_imp[(avg_imp['dataset_name'].str.lower() == ds.lower()) & (avg_imp['hidden_dim'] == hs)]
            if sub.empty:
                ax.set_title(f"{ds.upper()}, HIDDEN SIZE = {hs}\n(No Data)", fontsize=16)
                continue
            ax.plot(sub['ens_size'], sub['MLP_REL'], marker='o', label='MLP ENSEMBLE', color='blue', linewidth=2)
            ax.plot(sub['ens_size'], sub['WMLP_REL'], marker='s', label='WMLP ENSEMBLE', color='orange', linewidth=2)
            ax.set_title(f"{ds.upper()}, HIDDEN SIZE = {hs}", fontsize=16, fontweight='bold')
            ax.set_xlabel("Ensemble Size", fontsize=16, fontweight='bold')
            ax.set_ylabel("Improvement (%)", fontsize=16, fontweight='bold')
            ax.grid(True, linestyle='--', alpha=0.5)
            # Логарифмический масштаб по оси x
            ax.set_xscale('log', base=2)
            xticks = sorted(sub['ens_size'].unique())
            ax.set_xticks(xticks)
            ax.get_xaxis().set_major_formatter(ticker.ScalarFormatter())
            plt.setp(ax.get_xticklabels(), fontweight='bold')
            plt.setp(ax.get_yticklabels(), fontweight='bold')
    handles, labels = axs[0][0].get_legend_handles_labels()
    desired_order = ['MLP ENSEMBLE', 'WMLP ENSEMBLE']
    ordered_handles = []
    ordered_labels = []
    for lab in desired_order:
        for h, l in zip(handles, labels):
            if l == lab:
                ordered_handles.append(h)
                ordered_labels.append(l)
                break
    fig.legend(ordered_handles, ordered_labels, loc="upper center", ncol=2, frameon=False, fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.93])
    filename = os.path.join(save_dir, "deep_ensemble_relative_improvement_all.pdf")
    plt.savefig(filename, dpi=300)
    plt.close()


# =============================
# 6. MoE Relative Improvement (ALL DATASETS)
#    (Grid: Rows = Datasets, Columns = GATING TYPE [STANDARD, GUMBEL])
# =============================
def plot_moe_relative_improvement_all(csv_path="../data/processed/table3_gumbel_standart_64fix.csv", 
                                        save_dir="../reports/figures/moe"):
    os.makedirs(save_dir, exist_ok=True)
    df = pd.read_csv(csv_path)
    for col in ['metric', 'num_experts', 'hidden_dim']:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')
    df = df[df['hidden_dim'] == 64]
    df = df[df['dataset_name'].str.lower().isin([d.lower() for d in DATASETS])]
    
    # Фиксированный порядок: датасеты по DATASETS и gating_types как ниже
    datasets = DATASETS
    gating_types = ["standard", "gumbel"]
    
    def compute_improvement_moe(sub):
        metric_type = sub['metric_type'].iloc[0]
        baseline = sub[sub['num_experts'] == 2]['metric'].mean()
        if metric_type == 'rmse':
            sub['IMPROVEMENT'] = (baseline - sub['metric']) / baseline * 100
        else:
            sub['IMPROVEMENT'] = (sub['metric'] - baseline) / baseline * 100
        return sub
    df_impr = df.groupby(['dataset_name', 'gating_type', 'model_type'], group_keys=False).apply(compute_improvement_moe)
    
    avg_impr = df_impr.groupby(['dataset_name', 'gating_type', 'model_type', 'num_experts'], as_index=False)['IMPROVEMENT'].mean()
    
    nrows = len(datasets)
    ncols = len(gating_types)
    fig, axs = plt.subplots(nrows, ncols, figsize=(6 * ncols, 4 * nrows), squeeze=False)
    model_order = ['mlp', 'wmlp', 'imlp', 'iwmlp']
    model_order_upper = [m.upper() for m in model_order]
    model_colors = {"mlp": "blue", "wmlp": "orange", "imlp": "green", "iwmlp": "red"}
    
    for i, ds in enumerate(datasets):
        for j, gate in enumerate(gating_types):
            ax = axs[i][j]
            sub = avg_impr[(avg_impr['dataset_name'].str.lower() == ds.lower()) & (avg_impr['gating_type'] == gate)]
            if sub.empty:
                ax.set_title(f"{ds.upper()}, GATING: {gate.upper()}\n(No Data)", fontsize=16)
                continue
            for model in model_order:
                sub_model = sub[sub['model_type'] == model]
                if sub_model.empty:
                    continue
                ax.plot(sub_model['num_experts'], sub_model['IMPROVEMENT'], marker='o', label=model.upper(),
                        color=model_colors.get(model, None), linewidth=2)
            ax.set_title(f"{ds.upper()}, GATING: {gate.upper()}", fontsize=16, fontweight='bold')
            ax.set_xlabel("Number of Experts", fontsize=16, fontweight='bold')
            ax.set_ylabel("Improvement (%)", fontsize=16, fontweight='bold')
            ax.grid(True, linestyle='--', alpha=0.5)
            # Логарифмический масштаб по оси x
            ax.set_xscale('log', base=2)
            xticks = sorted(sub['num_experts'].unique())
            ax.set_xticks(xticks)
            ax.get_xaxis().set_major_formatter(ticker.ScalarFormatter())
            plt.setp(ax.get_xticklabels(), fontweight='bold')
            plt.setp(ax.get_yticklabels(), fontweight='bold')
    handles, labels = axs[0][0].get_legend_handles_labels()
    ordered_handles = []
    ordered_labels = []
    for lab in model_order_upper:
        for h, l in zip(handles, labels):
            if l == lab:
                ordered_handles.append(h)
                ordered_labels.append(l)
                break
    fig.legend(ordered_handles, ordered_labels, loc="upper center", ncol=len(model_order_upper), fontsize=16, frameon=False)
    plt.tight_layout(rect=[0, 0, 1, 0.93])
    filename = os.path.join(save_dir, "moe_relative_improvement_all.pdf")
    plt.savefig(filename, dpi=300)
    plt.close()


# Вызов функций для построения всех фигур
plot_deep_ensemble_composite_improvement()
plot_moe_composite_improvement()
plot_deep_ensemble_absolute()
plot_moe_absolute()
plot_deep_ensemble_relative_improvement_all()
plot_moe_relative_improvement_all()


  df_impr = df.groupby(['dataset_name', 'gating_type', 'model_type'], group_keys=False).apply(compute_improvement_moe)
