In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os
import itertools
import scipy.stats as stats

# Use seaborn's "whitegrid" theme for the plots drawn by the cells below.
sns.set_theme(style="whitegrid")

In [None]:
# Set the working directory so that relative paths used by later cells
# (e.g. 'finweb_results.csv') resolve against the project checkout.
# The location can be overridden with the LLM_LINE_DIR environment variable;
# when it is unset, the original cluster path is used.
PROJECT_DIR = os.environ.get(
    'LLM_LINE_DIR',
    '/lustre/fast/fast/pmayilvahanan/llm_line/code/llm_line',
)
os.chdir(PROJECT_DIR)

In [None]:
def create_performance_scatter(csv_file, dataset1, dataset2, step_interval=1000, pretraining_data=None, min_step=0):
    """
    Create a scatter plot comparing accuracy on two datasets, draw one
    least-squares trend line per run, and print the fit statistics.

    Args:
        csv_file: path to the CSV file containing the results; it must have
            'steps' and 'runname' columns plus one '<dataset>/acc' column
            per dataset
        dataset1: name of first dataset (x-axis)
        dataset2: name of second dataset (y-axis)
        step_interval: plot points at every step_interval steps (default: 1000)
        pretraining_data: list of specific pretraining datasets to include (default: None, includes all)
        min_step: minimum step to start plotting from (default: 0)

    Returns:
        None. The figure is shown with plt.show() and the statistics are printed.

    Raises:
        ValueError: if '<dataset1>/acc' or '<dataset2>/acc' is not a column
            of the CSV file.
    """
    # Load the data
    df = pd.read_csv(csv_file)

    # Keep only rows on the step grid and at/after the minimum step
    df = df[(df['steps'] % step_interval == 0) & (df['steps'] >= min_step)]

    # Filter by specific pretraining datasets if provided
    if pretraining_data is not None:
        df = df[df['runname'].isin(pretraining_data)]

    # Construct column names
    x_col = f"{dataset1}/acc"
    y_col = f"{dataset2}/acc"

    # Validate before any figure is created so a bad name leaves no empty plot
    if x_col not in df.columns or y_col not in df.columns:
        raise ValueError(f"Columns {x_col} or {y_col} not found in the data.")

    # Create figure and scatter plot
    plt.figure(figsize=(12, 8))
    sns.scatterplot(data=df, x=x_col, y=y_col, hue='runname', alpha=0.6)

    # One regression per run, computed once and reused for both the legend
    # label and the printed report. sort=False keeps first-appearance order.
    fits = {}
    for name, group in df.groupby('runname', sort=False):
        # Drop rows with a missing accuracy so the reported R² describes the
        # same points that regplot uses for the drawn line.
        points = group[[x_col, y_col]].dropna()

        # A line needs at least two points with distinct x values.
        if len(points) < 2 or points[x_col].nunique() < 2:
            print(f"Skipping trend line for {name}: fewer than two distinct data points.")
            continue

        fit = stats.linregress(points[x_col], points[y_col])
        fits[name] = fit
        r2 = fit.rvalue ** 2

        # Plot regression line
        sns.regplot(data=points, x=x_col, y=y_col,
                    scatter=False,
                    label=f"{name} (R² = {r2:.3f})",
                    ci=None)

    # Customize plot
    plt.title(f'Model Performance: {dataset1} vs {dataset2}\n(Steps interval: {step_interval}, Min step: {min_step})')
    plt.xlabel(f'{dataset1} Accuracy')
    plt.ylabel(f'{dataset2} Accuracy')

    # Place the legend outside the axes, to the right
    plt.legend(title='Pretraining Data', bbox_to_anchor=(1.05, 1), loc='upper left')

    # Print detailed fit statistics
    print("\nDetailed Fit Statistics:")
    print("-" * 50)
    for name, fit in fits.items():
        print(f"\nPretraining Data: {name}")
        print(f"R² Score: {fit.rvalue ** 2:.3f}")
        print(f"Slope: {fit.slope:.3f}")
        print(f"Intercept: {fit.intercept:.3f}")
        print(f"P-value: {fit.pvalue:.3e}")
        print(f"Standard Error: {fit.stderr:.3e}")

    # Adjust layout to prevent legend cutoff
    plt.tight_layout()

    # Show plot
    plt.show()

