# Permutations Task: Llama-8B vs Llama-70B Comparison

This notebook compares three agent strategies on the permutations task — prefix sum, majority voting, and chain of agents — when run with Llama-8B versus Llama-70B. For each strategy it plots exact match and element accuracy as a function of the number of swaps, using line plots with shaded standard-error regions.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import re
import scienceplots

# Configure plotting style: start from the SciencePlots "science" + "ieee"
# presets, then switch to LaTeX-rendered serif text with the font sizes below.
# text.usetex=True requires a working LaTeX installation on this machine.
_PUBLICATION_RC = {
    'text.usetex': True,
    'font.family': 'serif',
    'font.size': 12,
    'axes.titlesize': 13,
    'axes.labelsize': 12,
    'legend.fontsize': 11,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
}
plt.style.use(['science', 'ieee'])
plt.rcParams.update(_PUBLICATION_RC)

In [None]:
# Load the four metric tables: "SE" (presumably standard-error) and "avg"
# tables for exact match (em) and element accuracy (acc).
# NOTE(review): se_acc_df / se_em_df are only inspected in this cell; the
# later cells shown here never read them again (the plotted error bands are
# computed inside process_data_for_plotting instead) — confirm intent.
se_acc_df = pd.read_csv('data/se_acc_permutation_llama8B_70B.csv')
se_em_df = pd.read_csv('data/se_em_permutation_llama8B_70B.csv')
avg_em_df = pd.read_csv('data/avg_em_permutation_llama8B_70B.csv')
avg_acc_df = pd.read_csv('data/avg_acc_permutation_llama8B_70B.csv')

print("Dataset shapes:")
for label, frame in (
    ("SE Accuracy", se_acc_df),
    ("SE Exact Match", se_em_df),
    ("Avg Exact Match", avg_em_df),
    ("Avg Accuracy", avg_acc_df),
):
    print(f"{label}: {frame.shape}")

# Show a few metric column names so the naming scheme can be eyeballed.
for heading, frame, needle in (
    ("\nExample columns from SE Accuracy:", se_acc_df, 'se_element_accuracy'),
    ("\nExample columns from Avg Exact Match:", avg_em_df, 'avg_exact_match'),
):
    print(heading)
    print([name for name in frame.columns if needle in name][:3])

In [None]:
def extract_method_and_model(col_name):
    """Parse the agent method and Llama model size out of a column name.

    The method is the first of ``'prefix-sum'``, ``'maj-voting'`` and
    ``'coa'`` (checked in that order) that occurs as a substring of
    ``col_name``. The model size is ``'8B'`` when ``'llama8B'`` occurs and
    ``'70B'`` when ``'llama70B'`` occurs.

    Args:
        col_name: a results column name.

    Returns:
        ``(method, model)``, or ``(None, None)`` if either part is missing.
    """
    method = next(
        (tag for tag in ('prefix-sum', 'maj-voting', 'coa') if tag in col_name),
        None,
    )
    if method is None:
        return None, None

    size_tags = (('llama8B', '8B'), ('llama70B', '70B'))
    model = next((size for tag, size in size_tags if tag in col_name), None)
    if model is None:
        return None, None

    return method, model

def extract_hyperparams(col_name):
    """Return the integer hyperparameter encoded in a column name, or None.

    The first method tag found (checked in the order ``'maj-voting'``,
    ``'coa'``, ``'prefix-sum'``) decides which pattern is searched:

    * ``maj-voting`` -> number after ``agents`` (e.g. ``agents5`` -> 5)
    * ``coa``        -> number after ``chunk``  (e.g. ``chunk3`` -> 3)
    * ``prefix-sum`` -> number after ``b``      (e.g. ``b4`` -> 4)

    If that pattern does not match, None is returned without trying the
    other methods. None is also returned when no method tag is present.

    NOTE(review): the prefix-sum pattern accepts any lowercase ``b``
    followed by digits anywhere in the name, not only a dedicated ``_b<N>``
    field — confirm column names cannot contain such a sequence elsewhere.
    """
    for method_tag, pattern in (
        ('maj-voting', r'agents(\d+)'),
        ('coa', r'chunk(\d+)'),
        ('prefix-sum', r'b(\d+)'),
    ):
        if method_tag not in col_name:
            continue
        found = re.search(pattern, col_name)
        return None if found is None else int(found.group(1))
    return None

