# Deprecated: please use `create_df.py` and `plot_grids.ipynb` instead

In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

import utils
from mmlu_utils import compute_avg_losses, compute_avg_task_losses_and_accuracies

sns.set_theme(style="whitegrid")

# Get data and basic analysis

In [None]:
# Build `df` from the experiment logs stored under `dir_logs`.
# (presumably one row per logged checkpoint — verify against utils.gather_experiment_data)
# The commented-out path below is the older "lingua_old" log directory, kept for reference.
# dir_logs = '/fast/pmayilvahanan/lm_logs/lingua_old/'
dir_logs = "/fast/pmayilvahanan/lm_logs/lingua/"
df = utils.gather_experiment_data(dir_logs)

In [None]:
# Save df as CSV next to the notebooks.
# NOTE: DataFrame.to_csv writes the index as an unnamed first column by default.
df.to_csv("/lustre/fast/fast/pmayilvahanan/llm_line/code/llm_line/notebooks/df.csv")

In [None]:
# Keep an untouched deep copy of the freshly gathered data before it is filtered below.
# DataFrame.copy() is deep by default and needs no extra import (`copy` was never imported).
df_copy = df.copy()

In [None]:
# Drop, in place, every row that is missing the train loss or the test loss
df.dropna(subset=["train_loss", "test_loss"], how="any", inplace=True)

In [None]:
# Inspect the tokenizer column
df["tokenizer"]

In [None]:
# Preview the first rows of df
df.head()

In [None]:
# At the moment, only analyze data for jobs that are (almost) completed.
# Run this right after loading df, before any other processing.

# Completion threshold in tokens. The run names used in this notebook carry an
# "8.4BT" suffix, so 7.5e9 keeps the runs that got most of the way there;
# adjust as needed.
min_tokens = 7.5e9

# Largest token count reached by each model (one entry per `name`)
max_tokens_per_model = df.groupby("name")["tokens"].max()

# Keep only models that reached the minimum token count
valid_models = max_tokens_per_model[max_tokens_per_model >= min_tokens].index

# .copy() makes df an independent frame, so the column assignments in the
# following cells do not raise pandas' SettingWithCopyWarning on a filtered slice.
df = df[df["name"].isin(valid_models)].copy()

In [None]:
# Names of the models that passed the minimum-token filter
valid_models

In [None]:
# Add the boolean feature `within_chinchilla`: True where tokens < 20 * size.
# (presumably `size` is the parameter count, giving the 20-tokens-per-parameter
# Chinchilla budget — TODO confirm the unit of `size`)
df["within_chinchilla"] = df["tokens"] < df["size"] * 20

In [None]:
# Fill the overall and per-category MMLU loss columns from a row-wise helper.
# NOTE(review): `extract_mmlu_losses` is neither defined nor imported in the
# visible part of this notebook, so this cell raises NameError unless the name
# is provided elsewhere — confirm where it lives (possibly mmlu_utils).
df[
    [
        "mmlu/loss",
        "mmlu_stem/loss",
        "mmlu_other/loss",
        "mmlu_humanities/loss",
        "mmlu_social_sciences/loss",
    ]
] = df.apply(extract_mmlu_losses, axis=1)

In [None]:
# Row-wise mean of the per-task losses and accuracies for the tasks listed
# below, stored as "avg/loss" and "avg/acc".
task_columns = [
    "hellaswag",
    "piqa",
    "arc_easy",
    "arc_challenge",
    "commonsense_qa",
    "openbookqa",
    "winogrande",
    "social_iqa",
    "mmlu",
]
task_columns_loss = [task + "/loss" for task in task_columns]
task_columns_acc = [task + "/acc" for task in task_columns]

per_task_losses = df.loc[:, task_columns_loss]
per_task_accuracies = df.loc[:, task_columns_acc]
df["avg/loss"] = per_task_losses.mean(axis=1)
df["avg/acc"] = per_task_accuracies.mean(axis=1)

In [None]:
# Alternative to gathering the logs: load a previously saved df from the
# current working directory.
# NOTE(review): if this file was written by the `df.to_csv(...)` cell above, the
# saved index comes back as an extra "Unnamed: 0" column — confirm whether
# `index_col=0` is wanted here.
df = pd.read_csv("df.csv")

