# Setup and imports

In [None]:
# imports
import numpy as np
from tueplots import bundles, figsizes
import wandb
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import pandas as pd
import matplotlib


import sys

%load_ext autoreload
%autoreload 2

sys.path.insert(0, '.')

In [None]:
from analysis import sweep2df, plot_typography, stats2string, RED, BLUE, rule_stats2string_per_model, grouped_rule_stats
from llm_non_identifiability.data import A_token, B_token, OPENING_BRACKET_token, OPENING_PARENTHESIS_token, CLOSING_BRACKET_token, CLOSING_PARENTHESIS_token

In [None]:
# Single switch forwarded as `usetex=` to both the tueplots bundle and plot_typography.
USETEX = True

In [None]:
# Apply the tueplots ICML 2022 rcParams bundle to matplotlib.
plt.rcParams.update(bundles.icml2022(usetex=USETEX))
# Disabled: extra LaTeX preamble packages (amsfonts for mathbb, amsmath for boldsymbol).
# plt.rcParams.update({
#     'text.latex.preamble': [r'\usepackage{amsfonts}', # mathbb
#                             r'\usepackage{amsmath}'] # boldsymbol
# })

In [None]:
# Project helper from `analysis`; presumably sets the small/medium/big font sizes — confirm in analysis.py.
plot_typography(usetex=USETEX, small=14, medium=14, big=18)


In [None]:
# Constants: W&B entity and project that host every sweep loaded below
ENTITY = "causal-representation-learning"
PROJECT = "rule_extrapolation"

# W&B API
api = wandb.Api(timeout=200)
# NOTE(review): `runs` is not referenced by any other cell visible in this notebook.
runs = api.runs(ENTITY + "/" + PROJECT)

# Data loading

In [None]:
def get_sweep_stats(sweep_id, file_prefix, save=False, load=True, entity=ENTITY, project=PROJECT):
    """Fetch one W&B sweep and return its run table plus grouped rule statistics.

    Args:
        sweep_id: W&B sweep identifier, resolved as ``entity/project/sweep_id``.
        file_prefix: Prefix of the name handed to ``sweep2df``; the full name is
            ``f"{file_prefix}_{sweep_id}"``.
        save: Forwarded unchanged to ``sweep2df``.
        load: Forwarded unchanged to ``sweep2df``.
        entity: W&B entity that owns the project.
        project: W&B project that contains the sweep.

    Returns:
        A tuple ``(df, grouped_rule_stats(df))`` where ``df`` is the first
        value returned by ``sweep2df``.
    """
    # Uses its own client with a 400 s timeout.
    api = wandb.Api(timeout=400)
    sweep = api.sweep(f"{entity}/{project}/{sweep_id}")
    filename = f"{file_prefix}_{sweep_id}"
    # sweep2df returns further values after the data frame; none of them is used here.
    df, *_ = sweep2df(sweep.runs, filename, save=save, load=load)
    return df, grouped_rule_stats(df)

## baN

In [None]:
# Load the three baN sweeps; the grouped statistics returned as second value are discarded
# because they are recomputed on the concatenated frame further down.
ban_no_transformer_df, _ = get_sweep_stats(sweep_id="frrbkgg0", file_prefix="ban_no_transformer", save=True, load=True)
ban_transformer_df, _ = get_sweep_stats(sweep_id="9b9a0xtn", file_prefix="ban_transformer", save=True, load=True)
ban_lstm_df, _ = get_sweep_stats(sweep_id="vzcozopj", file_prefix="ban_lstm", save=True, load=True)

In [None]:
# Separate sweep (prefix "ban_lin"); unlike the frames above it is not filtered before concatenation.
ban_lin_df, _ = get_sweep_stats(sweep_id="42r89e8x", file_prefix="ban_lin", save=True, load=True)

In [None]:
# Filter out linear models (incorrect logs in these sweeps).
# Rows with model == "linear" are then contributed only by ban_lin_df.
ban_no_transformer_df = ban_no_transformer_df[ban_no_transformer_df.model != "linear"]
ban_transformer_df = ban_transformer_df[ban_transformer_df.model != "linear"]
ban_lstm_df = ban_lstm_df[ban_lstm_df.model != "linear"]