def _best_series(df, col_hparams, n_samples):
    """Return per-row best performance, its hyperparameter and a std-err estimate.

    ``col_hparams`` maps each candidate column name to its parsed
    hyperparameter. For every row of ``df``, the maximum numeric value across
    those columns is kept. Rows with no numeric value yield NaN / None / 0.
    """
    series = {'performance': [], 'hyperparams': [], 'std_err': []}
    for _, row in df.iterrows():
        perf_by_hparam = {}
        for col, hparam in col_hparams.items():
            value = row[col]
            if pd.isna(value) or value == '':
                continue
            try:
                perf_by_hparam[hparam] = float(value)
            except (ValueError, TypeError):
                # Non-numeric cell: ignore it, as if it were missing.
                continue

        if not perf_by_hparam:
            series['performance'].append(np.nan)
            series['hyperparams'].append(None)
            series['std_err'].append(0)
            continue

        best_hparam = max(perf_by_hparam, key=perf_by_hparam.get)
        best_perf = perf_by_hparam[best_hparam]
        series['performance'].append(best_perf)
        series['hyperparams'].append(best_hparam)
        # Binomial approximation sqrt(p(1-p)/n). Values outside [0, 1]
        # would give a negative variance (NaN); report 0 instead.
        variance = best_perf * (1 - best_perf)
        series['std_err'].append(np.sqrt(variance / n_samples) if variance > 0 else 0.0)
    return series


def process_data_for_plotting(df, metric_col_key, num_swaps_col='num_swaps', n_samples=100):
    """Pick the best hyperparameter per number of swaps for each method/model.

    Methods are ``prefix-sum``, ``maj-voting`` and ``coa``; models are ``8B``
    and ``70B``. Candidate columns contain ``metric_col_key``, the method tag
    and ``llama<model>``. Columns containing ``MIN``, ``MAX`` or ``step`` are
    skipped, as are columns from which ``extract_hyperparams`` cannot parse a
    hyperparameter. For each row the highest numeric candidate value is kept.

    Note: the best hyperparameter is chosen independently at every row, so
    each curve is an upper envelope over hyperparameter settings rather than
    a single fixed configuration.

    Args:
        df: wide results table, one row per number of swaps.
        metric_col_key: substring identifying the metric columns
            (e.g. ``'avg_exact_match'``).
        num_swaps_col: name of the x-axis column; if absent, a ``'Step'``
            column is used instead.
        n_samples: sample size assumed by the binomial standard-error
            approximation ``sqrt(p * (1 - p) / n_samples)``. Defaults to
            100, the value previously hard-coded.
            NOTE(review): the ``se_*`` tables loaded in this notebook are
            not used here — confirm the approximation is intended.

    Returns:
        ``(results, num_swaps)``. ``results`` maps ``'<method>_<model>'`` to
        a dict of parallel lists:
        * ``performance``: NaN where a row has no value;
        * ``hyperparams``: None where a row has no value;
        * ``std_err``: 0 where a row has no value.
        ``num_swaps`` is the swaps column of ``df``. If no swaps column
        exists, a warning is printed and ``({}, <empty Series>)`` is
        returned, so callers' two-value unpacking still works.

    Raises:
        ValueError: if ``n_samples`` is not positive.
    """
    if n_samples <= 0:
        raise ValueError(f"n_samples must be positive, got {n_samples!r}")

    # Handle different column names for number of swaps
    if num_swaps_col not in df.columns:
        if 'Step' in df.columns:
            num_swaps_col = 'Step'
        else:
            print(f"Warning: No swaps column found. Available columns: {df.columns.tolist()[:5]}...")
            return {}, pd.Series(dtype=float)

    excluded = ('MIN', 'MAX', 'step')
    metric_cols = [
        col for col in df.columns
        if metric_col_key in col and not any(tag in col for tag in excluded)
    ]

    results = {}
    for method in ('prefix-sum', 'maj-voting', 'coa'):
        for model in ('8B', '70B'):
            # The hyperparameter depends only on the column name, so parse
            # it once per column rather than once per row.
            col_hparams = {}
            for col in metric_cols:
                if method in col and f'llama{model}' in col:
                    hparam = extract_hyperparams(col)
                    if hparam is not None:
                        col_hparams[col] = hparam
            results[f'{method}_{model}'] = _best_series(df, col_hparams, n_samples)

    return results, df[num_swaps_col]