In [None]:
# Drop, in place, every row that is missing the train loss or the test loss
df.dropna(subset=["train_loss", "test_loss"], how="any", inplace=True)

# Compute average losses and accuracies with the mmlu_utils helpers
tasks = [
    "hellaswag",
    "piqa",
    "arc_easy",
    "arc_challenge",
    "commonsense_qa",
    "openbookqa",
    "winogrande",
    "social_iqa",
    "mmlu",
]
df = compute_avg_losses(df)
df = compute_avg_task_losses_and_accuracies(df, tasks=tasks)
# `tasks` is mutated in place here: from this line on it no longer contains
# "hellaswag", so the second call covers the remaining tasks and is given the
# column name "avg_minus_hellaswag".
tasks.remove("hellaswag")
df = compute_avg_task_losses_and_accuracies(
    df, tasks=tasks, col_name="avg_minus_hellaswag"
)

In [None]:
# Does overfitting happen? Plot the "mmlu/loss" values of one run.
# NOTE(review): the x values are the DataFrame index (no step column is passed)
# and the y axis is labelled "val loss" although "mmlu/loss" is plotted —
# confirm that both labels are intended.
plt.plot(df[df["name"] == "mamba_420M_fw_edu_8.4BT"]["mmlu/loss"])
plt.xlabel("steps")
plt.ylabel("val loss")
plt.show()

In [None]:
# Lowest commonsense_qa loss among the rows pretrained on pile_uc
df[df["pretraining_data"] == "pile_uc"]["commonsense_qa/loss"].min()

In [None]:
# Highest commonsense_qa accuracy among the rows pretrained on pile_uc
df[df["pretraining_data"] == "pile_uc"]["commonsense_qa/acc"].max()

# Correlation Analysis

In [None]:
# Columns used together as the grouping key for the correlation analysis
basic_columns = ["name", "arch", "size", "dim", "n_layers", "pretraining_data"]

# Task prefixes; each is expanded to a "<task>/acc" and a "<task>/loss" column
task_columns = [
    "hellaswag",
    "piqa",
    "arc_easy",
    "arc_challenge",
    "openbookqa",
    "commonsense_qa",
    "winogrande",
    "social_iqa",
    "mmlu",
    "avg",
]

# Validation-loss columns
val_columns = [
    "slimpajama_val_loss",
    "c4_val_loss",
    "pile_uc_val_loss",
    "fineweb_edu_100bt_val_loss",
    "refineweb_val_loss",
]

# Correlation targets, in order: validation losses, task accuracies, task losses
corr_targets = [
    *val_columns,
    *(f"{task}/acc" for task in task_columns),
    *(f"{task}/loss" for task in task_columns),
]


def safe_corrcoef(x, y):
    """Pearson correlation of two index-aligned Series, ignoring incomplete pairs.

    Only positions where both `x` and `y` are non-null are used, because a
    single NaN passed to `np.corrcoef` makes the whole result NaN.

    Args:
        x, y: pandas Series sharing the same index.

    Returns:
        The correlation coefficient, or NaN when fewer than two complete
        (x, y) pairs are available.
    """
    complete = x.notna() & y.notna()
    if complete.sum() < 2:
        return np.nan
    return np.corrcoef(x[complete], y[complete])[0, 1]


# Build one row per (group, anchor loss): the group's basic_columns values, the
# anchor name, and the correlation of the anchor with every target column.
anchor_losses = ("test_loss", "train_loss", "hellaswag/loss")
corr_rows = []

for model_id, group in df.groupby(basic_columns, dropna=False):
    # model_id holds the values of basic_columns for this group, in order
    model_info = dict(zip(basic_columns, model_id))
    available = set(group.columns)

    for anchor_loss in anchor_losses:
        if anchor_loss not in available:
            continue

        row = {**model_info, "anchor": anchor_loss}
        row.update(
            {
                # targets missing from the frame are recorded as NaN
                target_col: (
                    safe_corrcoef(group[anchor_loss], group[target_col])
                    if target_col in available
                    else np.nan
                )
                for target_col in corr_targets
            }
        )
        corr_rows.append(row)