In [None]:
# Stack all baN runs into one frame and compute the grouped statistics used by the plots.
ban_df = pd.concat([ban_no_transformer_df, ban_transformer_df, ban_lstm_df, ban_lin_df])
ban_stats = grouped_rule_stats(ban_df)

## bbaN

In [None]:
# Load the bbaN sweeps.
# NOTE(review): these calls pass load=False while every other section passes load=True —
# confirm whether bypassing the load path here is intentional.
bban_transformer_mamba_df, _ = get_sweep_stats(sweep_id="jjjtbh24", file_prefix="bban_transformer_mamba", save=True, load=False)
bban_all_df, _ = get_sweep_stats(sweep_id="unc8bv65", file_prefix="bban_all", save=True, load=False)

In [None]:
# Separate sweep (prefix "bban_lin"); not filtered before concatenation.
# NOTE(review): load=False here as well, unlike the other grammars — confirm.
bban_lin_df, _ = get_sweep_stats(sweep_id="dnpv6gpm", file_prefix="bban_lin", save=True, load=False)

In [None]:
# Filter out linear models (incorrect logs in these sweeps).
# Rows with model == "linear" are then contributed only by bban_lin_df.
bban_transformer_mamba_df = bban_transformer_mamba_df[bban_transformer_mamba_df.model != "linear"]
bban_all_df = bban_all_df[bban_all_df.model != "linear"]


In [None]:
# Stack all bbaN runs into one frame and compute the grouped statistics used by the plots.
bban_df = pd.concat([bban_transformer_mamba_df, bban_all_df, bban_lin_df])
bban_stats = grouped_rule_stats(bban_df)

## aNbN

In [None]:
# aNbN sweep vn4yrcl8 (prefix "anbn_transformer"); grouped stats are discarded here.
anbn_transformer_df, _ = get_sweep_stats(sweep_id="vn4yrcl8", file_prefix="anbn_transformer", save=True, load=True)

In [None]:
# Filter out linear models (incorrect logs in this sweep).
anbn_transformer_df = anbn_transformer_df[anbn_transformer_df.model != "linear"]

In [None]:
# aNbN sweep t4yzbech (prefix "anbn_lstm").
anbn_lstm_df, _ = get_sweep_stats(sweep_id="t4yzbech", file_prefix="anbn_lstm", save=True, load=True)

In [None]:
# Drop rows whose model is "linear", as is done for the other aNbN sweeps.
anbn_lstm_df = anbn_lstm_df[anbn_lstm_df.model != "linear"]

In [None]:
# aNbN sweep o27zaphz (prefix "anbn_mamba").
anbn_mamba_df, _ = get_sweep_stats(sweep_id="o27zaphz", file_prefix="anbn_mamba", save=True, load=True)

In [None]:
# Drop rows whose model is "linear", as is done for the other aNbN sweeps.
anbn_mamba_df = anbn_mamba_df[anbn_mamba_df.model != "linear"]


In [None]:
# aNbN sweep nfrfpkqm (prefix "anbn_no_transformer").
anbn_no_transformer_df, _ = get_sweep_stats(sweep_id="nfrfpkqm", file_prefix="anbn_no_transformer", save=True, load=True)

In [None]:
# Drop rows whose model is "linear", as is done for the other aNbN sweeps.
anbn_no_transformer_df = anbn_no_transformer_df[anbn_no_transformer_df.model != "linear"]

In [None]:
# Separate sweep (prefix "anbn_lin"); not filtered before concatenation.
anbn_lin_df, _ = get_sweep_stats(sweep_id="8lijfxk2", file_prefix="anbn_lin", save=True, load=True)

In [None]:
# Stack all aNbN runs into one frame and compute the grouped statistics used by the plots.
anbn_df = pd.concat([anbn_transformer_df, anbn_lstm_df, anbn_mamba_df, anbn_no_transformer_df, anbn_lin_df])
anbn_stats = grouped_rule_stats(anbn_df)

## aNbNcN

In [None]:
# Load two aNbNcN sweeps; note that `anbncn_df` holds ONE sweep only — the combined
# frame is `anbncn_df_merged`, built a few cells below.
anbncn_df, _ = get_sweep_stats(sweep_id="6m1qb70e", file_prefix="anbncn", save=True, load=True)
anbncn_lin_mamba_df, _ = get_sweep_stats(sweep_id="vtji6cx4", file_prefix="anbncn_lin_mamba", save=True, load=True)