In [None]:
# Process the data for plotting: reduce each metric table to one
# best-over-hyperparameters series per (method, model) pair. The swaps axis
# is taken from the exact-match table and reused for accuracy.
avg_em_results, num_swaps = process_data_for_plotting(avg_em_df, 'avg_exact_match')
avg_acc_results, _ = process_data_for_plotting(avg_acc_df, 'avg_element_accuracy')

print("Number of swaps:", list(num_swaps))
print("\nAvailable method-model combinations:")
for combo, series in avg_em_results.items():
    perf_array = np.asarray(series['performance'], dtype=float)
    n_valid = int(np.count_nonzero(~np.isnan(perf_array)))
    print(f"{combo}: {n_valid} valid data points")

In [None]:
def _plot_series_with_band(ax, x_vals, series, color, marker, label):
    """Draw one curve plus a +/-1 std-err band clipped to [0, 1] on ``ax``.

    ``series`` is one entry of a ``process_data_for_plotting`` results dict
    (keys ``performance`` and ``std_err``). NaN points are dropped first.
    Returns True if anything was drawn, False if every point was NaN.
    """
    y_vals = np.asarray(series['performance'], dtype=float)
    std_err = np.asarray(series['std_err'], dtype=float)
    valid = ~np.isnan(y_vals)
    if not valid.any():
        return False

    x_valid, y_valid, err_valid = x_vals[valid], y_vals[valid], std_err[valid]
    ax.plot(x_valid, y_valid, marker + '-', color=color,
            linewidth=2, markersize=6, label=label, alpha=0.8)
    ax.fill_between(x_valid,
                    np.maximum(0, y_valid - err_valid),
                    np.minimum(1, y_valid + err_valid),
                    color=color, alpha=0.2)
    return True


def plot_comparison_lines(em_results, acc_results, num_swaps):
    """Plot Llama-8B vs Llama-70B for each method in a 2x3 grid and save it.

    The top row shows exact match and the bottom row element accuracy. There
    is one column per method (Prefix Sum, Majority Voting, Chain of Agents).
    8B curves use circle markers and the lighter palette; 70B curves use
    square markers and the darker palette. The figure is saved to
    ``permutations_comparison_8B_vs_70B.pdf`` and shown.

    Args:
        em_results: exact-match results as returned by
            ``process_data_for_plotting``.
        acc_results: element-accuracy results in the same format.
        num_swaps: x-axis values aligned with each performance list.
    """
    methods = ['prefix-sum', 'maj-voting', 'coa']
    method_names = {'prefix-sum': 'Prefix Sum', 'maj-voting': 'Majority Voting', 'coa': 'Chain of Agents'}

    colors_8b = ['#4C72B0', '#55A868', '#C44E52']  # Blue, Green, Red for 8B
    colors_70b = ['#1f4788', '#3d7c47', '#8b2635']  # Darker versions for 70B

    # (model size, palette, marker, legend label)
    model_styles = [
        ('8B', colors_8b, 'o', 'Llama-8B'),
        ('70B', colors_70b, 's', 'Llama-70B'),
    ]
    # (grid row, results dict, title suffix, y-axis label)
    metric_rows = [
        (0, em_results, 'Exact Match', 'Exact Match Accuracy'),
        (1, acc_results, 'Element Accuracy', 'Element Accuracy'),
    ]

    fig, axes = plt.subplots(2, 3, figsize=(15, 8))
    x_vals = np.asarray(num_swaps)

    for i, method in enumerate(methods):
        for row, results, title_suffix, ylabel in metric_rows:
            ax = axes[row, i]
            drew_any = False
            for model, palette, marker, label in model_styles:
                key = f'{method}_{model}'
                if key in results:
                    drew_any |= _plot_series_with_band(
                        ax, x_vals, results[key], palette[i], marker, label)

            # text.usetex is enabled, so matplotlib's fontweight is not
            # applied; request bold through LaTeX instead.
            ax.set_title(rf'\textbf{{{method_names[method]} - {title_suffix}}}')
            ax.set_xlabel('Number of Swaps')
            ax.set_ylabel(ylabel)
            if drew_any:  # avoid an empty legend (and its warning)
                ax.legend(frameon=False)
            ax.grid(True, linestyle='--', alpha=0.6)
            ax.set_ylim(0, 1)

    plt.tight_layout()
    plt.savefig('permutations_comparison_8B_vs_70B.pdf', bbox_inches='tight', dpi=300)
    plt.show()