corr_all = pd.DataFrame(corr_rows)
corr_all.head()

In [None]:
# For now, look only at the correlations anchored on 'test_loss' (the in-distribution loss)
corr = corr_all[corr_all["anchor"] == "test_loss"]

In [None]:
# Select the rows whose correlation for the chosen task/metric is below the
# threshold (rows with a NaN correlation never satisfy the comparison).
task = "mmlu"
metric = "loss"
threshold = 0.7
below_threshold = corr[f"{task}/{metric}"] < threshold
filtered_corr = corr[below_threshold]

# Report: print how many rows matched, then display the selected columns
columns_to_report = [
    "name",
    "pretraining_data",
    *val_columns,
    *(f"{t}/loss" for t in task_columns),
]
report = filtered_corr[columns_to_report]
print(len(report))
report

In [None]:
# Correlation change


def compare_correlations(df, base_name, intervention_name, col1, col2):
    """
    Print how the col1/col2 correlation differs between two runs.

    Rows are selected by their 'name' value. Three correlations are printed:
    base rows only, intervention rows only, and both sets of rows combined,
    followed by the percentage change from the base correlation to the
    combined one (or N/A when the base correlation is NaN or zero).
    Nothing is returned.
    """
    import numpy as np

    names = df["name"]

    def _pearson(rows):
        return rows[col1].corr(rows[col2])

    corr_base = _pearson(df[names == base_name])
    corr_int = _pearson(df[names == intervention_name])
    corr_comb = _pearson(df[names.isin([base_name, intervention_name])])

    # The relative change is only defined for a finite, non-zero base correlation
    base_is_usable = not np.isnan(corr_base) and corr_base != 0
    pct_change = (
        100.0 * (corr_comb - corr_base) / abs(corr_base) if base_is_usable else None
    )

    print(f"Base ({base_name}) correlation: {corr_base:.3f}")
    print(f"Intervention ({intervention_name}) correlation: {corr_int:.3f}")
    print(f"Combined correlation: {corr_comb:.3f}")
    if pct_change is None:
        print("Percentage change from base to combined: N/A (invalid base corr).")
    else:
        print(f"Percentage change from base to combined: {pct_change:.2f}%")

In [None]:
# List the distinct run names in sorted order (the array is sorted in place)
temp = df["name"].unique()
temp.sort()
print(temp)

In [None]:
# Architecture comparison: llama (base) vs. mamba (intervention) on each
# pretraining dataset, correlating test loss with the average task loss.
col1, col2 = "test_loss", "avg/loss"

for dataset in ["fw_edu", "c4", "pile_uc"]:
    base_name = f"llama_416M_{dataset}_8.4BT"
    intervention_name = f"mamba_420M_{dataset}_8.4BT"
    compare_correlations(df, base_name, intervention_name, col1, col2)

In [None]:
# Same llama-vs-mamba comparison as the cell directly above (this cell repeats it verbatim)
col1, col2 = "test_loss", "avg/loss"

for dataset in ["fw_edu", "c4", "pile_uc"]:
    base_name = f"llama_416M_{dataset}_8.4BT"
    intervention_name = f"mamba_420M_{dataset}_8.4BT"
    compare_correlations(df, base_name, intervention_name, col1, col2)

In [None]:
# gpt2 (base) vs. llama (intervention) on each pretraining dataset,
# correlating the hellaswag loss with the piqa loss.
col1, col2 = "hellaswag/loss", "piqa/loss"

for dataset in ["fw_edu", "c4", "pile_uc"]:
    intervention_name = f"llama_416M_{dataset}_8.4BT"
    base_name = f"gpt2_420M_{dataset}_8.4BT"
    compare_correlations(df, base_name, intervention_name, col1, col2)

In [None]:
def _fit_trend_line(x, y):
    """Fit ``y = slope * x + intercept`` through the finite (x, y) pairs.

    Returns:
        ``(slope, intercept, r_squared, x_min, x_max)`` computed from the pairs
        in which both values are finite, or ``None`` when fewer than two such
        pairs exist or all x values are identical (no line can be fitted then).
    """
    import numpy as np
    from scipy import stats

    xs = np.asarray(x, dtype=float)
    ys = np.asarray(y, dtype=float)
    finite = np.isfinite(xs) & np.isfinite(ys)
    xs, ys = xs[finite], ys[finite]
    if xs.size < 2 or xs.min() == xs.max():
        return None
    fit = stats.linregress(xs, ys)
    return fit.slope, fit.intercept, fit.rvalue**2, xs.min(), xs.max()