In [None]:
# Separate sweep (prefix "anbncn_lin"); not filtered before concatenation.
anbncn_lin_df, _ = get_sweep_stats(sweep_id="p135j0eg", file_prefix="anbncn_lin", save=True, load=True)

In [None]:
# Drop rows whose model is "linear" from both sweeps; such rows are then contributed
# only by anbncn_lin_df.
anbncn_df = anbncn_df[anbncn_df.model != "linear"]
anbncn_lin_mamba_df = anbncn_lin_mamba_df[anbncn_lin_mamba_df.model != "linear"]


In [None]:
# Stack all aNbNcN runs. Unlike the other grammars the combined frame has its own
# name (`anbncn_df_merged`); `anbncn_df` remains the single filtered sweep.
anbncn_df_merged = pd.concat([anbncn_df, anbncn_lin_mamba_df, anbncn_lin_df])
anbncn_stats = grouped_rule_stats(anbncn_df_merged)

## Matched brackets and parentheses

In [None]:
# Load the Dyck sweep and drop rows whose model is "linear".
dyck_df, _ = get_sweep_stats(sweep_id="eruf1l2q", file_prefix="dyck", save=True, load=True)
dyck_df = dyck_df[dyck_df.model != "linear"]

In [None]:
# Separate sweep (prefix "dyck_lin"); not filtered before concatenation.
dyck_lin_df, _ = get_sweep_stats(sweep_id="gw4fwwsr", file_prefix="dyck_lin", save=True, load=True)

In [None]:
# Append the dyck_lin runs to dyck_df and compute the grouped statistics.
# Not idempotent: re-running this cell alone appends dyck_lin_df a second time.
dyck_df = pd.concat([dyck_df, dyck_lin_df])
dyck_stats = grouped_rule_stats(dyck_df)

## Human study

In [None]:
# Load the human-study spreadsheet; missing cells and the literal answer
# "already completed" are both normalised to the empty string.
human_df = (
    pd.read_excel("human_study.xlsx")
    .fillna("")
    .replace("already completed", "")
)

In [None]:
# Every column after the first is a prompt; remove all whitespace from
# between the characters of each prompt.
prompts = ["".join(header.split()) for header in human_df.columns[1:]]

In [None]:
# Overwrite the column names: the first column becomes "timestamp", the rest
# get the whitespace-free prompts so they can be looked up by prompt later.
human_df.columns = ["timestamp"] +prompts

In [None]:
# Scalar values of the project's token constants, extracted with .item()
# (presumably the *_token objects are 0-d tensors — confirm in llm_non_identifiability.data).
A = A_token.item()
B = B_token.item()
OB = OPENING_BRACKET_token.item()
OP = OPENING_PARENTHESIS_token.item()
CB = CLOSING_BRACKET_token.item()
CP = CLOSING_PARENTHESIS_token.item()

def char2token(char):
    """Map one character to its token value, case-insensitively.

    Recognised characters are ``a``/``A``, ``b``/``B``, ``(``, ``)``, ``[``
    and ``]``. Any other character yields ``None``.
    """
    token_by_char = {
        "A": A,
        "B": B,
        "(": OP,
        ")": CP,
        "[": OB,
        "]": CB,
    }
    return token_by_char.get(char.upper())

In [None]:
# Tokenize prompts character by character; a character that char2token
# does not recognise becomes None in the resulting list.
prompts_tokenized = [[char2token(c) for c in prompt] for prompt in prompts]



In [None]:
# Work on a copy so human_df keeps the raw text answers.
tokenized_human_df = human_df.copy()

In [None]:
# Tokenize every answer, one prompt column at a time.
for prompt in prompts:
    # Remove all whitespace from each answer, then map its characters to
    # tokens; characters char2token does not recognise become None.
    stripped_answers = ("".join(answer.split()) for answer in human_df[prompt])
    tokenized_human_df[prompt] = [
        [char2token(c) for c in answer] for answer in stripped_answers
    ]



In [None]:
from llm_non_identifiability.data import check_same_number_as_bs, check_as_before_bs, check_even_number_of_as, check_begins_with_b, check_matched_brackets, check_matched_parentheses
import torch

