# Parity Accuracy Analysis: Llama-8B vs Llama-70B

This notebook analyzes the parity accuracy results of the Llama-8B and Llama-70B models using bar plots with error bars, following the same methodology as `parity-notebook.ipynb`.

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import re
import scienceplots

In [None]:
# Paper-style plotting configuration for the ICLR format (12pt base font).
# First apply the SciencePlots 'science' + 'ieee' styles (LaTeX + clean
# scientific styling), then override the font settings below.
plt.style.use(['science', 'ieee'])

# LaTeX text rendering with a serif font
plt.rc('text', usetex=True)
plt.rc('font', family='serif', size=12)

# Per-element sizes: titles slightly larger, legend/tick labels slightly smaller
plt.rc('axes', titlesize=13, labelsize=12)
plt.rc('legend', fontsize=11)
plt.rc('xtick', labelsize=11)
plt.rc('ytick', labelsize=11)

In [None]:
# Load the parity-accuracy table; per the original note it holds both the
# llama-8B and llama-70B columns.
# NOTE(review): the filename mentions only 70B — confirm the 8B runs are in this export.
df = pd.read_csv('parity_accuracy_llama70B.csv')

# Quick overview of the raw table: shape, column names, leading rows
print(f"Dataset shape: {df.shape}")
print("\nColumn names:", df.columns.tolist(), sep="\n")
print("\nFirst few rows:", df.head(), sep="\n")

In [None]:
# Keep only columns whose names contain none of 'MIN', 'MAX', or '_step'
filtered_df = df[
    [name for name in df.columns
     if not any(marker in name for marker in ('MIN', 'MAX', '_step'))]
]
print("Filtered columns:")
print(filtered_df.columns.tolist())

In [None]:
def extract_hyperparams(col_name):
    """Return the integer hyperparameter encoded in a result column name.

    The first method marker found in ``col_name`` selects the pattern:
      - 'maj-voting' -> number after 'agents' (number of agents)
      - 'coa'        -> number after 'chunk'  (chunk size)
      - 'prefix-sum' -> number after 'b'      (branching factor)

    Returns None when no marker is present or the matching pattern is absent.
    """
    # Order matters: markers are tested in this sequence, first hit wins.
    marker_patterns = (
        ('maj-voting', r'agents(\d+)'),
        ('coa', r'chunk(\d+)'),
        ('prefix-sum', r'b(\d+)'),
    )
    for marker, pattern in marker_patterns:
        if marker not in col_name:
            continue
        found = re.search(pattern, col_name)
        return None if found is None else int(found.group(1))
    return None

def extract_model_info(col_name):
    """Return '8B' or '70B' depending on which llama tag appears in ``col_name``.

    'llama8B' is checked before 'llama70B'; returns None if neither is present.
    """
    for tag, size in (('llama8B', '8B'), ('llama70B', '70B')):
        if tag in col_name:
            return size
    return None

