In [None]:
import glob
import os
import pandas as pd
import wandb

In [None]:
# Map each layer name to its experiment path. Most entries are simply
# "dfa/<layer name>"; "retention" is the one exception and maps to "dfa/retnet".
layer_name_to_exp = {
    layer: f"dfa/{layer}"
    for layer in (
        "h3",
        "linear_transformer",
        "hyena",
        "transformer",
        "rwkv",
        "s4d",
        "lstm",
    )
}
layer_name_to_exp["retention"] = "dfa/retnet"
layer_name_to_exp["mamba"] = "dfa/mamba"

In [None]:
# Download every run of the probing project from Weights & Biases and gather
# one record per run (summary metrics, config, name, raw attributes).
api = wandb.Api(timeout=60)
entity, project = "akyurek", "interpret_dfa_all_probes_2500"
runs = api.runs(entity + "/" + project)

summary_list, config_list, name_list, attr_list = [], [], [], []
for run in runs:
    # .summary contains output keys/values for metrics such as accuracy.
    # We call ._json_dict to omit large files.
    summary_list.append(run.summary._json_dict)

    # .config contains the hyperparameters.
    # Special values whose key starts with "_" are removed; the filter was
    # described here before but missing from the comprehension.
    config_list.append(
        {k: v for k, v in run.config.items() if not k.startswith("_")}
    )

    # .name is the human-readable name of the run.
    name_list.append(run.name)

    # Raw run attributes, read from the private `_attrs` field of the run.
    attr_list.append(run._attrs)


# One row per run; the nested "summary" and "config" dicts are unpacked into
# flat columns by a later cell.
runs_df = pd.DataFrame(
    {
        "summary": summary_list,
        "config": config_list,
        "name": name_list,
        "attr": attr_list,
    }
)

In [None]:
def get_nested_arg(x, args):
    """Follow a sequence of keys/indices into a nested container.

    Args:
        x: Outermost container (anything that supports ``x[key]``).
        args: Iterable of keys or indices, applied one after another.

    Returns:
        The value reached after applying every entry of ``args`` (``x``
        itself when ``args`` is empty), or ``None`` as soon as one step
        cannot be resolved: missing key, index out of range, or an
        intermediate value that is not subscriptable.
    """
    for arg in args:
        try:
            x = x[arg]
        except (LookupError, TypeError):
            # Only look-up failures mean "value not present". A bare
            # `except:` would also swallow KeyboardInterrupt/SystemExit.
            return None
    return x

In [None]:
# Lift the two final test metrics out of each run's summary dict into flat
# columns; runs that did not log a metric get None.
for metric_name in ("acc", "error"):
    metric_column = f"test/final_{metric_name}"
    runs_df[metric_column] = runs_df["summary"].map(
        lambda summary, key=metric_column: summary.get(key)
    )

# Lift the config fields of interest into flat columns. A dotted field name
# is split and resolved as a nested look-up; missing values become None.
for field in ("layer", "exp", "hidden_key", "ngram", "binary", "use_ratio"):
    field_path = field.split(".")
    runs_df[field] = runs_df["config"].map(
        lambda cfg, path=field_path: get_nested_arg(cfg, path)
    )

# The nested source columns are no longer needed once flattened.
runs_df.drop(columns=["summary", "config", "attr"], inplace=True)
# Rows where both test/final_acc and test/final_error are missing are dropped
# in a later cell.


In [None]:
# Show the distinct experiment identifiers present in the fetched runs.
runs_df["exp"].unique()

In [None]:
# Discard, in place, every run that logged neither final test metric.
runs_df.dropna(
    how="all",
    subset=[f"test/final_{name}" for name in ("acc", "error")],
    inplace=True,
)

In [None]:
# Plotting setup for NeurIPS conference-quality figures.
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

# Imported for its side effect only — presumably registers the "science" and
# "ieee" styles used below; TODO confirm against the installed version.
import scienceplots

plt.style.use(["science", "ieee"])

# Overrides applied on top of the style sheets above.
plt.rcParams.update(
    {
        "xtick.top": False,
        "ytick.right": False,
        "axes.titlesize": 12,
        "legend.frameon": True,
        "legend.framealpha": 1.0,
        "xtick.minor.visible": False,
    }
)


In [None]:
# Thirteen-colour palette, converted from hex codes to RGB tuples.
# Unused hex codes kept from the original notes: 302DBE, 5E00C9
color_palette = [
    matplotlib.colors.to_rgb(hex_code)
    for hex_code in (
        "#f95d6a",
        "#e84e5b",
        "#ffa600",
        "#ea8900",
        "#d36d00",
        "#bb5100",
        "#665191",
        "#4c31aa",
        "#61b2e5",
        "#5397cb",
        "#477db1",
        "#3b6396",
        "#2f4b7c",
    )
]

In [None]:
# Display names for each experiment id. The Transformer variants follow a
# regular "transformer/<depth>" pattern; the remaining models are listed
# explicitly.
model_names = {
    f"transformer/{depth}": f"Transformer ({depth} layers)"
    for depth in (2, 1, 4, 8, 12)
}
model_names.update(
    lstm="LSTM",
    hyena="Hyena",
    h3="H3",
    s4d="S4",
    linear_transformer="Linear Transformer",
    rwkv="RWKV",
    retnet="RetNet",
    mamba="Mamba",
)