# The prompt columns are scored by position: indices 0-4 as aNbN,
# 5-9 as Dyck, and every later index as baN.
LAST_ANBN_PROMPT_IDX = 4
LAST_DYCK_PROMPT_IDX = 9

# One entry per scored answer. R1 is checked on prompt + completion,
# R2 on the completion alone.
anbn_r1_human = []
anbn_r2_completion_human = []

dyck_r1_human = []
dyck_r2_completion_human = []

ban_r1_human = []
ban_r2_completion_human = []

for idx, (prompt, prompt_tokenized) in enumerate(zip(prompts, prompts_tokenized)):
    # iterate over the tokenized answers given for this prompt
    for completion in tokenized_human_df[prompt]:

        # skip answers that contain a character without a token
        if None in completion:
            continue

        # list concatenation already builds a new list, so the prompt is not mutated
        completed_prompt = torch.tensor(prompt_tokenized + completion)
        completion_tokens = torch.tensor(completion)
        # an empty completion is counted as satisfying R2
        is_empty = len(completion) == 0

        if idx <= LAST_ANBN_PROMPT_IDX:  # anbn
            anbn_r1_human.append(check_same_number_as_bs(completed_prompt))
            anbn_r2_completion_human.append(True if is_empty else check_as_before_bs(completion_tokens))
        elif idx <= LAST_DYCK_PROMPT_IDX:  # dyck
            # NOTE(review): the first two tokens are dropped before the bracket check and
            # R2 uses the parentheses check — the reason is not visible here; confirm.
            dyck_r1_human.append(check_matched_brackets(completed_prompt[2:]))
            dyck_r2_completion_human.append(True if is_empty else check_matched_parentheses(completion_tokens))
        else:  # ban
            ban_r1_human.append(check_even_number_of_as(completed_prompt))
            ban_r2_completion_human.append(True if is_empty else check_begins_with_b(completion_tokens))



In [None]:
# Print accuracies: the mean of each per-answer outcome list, i.e. the fraction
# of scored human answers that satisfy the rule.
print(f"anbn R1: {np.mean(anbn_r1_human)}")
print(f"anbn R2: {np.mean(anbn_r2_completion_human)}")
print(f"dyck R1: {np.mean(dyck_r1_human)}"
)
print(f"dyck R2: {np.mean(dyck_r2_completion_human)}")
print(f"ban R1: {np.mean(ban_r1_human)}")
print(f"ban R2: {np.mean(ban_r2_completion_human)}")

In [None]:
# Mean human accuracy per grammar, keyed like the model statistics
# ("ood_rule_1", "ood_rule_2_completion").
human_stats = {
    "baN": {
        "ood_rule_1": np.mean(ban_r1_human),
        "ood_rule_2_completion": np.mean(ban_r2_completion_human),
    },
    "aNbN": {
        "ood_rule_1": np.mean(anbn_r1_human),
        "ood_rule_2_completion": np.mean(anbn_r2_completion_human),
    },
}

# Dyck is deliberately left out of human_stats:
# human_stats["Dyck"] = {
#     "ood_rule_1": np.mean(dyck_r1_human),
#     "ood_rule_2_completion": np.mean(dyck_r2_completion_human)
# }

## Chance levels

In [None]:
# Hard-coded chance-level accuracies per grammar, as fractions in [0, 1];
# the summary figure draws them as shaded bands.
chance_stats = {
    "baN": {"ood_rule_1": 0.5, "ood_rule_2_completion": 1./3},
    "bbaN": {"ood_rule_1": 0.5, "ood_rule_2_completion": 0.75},
    "aNbN": {"ood_rule_1": 0.154, "ood_rule_2_completion": 0.4445},
    "Dyck": {"ood_rule_1": 0.1273, "ood_rule_2_completion": 0.382},
    "aNbNcN": {"ood_rule_1": 0.00334, "ood_rule_2_completion": 0.5925},
}




# Plots

## Helper functions


