In [3]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import pandas as pd
import numpy as np
import os
import json

### Before using:Must run MMLU evaluation for the required models using open-instruct MMLU evaluation script (finetuning/open-instruct/eval/mmlu/run_eval_local.py)

In [None]:
def plot_model_accuracies(model_accuracies,baselines="olmo"):
    # Initialize an empty DataFrame
    data = []

    # Loop through the dictionary to populate the DataFrame
    for model_name, results in model_accuracies.items():
        for step, accuracy in results:
            data.append((model_name, step, accuracy))

    # Convert the list to a DataFrame
    df = pd.DataFrame(data, columns=['Model', 'Steps', 'Accuracy'])

    # Set the plot theme
    sns.set_theme(style="whitegrid")

    # Create the plot
    plt.figure(figsize=(20, 8))

    # Use a different color palette for the models
    palette = sns.color_palette("husl", len(model_accuracies))
    sns.lineplot(data=df, x='Steps', y='Accuracy', hue='Model', marker='o', palette=palette, linewidth=2.5)

    # Add horizontal lines for the baselines
    if baselines == "olmo":
        plt.axhline(y=28.6, color='black', linestyle='--', label='OLMo (base) - 28.6')
        plt.axhline(y=47.3, color='blue', linestyle='--', label='OLMo-SFT - 47.3')
    elif baselines == "t5":
        # 0-shot
        #plt.axhline(y=25.9, color='black', linestyle='--', label='T5 (base) - 25.9')
        #plt.axhline(y=55.1, color='blue', linestyle='--', label='Flan-T5 - 55.1')
        # 5-shot 
        plt.axhline(y=23.0, color='black', linestyle='--', label='T5 (5-shot) - 23.0')
        plt.axhline(y=54.6, color='blue', linestyle='--', label='Flan-T5 (5-shot) - 54.6')


    # Set plot labels and title
    plt.xlabel('Number of Steps')
    plt.ylabel('Accuracy')
    plt.title('Model Accuracy during Training')

    # Adjust the legend
    #plt.legend(title='Model and Baselines')
    #plt.legend(title='Model and Baselines', bbox_to_anchor=(1.05, 1.4), loc='upper left')
    #plt.legend(title='Model and Baselines', bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=14)
    #plt.legend(title='Model and Baselines', bbox_to_anchor=(1.5, 1.5), loc='upper left')


    # change font size
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.legend(fontsize=14)
    plt.xlabel('Number of Steps', fontsize=14)
    plt.ylabel('Accuracy', fontsize=14)
    plt.title('MMLU Accuracy during Training', fontsize=16)
    
    # Adjust layout to make room for the legend
    #plt.tight_layout(rect=[0, 0, 0.85, 1])
    #plt.tight_layout()

    plt.legend(title='Model and Baselines', bbox_to_anchor=(1.0, 1), loc='upper left', fontsize=14)

    # Adjust layout to make room for the legend
    plt.tight_layout(rect=[0, 0, 0.85, 1])
    
    # Show the plot
    plt.show()

## Extract MMLU Scores

In [16]:
def extract_all_average_acc(model_names, main_dir="output/", base_model="T5", return_categories=False):
    results = {}
    subcat_results = {}  # New dictionary for subcategories
    cat_results = {}     # New dictionary for main categories
    
    init_mmlu_score = 23.0 if base_model == "T5" else 28.6
    mmlu_folder_name = "mmlu_5_shot" if base_model == "T5" else "mmlu"

    for model_name in model_names:
        model_dir = os.path.join(main_dir, model_name)
        model_results = [(0, init_mmlu_score)]  # Start with the initial value
        model_subcat_results = {}  # Track subcategory results for this model
        model_cat_results = {}     # Track category results for this model
        
        print(model_dir)
        if os.path.isdir(model_dir):
            for step_dir in os.listdir(model_dir):
                step_path = os.path.join(model_dir, step_dir)
                if os.path.isdir(step_path) and step_dir.startswith("step_"):
                    metrics_file = os.path.join(step_path, "merged", mmlu_folder_name, "metrics_merged.json")

                    if not os.path.isfile(metrics_file):
                        # Sometimes the metrics_merged.json is not in the mmlu_5_shot folder, but in the mmlu folder
                        metrics_file = os.path.join(step_path, "merged", "mmlu", "metrics_merged.json")
                    if os.path.isfile(metrics_file):
                        with open(metrics_file, 'r') as f:
                            data = json.load(f)
                            step_value = int(step_dir.split("_")[1])
                            
                            # Average accuracy
                            average_acc = round(data.get("average_acc", 0) * 100, 1)
                            model_results.append((step_value, average_acc))
                            
                            # Subject subcategories
                            if "subcat_acc" in data:
                                if step_value not in model_subcat_results:
                                    model_subcat_results[step_value] = {}
                                for subject, acc in data["subcat_acc"].items():
                                    model_subcat_results[step_value][subject] = round(acc * 100, 1)
                            
                            # Main categories
                            if "cat_acc" in data:
                                if step_value not in model_cat_results:
                                    model_cat_results[step_value] = {}
                                for category, acc in data["cat_acc"].items():
                                    model_cat_results[step_value][category] = round(acc * 100, 1)

        results[model_name] = model_results
        subcat_results[model_name] = model_subcat_results
        cat_results[model_name] = model_cat_results

    if return_categories:
        return results, subcat_results, cat_results
    return results