# Create the comparison plots
plot_comparison_lines(avg_em_results, avg_acc_results, num_swaps)

In [None]:
# Display summary statistics: for every method and model, the mean / max /
# min of the best-per-swap performance across all numbers of swaps.
print("Performance Summary by Method and Model")
print("=" * 60)

methods = ['prefix-sum', 'maj-voting', 'coa']
models = ['8B', '70B']

for method in methods:
    print(f"\n{method.upper().replace('-', ' ')}:")
    print("-" * 40)

    for model in models:
        key = f'{method}_{model}'

        if key in avg_em_results:
            # avg_acc_results is built separately and may not contain every
            # key that avg_em_results has; treat a missing entry as "no data"
            # rather than raising KeyError.
            acc_series = avg_acc_results.get(key, {'performance': []})
            em_perfs = [p for p in avg_em_results[key]['performance'] if not np.isnan(p)]
            acc_perfs = [p for p in acc_series['performance'] if not np.isnan(p)]

            if em_perfs and acc_perfs:
                print(f"  Llama-{model}:")
                print(f"    Exact Match  - Mean: {np.mean(em_perfs):.3f}, Max: {np.max(em_perfs):.3f}, Min: {np.min(em_perfs):.3f}")
                print(f"    Element Acc  - Mean: {np.mean(acc_perfs):.3f}, Max: {np.max(acc_perfs):.3f}, Min: {np.min(acc_perfs):.3f}")
            else:
                print(f"  Llama-{model}: No valid data")

In [None]:
# Create a simplified single plot showing overall comparison
def plot_overall_comparison(em_results=None, acc_results=None, swaps=None):
    """Save and show a two-panel overview of every method and model size.

    The left panel shows exact match and the right panel element accuracy.
    Each panel draws one line per (method, model) pair that has at least one
    non-NaN point. Colour encodes the method. Llama-8B lines are solid,
    semi-transparent and use circle markers; Llama-70B lines are dashed,
    opaque and use square markers. No error bands are drawn. The figure is
    saved to ``permutations_overall_comparison.pdf``.

    Args:
        em_results: exact-match results as returned by
            ``process_data_for_plotting``; defaults to the global
            ``avg_em_results``.
        acc_results: element-accuracy results in the same format; defaults
            to the global ``avg_acc_results``.
        swaps: x-axis values aligned with each performance list; defaults to
            the global ``num_swaps``.
    """
    if em_results is None:
        em_results = avg_em_results
    if acc_results is None:
        acc_results = avg_acc_results
    if swaps is None:
        swaps = num_swaps

    # Colors for each method
    method_colors = {'prefix-sum': '#4C72B0', 'maj-voting': '#55A868', 'coa': '#C44E52'}
    # Display names matching the other figures in this notebook (previously
    # derived with .title(), which produced "Maj Voting" and "Coa").
    method_names = {'prefix-sum': 'Prefix Sum', 'maj-voting': 'Majority Voting', 'coa': 'Chain of Agents'}
    # (model size, linestyle, alpha, marker)
    model_styles = [('8B', '-', 0.7, 'o'), ('70B', '--', 1.0, 's')]

    x_vals = np.asarray(swaps)
    _, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    panels = [
        (ax1, em_results, 'Exact Match Accuracy Comparison', 'Exact Match Accuracy'),
        (ax2, acc_results, 'Element Accuracy Comparison', 'Element Accuracy'),
    ]
    for ax, results, title, ylabel in panels:
        for method, color in method_colors.items():
            for model, linestyle, alpha, marker in model_styles:
                key = f'{method}_{model}'
                if key not in results:
                    continue
                y_vals = np.asarray(results[key]['performance'], dtype=float)
                valid = ~np.isnan(y_vals)
                if not valid.any():
                    continue
                ax.plot(x_vals[valid], y_vals[valid], linestyle=linestyle,
                        color=color, linewidth=2, alpha=alpha,
                        label=f'{method_names[method]} ({model})', marker=marker)

        # text.usetex is enabled, so bold must be requested through LaTeX.
        ax.set_title(rf'\textbf{{{title}}}')
        ax.set_xlabel('Number of Swaps')
        ax.set_ylabel(ylabel)
        if ax.get_legend_handles_labels()[0]:  # skip an empty legend
            ax.legend(frameon=False, fontsize=9)
        ax.grid(True, linestyle='--', alpha=0.6)
        ax.set_ylim(0, 1)

    plt.tight_layout()
    plt.savefig('permutations_overall_comparison.pdf', bbox_inches='tight', dpi=300)
    plt.show()

