In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os
import itertools
import json

sns.set_theme(style="whitegrid")

In [None]:
# Set the working directory to the project root so that relative paths used
# later in the notebook (e.g. 'results/') resolve.
# The location can be overridden with the LLM_LINE_ROOT environment variable;
# when it is unset, the original hard-coded path is used.
os.chdir(os.environ.get('LLM_LINE_ROOT', '/fast/pmayilvahanan/llm_line/code/llm_line/'))

In [None]:
# Extract model metadata from result filenames and load the evaluation results

def parse_model_info(filename):
    """Parse ``key=value`` fields out of a result filename.

    The filename is split on ``'__'``. Every piece containing ``'='`` is split
    at its *first* ``'='`` into a key and a value, so values may themselves
    contain ``'='``. Pieces without ``'='`` are ignored. If a key occurs more
    than once, the last occurrence wins.

    Note: nothing is stripped from the filename, so a file extension stays
    attached to whichever piece comes last.

    Args:
        filename: File name (string) to parse.

    Returns:
        dict mapping each key to its value, both as strings.
    """
    model_dict = {}
    for part in filename.split('__'):
        # partition() never raises on extra '=' characters, unlike unpacking
        # the result of split('=').
        key, separator, value = part.partition('=')
        if separator:
            model_dict[key] = value
    return model_dict

# Read the JSON result files and collect one record per model of interest.
results_dir = 'results/'
# Only include models trained on The Pile or C4
target_datasets = ('C4', 'ThePile')
# Substrings looked for in the lower-cased model name (Mamba/Mamba2,
# HuggingFace and T5 models). 'mamba2' is already covered by the 'mamba'
# substring test; it is kept so the list mirrors the intended targets.
target_name_fragments = ('mamba', 'mamba2', 'huggingface', 't5')
data = []

# Sorted so the row order of `df` does not depend on directory listing order.
for filename in sorted(os.listdir(results_dir)):
    if not filename.endswith('.json'):
        continue

    # Filter on the filename first so files that will be discarded are never parsed.
    model_info = parse_model_info(filename)
    if model_info.get('dataset') not in target_datasets:
        continue

    # .get(): a filename without a `name=` field is skipped instead of raising KeyError.
    model_name = model_info.get('name', '').lower()
    if not any(fragment in model_name for fragment in target_name_fragments):
        continue

    # Keep the file open only for the duration of the read.
    with open(os.path.join(results_dir, filename)) as f:
        result = json.load(f)

    data.append({
        'dataset': model_info['dataset'],
        'arch': model_info.get('arch', 'unknown'),
        # One column per top-level dict entry of the result, holding its
        # 'acc,none' value (None when that key is absent).
        **{k: v.get('acc,none', None) for k, v in result.items() if isinstance(v, dict)}
    })

df = pd.DataFrame(data)

In [None]:
# Display the assembled table: one row per result file that passed the filters above.
df

In [None]:
def plot_metric_comparison(df, x_metric, y_metric, title=None):
    """Scatter ``y_metric`` against ``x_metric`` with one linear trend per dataset.

    Points are coloured by the value of the ``dataset`` column ('C4' or
    'ThePile') and drawn with a marker chosen by the ``arch`` column ('GPT',
    'Mamba', 'T5' or 'Mamba2'). Rows with any other dataset/arch value are not
    drawn. For each dataset a degree-1 least-squares line is fitted through
    all of its plotted points and annotated with the squared Pearson
    correlation.

    The fit uses only rows where both metrics are finite, and is skipped when
    fewer than two such rows exist or all x values coincide, because a line
    (and R²) is undefined in those cases.

    Args:
        df: DataFrame with 'dataset' and 'arch' columns plus the metric columns.
        x_metric: Name of the column plotted on the x axis.
        y_metric: Name of the column plotted on the y axis.
        title: Optional plot title; defaults to '<y_metric> vs <x_metric>'.
    """
    fig, ax = plt.subplots(figsize=(10, 8))

    # Define markers for different architectures
    markers = {'GPT': 'o', 'Mamba': 's', 'T5': '^', 'Mamba2': 'D'}
    colors = {'C4': 'skyblue', 'ThePile': 'coral'}

    for dataset, color in colors.items():
        xs, ys = [], []  # all points of this dataset, across architectures

        for arch, marker in markers.items():
            subset = df[(df['dataset'] == dataset) & (df['arch'] == arch)]
            if subset.empty:
                continue

            ax.scatter(subset[x_metric], subset[y_metric],
                       label=f'{dataset}-{arch}',
                       marker=marker,
                       c=color,
                       alpha=0.7,
                       s=100)

            # Collect points for trendline calculation
            xs.extend(subset[x_metric].values)
            ys.extend(subset[y_metric].values)

        # dtype=float turns missing values (None) into NaN so they can be masked.
        xs = np.asarray(xs, dtype=float)
        ys = np.asarray(ys, dtype=float)
        finite = np.isfinite(xs) & np.isfinite(ys)
        xs, ys = xs[finite], ys[finite]

        # A trend line needs at least two points with distinct x values.
        if len(xs) < 2 or xs.min() == xs.max():
            continue

        trend = np.poly1d(np.polyfit(xs, ys, 1))
        x_range = np.linspace(xs.min(), xs.max(), 100)
        ax.plot(x_range, trend(x_range), c=color, alpha=0.5)

        # R² of a simple linear fit equals the squared Pearson correlation.
        r_squared = np.corrcoef(xs, ys)[0, 1] ** 2
        ax.text(xs.mean(), ys.mean(),
                f'R²={r_squared:.3f}',
                fontsize=8)

    ax.set_xlabel(x_metric)
    ax.set_ylabel(y_metric)
    ax.set_title(title or f'{y_metric} vs {x_metric}')
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    plt.show()