In [None]:
def create_parity_bar_plot(model_type='8B', title_suffix='', n_runs=100):
    """
    Create a grouped bar plot of the best parity accuracy per method, with error bars.

    For every row of ``filtered_df`` (one x-position per row, labelled by its
    ``sequence_length``) and every method ('prefix-sum', 'maj-voting', 'coa'),
    the plotted value is the maximum non-NaN value over that model's
    ``avg_accuracy`` columns for the method. Rows with no data are plotted as
    0 with a 0 error bar. The figure is saved as
    ``parity_accuracy_llama<model_type lowercased>.pdf`` and then shown.

    Args:
        model_type: Model tag matched against column names as ``llama{model_type}``
            (e.g. '8B', '70B').
        title_suffix: Text appended to the plot title. The title is rendered
            with LaTeX, so special characters must be escaped. Empty by default.
        n_runs: Number of runs ``n`` used in the standard error
            ``sqrt(p * (1 - p) / n)``. Defaults to 100.

    Returns:
        (best_accs, error_bars): two dicts mapping method name to a list with
        one value per row of ``filtered_df``.

    Raises:
        ValueError: if ``n_runs`` is not positive.
    """
    if n_runs <= 0:
        raise ValueError(f"n_runs must be positive, got {n_runs!r}")

    # Columns that belong to the requested model size
    model_cols = [col for col in filtered_df.columns
                  if col != 'sequence_length' and f'llama{model_type}' in col]

    methods = ['prefix-sum', 'maj-voting', 'coa']
    sequence_lengths = filtered_df['sequence_length']
    best_accs = {method: [] for method in methods}
    error_bars = {method: [] for method in methods}

    for method in methods:
        method_cols = [col for col in model_cols if method in col and 'avg_accuracy' in col]

        for _, row in filtered_df.iterrows():
            # Best accuracy for this method in this row = max over its hyperparameter columns
            method_accs = [row[col] for col in method_cols if pd.notna(row[col])]

            if method_accs:
                best_acc = max(method_accs)
                best_accs[method].append(best_acc)
                # Binomial standard error sqrt(p * (1 - p) / n). Assumes p is a
                # fraction in [0, 1]. Because p is the max over hyperparameter
                # settings, this bar does not account for that selection.
                error_bars[method].append(np.sqrt(best_acc * (1 - best_acc) / n_runs))
            else:
                # No data for this method/row: zero-height bar, no error bar
                best_accs[method].append(0)
                error_bars[method].append(0)

    bar_width = 0.25
    x_pos = range(len(sequence_lengths))

    fig, ax = plt.subplots(figsize=(4.2, 3.4))  # compact single-column size
    # One fill color per method, in `methods` order
    colors = ['#2E5EAA', '#2E8B57', '#B22222']

    for i, method in enumerate(methods):
        # "coa" is shown as the acronym "CoA"; other names become Title Case words
        label = 'CoA' if method == 'coa' else method.replace('-', ' ').title()
        ax.bar(
            [p + i * bar_width for p in x_pos],
            best_accs[method],
            width=bar_width,
            label=label,
            color=colors[i],
            yerr=error_bars[method],
            capsize=3,
            error_kw={'linewidth': 0.8, 'capthick': 0.8, 'ecolor': 'black'},
            edgecolor='black',
            linewidth=0.5,
            alpha=0.9,
        )

    # Put each tick under the middle of its three-bar group
    ax.set_xticks([p + bar_width for p in x_pos])
    ax.set_xticklabels(sequence_lengths)
    ax.set_xlabel(r'\textbf{Sequence Length}')
    ax.set_ylabel(r'\textbf{Accuracy}')
    ax.set_title(r'\textbf{' + f'Llama-{model_type}: Parity Accuracy{title_suffix}' + '}', pad=15)

    # Horizontal legend below the axes
    legend = ax.legend(frameon=True, loc='upper center', bbox_to_anchor=(0.5, -0.2),
                       ncol=3, fontsize=8, fancybox=True, shadow=True, framealpha=0.95,
                       edgecolor='black', facecolor='white')
    legend.get_frame().set_linewidth(0.8)

    # Dashed horizontal grid drawn behind the bars
    ax.grid(True, axis='y', linestyle='--', linewidth=0.6, alpha=0.6, color='gray')
    ax.set_axisbelow(True)

    # Only left/bottom spines
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(0.8)
    ax.spines['bottom'].set_linewidth(0.8)

    ax.yaxis.set_major_locator(plt.MaxNLocator(nbins=6))
    ax.set_ylim(0, 1.05)  # full [0, 1] accuracy range plus headroom for error bars

    fig.tight_layout()
    fig.subplots_adjust(bottom=0.25)  # room for the legend below the axes

    filename = f"parity_accuracy_llama{model_type.lower()}.pdf"
    fig.savefig(filename, bbox_inches='tight', dpi=300, facecolor='white')
    print(f"Saved plot as {filename}")
    plt.show()
    plt.close(fig)

    return best_accs, error_bars

# Create plots for both models. Both results are required later:
# print_summary_comparison reads best_accs_8b and best_accs_70b.
print("Creating bar plot for Llama-8B...")
best_accs_8b, error_bars_8b = create_parity_bar_plot('8B')

print("\nCreating bar plot for Llama-70B...")
best_accs_70b, error_bars_70b = create_parity_bar_plot('70B')