In [None]:
# Example usage: hellaswag (x) against mmlu (y), sampled every 1000 steps
# from step 50000 onwards, restricted to four runs.
create_performance_scatter(
    csv_file='finweb_results.csv',
    dataset1='hellaswag',
    dataset2='mmlu',
    step_interval=1000,
    pretraining_data=['C4', 'The Pile', 'RefineWeb', 'RedPajama2'],
    min_step=50000,
)
# Alternative comparison, currently disabled:
#create_performance_scatter('finweb_results.csv', 'arc', 'piqa', step_interval=10000)

In [None]:
# def create_scatterplot_grid(csv_file, datasets, step_interval=1000, pretraining_data=None, min_step=0):
#     """
#     Create a grid of scatter plots for each pair of downstream tasks.
    
#     Args:
#         csv_file: path to the CSV file containing the results
#         datasets: list of dataset names to compare
#         step_interval: plot points at every step_interval steps (default: 1000)
#         pretraining_data: list of specific pretraining datasets to include (default: None, includes all)
#         min_step: minimum step to start plotting from (default: 0)
#     """
#     num_datasets = len(datasets)
#     fig, axes = plt.subplots(num_datasets, num_datasets, figsize=(15, 15))
    
#     for i, j in itertools.combinations(range(num_datasets), 2):
#         dataset1 = datasets[i]
#         dataset2 = datasets[j]
        
#         ax = axes[i, j]
#         plt.sca(ax)  # Set current axis
        
#         # Create scatter plot for the current pair of datasets
#         create_performance_scatter(csv_file, dataset1, dataset2, step_interval, pretraining_data, min_step)
        
#         # Remove axis labels for cleaner grid
#         ax.set_xlabel('')
#         ax.set_ylabel('')
    
#     # Remove unused subplots (diagonal and lower triangle)
#     for i in range(num_datasets):
#         for j in range(num_datasets):
#             if i >= j:
#                 fig.delaxes(axes[i, j])
    
#     # Adjust layout
#     plt.tight_layout()
#     plt.show()

# # Example usage
# datasets = ['hellaswag', 'commonsense_qa', 'arc', 'piqa']
# create_scatterplot_grid(
#     'finweb_results.csv', 
#     datasets, 
#     step_interval=10000, 
#     pretraining_data=['C4', 'The Pile', 'RefineWeb', 'RedPajama2'],
#     min_step=50000
# )

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy import stats
import itertools