def create_comparison_plots(
    df, col1, col2, primary="arch", secondary="dim", tertiary="n_layers", save_path=None
):
    """
    Creates flexible comparison plots with any combination of grouping variables.

    One figure is drawn per value of `primary`; each figure has one subplot per
    value of `secondary`. Points are colored by 'pretraining_data' and shaped by
    `tertiary`. Each subplot also gets one regression line with an R² label per
    pretraining dataset, whenever a line can be fitted.

    Only the pretraining datasets listed in the color table below ("c4",
    "pile_uc", "fineweb_edu_100bt") are drawn; rows with other values are skipped.

    Args:
        df: DataFrame with the data; must contain 'arch', 'dim', 'n_layers',
            'pretraining_data', `col1` and `col2`
        col1, col2: column names for x and y axes
        primary: column to group separate plots by ('arch', 'dim', or 'n_layers')
        secondary: column for subplots ('arch', 'dim', or 'n_layers')
        tertiary: column for point shapes ('arch', 'dim', or 'n_layers')
        save_path: optional path prefix; each figure is saved as
            "{save_path}_{primary value}.png"

    Raises:
        ValueError: if primary, secondary or tertiary is not one of the three
            grouping columns.
    """
    import itertools

    # Validate before touching the plotting stack so bad arguments fail fast
    grouping_columns = ("arch", "dim", "n_layers")
    for col in (primary, secondary, tertiary):
        if col not in grouping_columns:
            raise ValueError(
                f"Invalid grouping column: {col}. Must be one of {list(grouping_columns)}"
            )

    import matplotlib.pyplot as plt

    # Dynamic grouping values from data
    group_values = {name: sorted(df[name].unique()) for name in grouping_columns}

    colors = {"c4": "blue", "pile_uc": "red", "fineweb_edu_100bt": "green"}
    markers = ["o", "s", "^", "D", "v", "<", ">", "p", "h", "8"]
    # Cycle the marker list so every tertiary value gets a marker even when
    # there are more values than distinct marker shapes
    marker_dict = dict(zip(group_values[tertiary], itertools.cycle(markers)))

    n_panels = len(group_values[secondary])

    for prim_val in group_values[primary]:
        # squeeze=False always returns a 2-D array, so one panel needs no special case
        fig, axes = plt.subplots(
            1, n_panels, figsize=(5 * n_panels, 5), squeeze=False
        )
        axes = axes[0]
        fig.suptitle(f"{primary}={prim_val}: {col2} vs {col1}")

        legend_elements = []
        seen_legend_keys = set()

        for sec_val, ax in zip(group_values[secondary], axes):
            mask = (df[primary] == prim_val) & (df[secondary] == sec_val)
            subset = df[mask]

            for p_data in colors:
                p_data_subset = subset[subset["pretraining_data"] == p_data]

                for tert_val in sorted(p_data_subset[tertiary].unique()):
                    data = p_data_subset[p_data_subset[tertiary] == tert_val]
                    if len(data) == 0:
                        continue

                    scatter = ax.scatter(
                        data[col1],
                        data[col2],
                        c=colors[p_data],
                        marker=marker_dict[tert_val],
                        alpha=0.6,
                    )

                    # One legend entry per (dataset, tertiary value) per figure
                    legend_key = (p_data, tert_val)
                    if legend_key not in seen_legend_keys:
                        seen_legend_keys.add(legend_key)
                        scatter.set_label(f"{p_data}, {tert_val} {tertiary}")
                        legend_elements.append(scatter)

                fit = _fit_trend_line(p_data_subset[col1], p_data_subset[col2])
                if fit is not None:
                    slope, intercept, r_squared, x_min, x_max = fit
                    x_range = np.linspace(x_min, x_max, 100)
                    y_range = slope * x_range + intercept
                    ax.plot(x_range, y_range, c=colors[p_data])

                    # Put the R² label at the middle of the fitted line
                    mid_idx = len(x_range) // 2
                    ax.annotate(
                        f"R² = {r_squared:.2f}",
                        xy=(x_range[mid_idx], y_range[mid_idx]),
                        xytext=(10, 10),
                        textcoords="offset points",
                        color=colors[p_data],
                        bbox=dict(facecolor="white", edgecolor="none", alpha=0.7),
                    )

            ax.set_title(f"{secondary}={sec_val}")
            ax.set_xlabel(col1)
            ax.set_ylabel(col2)

        fig.legend(
            handles=legend_elements, bbox_to_anchor=(1.05, 0.5), loc="center left"
        )
        plt.tight_layout()

        if save_path:
            plt.savefig(f"{save_path}_{prim_val}.png", bbox_inches="tight")
        plt.show()