In [None]:
def plot_loss_vs_rules(df, stats, cmap="coolwarm", TICK_PADDING=2, LABELPAD=1, filename=None):
    """Scatter OOD rule accuracy (in percent) against minimum test loss.

    Draws two side-by-side panels, one point per row of ``df`` and one colour
    per model: the left panel shows ``ood_rule_1_accuracy4min_val_loss``, the
    right panel ``ood_rule_2_completion_accuracy4min_val_loss``, both plotted
    against ``min_val_loss``.

    Args:
        df: Run table with the columns ``model``, ``min_val_loss`` and the two
            accuracy columns named above.
        stats: Mapping whose ``"val_loss"`` entry exposes ``.groups``; its keys
            are the model names that get plotted.
        cmap: Unused; kept so existing calls remain valid.
        TICK_PADDING: ``pad`` passed to ``tick_params`` on both panels.
        LABELPAD: ``labelpad`` passed to every axis label.
        filename: If not ``None``, the figure is saved as ``f"{filename}.svg"``.

    Raises:
        KeyError: If ``stats`` contains a model with no entry in the colour table.
    """
    colors = {
        "transformer": "tab:blue",
        "lstm": "tab:orange",
        "linear": "tab:green",
        "mamba": "tab:red",
    }

    # (accuracy column, y label) per panel. Raw strings: "\%" is not a valid
    # Python escape sequence, and LaTeX needs the backslash to print a percent sign.
    panels = (
        ("ood_rule_1_accuracy4min_val_loss", r"R1 (\%)"),
        ("ood_rule_2_completion_accuracy4min_val_loss", r"R2 completion (\%)"),
    )
    model_names = list(stats["val_loss"].groups.keys())

    fig = plt.figure(figsize=figsizes.icml2022_full(nrows=1, ncols=2)['figure.figsize'])
    for panel_idx, (accuracy_col, ylabel) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, 2, panel_idx)
        ax.grid(True, which="both", ls="-.")
        ax.set_axisbelow(True)
        for model in model_names:
            model_runs = df[df.model == model]
            ax.scatter(
                model_runs.min_val_loss,
                100 * model_runs[accuracy_col],
                c=colors[model],
                label=model.capitalize() if model != "lstm" else model.upper(),
            )
        ax.set_ylabel(ylabel, labelpad=LABELPAD)
        ax.set_xlabel("Minimum test loss", labelpad=LABELPAD)
        ax.tick_params(axis='both', which='major', pad=TICK_PADDING)

    # `ax` is now the right-hand panel; the legend is anchored outside it.
    ax.legend(loc="center right", bbox_to_anchor=(1.75, 0.5))

    if filename is not None:
        fig.savefig(f"{filename}.svg")

## baN

In [None]:
# Scatter plot for baN; also written to ban_loss_vs_rules.svg.
plot_loss_vs_rules(ban_df, ban_stats, filename="ban_loss_vs_rules")

In [None]:
# Per-model summary for baN; this is the only grammar called with include_r2=False.
rule_stats2string_per_model(ban_stats, include_r2=False)

## bbaN

In [None]:
# Scatter plot for bbaN; also written to bban_loss_vs_rules.svg.
plot_loss_vs_rules(bban_df, bban_stats, filename="bban_loss_vs_rules")

In [None]:
# Per-model summary for bbaN (helper from `analysis`).
rule_stats2string_per_model(bban_stats)

## aNbN

In [None]:
# Scatter plot for aNbN; also written to anbn_loss_vs_rules.svg.
plot_loss_vs_rules(anbn_df, anbn_stats, filename="anbn_loss_vs_rules")

In [None]:
# Per-model summary for aNbN (helper from `analysis`).
rule_stats2string_per_model(anbn_stats)

## aNbNcN

In [None]:
# Plot the merged frame: anbncn_stats was computed from anbncn_df_merged, whereas
# anbncn_df holds only the first sweep, so the runs of the other two sweeps
# would otherwise be missing from the scatter plot.
plot_loss_vs_rules(anbncn_df_merged, anbncn_stats, filename="anbncn_loss_vs_rules")

In [None]:
# Per-model summary for aNbNcN (helper from `analysis`).
rule_stats2string_per_model(anbncn_stats)

## Matched brackets and parentheses

In [None]:
# Scatter plot for Dyck; also written to dyck_loss_vs_rules.svg.
plot_loss_vs_rules(dyck_df, dyck_stats, filename="dyck_loss_vs_rules")

In [None]:
# Per-model summary for Dyck (helper from `analysis`).
rule_stats2string_per_model(dyck_stats)

## Plot for all languages

In [None]:
TICK_PADDING = 2
LABELPAD = 1
cmap = "coolwarm"  # not used in this cell