def create_performance_scatter(csv_file, dataset1, dataset2, step_interval=1000, pretraining_data=None, min_step=0, ax=None):
    """
    Create a scatter plot comparing accuracy on two datasets, with one
    least-squares trend line per run.

    Args:
        csv_file: path to the CSV file containing the results; it must have
            'steps' and 'runname' columns plus one '<dataset>/acc' column
            per dataset
        dataset1: name of first dataset (x-axis)
        dataset2: name of second dataset (y-axis)
        step_interval: plot points at every step_interval steps (default: 1000)
        pretraining_data: list of specific pretraining datasets to include (default: None, includes all)
        min_step: minimum step to start plotting from (default: 0)
        ax: matplotlib axis to plot on (default: None, creates new figure)

    Returns:
        (ax, r2_values): the axis that was drawn on and a dict mapping each
        run name to the R² of its fit. Runs with fewer than two distinct
        data points get no trend line and are omitted from the dict.

    Raises:
        ValueError: if '<dataset1>/acc' or '<dataset2>/acc' is not a column
            of the CSV file.
    """
    # Load the data
    df = pd.read_csv(csv_file)

    # Keep only rows on the step grid and at/after the minimum step
    df = df[(df['steps'] % step_interval == 0) & (df['steps'] >= min_step)]

    # Filter by specific pretraining datasets if provided
    if pretraining_data is not None:
        df = df[df['runname'].isin(pretraining_data)]

    # Construct column names
    x_col = f"{dataset1}/acc"
    y_col = f"{dataset2}/acc"

    # Validate before any figure is created so a bad name leaves no empty plot
    if x_col not in df.columns or y_col not in df.columns:
        raise ValueError(f"Columns {x_col} or {y_col} not found in the data.")

    # Create new figure if no axis provided
    if ax is None:
        plt.figure(figsize=(12, 8))
        ax = plt.gca()

    # Create scatter plot
    sns.scatterplot(data=df, x=x_col, y=y_col, hue='runname', alpha=0.6, ax=ax)

    r2_values = {}
    fit_lines = []  # line artists of the trend lines, used to build the legend

    # sort=False keeps the runs in order of first appearance
    for name, group in df.groupby('runname', sort=False):
        # Drop rows with a missing accuracy so the reported R² describes the
        # same points that regplot uses for the drawn line.
        points = group[[x_col, y_col]].dropna()

        # A line needs at least two points with distinct x values.
        if len(points) < 2 or points[x_col].nunique() < 2:
            continue

        fit = stats.linregress(points[x_col], points[y_col])
        r2 = fit.rvalue ** 2
        r2_values[name] = r2

        # Plot regression line and remember the artist(s) it added
        lines_before = len(ax.lines)
        sns.regplot(data=points, x=x_col, y=y_col,
                    scatter=False,
                    label=f"{name} (R² = {r2:.3f})",
                    ci=None,
                    ax=ax)
        fit_lines.extend(ax.lines[lines_before:])

    # The scatter legend was built before the trend lines existed, so it does
    # not carry the R² labels. Replace it with one built from the fit lines.
    # If nothing could be fitted, the scatter legend is left in place.
    if fit_lines:
        ax.legend(handles=fit_lines, title='Pretraining Data')

    # Customize plot
    ax.set_title(f'{dataset1} vs {dataset2}\n(Steps: {step_interval}, Min: {min_step})')
    ax.set_xlabel(f'{dataset1} Accuracy')
    ax.set_ylabel(f'{dataset2} Accuracy')

    # Return the axis and fit statistics
    return ax, r2_values

def create_scatterplot_grid(csv_file, datasets, step_interval=1000, pretraining_data=None, min_step=0, figsize=(30, 30)):
    """
    Create a grid of scatter plots for each pair of downstream tasks and
    print the per-run R² of every comparison.

    Only the upper triangle of the grid is used: the panel in row i,
    column j (i < j) shows datasets[i] on the x-axis against datasets[j]
    on the y-axis. The diagonal and lower-triangle panels are removed.

    Args:
        csv_file: path to the CSV file containing the results
        datasets: list of dataset names to compare (at least two)
        step_interval: plot points at every step_interval steps (default: 1000)
        pretraining_data: list of specific pretraining datasets to include (default: None, includes all)
        min_step: minimum step to start plotting from (default: 0)
        figsize: (width, height) of the whole figure in inches (default: (30, 30))

    Raises:
        ValueError: if fewer than two datasets are given, since no pair can
            be formed.
    """
    num_datasets = len(datasets)
    if num_datasets < 2:
        raise ValueError(
            f"Need at least two datasets to build a pairwise grid, got {num_datasets}."
        )

    fig, axes = plt.subplots(num_datasets, num_datasets, figsize=figsize)

    # R² values per (dataset1, dataset2) pair
    all_stats = {}

    for i, j in itertools.combinations(range(num_datasets), 2):
        dataset1 = datasets[i]
        dataset2 = datasets[j]

        # Create scatter plot for the current pair of datasets
        ax, r2_values = create_performance_scatter(
            csv_file, dataset1, dataset2,
            step_interval, pretraining_data, min_step,
            ax=axes[i, j]
        )

        # Store statistics
        all_stats[(dataset1, dataset2)] = r2_values

        # Remove legend for cleaner grid (except for rightmost plots).
        # An axis may have no legend at all, so guard against None.
        legend = ax.get_legend()
        if j != num_datasets - 1 and legend is not None:
            legend.remove()

    # Remove unused subplots (diagonal and lower triangle)
    for i in range(num_datasets):
        for j in range(num_datasets):
            if i >= j:
                fig.delaxes(axes[i, j])

    # Adjust layout
    plt.tight_layout()

    # Print all fit statistics
    print("\nDetailed Fit Statistics:")
    print("-" * 50)
    for (dataset1, dataset2), r2_values in all_stats.items():
        print(f"\nComparison: {dataset1} vs {dataset2}")
        for name, r2 in r2_values.items():
            print(f"{name}: R² = {r2:.3f}")

    plt.show()