In [None]:
# Extract and display best hyperparameters for each model
def analyze_best_hyperparams(model_type='8B'):
    """
    Find and print, for each method and sequence length, the hyperparameter
    value that achieved the highest accuracy for the given model size.

    Only ``avg_accuracy`` columns tagged ``llama{model_type}`` are considered.
    NaN accuracies and columns whose hyperparameter cannot be parsed are
    skipped. For each sequence length, the first row of ``filtered_df`` with
    that length is used. Ties go to the hyperparameter seen first.

    Returns:
        dict: method -> {sequence_length: {'hyperparam': int, 'accuracy': value}}.
        Sequence lengths with no usable value are omitted.
    """
    # method key -> (printed heading, printed hyperparameter name)
    display_info = {
        'maj-voting': ("MAJ VOTING", "num_agents"),
        'coa': ("CoA", "chunk_size"),
        'prefix-sum': ("PREFIX SUM", "branching_factor"),
    }

    print(f"\n=== LLAMA-{model_type} ANALYSIS ===")
    print("Best hyperparameters for each agent type and sequence length:")
    print("=" * 60)

    model_tag = f'llama{model_type}'
    candidate_cols = [c for c in filtered_df.columns
                      if c != 'sequence_length' and model_tag in c]

    best_hyperparams = {method: {} for method in display_info}

    for method, per_length in best_hyperparams.items():
        acc_cols = [c for c in candidate_cols if method in c and 'avg_accuracy' in c]

        for seq_len in filtered_df['sequence_length']:
            # First row carrying this sequence length
            row_data = filtered_df.loc[filtered_df['sequence_length'] == seq_len].iloc[0]

            scores = {}
            for c in acc_cols:
                value = row_data[c]
                if pd.isna(value):
                    continue
                hyper = extract_hyperparams(c)
                if hyper is not None:
                    scores[hyper] = value

            if not scores:
                continue
            winner = max(scores, key=scores.get)
            per_length[seq_len] = {'hyperparam': winner, 'accuracy': scores[winner]}

    for method, (heading, param_name) in display_info.items():
        print(f"\n{heading}:")
        print("-" * 40)
        results = best_hyperparams[method]
        for seq_len in sorted(results):
            entry = results[seq_len]
            print(f"Seq length {seq_len:3d}: {param_name}={entry['hyperparam']:2d}, accuracy={entry['accuracy']:.3f}")

    return best_hyperparams

# Analyze both models: 8B first, then 70B
best_hyperparams_8b, best_hyperparams_70b = (
    analyze_best_hyperparams(size) for size in ('8B', '70B')
)

In [None]:
# Create side-by-side comparison plot
def create_model_comparison_plot(model_types=('8B', '70B')):
    """
    Create side-by-side bar-plot panels (one per model size) of the best parity
    accuracy per method, and save the figure as
    ``parity_accuracy_model_comparison.pdf``.

    For each model and method, the value plotted for each row of ``filtered_df``
    is the maximum non-NaN ``avg_accuracy`` over that method's columns, or 0
    when there is none. No error bars are drawn.

    Args:
        model_types: Model tags matched as ``llama{tag}`` in column names. One
            panel is drawn per tag, left to right; the panels share the y-axis.
            Defaults to ('8B', '70B').

    Raises:
        ValueError: if ``model_types`` is empty.
    """
    model_types = list(model_types)
    if not model_types:
        raise ValueError("model_types must contain at least one model tag")

    methods = ['prefix-sum', 'maj-voting', 'coa']
    # "coa" is shown as the acronym "CoA"; other names become Title Case words
    method_labels = {m: ('CoA' if m == 'coa' else m.replace('-', ' ').title()) for m in methods}
    colors = ['#2E5EAA', '#2E8B57', '#B22222']  # one per method, in `methods` order
    bar_width = 0.25
    sequence_lengths = filtered_df['sequence_length']
    x_pos = range(len(sequence_lengths))

    # Best accuracy per model -> method -> row
    model_data = {}
    for model_type in model_types:
        model_cols = [col for col in filtered_df.columns
                      if col != 'sequence_length' and f'llama{model_type}' in col]
        best_accs = {method: [] for method in methods}
        for method in methods:
            method_cols = [col for col in model_cols if method in col and 'avg_accuracy' in col]
            for _, row in filtered_df.iterrows():
                method_accs = [row[col] for col in method_cols if pd.notna(row[col])]
                best_accs[method].append(max(method_accs) if method_accs else 0)
        model_data[model_type] = best_accs

    n_panels = len(model_types)
    # 4.25in per panel: two panels give the original 8.5in width
    fig, axes = plt.subplots(1, n_panels, figsize=(4.25 * n_panels, 4.5),
                             sharey=True, squeeze=False)
    axes = axes[0]

    # One panel per model; all panels share identical styling
    for panel_idx, (ax, model_type) in enumerate(zip(axes, model_types)):
        for i, method in enumerate(methods):
            ax.bar(
                [p + i * bar_width for p in x_pos],
                model_data[model_type][method],
                width=bar_width,
                label=method_labels[method],
                color=colors[i],
                edgecolor='black',
                linewidth=0.5,
                alpha=0.9,
            )

        # Put each tick under the middle of its three-bar group
        ax.set_xticks([p + bar_width for p in x_pos])
        ax.set_xticklabels(sequence_lengths)
        ax.set_xlabel(r'\textbf{Sequence Length}')
        if panel_idx == 0:  # y-axis is shared, so only label the leftmost panel
            ax.set_ylabel(r'\textbf{Accuracy}')
        ax.set_title(r'\textbf{Llama-' + model_type + '}', pad=15)

        ax.grid(True, axis='y', linestyle='--', linewidth=0.6, alpha=0.6, color='gray')
        ax.set_axisbelow(True)

        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_linewidth(0.8)
        ax.spines['bottom'].set_linewidth(0.8)
        ax.yaxis.set_major_locator(plt.MaxNLocator(nbins=6))
        ax.set_ylim(0, 1.05)

    # Single shared legend below all panels (every panel uses the same labels)
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, -0.08),
               ncol=3, fontsize=11, fancybox=True, shadow=True, framealpha=0.95,
               edgecolor='black', facecolor='white')

    fig.tight_layout()
    fig.subplots_adjust(bottom=0.18)  # room at the bottom for the shared legend

    fig.savefig("parity_accuracy_model_comparison.pdf", bbox_inches='tight', dpi=300, facecolor='white')
    print("Saved comparison plot as parity_accuracy_model_comparison.pdf")
    plt.show()
    plt.close(fig)