In [None]:
# Alphabetical ordering of the display names. The next cell replaces this
# with a hand-written order.
hue_order = sorted(model_names[exp_id] for exp_id in model_names)

In [None]:
# Fixed bar/legend order: the eight individually named models first, then the
# Transformer variants by increasing layer count.
hue_order = [
    "LSTM",
    "RWKV",
    "S4",
    "H3",
    "Hyena",
    "Mamba",
    "RetNet",
    "Linear Transformer",
    *(f"Transformer ({depth} layers)" for depth in (1, 2, 4, 8, 12)),
]

In [None]:
# Display only the runs that reported a final test error; runs_df itself is
# left unchanged.
runs_df.dropna(how="any", subset=["test/final_error"])

In [None]:
# Make sure the output directory exists; plt.savefig does not create it and
# would fail on a fresh checkout.
os.makedirs("figures", exist_ok=True)

# --- Probes scored by test/final_error ----------------------------------
# Grouped bar chart where groups are "ngram"s and bars are the different
# "exp"s. For every (ngram, exp) pair keep only the run with the minimum
# test/final_error, i.e. aggregate over "layer"s and "hidden_key"s.
# One figure is produced per value of use_ratio.
for use_ratio in [True, False]:
    df = runs_df.copy()
    # keep only runs with the requested use_ratio setting
    df = df[df.use_ratio == use_ratio]
    # runs without a final error cannot take part in the minimum
    df = df.dropna(subset=["test/final_error"])
    data = df.loc[df.groupby(["ngram", "exp"])["test/final_error"].idxmin()]
    # swap raw ids for display labels
    data = data.replace({"exp": model_names})
    data = data.replace({"ngram": {1: "1-gram", 2: "2-gram", 3: "3-gram", 4: "4-gram"}})
    sns.catplot(
        data=data,
        x="ngram",
        y="test/final_error",
        hue="exp",
        kind="bar",
        height=2,
        aspect=3,
        legend=False,
        palette=color_palette,
        hue_order=hue_order,
    )
    # leave 10% headroom above the tallest bar
    plt.ylim([0, data["test/final_error"].max() * 1.1])

    plt.ylabel("Error Percentage")
    plt.xlabel(None)
    # title
    normalized_text = "Normalized" if use_ratio else "Absolute"
    plt.title(f"{normalized_text} Ngram Counts")

    # save pdf
    plt.savefig(f"figures/lr_ngram_{normalized_text}.pdf")


# --- Probes scored by test/final_acc -------------------------------------
# Same grouped bar chart, but for every (ngram, exp) pair keep the run with
# the maximum test/final_acc. Unlike the error plots above, the runs are not
# filtered by use_ratio here.
df = runs_df.copy()
# runs without a final accuracy cannot take part in the maximum
df = df.dropna(subset=["test/final_acc"])
data = df.loc[df.groupby(["ngram", "exp"])["test/final_acc"].idxmax()]
data = data.replace({"exp": model_names})
# ngram == 0 is labelled as the "Same State" probe
data = data.replace({"ngram": {0: "Same State", 2: "2-gram", 3: "3-gram"}})

sns.catplot(
    data=data,
    x="ngram",
    y="test/final_acc",
    hue="exp",
    kind="bar",
    height=2,
    aspect=3,
    legend='full',
    palette=color_palette,
    hue_order=hue_order,
)
# ylabel
plt.ylabel("Accuracy")
plt.xlabel(None)
# title
plt.title("Ngram Existence and Same State")
# this is the only bar chart drawn with a legend, hence the file name
plt.savefig("figures/lr_classification_legend.pdf")



In [None]:
# Plot the probe error of each ngram order against the layers of the
# 8-layer Transformer (exp == "transformer/8").

# Make sure the output directory exists; plt.savefig does not create it.
os.makedirs("figures", exist_ok=True)

use_ratio = True
df = runs_df.copy()
# keep only the 8-layer Transformer runs
df = df[df.exp == "transformer/8"]
# drop the ngram == 0 probe (labelled "Same State" in the accuracy bar chart)
# and keep only runs with the chosen use_ratio setting
df = df[(df.ngram != 0) & (df.use_ratio == use_ratio)]
# only runs that report a final error can be plotted
df = df.dropna(subset=["test/final_error"])
# Shift probes on "attention_contexts" by one layer.
# NOTE(review): presumably aligns them with the layer whose output they feed
# into — confirm against the probing code.
df.loc[(df['hidden_key'] == "attention_contexts"), "layer"] += 1

# for every (ngram, layer) pair keep the run with the minimum error
data = df.loc[df.groupby(["ngram", "layer"])["test/final_error"].idxmin()]

# one line per ngram order
# NOTE(review): the palette lists three colours, so this assumes exactly
# three ngram values remain after filtering — confirm.
sns.lineplot(
    data=data,
    x="layer",
    y="test/final_error",
    hue="ngram",
    style="ngram",
    markers=True,
    dashes=False,
    palette=[color_palette[0], color_palette[2], color_palette[9]],
)

# ylabel
plt.ylabel("Error Percentage")
plt.xlabel("Hidden Outputs of Transformer with 8 Layers")
# save fig
plt.savefig("figures/lr_ngram_layers.pdf")





In [None]:
# Peek at the first five rows of the frame left over from the previous cell.
df.head(n=5)