# Visualizations for evaluation and failure analysis

## Setup

In [2]:
import os
import json
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from google.cloud import storage
import io

GCS_BUCKET_NAME = "open-llm-finetuning"
GCS_EVAL_PATH = f"gs://{GCS_BUCKET_NAME}/data/evaluation"
GCS_OUTPUT_PATH = f"gs://{GCS_BUCKET_NAME}/outputs"

## Define functions

Here we define functions to read the .json files containing the evaluation results from Google Cloud storage, we then process the data to export the visualization plots.

In [3]:
def load_all_data_from_gcs_directory(gcs_directory_path):
    all_records = []
    try:
        storage_client = storage.Client()
        bucket_name = gcs_directory_path.split('/')[2]
        prefix = '/'.join(gcs_directory_path.split('/')[3:]) + '/'
        
        blobs = storage_client.list_blobs(bucket_name, prefix=prefix)

        for blob in blobs:
            if blob.name.endswith('.json'):
                print(f"Reading {blob.name}...")
                json_data = blob.download_as_text()
                data = json.loads(json_data)

                config_general = data['config_general']
                results = data['results']
                
                for result_title, metrics in results.items():
                    if result_title == 'all':
                        continue
                    
                    accuracy = metrics.get('qem') or metrics.get('acc_norm') or metrics.get('acc')

                    all_records.append({
                        'model_name': config_general['model_name'],
                        'accuracy': accuracy,
                        'total_seconds': float(config_general['total_evaluation_time_secondes']),
                        'source_file': os.path.basename(blob.name)
                    })
    except Exception as e:
        print(f"Error accessing GCS. Details: {e}")
        return pd.DataFrame()

    return pd.DataFrame(all_records)

def save_plot_to_gcs(fig, gcs_output_path, filename):
    storage_client = storage.Client()
    bucket_name = gcs_output_path.split('/')[2]
    bucket = storage_client.bucket(bucket_name)
        
    img_data = io.BytesIO()
    fig.savefig(img_data, format='png', bbox_inches='tight')
    img_data.seek(0)
        
    blob_path = os.path.join('/'.join(gcs_output_path.split('/')[3:]), filename)
    blob = bucket.blob(blob_path)
        
    blob.upload_from_file(img_data, content_type='image/png')

def generate_time_analysis_plots(df, gcs_output_path):
    df_avg = df.groupby('model_name').agg({'total_seconds': 'mean'}).reset_index()
    df_avg['short_model_name'] = df_avg['model_name'].apply(lambda x: x.split('/')[-1].upper())
    df_avg['total_minutes'] = df_avg['total_seconds'] / 60
    
    df_with_letters = df_avg[df_avg['short_model_name'].str.contains("LETTERS")]
    df_without_letters = df_avg[~df_avg['short_model_name'].str.contains("LETTERS")]
    
    sns.set_style("whitegrid")
    fig, axes = plt.subplots(2, 1, figsize=(10, 10))
    
    sns.barplot(x='total_minutes', y='short_model_name', data=df_with_letters.sort_values('total_minutes', ascending=False), palette='viridis', ax=axes[0])
    axes[0].set_title('Generative Evaluation Time', fontsize=12, weight='bold')

    sns.barplot(x='total_minutes', y='short_model_name', data=df_without_letters.sort_values('total_minutes', ascending=False), palette='viridis', ax=axes[1])
    axes[1].set_title('Log-Likelihood Evaluation Time', fontsize=12, weight='bold')

    for ax in axes:
        ax.set_xlabel('Time (Minutes)', fontsize=12)
        ax.set_ylabel('Model', fontsize=12)
        for p in ax.patches:
            ax.text(p.get_width() + 0.1, p.get_y() + p.get_height() / 2, f' {p.get_width():.1f} min', va='center', ha='left')

    plt.suptitle('Average Evaluation Time per Model Type', fontsize=20, weight='bold')
    
    save_plot_to_gcs(fig, gcs_output_path, 'average_evaluation_time_combined.png')
    plt.close(fig)