# Grouped statistics per grammar, in plotting order (left to right).
stats_dict = {
    "baN": ban_stats,
    "bbaN": bban_stats,
    "aNbN": anbn_stats,
    "aNbNcN": anbncn_stats,
    "Dyck": dyck_stats,
}

# x tick labels, in the same order as stats_dict
labels = [
    r"$b\alpha$",
    r"$b^na^{2m}$",
    r"$a^nb^n$",
    r"$a^nb^nc^n$",
    "Dyck"
]

# One colour per plotted series; the key order also fixes which x offset a
# series gets inside its grammar group (see x_offsets).
colors = {
    "transformer": "tab:blue",
    "lstm": "tab:orange",
    "linear": "tab:green",
    "mamba": "tab:red",
    "human": "black",
    # "chance": "purple"
}


x_pos = list(range(len(stats_dict)))
x_offsets = [-0.3, -0.15, .0, .15, 0.3]  # one offset per entry of `colors`
x_factor = 1/ len(stats_dict)  # not used in this cell
width = 0.35
x_stretch = 3  # horizontal distance between neighbouring grammar groups
x_pos = [x*x_stretch for x in x_pos]

models = colors.keys()


def _plot_rule_summary(ax, rule_key, ylabel):
    """Draw one summary panel: per-grammar series for ``rule_key`` on ``ax``.

    For every grammar in ``stats_dict`` this draws the chance level as a grey
    band (when listed in ``chance_stats``), one bar per series present for
    that grammar, and a dashed separator to the right of the group. A series
    whose mean is below 0.01 is drawn as a marker instead of a bar.
    """
    for x, (grammar, stats) in enumerate(stats_dict.items()):
        if grammar in chance_stats:
            chance = chance_stats[grammar][rule_key]
            rectangle = matplotlib.patches.Rectangle((x_stretch*(x-0.5), 0), width=x_stretch, height=100*chance, color='gray', alpha=.25)
            ax.add_patch(rectangle)

        for i, model in enumerate(models):
            if model == "human":
                # human accuracy exists only for some grammars and has no spread
                if grammar not in human_stats:
                    continue
                mean = human_stats[grammar][rule_key]
                std = 0
            else:
                if model not in stats[rule_key].groups.keys():
                    continue
                group = stats[rule_key].get_group(model)
                mean = group.mean()
                std = group.std()

            x_center = x_stretch*(x + x_offsets[i])
            # NOTE(review): the mean is scaled by 100 (percent) but the error bar by 10 —
            # confirm whether 10 * std is intended or should be 100 * std.
            if mean < 0.01:
                ax.errorbar(x_center, 100 * mean, yerr=10 * std, fmt="o", label=model, c=colors[model])
            else:
                ax.bar(x_center, 100 * mean, yerr=10 * std, width=width, label=model, color=colors[model])

        ax.axvline(x_stretch*(x+0.5), color='black', linestyle='--', linewidth=0.5)

    # Raw string passed by the caller: "\%" is not a valid Python escape sequence.
    ax.set_ylabel(ylabel, labelpad=LABELPAD)

    # set xtick names
    ax.set_xticks(x_pos)
    ax.set_xticklabels(labels)
    ax.set_xlim(x_stretch*x_offsets[0]-.35, x_stretch*(len(stats_dict)-1+x_offsets[-1])+.2)
    ax.tick_params(axis='both', which='major', pad=TICK_PADDING)


fig = plt.figure(figsize=figsizes.neurips2022(nrows=1, ncols=3, tight_layout=True, rel_width=2)['figure.figsize'])

ax = fig.add_subplot(121)
_plot_rule_summary(ax, "ood_rule_1", r"R1 (\%)")

ax2 = fig.add_subplot(122)
_plot_rule_summary(ax2, "ood_rule_2_completion", r"R2 completion (\%)")

# Build the legend from proxy handles so every series appears exactly once.
handles = [mlines.Line2D([], [], color=colors[model], marker='o', label=model.capitalize() if model != "lstm" else model.upper()) for model in models]
ax2.legend(handles=handles, loc="center right", bbox_to_anchor=(1.5, 0.5))

plt.savefig("ood_summary.svg")


In [None]:
# Spot check: mean of the "ood_rule_1" statistic for the "linear" group on aNbNcN.
anbncn_stats["ood_rule_1"].get_group("linear").mean()