# Render and save the side-by-side Llama-8B / Llama-70B comparison figure
create_model_comparison_plot()

In [None]:
# Summary statistics and comparison
def print_summary_comparison(accs_8b=None, accs_70b=None):
    """
    Print per-method and overall mean best accuracy for Llama-8B vs Llama-70B.

    For each method, the mean over sequence lengths is printed for both models,
    together with the absolute and relative (percent) 70B-vs-8B difference.
    The same is then printed for the mean over all methods.

    Args:
        accs_8b: dict mapping 'prefix-sum', 'maj-voting' and 'coa' to lists of
            per-sequence-length best accuracies for Llama-8B. Defaults to the
            notebook global ``best_accs_8b``.
        accs_70b: same for Llama-70B. Defaults to the global ``best_accs_70b``.

    The relative change is printed as "n/a" when the 8B mean is 0, instead of
    dividing by zero.
    """
    # Backward compatible: fall back to the globals set by the plotting cell
    if accs_8b is None:
        accs_8b = best_accs_8b
    if accs_70b is None:
        accs_70b = best_accs_70b

    def relative_change(new, base):
        # Percent change of `new` over `base`; undefined for a zero baseline.
        if base == 0:
            return "n/a"
        return f"{(new / base - 1) * 100:+.1f}%"

    print("\n" + "=" * 80)
    print("SUMMARY COMPARISON: LLAMA-8B vs LLAMA-70B")
    print("=" * 80)

    methods = ['prefix-sum', 'maj-voting', 'coa']
    display_names = {'prefix-sum': "PREFIX SUM", 'maj-voting': "MAJ VOTING", 'coa': "CoA"}

    for method in methods:
        print(f"\n{display_names[method]} METHOD:")
        print("-" * 50)

        # Mean best accuracy across all sequence lengths
        avg_8b = np.mean(accs_8b[method])
        avg_70b = np.mean(accs_70b[method])

        print(f"Average accuracy - Llama-8B:  {avg_8b:.3f}")
        print(f"Average accuracy - Llama-70B: {avg_70b:.3f}")
        print(f"Improvement (70B vs 8B): {avg_70b - avg_8b:+.3f} ({relative_change(avg_70b, avg_8b)})")

    # Overall comparison: mean of the per-method means
    overall_8b = np.mean([np.mean(accs_8b[m]) for m in methods])
    overall_70b = np.mean([np.mean(accs_70b[m]) for m in methods])

    print("\nOVERALL PERFORMANCE:")
    print("-" * 50)
    print(f"Overall average - Llama-8B:  {overall_8b:.3f}")
    print(f"Overall average - Llama-70B: {overall_70b:.3f}")
    print(f"Overall improvement (70B vs 8B): {overall_70b - overall_8b:+.3f} ({relative_change(overall_70b, overall_8b)})")

# Print the 8B-vs-70B summary; relies on best_accs_8b / best_accs_70b set by the bar-plot cell
print_summary_comparison()