# Example usage - you can create multiple plots for different metric pairs


In [None]:
# Compare the 'arc_easy_0_shot' column (x axis) with 'piqa_0_shot' (y axis).
plot_metric_comparison(df, 'arc_easy_0_shot', 'piqa_0_shot')
# Metric column names used in this notebook; any pair can be passed above:
# 'openbookqa_0_shot'
# 'arc_challenge_0_shot'
# 'arc_easy_0_shot'
# 'hellaswag_0_shot'
# 'piqa_0_shot'
# 'winogrande_0_shot'

In [None]:
# Compare the 'arc_easy_0_shot' column (x axis) with 'hellaswag_0_shot' (y axis).
plot_metric_comparison(df, 'arc_easy_0_shot', 'hellaswag_0_shot')


In [None]:
# Compare the 'arc_challenge_0_shot' column (x axis) with 'hellaswag_0_shot' (y axis).
plot_metric_comparison(df, 'arc_challenge_0_shot', 'hellaswag_0_shot')


In [None]:
# Compare the 'openbookqa_0_shot' column (x axis) with 'hellaswag_0_shot' (y axis).
plot_metric_comparison(df, 'openbookqa_0_shot', 'hellaswag_0_shot')


In [None]:
def plot_metric_grid(df, metrics=None):
    """Draw a lower-triangular grid of pairwise metric scatter plots.

    Cell (i, j) with i > j plots ``metrics[j]`` on x against ``metrics[i]`` on
    y. Diagonal cells only show the metric name (with '_0_shot' removed) and
    cells above the diagonal are hidden. Points are coloured by the ``dataset``
    column and use a marker per ``arch`` column value; only the
    architecture/dataset pairs listed in ``valid_combinations`` are drawn.
    Each dataset gets a degree-1 least-squares trend line annotated with the
    squared Pearson correlation.

    The fit uses only rows where both metrics are finite, and is skipped when
    fewer than two such rows exist or all x values coincide.

    NOTE: a later cell of this notebook defines another ``plot_metric_grid``
    (with a ``transform`` parameter); once that cell has run, it replaces this
    definition.

    Args:
        df: DataFrame with 'dataset' and 'arch' columns plus the metric columns.
        metrics: List of metric column names; defaults to six ``*_0_shot`` columns.
    """
    if metrics is None:
        metrics = [
            'openbookqa_0_shot',
            'arc_challenge_0_shot',
            'arc_easy_0_shot',
            'hellaswag_0_shot',
            'piqa_0_shot',
            'winogrande_0_shot'
        ]

    n = len(metrics)
    # squeeze=False keeps `axes` two-dimensional even for a single metric,
    # so axes[i, j] indexing below always works.
    fig, axes = plt.subplots(n, n, figsize=(4*n, 4*n), squeeze=False)

    # Define valid architecture-dataset combinations
    valid_combinations = {
        'T5': ['C4'],
        'Mamba': ['ThePile'],
        'Mamba2': ['ThePile'],
        'GPT': ['C4', 'ThePile']
    }

    markers = {'GPT': 'o', 'Mamba': 's', 'T5': '^', 'Mamba2': 'D'}
    colors = {'C4': 'skyblue', 'ThePile': 'coral'}

    for i, j in itertools.product(range(n), range(n)):
        ax = axes[i, j]

        # Hide upper triangle
        if i < j:
            ax.set_visible(False)
            continue

        x_metric = metrics[j]
        y_metric = metrics[i]

        # If on diagonal, show metric name
        if i == j:
            ax.text(0.5, 0.5, metrics[i].replace('_0_shot', ''),
                    ha='center', va='center', wrap=True)
            ax.set_xticks([])
            ax.set_yticks([])
            continue

        # Plot data for each dataset
        for dataset, color in colors.items():
            xs, ys = [], []

            for arch, valid_datasets in valid_combinations.items():
                if dataset not in valid_datasets:
                    continue
                subset = df[(df['dataset'] == dataset) & (df['arch'] == arch)]
                if subset.empty:
                    continue

                ax.scatter(subset[x_metric], subset[y_metric],
                           marker=markers[arch],
                           c=color,
                           alpha=0.7,
                           s=50)

                xs.extend(subset[x_metric].values)
                ys.extend(subset[y_metric].values)

            # dtype=float turns missing values (None) into NaN so they can be masked.
            xs = np.asarray(xs, dtype=float)
            ys = np.asarray(ys, dtype=float)
            finite = np.isfinite(xs) & np.isfinite(ys)
            xs, ys = xs[finite], ys[finite]

            # A trend line needs at least two points with distinct x values.
            if len(xs) < 2 or xs.min() == xs.max():
                continue

            # Add trendline for the entire dataset
            trend = np.poly1d(np.polyfit(xs, ys, 1))
            x_range = np.linspace(xs.min(), xs.max(), 100)
            ax.plot(x_range, trend(x_range), c=color, alpha=0.5)

            # R² of a simple linear fit equals the squared Pearson correlation.
            r_squared = np.corrcoef(xs, ys)[0, 1] ** 2

            # Position R-squared near the end of the trendline
            x_pos = x_range[-1]
            y_pos = trend(x_pos)
            ax.text(x_pos, y_pos, f'R²={r_squared:.3f}',
                    fontsize=8,
                    verticalalignment='bottom',
                    horizontalalignment='right',
                    color=color)

        # Only show x/y labels on edges
        if i == n-1:
            ax.set_xlabel(x_metric.replace('_0_shot', ''), fontsize=8)
        if j == 0:
            ax.set_ylabel(y_metric.replace('_0_shot', ''), fontsize=8)

        ax.tick_params(axis='both', which='major', labelsize=6)

    # Build one legend entry per valid architecture/dataset pair
    legend_elements = []
    for arch, valid_datasets in valid_combinations.items():
        for dataset in valid_datasets:
            legend_elements.append(plt.Line2D([0], [0], marker=markers[arch],
                                              color=colors[dataset],
                                              label=f'{dataset}-{arch}',
                                              linestyle='None', markersize=8))

    fig.legend(handles=legend_elements, loc='center right', bbox_to_anchor=(0.98, 0.5))

    fig.tight_layout()
    plt.show()