In [None]:
# Example calls kept for reference (commented out):

# # Original way (by arch, then dim)
# create_comparison_plots(df, 'test_loss', 'hellaswag/loss',
#                        primary='arch', secondary='dim', tertiary='n_layers')

# # By dimension, then architecture
# create_comparison_plots(df, 'test_loss', 'hellaswag/loss',
#                        primary='dim', secondary='arch', tertiary='n_layers')

# # By n_layers, then dimension
# create_comparison_plots(df, 'test_loss', 'hellaswag/loss',
#                        primary='n_layers', secondary='dim', tertiary='arch')

# Active call: arc_challenge loss vs. hellaswag loss, one figure per dim,
# one subplot per n_layers, marker shape by arch.
create_comparison_plots(
    df,
    "hellaswag/loss",
    "arc_challenge/loss",
    primary="dim",
    secondary="n_layers",
    tertiary="arch",
)

In [None]:
def _fit_overall_line(x, y):
    """Fit ``y = slope * x + intercept`` through the finite (x, y) pairs.

    Returns:
        ``(slope, intercept, r_squared, x_min, x_max)`` computed from the pairs
        in which both values are finite, or ``None`` when fewer than two such
        pairs exist or all x values are identical (no line can be fitted then).
    """
    import numpy as np
    from scipy import stats

    xs = np.asarray(x, dtype=float)
    ys = np.asarray(y, dtype=float)
    finite = np.isfinite(xs) & np.isfinite(ys)
    xs, ys = xs[finite], ys[finite]
    if xs.size < 2 or xs.min() == xs.max():
        return None
    fit = stats.linregress(xs, ys)
    return fit.slope, fit.intercept, fit.rvalue**2, xs.min(), xs.max()


