In [None]:
import pandas as pd
from typing import Optional, Union, List
from pathlib import Path

def read_parquet_file(file_path: Union[str, Path], columns: Optional[List[str]] = None) -> pd.DataFrame:
    """Load a parquet file from disk as a pandas DataFrame.

    Args:
        file_path: Location of the parquet file, as a string or ``Path``.
        columns: Subset of columns to load; ``None`` (the default) loads every column.

    Returns:
        The file's contents as a DataFrame.
    """
    # Delegates straight to pandas; only the column selection is forwarded.
    frame = pd.read_parquet(file_path, columns=columns)
    return frame


# Preview a single model's classification results before the cross-model analysis.
file_path = "./quantization_ablation_model/Meta-Llama-3-8B-Instruct.Q2_K.gguf/llama_classification.parquet"
df = read_parquet_file(file_path, columns=None)
print("Loaded dataframe with shape: {}".format(df.shape))
df.head()


In [None]:
# TODO: some classification results were saved without a `subject` column, so the
# subject is re-attached here by looking it up in the input dataset.
input_data_file_path = "./input_datasets/classification_pairs.parquet"
input_data = read_parquet_file(file_path=input_data_file_path)
def add_subject_column(df: pd.DataFrame, input_data: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of ``df`` with a ``subject`` column looked up from ``input_data``.

    Subjects are matched on the ``question`` column. If a question appears more than
    once in ``input_data``, its last occurrence wins. The input ``df`` is not modified.

    Args:
        df: Target dataframe; must have a ``question`` column.
        input_data: Source dataframe with ``question`` and ``subject`` columns.

    Returns:
        A new DataFrame with a ``subject`` column. Rows whose question is not found in
        ``input_data`` get NaN, unless ``df`` already had a ``subject`` column, in which
        case the existing value for that row is kept instead of being overwritten.
    """
    # question -> subject lookup; dict() keeps the last occurrence of duplicate questions.
    question_to_subject = dict(zip(input_data['question'], input_data['subject']))
    subjects = df['question'].map(question_to_subject)
    if 'subject' in df.columns:
        # Fall back to any subject already present when the lookup has no match.
        # to_numpy() makes this positional, so duplicate index labels cannot misalign it.
        subjects = subjects.where(subjects.notna(), df['subject'].to_numpy())
    # assign() builds a new frame, leaving the caller's df untouched.
    return df.assign(subject=subjects)

df = add_subject_column(df, input_data)
# Rows whose question is missing from input_data end up with a NaN subject and are
# dropped by groupby('subject') in the per-subject accuracy below, so report the count.
missing_subject_count = int(df['subject'].isna().sum())
if missing_subject_count:
    print(f"Warning: {missing_subject_count} of {len(df)} rows have no subject after mapping")

In [None]:
# Spot-check the raw model output for one example id.
df.loc[df["id"].eq(13594), "response"].iloc[0]

In [None]:
# Re-display the preview; df now carries the subject column assigned by add_subject_column.
df.head()

In [None]:
def calculate_accuracy(
    df: pd.DataFrame,
    group_by: Optional[Union[str, List[str]]] = None,
    dropna: bool = True,
) -> pd.DataFrame:
    """
    Compute the mean of the ``correct`` column, overall or per group.

    Args:
        df: Results dataframe with a ``correct`` column (True/False or 1/0).
        group_by: Column name, or list of column names, to group by. When falsy, a
            single overall accuracy is returned.
        dropna: Forwarded to ``DataFrame.groupby``. With the default (True), rows whose
            group key is NaN (e.g. questions without a subject) are left out of the
            grouped result; pass False to report them as their own group.

    Returns:
        Grouped: one row per group with the key column(s), ``accuracy`` and
        ``sample_count``. Ungrouped: a single row with ``overall_accuracy`` and
        ``sample_count``.
    """
    if not group_by:
        return pd.DataFrame({'overall_accuracy': [df['correct'].mean()], 'sample_count': [len(df)]})

    # A single named aggregation gives the key column(s) followed by accuracy and
    # sample_count, and works for one or several grouping columns.
    return (
        df.groupby(group_by, dropna=dropna)['correct']
        .agg(accuracy='mean', sample_count='size')
        .reset_index()
    )

# Accuracy over every row, whatever the model answered.
overall_accuracy = calculate_accuracy(df)
print(f"Overall accuracy: {overall_accuracy['overall_accuracy'].iloc[0]:.4f}")

# Accuracy restricted to rows whose response is exactly one of the four answer letters.
valid_responses_df = df.loc[df['response'].isin(['A', 'B', 'C', 'D'])]
valid_responses_accuracy = calculate_accuracy(valid_responses_df)
print(f"\nAccuracy for A, B, C, D responses only: {valid_responses_accuracy['overall_accuracy'].iloc[0]:.4f}")
print(f"Sample count: {valid_responses_accuracy['sample_count'].iloc[0]}")

# Per-subject accuracy, best subject first.
subject_accuracy = calculate_accuracy(df, group_by='subject').sort_values('accuracy', ascending=False)

print("\nAccuracy by subject:")
print(subject_accuracy)


In [None]:
import os
import glob
import re
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from typing import List, Dict, Tuple, Optional, Union

# Global plot styling for every figure below: apply the 'fivethirtyeight' matplotlib
# style, then seaborn's "talk" context. Both update global rcParams, so the later call
# wins for any setting they both touch — keep this order.
plt.style.use('fivethirtyeight')
sns.set_context("talk")

def _model_name_from_path(file_path: str, root: str) -> str:
    """Derive a model name from where a results file sits relative to the absolute search root."""
    parent = os.path.dirname(os.path.abspath(file_path))
    if parent == root:
        # File sits directly in the search directory: use its own name.
        return os.path.basename(file_path).replace('.parquet', '')
    name = os.path.basename(parent)
    prefix = "quantization_ablation_"
    # "quantization_ablation_<name>/<file>.parquet" -> "<name>"; otherwise the directory name itself,
    # e.g. "quantization_ablation_model/<model>.gguf/<file>.parquet" -> "<model>.gguf".
    if name.startswith(prefix) and len(name) > len(prefix):
        return name[len(prefix):]
    return name


def load_classification_files(directory: str = ".") -> Dict[str, pd.DataFrame]:
    """
    Load all classification result files found under ``directory``.

    Files are discovered recursively with the pattern ``**/*classification*.parquet``
    and processed in sorted path order. A file is kept only if it has a ``correct``
    column, so parquet files that match the pattern but hold no results (for example
    an input dataset) are skipped instead of appearing as an extra "model".

    Each model is named after the file's parent directory, with a leading
    ``quantization_ablation_`` prefix stripped. Files directly inside ``directory`` use
    their file name without ``.parquet``. If a name is already taken, the file's path
    relative to ``directory`` is used instead so no result is silently overwritten.
    Files that fail to load are reported and skipped.

    Args:
        directory: Directory to search for classification files.

    Returns:
        Dictionary mapping model names to their dataframes; each dataframe also gets a
        ``model`` column holding that name.
    """
    pattern = os.path.join(directory, "**", "*classification*.parquet")
    classification_files = sorted(glob.glob(pattern, recursive=True))
    root = os.path.abspath(directory)

    model_dfs: Dict[str, pd.DataFrame] = {}
    for file_path in classification_files:
        model_name = _model_name_from_path(file_path, root)
        if model_name in model_dfs:
            # Two files mapped to the same name; keep both under distinct keys.
            model_name = os.path.relpath(file_path, root)

        try:
            df = pd.read_parquet(file_path)
        except Exception as e:
            # Best-effort loading: report the bad file and continue with the rest.
            print(f"Error loading {file_path}: {e}")
            continue

        if 'correct' not in df.columns:
            print(f"Skipping {file_path}: no 'correct' column")
            continue

        df['model'] = model_name
        model_dfs[model_name] = df
        print(f"Loaded {model_name} with {len(df)} entries")

    return model_dfs

def combine_dataframes(model_dfs: Dict[str, pd.DataFrame]) -> pd.DataFrame:
    """
    Stack the per-model dataframes into a single long dataframe.

    Args:
        model_dfs: Mapping of model name to that model's dataframe.

    Returns:
        Every row of every dataframe, in the mapping's iteration order, with a fresh
        0..n-1 index.

    Raises:
        ValueError: If ``model_dfs`` is empty.
    """
    if not model_dfs:
        raise ValueError("No model dataframes provided")
    frames = list(model_dfs.values())
    combined = pd.concat(frames, ignore_index=True)
    return combined

def _plot_bar_by_model(data: pd.DataFrame, y: str, title: str, ylabel: str) -> None:
    """Draw and show a bar chart with one bar per model for column ``y`` of ``data``."""
    fig, ax = plt.subplots(figsize=(12, 8))
    sns.barplot(x='model', y=y, data=data, ax=ax)
    ax.set_title(title)
    ax.set_xlabel('Model')
    ax.set_ylabel(ylabel)
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    fig.tight_layout()
    plt.show()


def analyze_accuracy(
    combined_df: pd.DataFrame,
    valid_choices: Tuple[str, ...] = ('A', 'B', 'C', 'D'),
) -> None:
    """
    Analyze and plot accuracy metrics across models.

    Produces four figures:
    overall accuracy per model, accuracy per subject and model (bars), the same as a
    heatmap, and the share of responses that are exactly one of ``valid_choices``.

    Args:
        combined_df: Combined dataframe with ``model``, ``subject``, ``response`` and
            ``correct`` columns.
        valid_choices: Responses counted as well-formed answers for the valid response
            rate.
    """
    # Overall accuracy per model, best first; bars follow the row order of the data.
    model_accuracy = (
        combined_df.groupby('model')['correct']
        .mean()
        .reset_index()
        .sort_values('correct', ascending=False)
    )
    _plot_bar_by_model(model_accuracy, 'correct', 'Overall Accuracy by Model', 'Accuracy')

    # Accuracy per (model, subject); rows with a NaN subject are dropped by groupby.
    subject_accuracy = combined_df.groupby(['model', 'subject'])['correct'].mean().reset_index()

    fig, ax = plt.subplots(figsize=(15, 10))
    sns.barplot(x='subject', y='correct', hue='model', data=subject_accuracy, ax=ax)
    ax.set_title('Accuracy by Subject and Model')
    ax.set_xlabel('Subject')
    ax.set_ylabel('Accuracy')
    plt.setp(ax.get_xticklabels(), rotation=90)
    ax.legend(title='Model', bbox_to_anchor=(1.05, 1), loc='upper left')
    fig.tight_layout()
    plt.show()

    # Same numbers as a subject x model heatmap.
    pivot_accuracy = subject_accuracy.pivot(index='subject', columns='model', values='correct')
    fig, ax = plt.subplots(figsize=(14, 12))
    sns.heatmap(pivot_accuracy, annot=True, cmap='YlGnBu', fmt='.3f', linewidths=.5, ax=ax)
    ax.set_title('Accuracy Heatmap by Subject and Model')
    fig.tight_layout()
    plt.show()

    # Valid response rate per model, computed with a vectorised Series groupby rather than
    # groupby('model').apply(...), which pandas >= 2.2 warns about because the applied
    # function also receives the grouping column.
    valid_responses = (
        combined_df['response'].isin(list(valid_choices))
        .groupby(combined_df['model'])
        .mean()
        .reset_index(name='valid_response_rate')
    )
    _plot_bar_by_model(
        valid_responses,
        'valid_response_rate',
        f"Valid Response Rate by Model ({', '.join(valid_choices)} only)",
        'Valid Response Rate',
    )

def _show_timing_plot(plot_func, data: pd.DataFrame, title: str) -> None:
    """Render one model-vs-time chart (seaborn bar or box plot) with one hue per timing metric."""
    fig, ax = plt.subplots(figsize=(14, 8))
    plot_func(x='model', y='Time (seconds)', hue='Timing Metric', data=data, ax=ax)
    ax.set_title(title)
    ax.set_xlabel('Model')
    ax.set_ylabel('Time (seconds)')
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    ax.legend(title='Timing Metric')
    fig.tight_layout()
    plt.show()


def analyze_timing(combined_df: pd.DataFrame) -> None:
    """
    Analyze and plot timing metrics across models.

    Plots mean and distribution of ``prompt_time``, ``response_time`` and their sum
    per model and, when a ``response_tokens`` column exists, response tokens per
    second. Does nothing but print a message if either timing column is missing.
    ``combined_df`` itself is not modified; derived columns live on a working copy.

    Args:
        combined_df: Combined dataframe with all models.
    """
    timing_cols = ['prompt_time', 'response_time']
    if not all(col in combined_df.columns for col in timing_cols):
        print("Timing columns not found in the dataframe")
        return

    metric_cols = timing_cols + ['total_time']
    has_tokens = 'response_tokens' in combined_df.columns
    keep_cols = ['model'] + timing_cols + (['response_tokens'] if has_tokens else [])
    # Work on a narrow copy so derived columns are not written back into the caller's frame.
    timing_df = combined_df[keep_cols].assign(
        total_time=lambda d: d['prompt_time'] + d['response_time']
    )

    melt_kwargs = dict(
        id_vars=['model'],
        value_vars=metric_cols,
        var_name='Timing Metric',
        value_name='Time (seconds)',
    )
    mean_timing = timing_df.groupby('model')[metric_cols].mean().reset_index()
    _show_timing_plot(sns.barplot, pd.melt(mean_timing, **melt_kwargs), 'Average Timing Metrics by Model')
    _show_timing_plot(sns.boxplot, pd.melt(timing_df, **melt_kwargs), 'Timing Distribution by Model')

    if has_tokens:
        # A zero (or negative) response_time would give inf/garbage speeds; treat it as missing.
        positive_time = timing_df['response_time'].where(timing_df['response_time'] > 0)
        speed_df = timing_df.assign(tokens_per_second=timing_df['response_tokens'] / positive_time)

        fig, ax = plt.subplots(figsize=(12, 8))
        sns.boxplot(x='model', y='tokens_per_second', data=speed_df, ax=ax)
        ax.set_title('Response Generation Speed (Tokens per Second) by Model')
        ax.set_xlabel('Model')
        ax.set_ylabel('Tokens per Second')
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        fig.tight_layout()
        plt.show()

def analyze_error_patterns(combined_df: pd.DataFrame, top_n: int = 10) -> None:
    """
    Plot, for each model, the most frequent responses among its incorrect answers.

    Args:
        combined_df: Combined dataframe with ``model``, ``response`` and ``correct``
            columns.
        top_n: Number of most common wrong responses shown per model.
    """
    # eq(False) matches the `== False` test: rows with a missing `correct` are not counted as wrong.
    incorrect_df = combined_df[combined_df['correct'].eq(False)]
    error_patterns = incorrect_df.groupby(['model', 'response']).size().reset_index(name='count')

    # One figure per model; no figure is created when there are no errors at all.
    for model, model_errors in error_patterns.groupby('model', sort=False):
        top_errors = model_errors.sort_values('count', ascending=False).head(top_n)

        fig, ax = plt.subplots(figsize=(10, 6))
        sns.barplot(x='response', y='count', data=top_errors, ax=ax)
        ax.set_title(f'Top {top_n} Error Responses for {model}')
        ax.set_xlabel('Response')
        ax.set_ylabel('Count')
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        fig.tight_layout()
        plt.show()

# Main analysis function
def run_classification_analysis() -> None:
    """Load every classification results file and run the accuracy, timing and error-pattern analyses on them."""
    model_dfs = load_classification_files()
    if not model_dfs:
        print("No classification files found.")
        return

    combined_df = combine_dataframes(model_dfs)

    # Each section prints its header, then plots from the same combined dataframe.
    sections = (
        ("\n=== Accuracy Analysis ===", analyze_accuracy),
        ("\n=== Timing Analysis ===", analyze_timing),
        ("\n=== Error Pattern Analysis ===", analyze_error_patterns),
    )
    for header, analysis in sections:
        print(header)
        analysis(combined_df)

# Run the full cross-model analysis over every classification file found under the
# current directory (load_classification_files' default search root).
run_classification_analysis()