In [None]:
# Render the pairwise grid for the default metric list.
plot_metric_grid(df)

In [None]:
def plot_metric_grid(df, metrics=None, transform='probit'):
    """
    Plot metric grid with transformed axes but original tick labels.

    Cell (i, j) with i > j plots ``metrics[j]`` on x against ``metrics[i]`` on
    y. Diagonal cells only show the metric name (with '_0_shot' removed) and
    cells above the diagonal are hidden. Values are mapped through the chosen
    transform before plotting and fitting; tick marks sit at transformed
    positions but are labelled with the untransformed values. Each dataset gets
    a degree-1 least-squares trend line (fitted in transformed space)
    annotated with the squared Pearson correlation. The fit uses only finite
    points and is skipped with fewer than two of them or when all x coincide.

    NOTE: this definition reuses the name of an earlier ``plot_metric_grid`` in
    this notebook and replaces it once this cell has run.

    Args:
        df: DataFrame with 'dataset' and 'arch' columns plus the metric columns
        metrics: List of metrics to plot; defaults to six ``*_0_shot`` columns
        transform: Either 'probit' or 'normal'. 'probit' maps values through
            the standard normal quantile function (``scipy.stats.norm.ppf``);
            'normal' plots the values unchanged.

    Raises:
        ValueError: if ``transform`` is neither 'probit' nor 'normal'.
    """
    from scipy.stats import norm

    # Pick the value -> axis-position mapping once, and validate `transform`
    # before any figure is created.
    if transform == 'probit':
        def to_axis(values):
            return norm.ppf(np.asarray(values, dtype=float))
    elif transform == 'normal':
        def to_axis(values):
            return np.asarray(values, dtype=float)
    else:
        raise ValueError(f"transform must be 'probit' or 'normal', got {transform!r}")

    if metrics is None:
        metrics = [
            'openbookqa_0_shot',
            'arc_challenge_0_shot',
            'arc_easy_0_shot',
            'hellaswag_0_shot',
            'piqa_0_shot',
            'winogrande_0_shot'
        ]

    n = len(metrics)
    # squeeze=False keeps `axes` two-dimensional even for a single metric,
    # so axes[i, j] indexing below always works.
    fig, axes = plt.subplots(n, n, figsize=(4*n, 4*n), squeeze=False)

    valid_combinations = {
        'T5': ['C4'],
        'Mamba': ['ThePile'],
        'Mamba2': ['ThePile'],
        'GPT': ['C4', 'ThePile']
    }

    markers = {'GPT': 'o', 'Mamba': 's', 'T5': '^', 'Mamba2': 'D'}
    colors = {'C4': 'skyblue', 'ThePile': 'coral'}

    for i, j in itertools.product(range(n), range(n)):
        ax = axes[i, j]

        if i < j:
            ax.set_visible(False)
            continue

        x_metric = metrics[j]
        y_metric = metrics[i]

        if i == j:
            ax.text(0.5, 0.5, metrics[i].replace('_0_shot', ''),
                    ha='center', va='center', wrap=True)
            ax.set_xticks([])
            ax.set_yticks([])
            continue

        # Create tick positions and labels based on data range
        x_ticks = np.linspace(df[x_metric].min(), df[x_metric].max(), 4)
        y_ticks = np.linspace(df[y_metric].min(), df[y_metric].max(), 4)
        x_tick_pos = to_axis(x_ticks)
        y_tick_pos = to_axis(y_ticks)
        # The probit of 0 or 1 is infinite (and NaN outside [0, 1]); such
        # positions cannot be used as ticks, so drop them with their labels.
        x_keep = np.isfinite(x_tick_pos)
        y_keep = np.isfinite(y_tick_pos)

        for dataset, color in colors.items():
            xs, ys = [], []

            for arch, valid_datasets in valid_combinations.items():
                if dataset not in valid_datasets:
                    continue
                subset = df[(df['dataset'] == dataset) & (df['arch'] == arch)]
                if subset.empty:
                    continue

                x_data = to_axis(subset[x_metric])
                y_data = to_axis(subset[y_metric])

                ax.scatter(x_data, y_data,
                           marker=markers[arch],
                           c=color,
                           alpha=0.7,
                           s=50)

                xs.extend(x_data)
                ys.extend(y_data)

            xs = np.asarray(xs, dtype=float)
            ys = np.asarray(ys, dtype=float)
            finite = np.isfinite(xs) & np.isfinite(ys)
            xs, ys = xs[finite], ys[finite]

            # A trend line needs at least two points with distinct x values.
            if len(xs) < 2 or xs.min() == xs.max():
                continue

            trend = np.poly1d(np.polyfit(xs, ys, 1))
            x_range = np.linspace(xs.min(), xs.max(), 100)
            ax.plot(x_range, trend(x_range), c=color, alpha=0.5)

            # R² of a simple linear fit equals the squared Pearson correlation.
            r_squared = np.corrcoef(xs, ys)[0, 1] ** 2

            # Position R-squared near the end of the trendline
            x_pos = x_range[-1]
            y_pos = trend(x_pos)
            ax.text(x_pos, y_pos, f'R²={r_squared:.3f}',
                    fontsize=8,
                    verticalalignment='bottom',
                    horizontalalignment='right',
                    color=color)

        # Set the transformed tick positions with original scale labels
        ax.set_xticks(x_tick_pos[x_keep])
        ax.set_yticks(y_tick_pos[y_keep])
        ax.set_xticklabels([f'{x:.2f}' for x in x_ticks[x_keep]])
        ax.set_yticklabels([f'{y:.2f}' for y in y_ticks[y_keep]])

        if i == n-1:
            ax.set_xlabel(x_metric.replace('_0_shot', ''), fontsize=8)
        if j == 0:
            ax.set_ylabel(y_metric.replace('_0_shot', ''), fontsize=8)

        ax.tick_params(axis='both', which='major', labelsize=6)

    # Build one legend entry per valid architecture/dataset pair
    legend_elements = []
    for arch, valid_datasets in valid_combinations.items():
        for dataset in valid_datasets:
            legend_elements.append(plt.Line2D([0], [0], marker=markers[arch],
                                              color=colors[dataset],
                                              label=f'{dataset}-{arch}',
                                              linestyle='None', markersize=8))

    fig.legend(handles=legend_elements, loc='center right', bbox_to_anchor=(0.98, 0.5))

    fig.tight_layout()
    plt.show()

In [None]:
# Render the grid again with the redefined plot_metric_grid, requesting the probit transform.
plot_metric_grid(df, transform='probit')