def print_category_tables(subcat_results, cat_results, model_names, step_to_analyze=None):
    """
    Print formatted tables of accuracies and variances for categories and subcategories.
    
    Args:
        subcat_results: Dictionary of subcategory results from extract_average_acc
        cat_results: Dictionary of category results from extract_average_acc
        step_to_analyze: Specific step to analyze. If None, uses the last step for each model
    """
    
    def create_table(results_dict, level="category"):
        # Initialize storage for means and variances
        all_metrics = {}
        
        for model_name in results_dict.keys():
            # If step not specified, use the last available step
            print(f"model_name: {model_name}")
            if step_to_analyze is None:
                step = max(results_dict[model_name].keys())
            else:
                step = step_to_analyze
                
            if step not in results_dict[model_name]:
                print(f"Warning: Step {step} not found for model {model_name}")
                continue
                
            metrics = results_dict[model_name][step]
            all_metrics[model_name] = metrics
        
        # Convert to DataFrame
        df = pd.DataFrame(all_metrics)
        
        # Sort index alphabetically
        df = df.sort_index()
        
        # Calculate mean and std across models
        df['Mean'] = df.mean(axis=1)
        df['Std'] = df.std(axis=1)
        
        # Round all values to 1 decimal place
        df = df.round(1)
        
        # Format the table
        if level == "category":
            print(f"\n{'='*80}\nMain Category Results (Step {step})\n{'='*80}")
        else:
            print(f"\n{'='*80}\nSubcategory Results (Step {step})\n{'='*80}")
            
        # Calculate column means
        model_means = df.mean()
        model_stds = df.std()
        
        # Add model average and std row
        df.loc['Model Average'] = model_means
        
        # Format the DataFrame for display
        pd.set_option('display.max_columns', None)
        pd.set_option('display.width', None)
        pd.set_option('display.max_rows', None)
        
        print(df)
        print("\nModel Averages:")
        for column in df.columns:
            if column not in ['Mean', 'Std']:
                print(f"{column}: {model_means[column]:.1f} ± {model_stds[column]:.1f}")
        
        # Print the average std of the models considering accuracy in 0.XX format
        df_accuracy = df.applymap(lambda x: x/100)
        #print(f"df_accuracy: {df_accuracy}")
        model_stds_in_accuracy_format = df_accuracy.std()
        print(f"Model Average Std in accuracy format: {model_stds_in_accuracy_format.mean()}")
    
    # Convert the model names to the original model names
    model_names = {model_names[key]: key for key in model_names}
    
    # Create and print both tables
    create_table(cat_results, level="category")
    create_table(subcat_results, level="subcategory")


base_model: T5
output/allenai/tulu-v2-sft-mixture_t5-v1_1-xxl_lora_r128_alpha256_LR1e-4
output/allenai/tulu-v2-sft-mixture_t5-v1_1-xxl_lora_r128_alpha256_LR1e-4_seed_1
output/allenai/tulu-v2-sft-mixture_t5-v1_1-xxl_lora_r128_alpha256_LR1e-4_seed_2
model_name: allenai/tulu-v2-sft-mixture_t5-v1_1-xxl_lora_r128_alpha256_LR1e-4
model_name: allenai/tulu-v2-sft-mixture_t5-v1_1-xxl_lora_r128_alpha256_LR1e-4_seed_1
model_name: allenai/tulu-v2-sft-mixture_t5-v1_1-xxl_lora_r128_alpha256_LR1e-4_seed_2

Main Category Results (Step 15000)
                                 allenai/tulu-v2-sft-mixture_t5-v1_1-xxl_lora_r128_alpha256_LR1e-4  \
STEM                                                                         38.30                   
humanities                                                                   45.60                   
other (business, health, misc.)                                              54.40                   
social sciences                                             

  df_accuracy = df.applymap(lambda x: x/100)
  df_accuracy = df.applymap(lambda x: x/100)


In [None]:
model_names = {
    "path/to/model_name":"T5-Tulu-Seed-0",
    "path/to/model_name":"T5-Tulu-Seed-1",
    "path/to/model_name":"T5-Tulu-Seed-2",
}
if 'T5' in list(model_names.values())[0]:
    base_model = "T5"
else:
    base_model = "OLMo"
print(f"base_model: {base_model}")

# First get the results using the modified extract_average_acc
results, subcat_results, cat_results = extract_all_average_acc(model_names.keys(), main_dir="output/", base_model=base_model, return_categories=True)

# Then print the formatted tables
print_category_tables(subcat_results, cat_results, model_names)

# Or for a specific step:
#print_category_tables(subcat_results, cat_results, model_names, step_to_analyze=2500)