plot_overall_comparison()

In [None]:
# Create plots comparing all three agent types for each model separately
def plot_agent_comparison_by_model():
    """Save and show four single-panel figures comparing the three agent types.

    There is one figure per (model, metric) pair: Llama-8B and Llama-70B,
    each for Exact Match and Element Accuracy. Each figure draws every
    method's curve from the globals ``avg_em_results`` / ``avg_acc_results``
    against ``num_swaps``, with a +/-1 std-err band clipped to [0, 1]. It is
    written to ``figures/permutations_<model>_<metric>.pdf``.

    The larger font sizes are applied with ``plt.rc_context``, so they affect
    only these figures. The global ``plt.rcParams`` configured at the top of
    the notebook are left unchanged for later cells.
    """
    import os

    methods = ['prefix-sum', 'maj-voting', 'coa']
    method_names = {'prefix-sum': 'Prefix Sum', 'maj-voting': 'Majority Voting', 'coa': 'Chain of Agents'}

    # Colors for the three methods
    method_colors = ['#4C72B0', '#55A868', '#C44E52']  # Blue, Green, Red

    larger_fonts = {
        'font.size': 14,
        'axes.titlesize': 16,
        'axes.labelsize': 14,
        'legend.fontsize': 12,
        'xtick.labelsize': 12,
        'ytick.labelsize': 12,
    }

    x_vals = np.array(num_swaps)

    # Plot configurations for each model and metric combination
    plot_configs = [
        ('8B', avg_em_results, 'Exact Match'),
        ('8B', avg_acc_results, 'Element Accuracy'),
        ('70B', avg_em_results, 'Exact Match'),
        ('70B', avg_acc_results, 'Element Accuracy'),
    ]

    os.makedirs('figures', exist_ok=True)

    with plt.rc_context(larger_fonts):
        for model, results_dict, metric_name in plot_configs:
            fig, ax = plt.subplots(figsize=(6, 5))

            for color, method in zip(method_colors, methods):
                key = f'{method}_{model}'
                if key not in results_dict:
                    continue

                y_vals = np.array(results_dict[key]['performance'])
                std_err = np.array(results_dict[key]['std_err'])
                valid_mask = ~np.isnan(y_vals)
                if not np.any(valid_mask):
                    continue

                x_valid = x_vals[valid_mask]
                y_valid = y_vals[valid_mask]
                err_valid = std_err[valid_mask]

                ax.plot(x_valid, y_valid, 'o-', color=color,
                        linewidth=2.5, markersize=6, label=method_names[method],
                        alpha=0.8, markeredgecolor='white', markeredgewidth=0.5)
                ax.fill_between(x_valid,
                                np.maximum(0, y_valid - err_valid),
                                np.minimum(1, y_valid + err_valid),
                                color=color, alpha=0.2)

            # text.usetex is enabled, so bold is requested through LaTeX.
            ax.set_title(rf'\textbf{{Llama-{model}: {metric_name}}}')
            ax.set_xlabel(r'\textbf{Number of Swaps}')
            ax.set_ylabel(rf'\textbf{{{metric_name}}}')

            # Framed, shadowed legend (khop-notebook style). Skipped when no
            # curve was drawn, to avoid an empty legend box and its warning.
            if ax.get_legend_handles_labels()[0]:
                legend = ax.legend(frameon=True, loc='upper right',
                                   fancybox=True, shadow=True, framealpha=0.95,
                                   edgecolor='black', facecolor='white')
                legend.get_frame().set_linewidth(0.8)

            ax.grid(True, linestyle='--', alpha=0.6)
            ax.set_ylim(0, 1)
            ax.tick_params(labelsize=12)

            plt.tight_layout()

            metric_clean = metric_name.replace(' ', '_').lower()
            filename = f'figures/permutations_{model.lower()}_{metric_clean}.pdf'
            plt.savefig(filename, bbox_inches='tight', dpi=300)
            print(f"Saved: {filename}")

            plt.show()
            plt.close(fig)

# Create the agent comparison plots
plot_agent_comparison_by_model()