# Example usage: pairwise grid over six downstream tasks, sampled every
# 10000 steps from step 50000 onwards, restricted to four runs.
datasets = ['hellaswag', 'commonsense_qa', 'arc', 'piqa', 'sciq', 'mmlu']
create_scatterplot_grid(
    csv_file='finweb_results.csv',
    datasets=datasets,
    step_interval=10000,
    pretraining_data=['C4', 'The Pile', 'RefineWeb', 'RedPajama2'],
    min_step=50000,
)

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy import stats
import itertools

def create_performance_scatter(csv_file, dataset1, dataset2, step_interval=1000, pretraining_data=None, min_step=0, ax=None):
    """
    Create a scatter plot comparing accuracy on two datasets, with one
    least-squares trend line per run and a legend listing each run's R².

    Args:
        csv_file: path to the CSV file containing the results; it must have
            'steps' and 'runname' columns plus one '<dataset>/acc' column
            per dataset
        dataset1: name of first dataset (x-axis)
        dataset2: name of second dataset (y-axis)
        step_interval: plot points at every step_interval steps (default: 1000)
        pretraining_data: list of specific pretraining datasets to include (default: None, includes all)
        min_step: minimum step to start plotting from (default: 0)
        ax: matplotlib axis to plot on (default: None, creates new figure)

    Returns:
        (ax, r2_values): the axis that was drawn on and a dict mapping each
        run name to the R² of its fit. Runs with fewer than two distinct
        data points get no trend line and are omitted from the dict.

    Raises:
        ValueError: if '<dataset1>/acc' or '<dataset2>/acc' is not a column
            of the CSV file.
    """
    # Load the data
    df = pd.read_csv(csv_file)

    # Keep only rows on the step grid and at/after the minimum step
    df = df[(df['steps'] % step_interval == 0) & (df['steps'] >= min_step)]

    # Filter by specific pretraining datasets if provided
    if pretraining_data is not None:
        df = df[df['runname'].isin(pretraining_data)]

    # Construct column names
    x_col = f"{dataset1}/acc"
    y_col = f"{dataset2}/acc"

    # Validate before any figure is created so a bad name leaves no empty plot
    if x_col not in df.columns or y_col not in df.columns:
        raise ValueError(f"Columns {x_col} or {y_col} not found in the data.")

    # Create new figure if no axis provided
    if ax is None:
        plt.figure(figsize=(12, 8))
        ax = plt.gca()

    # Create scatter plot
    sns.scatterplot(data=df, x=x_col, y=y_col, hue='runname', alpha=0.6, ax=ax)

    # Drop the hue legend created by the scatter plot; the legend built
    # below from the trend lines carries the R² values instead.
    scatter_legend = ax.get_legend()
    if scatter_legend is not None:
        scatter_legend.remove()

    r2_values = {}
    fit_lines = []  # line artists of the trend lines, used to build the legend

    # sort=False keeps the runs in order of first appearance
    for name, group in df.groupby('runname', sort=False):
        # Drop rows with a missing accuracy so the reported R² describes the
        # same points that regplot uses for the drawn line.
        points = group[[x_col, y_col]].dropna()

        # A line needs at least two points with distinct x values.
        if len(points) < 2 or points[x_col].nunique() < 2:
            continue

        fit = stats.linregress(points[x_col], points[y_col])
        r2 = fit.rvalue ** 2
        r2_values[name] = r2

        # Plot regression line with R² in its label and remember the artist(s)
        lines_before = len(ax.lines)
        sns.regplot(data=points, x=x_col, y=y_col,
                    scatter=False,
                    label=f"{name} (R²={r2:.2f})",
                    ci=None,
                    ax=ax)
        fit_lines.extend(ax.lines[lines_before:])

    # Labels on artists are only shown once a legend is created, so build
    # one explicitly from the trend lines.
    if fit_lines:
        ax.legend(handles=fit_lines, title='Pretraining Data')

    # Customize plot
    ax.set_title(f'{dataset1} vs {dataset2}\n(Steps: {step_interval}, Min: {min_step})')
    ax.set_xlabel(f'{dataset1} Accuracy')
    ax.set_ylabel(f'{dataset2} Accuracy')

    # Return the axis and fit statistics
    return ax, r2_values

