# Computation vs Communication Tradeoff Analysis: Llama-8B vs Llama-70B

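This notebook analyzes the computation vs. communication tradeoff for the Llama-8B and Llama-70B models, using the same methodology as `pareto-notebook.ipynb`. Computation is measured as the average number of completion tokens, and communication as the number of edges in the PrefixSumAgents tree.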

In [None]:
import math
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scienceplots
from scipy.interpolate import make_interp_spline

In [None]:
# Publication styling (ICLR paper format, 12pt font): apply the 'science' and
# 'ieee' style sheets first, then set LaTeX/serif font options on top of them.
plt.style.use(['science', 'ieee'])

# Each plt.rc call sets the rcParams of one group ('font' -> 'font.family', ...).
plt.rc('text', usetex=True)                  # render all text through LaTeX
plt.rc('font', family='serif', size=12)      # 12pt base font
plt.rc('axes', titlesize=13, labelsize=12)
plt.rc('legend', fontsize=11)
plt.rc('xtick', labelsize=11)
plt.rc('ytick', labelsize=11)

In [None]:
# Load the average-completion-token table covering the llama-8B and llama-70B runs
df = pd.read_csv('data/avg_completion_tokens_llama8B_70B.csv')

# Quick summary of what was loaded: dimensions, column names, leading rows
print("Dataset shape:", df.shape)
print("\nColumn names:", df.columns.tolist(), sep="\n")
print("\nFirst few rows:", df.head(), sep="\n")

In [None]:
# Helper functions from the original pareto notebook
def compute_num_agents(sequence_length, branching_factor):
    """
    Compute total number of agents in PrefixSumAgents hierarchical structure.

    For ``branching_factor > 1`` the count is ``ceil((N - 1) / (b - 1))``.
    For ``branching_factor <= 1`` every element gets its own agent, so the
    count is ``N``. The result is never negative: degenerate inputs such as
    ``sequence_length = 0`` yield 0 (the bare formula would give -1).

    Args:
        sequence_length: Length of input sequence (e.g., 32, 64, 128)
        branching_factor: Branching factor b (number of inputs each manager processes)

    Returns:
        int: Total number of agents needed (>= 0)
    """
    if branching_factor <= 1:
        # Each element needs its own agent
        return max(0, sequence_length)

    # Clamp at zero: for sequence_length < 1 the numerator is negative and
    # ceil() would otherwise report a negative agent count.
    return max(0, math.ceil((sequence_length - 1) / (branching_factor - 1)))

def compute_total_communication(sequence_length, branching_factor):
    """
    Compute total communication (number of edges) in PrefixSumAgents b-ary tree.

    The tree is built bottom-up: a level of size ``n`` feeds
    ``ceil(n / b)`` managers on the next level, and every node of the current
    level contributes exactly one edge to its manager. The total is therefore
    the sum of all level sizes except the final single-node level.

    Args:
        sequence_length: Length of input sequence (e.g., 32, 64, 128)
        branching_factor: Branching factor b (number of inputs each manager processes)

    Returns:
        int: Total number of edges (communications) in the tree. For
        ``branching_factor <= 1`` this is the linear-chain count ``N - 1``.
        Never negative (0 for ``sequence_length <= 1``).

    Raises:
        ValueError: If the branching factor does not shrink a level (this
            happens for non-integer values strictly between 1 and 2, where
            ``ceil(2 / b) == 2``); without this check the loop never ends.
    """
    if branching_factor <= 1:
        # Linear chain: N elements are joined by N - 1 edges (none for N <= 1)
        return max(0, sequence_length - 1)

    total_edges = 0
    current_level_size = sequence_length

    # Simulate the hierarchical processing, one level per iteration
    while current_level_size > 1:
        # Number of managers at next level
        next_level_size = math.ceil(current_level_size / branching_factor)

        if next_level_size >= current_level_size:
            # No reduction means the level size would stay constant forever.
            raise ValueError(
                f"branching_factor={branching_factor!r} does not reduce a level "
                f"of size {current_level_size}; use a value <= 1 or >= 2"
            )

        # Each node on the current level sends one edge up to its manager,
        # so the edge count between two levels equals the current level size.
        total_edges += current_level_size

        current_level_size = next_level_size

    return total_edges

def extract_sequence_info(column_name):
    """Return the integer that follows the first ``seq`` tag in a column name, or None."""
    # e.g. 'pareto_prefixsum_seq128_b1-6_llama70B - avg_completion_tokens' -> 128
    found = re.search(r'seq(\d+)', column_name)
    if not found:
        return None
    return int(found.group(1))

def extract_model_info(column_name):
    """Return the model size tag ('8B' or '70B') found in a column name, else None."""
    # Tags are tried in this order, so 'llama8B' takes precedence if both occur.
    for size_tag in ('8B', '70B'):
        if 'llama' + size_tag in column_name:
            return size_tag
    return None

In [None]:
# Keep only the main metric columns: drop MIN/MAX columns and step columns
def filter_columns(df):
    """Return ``df`` restricted to columns whose names contain none of 'MIN', 'MAX', '_step'."""
    unwanted_markers = ['MIN', 'MAX', '_step']
    kept_columns = []
    for name in df.columns:
        # Substring test: a single marker anywhere in the name excludes the column
        if not any(marker in name for marker in unwanted_markers):
            kept_columns.append(name)
    return df[kept_columns]

# Apply the column filter once; the cells below read from df_filtered
df_filtered = filter_columns(df)
print("Filtered columns:", df_filtered.columns.tolist(), sep="\n")

In [None]:
def plot_computation_vs_communication(model_type='8B', title_suffix='', min_seq_len=64):
    """
    Plot average completion tokens against total communication (edges) for one model.

    Reads the module-level ``df_filtered``. For every sequence length N found
    in the columns tagged ``llama{model_type}`` (and with N >= ``min_seq_len``),
    one line-plus-markers series is drawn: x is
    ``compute_total_communication(N, b)`` for each row's ``branching_factor``,
    y is that row's value in the matching token column. Rows whose token value
    is NaN are skipped. Points are connected in DataFrame row order.

    NOTE(review): the y values are completion tokens while the axis label reads
    "Computation Depth" - confirm this is the intended label.

    Args:
        model_type: Model size tag used to select columns, e.g. '8B' or '70B'.
        title_suffix: Text appended verbatim to the figure title. The title is
            rendered through LaTeX, so special characters must be escaped by
            the caller. Defaults to no suffix.
        min_seq_len: Sequence lengths below this value are left out (default 64).

    Side effects:
        Creates the ``figures/`` directory if needed, saves
        ``figures/computation_vs_communication_llama{model_type.lower()}.pdf``,
        prints the saved path, and shows the figure.
    """
    fig, ax = plt.subplots(figsize=(5.5, 3.2))

    # Columns belonging to the requested model
    model_cols = [col for col in df_filtered.columns
                  if col != 'branching_factor' and f'llama{model_type}' in col]

    # Map each sequence length to its first matching column. Matching on the
    # parsed integer (not on the substring 'seq<N>') keeps e.g. N=12 from
    # picking up a 'seq128' column.
    column_for_seq = {}
    for col in model_cols:
        seq_len = extract_sequence_info(col)
        if seq_len is not None and seq_len >= min_seq_len:
            column_for_seq.setdefault(seq_len, col)

    # Enhanced colors for better print quality and accessibility
    colors = ['#2E5EAA', '#2E8B57', '#B22222', '#FF8C00']

    for i, seq_len in enumerate(sorted(column_for_seq)):
        token_col = column_for_seq[seq_len]

        # Drop rows with no measurement for this sequence length
        valid_data = df_filtered[df_filtered[token_col].notna()]
        if valid_data.empty:
            continue

        # x: total communication (edges) for each branching factor
        total_communications = valid_data['branching_factor'].apply(
            lambda b: compute_total_communication(seq_len, b)
        )
        # y: completion tokens
        total_computation = valid_data[token_col]

        # Colors cycle if there are more sequence lengths than palette entries
        base_color = colors[i % len(colors)]

        # Connecting line (carries the legend label)
        ax.plot(total_communications, total_computation,
                linewidth=1.5,
                color=base_color,
                linestyle='-',
                alpha=0.8,
                label=f'$N={seq_len}$',
                zorder=3)

        # Markers drawn above the line
        ax.scatter(total_communications, total_computation,
                   color=base_color,
                   alpha=0.8,
                   s=60,
                   edgecolors='white',
                   linewidth=0.5,
                   zorder=5)

    ax.set_xlabel(r'\textbf{Total Communication (Edges)}')
    ax.set_ylabel(r'\textbf{Computation Depth}')
    ax.set_title(
        r'\textbf{' + f'Llama-{model_type}: Computation vs Communication{title_suffix}' + '}',
        pad=15,
    )

    # Legend listing the sequence lengths
    legend1 = ax.legend(frameon=True, loc='upper left', fontsize=9,
                        fancybox=True, shadow=True, framealpha=0.95,
                        edgecolor='black', facecolor='white')
    legend1.get_frame().set_linewidth(0.8)

    ax.grid(True, linestyle='--', linewidth=0.6, alpha=0.6, color='gray')

    # Hide the top/right spines and thin the remaining two
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(0.8)
    ax.spines['bottom'].set_linewidth(0.8)

    # At most ~6 ticks per axis; integer ticks on x since edges are counts
    ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True, nbins=6))
    ax.yaxis.set_major_locator(plt.MaxNLocator(nbins=6))

    fig.tight_layout()

    filename = f"figures/computation_vs_communication_llama{model_type.lower()}.pdf"
    # savefig does not create missing directories; make sure figures/ exists
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    fig.savefig(filename, bbox_inches='tight', dpi=300, facecolor='white')
    print(f"Saved plot as {filename}")
    plt.show()

# Create plots for both models
plot_computation_vs_communication('8B')
plot_computation_vs_communication('70B')

In [None]:
# Display some example calculations for verification (N>=64 only)
def _print_verification_table(model_label, model_cols, min_seq_len=64):
    """
    Print one verification row per (sequence length, branching factor) pair.

    Reads the module-level ``df_filtered``. For each column in ``model_cols``
    whose parsed sequence length is at least ``min_seq_len``, every row with a
    non-NaN value is printed together with the total communication computed
    for that row's branching factor. A blank line follows each column.

    Args:
        model_label: Name shown in the table heading, e.g. 'Llama-8B'.
        model_cols: Column names of ``df_filtered`` to report on.
        min_seq_len: Columns with a smaller sequence length are skipped.
    """
    print(f"\n{model_label} data:")
    print("Seq Length | Branching Factor | Total Communication | Avg Completion Tokens")
    print("-----------|------------------|--------------------|-----------------------")

    for col in model_cols:
        seq_len = extract_sequence_info(col)
        if seq_len is None or seq_len < min_seq_len:
            continue

        valid_rows = df_filtered[df_filtered[col].notna()]
        # Iterate column-wise rather than with iterrows(): iterrows() builds one
        # Series per row, which upcasts the integer branching factor to float
        # when the other columns are float, and the ':2d' format then raises.
        for b, tokens in zip(valid_rows['branching_factor'], valid_rows[col]):
            b = int(b)  # guarantee an int for the 'd' format below
            comm = compute_total_communication(seq_len, b)
            print(f"    {seq_len:3d}    |        {b:2d}        |        {comm:4d}        |       {tokens:.1f}")
        print()


print("Example calculations for verification:")

# Find columns for 8B model
model_8b_cols = [col for col in df_filtered.columns if 'llama8B' in col]
_print_verification_table("Llama-8B", model_8b_cols)

# Find columns for 70B model
model_70b_cols = [col for col in df_filtered.columns if 'llama70B' in col]
_print_verification_table("Llama-70B", model_70b_cols)

In [None]:
# Additional analysis: Compare models side by side for same sequence length
def _find_token_column(seq_length, model_type):
    """
    Return the first ``df_filtered`` column for ``model_type`` at ``seq_length``.

    The sequence length is compared as a parsed integer, so N=64 does not match
    a 'seq640' column (a plain substring test would). Returns None when no
    column matches.
    """
    for col in df_filtered.columns:
        if f'llama{model_type}' in col and extract_sequence_info(col) == seq_length:
            return col
    return None


def plot_model_comparison(seq_length=128):
    """
    Plot the 8B and 70B models against each other for one sequence length.

    Reads the module-level ``df_filtered``. For each model that has a column
    at ``seq_length``, one line-plus-markers series is drawn: x is
    ``compute_total_communication(seq_length, b)`` for each row's
    ``branching_factor``, y is that row's value in the model's token column.
    Rows whose token value is NaN are skipped. A model with no matching column
    (or no valid rows) is simply left out of the figure.

    NOTE(review): the y values are completion tokens while the axis label reads
    "Computation Depth" - confirm this is the intended label.

    Args:
        seq_length: Sequence length N whose columns are compared.

    Side effects:
        Creates the ``figures/`` directory if needed, saves
        ``figures/model_comparison_seq{seq_length}.pdf``, prints the saved
        path, and shows the figure.
    """
    fig, ax = plt.subplots(figsize=(5.5, 3.2))

    # One fixed color per model
    model_colors = {'8B': '#2E5EAA', '70B': '#B22222'}

    for model_type, base_color in model_colors.items():
        token_col = _find_token_column(seq_length, model_type)
        if token_col is None:
            continue

        # Drop rows with no measurement for this model / sequence length
        valid_data = df_filtered[df_filtered[token_col].notna()]
        if valid_data.empty:
            continue

        # x: total communication (edges) for each branching factor
        total_communications = valid_data['branching_factor'].apply(
            lambda b: compute_total_communication(seq_length, b)
        )
        # y: completion tokens
        total_computation = valid_data[token_col]

        # Connecting line (carries the legend label)
        ax.plot(total_communications, total_computation,
                linewidth=2.0,
                color=base_color,
                linestyle='-',
                alpha=0.8,
                label=f'Llama-{model_type}',
                zorder=3)

        # Markers drawn above the line
        ax.scatter(total_communications, total_computation,
                   color=base_color,
                   alpha=0.8,
                   s=80,
                   edgecolors='white',
                   linewidth=1.0,
                   zorder=5)

    ax.set_xlabel(r'\textbf{Total Communication (Edges)}')
    ax.set_ylabel(r'\textbf{Computation Depth}')
    ax.set_title(r'\textbf{' + f'Model Comparison (N={seq_length})' + '}', pad=15)

    legend = ax.legend(frameon=True, loc='upper left', fontsize=10,
                       fancybox=True, shadow=True, framealpha=0.95,
                       edgecolor='black', facecolor='white')
    legend.get_frame().set_linewidth(0.8)

    ax.grid(True, linestyle='--', linewidth=0.6, alpha=0.6, color='gray')

    # Hide the top/right spines and thin the remaining two
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(0.8)
    ax.spines['bottom'].set_linewidth(0.8)

    # At most ~6 ticks per axis; integer ticks on x since edges are counts
    ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True, nbins=6))
    ax.yaxis.set_major_locator(plt.MaxNLocator(nbins=6))

    fig.tight_layout()

    filename = f"figures/model_comparison_seq{seq_length}.pdf"
    # savefig does not create missing directories; make sure figures/ exists
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    fig.savefig(filename, bbox_inches='tight', dpi=300, facecolor='white')
    print(f"Saved comparison plot as {filename}")
    plt.show()

# Create comparison plots for different sequence lengths (N>=64 only)
for seq_len in [64, 128, 256]:
    # Only compare when both models have a column for this sequence length
    has_8b = _find_token_column(seq_len, '8B') is not None
    has_70b = _find_token_column(seq_len, '70B') is not None

    if has_8b and has_70b:
        plot_model_comparison(seq_len)