def create_grid_comparison(df, columns, max_checkpoints_per_group=100, save_path=None):
    """
    Creates a grid of scatter plots comparing each column against others.

    Cell (i, j) plots columns[i] on the x axis against columns[j] on the y axis,
    with one dashed regression line (and R² label) fitted over all sampled rows.
    Diagonal cells only show the column name. Points are colored by
    'pretraining_data' and shaped by 'arch'; only the pretraining datasets in
    the color table below are drawn.

    Args:
        df: DataFrame with the data; must contain 'arch', 'pretraining_data'
            and every name in `columns`
        columns: list of column names to compare (at least one)
        max_checkpoints_per_group: maximum number of checkpoints to sample per
            arch-pretraining combination; a falsy value disables sampling
        save_path: optional path to save figure

    Raises:
        ValueError: if `columns` is empty.
    """
    n = len(columns)
    if n == 0:
        raise ValueError("columns must contain at least one column name")

    import matplotlib.pyplot as plt
    import numpy as np

    # Sample checkpoints if needed (fixed seed, so repeated runs give the same plot)
    sampled_df = df.copy()
    if max_checkpoints_per_group:
        samples = []
        for arch in df["arch"].unique():
            for p_data in df["pretraining_data"].unique():
                mask = (df["arch"] == arch) & (df["pretraining_data"] == p_data)
                group_data = df[mask]
                if len(group_data) > max_checkpoints_per_group:
                    group_data = group_data.sample(
                        n=max_checkpoints_per_group, random_state=42
                    )
                samples.append(group_data)
        # pd.concat raises on an empty list, which happens when df has no rows
        if samples:
            sampled_df = pd.concat(samples)

    # squeeze=False keeps `axes` 2-D even for a single column
    fig, axes = plt.subplots(n, n, figsize=(5 * n, 5 * n), squeeze=False)

    # Style settings
    colors = {"c4": "blue", "pile_uc": "red", "fineweb_edu_100bt": "green"}
    markers = {"llama": "o", "mamba": "s", "gpt2": "^"}
    fallback_marker = "D"  # used for architectures missing from `markers`

    # The (arch, pretraining data) groups are the same for every cell of the
    # grid, so split the sampled rows once instead of once per cell.
    groups = []
    for arch in sampled_df["arch"].unique():
        arch_data = sampled_df[sampled_df["arch"] == arch]
        for p_data in colors:
            data = arch_data[arch_data["pretraining_data"] == p_data]
            if len(data) > 0:
                groups.append((arch, p_data, data))

    # Create plots
    for i, col1 in enumerate(columns):
        for j, col2 in enumerate(columns):
            ax = axes[i, j]

            if i == j:
                # Show column name in diagonal
                ax.text(
                    0.5,
                    0.5,
                    col1,
                    horizontalalignment="center",
                    verticalalignment="center",
                    transform=ax.transAxes,
                    fontsize=12,
                )
                ax.set_xticks([])
                ax.set_yticks([])
                continue

            # Plot points for each architecture and pretraining data
            for arch, p_data, data in groups:
                ax.scatter(
                    data[col1],
                    data[col2],
                    c=colors[p_data],
                    marker=markers.get(arch, fallback_marker),
                    alpha=0.6,
                    label=f"{arch}-{p_data}",
                )

            # Add one fit line for all data combined
            fit = _fit_overall_line(sampled_df[col1], sampled_df[col2])
            if fit is not None:
                slope, intercept, r_squared, x_min, x_max = fit
                x_range = np.linspace(x_min, x_max, 100)
                y_range = slope * x_range + intercept
                # Use black color for the overall correlation line
                ax.plot(x_range, y_range, c="black", linestyle="--")

                # Add R² annotation at the middle of the line
                mid_idx = len(x_range) // 2
                ax.annotate(
                    f"R² = {r_squared:.2f}",
                    xy=(x_range[mid_idx], y_range[mid_idx]),
                    xytext=(10, 10),
                    textcoords="offset points",
                    color="black",
                    bbox=dict(facecolor="white", edgecolor="none", alpha=0.7),
                )

            ax.set_xlabel(col1)
            ax.set_ylabel(col2)
            ax.grid(True, alpha=0.3)

    # Figure-level legend built from the first off-diagonal cell; with a single
    # column there is no such cell and the legend stays empty.
    handles, labels = axes[0, 1].get_legend_handles_labels() if n > 1 else ([], [])
    legend_entries = {}
    for handle, label in zip(handles, labels):
        legend_entries.setdefault(label, handle)  # keep the first handle per label
    fig.legend(
        list(legend_entries.values()),
        list(legend_entries.keys()),
        bbox_to_anchor=(1.02, 0.5),
        loc="center left",
    )

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, bbox_inches="tight", dpi=300)
    plt.show()


# Loss columns shown on the grid (each one is plotted against every other)
columns_to_compare = [
    "arc_challenge/loss",
    "arc_easy/loss",
    "hellaswag/loss",
    "piqa/loss",
    "openbookqa/loss",
    "winogrande/loss",
]

# Restrict the plot to rows where within_chinchilla is True
is_within_chinchilla = df["within_chinchilla"].eq(True)
df_within_chinchilla = df.loc[is_within_chinchilla]

# Draw the grid and save it next to the notebooks
create_grid_comparison(
    df=df_within_chinchilla,
    columns=columns_to_compare,
    save_path="/lustre/fast/fast/pmayilvahanan/llm_line/code/llm_line/notebooks/grid_comparison.png",
)

In [None]:
# Show the current working directory (relative paths such as "df.csv" resolve against it).
# `os` is imported here because the notebook's import cell does not import it.
import os

os.getcwd()