def create_scatterplot_grid(csv_file, datasets, step_interval=1000, pretraining_data=None, min_step=0, figsize=(30, 30)):
    """
    Create a grid of scatter plots for each pair of downstream tasks.

    Only the upper triangle of the grid is used: the panel in row i,
    column j (i < j) shows datasets[i] on the x-axis against datasets[j]
    on the y-axis. The diagonal and lower-triangle panels are removed.

    Args:
        csv_file: path to the CSV file containing the results
        datasets: list of dataset names to compare (at least two)
        step_interval: plot points at every step_interval steps (default: 1000)
        pretraining_data: list of specific pretraining datasets to include (default: None, includes all)
        min_step: minimum step to start plotting from (default: 0)
        figsize: (width, height) of the whole figure in inches (default: (30, 30))

    Raises:
        ValueError: if fewer than two datasets are given, since no pair can
            be formed.
    """
    num_datasets = len(datasets)
    if num_datasets < 2:
        raise ValueError(
            f"Need at least two datasets to build a pairwise grid, got {num_datasets}."
        )

    fig, axes = plt.subplots(num_datasets, num_datasets, figsize=figsize)

    last_column = num_datasets - 1
    for i, j in itertools.combinations(range(num_datasets), 2):
        # Create scatter plot for the current pair of datasets;
        # the returned fit statistics are not used here.
        ax, _ = create_performance_scatter(
            csv_file, datasets[i], datasets[j],
            step_interval, pretraining_data, min_step,
            ax=axes[i, j]
        )

        # Remove legend for cleaner grid (except for rightmost plots)
        legend = ax.get_legend()
        if j != last_column and legend is not None:
            legend.remove()

    # Remove unused subplots (diagonal and lower triangle)
    for i in range(num_datasets):
        for j in range(i + 1):
            fig.delaxes(axes[i, j])

    # Adjust layout
    plt.tight_layout()
    plt.show()

# Example usage: pairwise grid over six downstream tasks, sampled every
# 10000 steps from step 50000 onwards, restricted to four runs.
datasets = ['hellaswag', 'commonsense_qa', 'arc', 'piqa', 'sciq', 'mmlu']
create_scatterplot_grid(
    csv_file='finweb_results.csv',
    datasets=datasets,
    step_interval=10000,
    pretraining_data=['C4', 'The Pile', 'RefineWeb', 'RedPajama2'],
    min_step=50000,
)