def generate_comparison_plots(df, gcs_output_path):
    def get_accuracy_from_df(filename, dataframe):
        search_name = os.path.basename(filename)
        return dataframe[dataframe['source_file'] == search_name]['accuracy'].iloc[0]

    def create_plot(title, data_files, dataframe, output_filename, palette_name):
        labels = list(data_files.keys())
        palette = sns.color_palette(palette_name, 2)
        
        baseline_scores = [get_accuracy_from_df(data_files[cat]['Baseline'], dataframe) for cat in labels]
        qlora_scores = [get_accuracy_from_df(data_files[cat]['QLORA'], dataframe) for cat in labels]

        x = np.arange(len(labels))
        width = 0.35
        fig, ax = plt.subplots(figsize=(10, 7))
        
        rects1 = ax.bar(x - width/2, baseline_scores, width, label='Baseline (Llama-3-8B)', color=palette[0])
        rects2 = ax.bar(x + width/2, qlora_scores, width, label='Fine-tuned (QLORA)', color=palette[1])
        
        ax.set_ylabel('Accuracy', fontsize=12)
        ax.set_title(title, fontsize=16, weight='bold', pad=20)
        ax.set_xticks(x)
        ax.set_xticklabels(labels, fontsize=12)
        ax.legend(fontsize=11)
        ax.set_ylim(0, max(1.0, max(baseline_scores + qlora_scores) * 1.2))

        ax.bar_label(rects1, padding=3, fmt='%.4f')
        ax.bar_label(rects2, padding=3, fmt='%.4f')
        
        save_plot_to_gcs(fig, gcs_output_path, output_filename)
        plt.close(fig)

    log_likelihood_files = {
        'Cloze Formulation': {
            'Baseline': 'evaluation_results_results_meta-llama_Meta-Llama-3-8B_results_2025-07-21T07-39-28.418935.json',
            'QLORA': 'evaluation_results_results_jihbr_usmle-llama8b-qlora_results_2025-07-16T07-37-57.051004.json'
        },
        'Multiple Choice Formulation': {
            'Baseline': 'evaluation_results_results_meta-llama_Meta-Llama-3-8B_results_2025-07-21T07-18-22.495821.json',
            'QLORA': 'evaluation_results_results_jihbr_usmle-llama8b-qlora_results_2025-07-16T13-38-59.159732.json'
        }
    }
    create_plot('Log-Likelihood Evaluation: Baseline vs. QLORA', log_likelihood_files, df, 'log_likelihood_comparison.png', 'viridis')

    generative_files = {
        'Letter Completion': {
            'Baseline': 'evaluation_results_results_meta-llama_Meta-Llama-3-8B_results_2025-07-19T19-30-24.638550.json',
            'QLORA': 'evaluation_results_results_jihbr_usmle-llama8b-qlora_results_2025-07-19T19-24-00.164953.json'
        },
        'Answer Completion': {
            'Baseline': 'evaluation_results_results_meta-llama_Meta-Llama-3-8B_results_2025-07-19T20-04-07.914896.json',
            'QLORA': 'evaluation_results_results_jihbr_usmle-llama8b-qlora_results_2025-07-19T19-52-24.700413.json'
        }
    }
    create_plot('Generative Evaluation: Baseline vs. QLORA', generative_files, df, 'generative_evaluation_comparison.png', 'viridis')


### Run code

In [4]:
generate_time_analysis_plots(master_df, gcs_output_path=GCS_OUTPUT_PATH)
generate_comparison_plots(master_df, gcs_output_path=GCS_OUTPUT_PATH)


Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(x='total_minutes', y='short_model_name', data=df_with_letters.sort_values('total_minutes', ascending=False), palette='viridis', ax=axes[0])

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(x='total_minutes', y='short_model_name', data=df_without_letters.sort_values('total_minutes', ascending=False), palette='viridis', ax=axes[1])
