# Setup and imports

In [None]:
# imports
import numpy as np
from tueplots import bundles, figsizes
import wandb
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import pandas as pd
import matplotlib


import sys

%load_ext autoreload
%autoreload 2

sys.path.insert(0, '.')

In [None]:
from analysis import sweep2df, plot_typography, stats2string, RED, BLUE, rule_stats2string_per_model, grouped_rule_stats
from rule_extrapolation.data import A_token, B_token, OPENING_BRACKET_token, OPENING_PARENTHESIS_token, CLOSING_BRACKET_token, CLOSING_PARENTHESIS_token

In [None]:
USETEX = True

In [None]:
plt.rcParams.update(bundles.icml2022(usetex=USETEX))
# plt.rcParams.update({
#     'text.latex.preamble': [r'\usepackage{amsfonts}', # mathbb
#                             r'\usepackage{amsmath}'] # boldsymbol
# })

In [None]:
plot_typography(usetex=USETEX, small=14, medium=14, big=18)


In [None]:
# Constants
ENTITY = "causal-representation-learning"
PROJECT = "rule_extrapolation"

# W&B API
api = wandb.Api(timeout=400)
runs = api.runs(ENTITY + "/" + PROJECT)

# Data loading

In [None]:
def get_sweep_stats(sweep_id, file_prefix, save=False, load=True, entity=ENTITY, project=PROJECT, pick_max=True):
    """Fetch one W&B sweep and return its run table plus grouped rule statistics.

    Args:
        sweep_id: Identifier of the sweep inside ``entity/project``.
        file_prefix: Prefix of the name handed to ``sweep2df``; the full name is
            ``f"{file_prefix}_{sweep_id}"``.
        save: Forwarded unchanged to ``sweep2df``.
        load: Forwarded unchanged to ``sweep2df``.
        entity: W&B entity that owns the project.
        project: W&B project that contains the sweep.
        pick_max: Forwarded unchanged to ``sweep2df``.

    Returns:
        ``(df, stats)`` where ``df`` is the first value returned by ``sweep2df``
        and ``stats`` is ``grouped_rule_stats(df)``.
    """
    api = wandb.Api(timeout=400)
    sweep = api.sweep(f"{entity}/{project}/{sweep_id}")
    filename = f"{file_prefix}_{sweep_id}"
    # sweep2df returns the DataFrame followed by a long tail of per-metric
    # values; none of them is used here, so they are discarded instead of being
    # bound to twenty throw-away names.
    df, *_ = sweep2df(sweep.runs, filename, save=save, load=load, pick_max=pick_max)
    return df, grouped_rule_stats(df)

## baN

In [None]:
ban_no_transformer_df, _ = get_sweep_stats(sweep_id="frrbkgg0", file_prefix="ban_no_transformer", save=True, load=True)
ban_transformer_df, _ = get_sweep_stats(sweep_id="9b9a0xtn", file_prefix="ban_transformer", save=True, load=True)
ban_lstm_df, _ = get_sweep_stats(sweep_id="vzcozopj", file_prefix="ban_lstm", save=True, load=True)

In [None]:
ban_lin_df, _ = get_sweep_stats(sweep_id="8r2yb016", file_prefix="ban_lin", save=True, load=True)

In [None]:
# filter out linear models (incorrect logs)
ban_no_transformer_df = ban_no_transformer_df[ban_no_transformer_df.model != "linear"]
ban_transformer_df = ban_transformer_df[ban_transformer_df.model != "linear"]
ban_lstm_df = ban_lstm_df[ban_lstm_df.model != "linear"]




In [None]:
ban_df = pd.concat([ban_no_transformer_df, ban_transformer_df, ban_lstm_df, ban_lin_df])
ban_stats = grouped_rule_stats(ban_df)

### sampling next token

In [None]:
ban_sampling_df, ban_sampling_stats = get_sweep_stats(sweep_id="x3d2i6ja", file_prefix="ban_sampling", save=True, load=True, pick_max=False)

## bbaN

In [None]:
bban_transformer_mamba_df, _ = get_sweep_stats(sweep_id="jjjtbh24", file_prefix="bban_transformer_mamba", save=True, load=True)
bban_all_df, _ = get_sweep_stats(sweep_id="unc8bv65", file_prefix="bban_all", save=True, load=True)

In [None]:
bban_lin_df, _ = get_sweep_stats(sweep_id="dnpv6gpm", file_prefix="bban_lin", save=True, load=True)

In [None]:
# filter out linear models (incorrect logs)
bban_transformer_mamba_df = bban_transformer_mamba_df[bban_transformer_mamba_df.model != "linear"]
bban_all_df = bban_all_df[bban_all_df.model != "linear"]


In [None]:
bban_df = pd.concat([bban_transformer_mamba_df, bban_all_df, bban_lin_df])
bban_stats = grouped_rule_stats(bban_df)

### sampling next token

In [None]:
bban_sampling_df, bban_sampling_stats = get_sweep_stats(sweep_id="o9kfhvlu", file_prefix="bban_sampling", save=True, load=True,  pick_max=True)

## aNbN

In [None]:
anbn_transformer_df, _ = get_sweep_stats(sweep_id="vn4yrcl8", file_prefix="anbn_transformer", save=True, load=True)
anbn_lstm_df, _ = get_sweep_stats(sweep_id="t4yzbech", file_prefix="anbn_lstm", save=True, load=True)
anbn_mamba_df, _ = get_sweep_stats(sweep_id="o27zaphz", file_prefix="anbn_mamba", save=True, load=True)
anbn_no_transformer_df, _ = get_sweep_stats(sweep_id="nfrfpkqm", file_prefix="anbn_no_transformer", save=True, load=True)
anbn_lin_df, _ = get_sweep_stats(sweep_id="8lijfxk2", file_prefix="anbn_lin", save=True, load=True)

In [None]:
# filter out linear models (incorrect logs)
anbn_transformer_df = anbn_transformer_df[anbn_transformer_df.model != "linear"]
anbn_lstm_df = anbn_lstm_df[anbn_lstm_df.model != "linear"]
anbn_mamba_df = anbn_mamba_df[anbn_mamba_df.model != "linear"]
anbn_no_transformer_df = anbn_no_transformer_df[anbn_no_transformer_df.model != "linear"]

In [None]:
anbn_df = pd.concat([anbn_transformer_df, anbn_lstm_df, anbn_mamba_df, anbn_no_transformer_df, anbn_lin_df])
anbn_stats = grouped_rule_stats(anbn_df)

### sampling next token

In [None]:
anbn_sampling_df, anbn_sampling_stats = get_sweep_stats(sweep_id="na40gehn", file_prefix="anbn_sampling", save=True, load=True,  pick_max=True)

### aNbN parity

In [None]:
anbn_parity_df, anbn_parity_stats = get_sweep_stats(sweep_id="9qz5sdef", file_prefix="anbn_parity", save=True, load=False,  pick_max=True)


## aNbNcN

In [None]:
anbncn_df, _ = get_sweep_stats(sweep_id="6m1qb70e", file_prefix="anbncn", save=True, load=True)
anbncn_lin_mamba_df, _ = get_sweep_stats(sweep_id="vtji6cx4", file_prefix="anbncn_lin_mamba", save=True, load=True)

In [None]:
anbncn_lin_df, _ = get_sweep_stats(sweep_id="p135j0eg", file_prefix="anbncn_lin", save=True, load=True)

In [None]:
anbncn_df = anbncn_df[anbncn_df.model != "linear"]
anbncn_lin_mamba_df = anbncn_lin_mamba_df[anbncn_lin_mamba_df.model != "linear"]

In [None]:
anbncn_df_merged = pd.concat([anbncn_df, anbncn_lin_mamba_df, anbncn_lin_df])
anbncn_stats = grouped_rule_stats(anbncn_df_merged)


### sampling next token

In [None]:
anbncn_sampling_df, anbncn_sampling_stats = get_sweep_stats(sweep_id="ha3dnqdt", file_prefix="anbncn_sampling", save=True, load=True,  pick_max=True)

## Dyck

In [None]:
dyck_df, _ = get_sweep_stats(sweep_id="eruf1l2q", file_prefix="dyck", save=True, load=True)
dyck_df = dyck_df[dyck_df.model != "linear"]

In [None]:
dyck_lin_df, _ = get_sweep_stats(sweep_id="gw4fwwsr", file_prefix="dyck_lin", save=True, load=True)

In [None]:
dyck_df = pd.concat([dyck_df, dyck_lin_df])
dyck_stats = grouped_rule_stats(dyck_df)

### sampling next token

In [None]:
dyck_sampling_df, dyck_sampling_stats = get_sweep_stats(sweep_id="i6um5idm", file_prefix="dyck_sampling", save=True, load=True,  pick_max=True)


## Context-sensitive Dyck

In [None]:
cs_dyck_df, _ = get_sweep_stats(sweep_id="6dunl50v", file_prefix="cs_dyck", save=True, load=True)
cs_dyck_lstm_df, _ = get_sweep_stats(sweep_id="c9vbbut5", file_prefix="cs_dyck_lstm", save=True, load=True)

cs_dyck_df_merged = pd.concat([cs_dyck_df, cs_dyck_lstm_df])
cs_dyck_stats = grouped_rule_stats(cs_dyck_df_merged)


### Sampling next token

In [None]:
cs_dyck_sampling_df, cs_dyck_sampling_stats = get_sweep_stats(sweep_id="81d37nzc", file_prefix="cs_dyck_sampling", save=True, load=True,  pick_max=True)


## xLSTM

In [None]:
xlstm_ban_df, xlstm_ban_stats = get_sweep_stats(sweep_id="bg56nj7q", file_prefix="xlstm_ban", save=True, load=False)

In [None]:
xlstm_bban_df, xlstm_bban_stats = get_sweep_stats(sweep_id="5vnntqcp", file_prefix="xlstm_bban", save=True, load=True)

In [None]:
xlstm_anbn_df, xlstm_anbn_stats = get_sweep_stats(sweep_id="0xws3drd", file_prefix="xlstm_anbn", save=True, load=True)

In [None]:
xlstm_anbncn_df, xlstm_anbncn_stats = get_sweep_stats(sweep_id="krvkg6dj", file_prefix="xlstm_anbncn", save=True, load=True)

In [None]:
xlstm_dyck_df, xlstm_dyck_stats = get_sweep_stats(sweep_id="rp0zyg2c", file_prefix="xlstm_dyck", save=True, load=True)

In [None]:
xlstm_cs_dyck_df, xlstm_cs_dyck_stats = get_sweep_stats(sweep_id="7ebw68ab", file_prefix="xlstm_cs_dyck", save=True, load=True)


In [None]:
# Append the xLSTM runs to every per-grammar frame and recompute the grouped
# statistics from the extended frame.
#
# NOTE(review): each frame is rebound to "itself + xLSTM runs", so this cell is
# not idempotent - running it a second time appends the xLSTM rows again.
# After this cell the per-grammar frames already contain the xLSTM runs, so
# later code must not concatenate the xlstm_* frames onto them once more.
ban_df = pd.concat([ban_df, xlstm_ban_df])
ban_stats = grouped_rule_stats(ban_df)

bban_df = pd.concat([bban_df, xlstm_bban_df])
bban_stats = grouped_rule_stats(bban_df)

anbn_df = pd.concat([anbn_df, xlstm_anbn_df])
anbn_stats = grouped_rule_stats(anbn_df)

anbncn_df_merged = pd.concat([anbncn_df_merged, xlstm_anbncn_df])
anbncn_stats = grouped_rule_stats(anbncn_df_merged)

dyck_df = pd.concat([dyck_df, xlstm_dyck_df])
dyck_stats = grouped_rule_stats(dyck_df)

cs_dyck_df_merged = pd.concat([cs_dyck_df_merged, xlstm_cs_dyck_df])
cs_dyck_stats = grouped_rule_stats(cs_dyck_df_merged)

## Hyperparameter search

In [None]:
hyper_ban_df, hyper_ban_stats = get_sweep_stats(sweep_id="nza9ka3b", file_prefix="hyper_ban", save=True, load=True)

In [None]:
hyper_bban_df, hyper_bban_stats = get_sweep_stats(sweep_id="6fpd9uqg", file_prefix="hyper_bban", save=True, load=True)

In [None]:
hyper_anbn_df, hyper_anbn_stats = get_sweep_stats(sweep_id="amsk1ba7", file_prefix="hyper_anbn", save=True, load=True)

In [None]:
hyper_anbncn_df, hyper_anbncn_stats = get_sweep_stats(sweep_id="5prt0zmv", file_prefix="hyper_anbncn", save=True, load=True)

In [None]:
hyper_dyck_df, hyper_dyck_stats = get_sweep_stats(sweep_id="2r7j2cfm", file_prefix="hyper_dyck", save=True, load=True)

In [None]:
hyper_cs_dyck_df, hyper_cs_dyck_stats = get_sweep_stats(sweep_id="nakt9wnj", file_prefix="hyper_cs_dyck", save=True, load=True)



## Transformer size ablation

In [None]:
transformer_size_ban_df, transformer_size_ban_stats = get_sweep_stats(sweep_id="r6gelh9k", file_prefix="transformer_size_ban", save=True, load=True)

In [None]:
transformer_size_bban_df, transformer_size_bban_stats = get_sweep_stats(sweep_id="53nmwmq1", file_prefix="transformer_size_bban", save=True, load=True)

In [None]:
transformer_size_anbn_df, transformer_size_anbn_stats = get_sweep_stats(sweep_id="ddfjwbsl", file_prefix="transformer_size_anbn", save=True, load=True)

In [None]:
transformer_size_anbncn_df, transformer_size_anbncn_stats = get_sweep_stats(sweep_id="yv9ajwdv", file_prefix="transformer_size_anbncn", save=True, load=True)

In [None]:
transformer_size_dyck_df, transformer_size_dyck_stats = get_sweep_stats(sweep_id="v672gglg", file_prefix="transformer_size_dyck", save=True, load=True)


## Number of all runs

In [None]:
# Every run is counted exactly once. The per-grammar frames (ban_df, bban_df,
# anbn_df, anbncn_df_merged, dyck_df, cs_dyck_df_merged) already contain the
# xLSTM runs that were merged into them in the xLSTM section, so the xlstm_*
# frames are deliberately not listed again (they used to be added a second time).
all_run_frames = [
    ban_df, ban_sampling_df,
    bban_df, bban_sampling_df,
    anbn_df, anbn_sampling_df, anbn_parity_df,
    anbncn_df_merged, anbncn_sampling_df,
    dyck_df, dyck_sampling_df,
    cs_dyck_df_merged, cs_dyck_sampling_df,
    hyper_ban_df, hyper_bban_df, hyper_anbn_df,
    hyper_anbncn_df, hyper_dyck_df, hyper_cs_dyck_df,
    transformer_size_ban_df, transformer_size_bban_df, transformer_size_anbn_df,
    transformer_size_anbncn_df, transformer_size_dyck_df,
]
sum(len(frame) for frame in all_run_frames)

## Human study

In [None]:
# Read the human-study spreadsheet. Missing answers (NaN) and the literal
# placeholder "already completed" are both turned into empty strings.
human_df = (
    pd.read_excel("human_study.xlsx")
    .fillna("")
    .replace("already completed", "")
)

In [None]:
# Every column after the first one is a prompt. Remove the whitespace between
# the characters of each header so that a prompt is one contiguous string.
prompts = ["".join(header.split()) for header in human_df.columns[1:]]

In [None]:
# overwrite columns names in df
human_df.columns = ["timestamp"] +prompts

In [None]:
A = A_token.item()
B = B_token.item()
OB = OPENING_BRACKET_token.item()
OP = OPENING_PARENTHESIS_token.item()
CB = CLOSING_BRACKET_token.item()
CP = CLOSING_PARENTHESIS_token.item()

def char2token(char):
    """Translate a single character into the matching token constant.

    Letters are matched case-insensitively (the character is upper-cased
    first); parentheses and brackets are matched literally. Returns ``None``
    for any character without a token.
    """
    token_by_char = {
        "A": A,
        "B": B,
        "(": OP,
        ")": CP,
        "[": OB,
        "]": CB,
    }
    return token_by_char.get(char.upper())

In [None]:
# tokenize prompts
prompts_tokenized = [[char2token(c) for c in prompt] for prompt in prompts]



In [None]:
tokenized_human_df = human_df.copy()

In [None]:
for prompt in prompts:
    # One column per prompt; every entry is one participant's typed completion.
    responses = human_df[prompt]

    # Drop all whitespace from a response, then map it character by character
    # to tokens (characters without a token become None).
    tokenized_human_df[prompt] = [
        [char2token(char) for char in "".join(response.split())]
        for response in responses
    ]



In [None]:
from rule_extrapolation.data import check_same_number_as_bs, check_as_before_bs, check_even_number_of_as, check_begins_with_b, check_matched_brackets, check_matched_parentheses
import torch

# One boolean per scored human response, collected per grammar and rule.
anbn_r1_human = []
anbn_r2_completion_human = []

dyck_r1_human = []
dyck_r2_completion_human = []

ban_r1_human = []
ban_r2_completion_human = []

for idx, (prompt, prompt_tokenized) in enumerate(zip(prompts, prompts_tokenized)):
    if idx == 4:
        # Column 4 is the in-distribution aNbN prompt and belongs to none of
        # the OOD statistics. It used to fall through to the final `else`
        # branch and was scored with the baN rules.
        continue

    # get the column by prompt
    col = tokenized_human_df[prompt]

    # iterate over the rows
    for completion in col:

        # responses containing a character without a token cannot be scored
        if None in completion:
            continue

        # full sequence = tokenized prompt followed by the participant's completion
        completed_prompt = torch.tensor(prompt_tokenized + completion)
        completion = torch.tensor(completion)

        # An empty completion counts as satisfying the completion-only rule.
        if idx < 4:  # anbn
            anbn_r1_human.append(check_same_number_as_bs(completed_prompt))
            anbn_r2_completion_human.append(True if len(completion)==0 else check_as_before_bs(completion))
        elif idx <= 9:  # dyck (columns 5..9)
            # NOTE(review): the first two tokens are dropped before the bracket
            # check - presumably a fixed prompt prefix; confirm against the data.
            dyck_r1_human.append(check_matched_brackets(completed_prompt[2:]))
            dyck_r2_completion_human.append(True if len(completion)==0 else check_matched_parentheses(completion))
        else:  # ban
            ban_r1_human.append(check_even_number_of_as(completed_prompt))
            ban_r2_completion_human.append(True if len(completion)==0 else check_begins_with_b(completion))



In [None]:
# print accuracies
# Mean of the per-response booleans = fraction of responses obeying the rule.
human_accuracy_report = (
    ("anbn R1", anbn_r1_human),
    ("anbn R2", anbn_r2_completion_human),
    ("dyck R1", dyck_r1_human),
    ("dyck R2", dyck_r2_completion_human),
    ("ban R1", ban_r1_human),
    ("ban R2", ban_r2_completion_human),
)
for report_label, outcomes in human_accuracy_report:
    print(f"{report_label}: {np.mean(outcomes)}")

In [None]:
# Human accuracies per grammar, keyed like the model statistics.
human_stats = {
    "baN": {
        "ood_rule_1": np.mean(ban_r1_human),
        "ood_rule_2_completion": np.mean(ban_r2_completion_human),
    },
    "aNbN": {
        "ood_rule_1": np.mean(anbn_r1_human),
        "ood_rule_2_completion": np.mean(anbn_r2_completion_human),
    },
    # Dyck is currently left out:
    # "Dyck": {
    #     "ood_rule_1": np.mean(dyck_r1_human),
    #     "ood_rule_2_completion": np.mean(dyck_r2_completion_human),
    # },
}

## Chance levels

In [None]:
# Chance-level accuracies per grammar, as fractions in [0, 1]. The plots draw
# them as grey background rectangles (scaled to percent) behind the bars.
# Each grammar is listed exactly once; "CS Dyck" used to be assigned twice
# with identical values.
chance_stats = {
    "baN": {
        "ood_rule_1": 0.5,
        "ood_rule_2_completion": 1./3,
    },
    "bbaN": {
        "ood_rule_1": 0.5,
        "ood_rule_2_completion": 0.75,
    },
    "aNbN": {
        "ood_rule_1": 0.154,
        "ood_rule_2_completion": 0.4445,
    },
    "Dyck": {
        "ood_rule_1": 0.1273,
        "ood_rule_2_completion": 0.382,
    },
    "CS Dyck": {
        "ood_rule_1": 0.1273,
        "ood_rule_2_completion": 0.382,
    },
    "aNbNcN": {
        "ood_rule_1": 0.00334,
        "ood_rule_2_completion": 0.5925,
    },
}



# Plots

## helper functions


In [None]:
def plot_loss_vs_rules(df, stats, cmap="coolwarm", TICK_PADDING=2, LABELPAD=1, filename=None):
    """Scatter each run's minimum validation loss against its OOD rule accuracies.

    Draws two side-by-side panels (R1 on the left, R2 completion on the right),
    with one colour per model.

    Args:
        df: Run table. The columns ``model``, ``min_val_loss``,
            ``ood_rule_1_accuracy4min_val_loss`` and
            ``ood_rule_2_completion_accuracy4min_val_loss`` are read.
        stats: Mapping whose ``"val_loss"`` entry exposes ``.groups``; its keys
            are the model names that get plotted.
        cmap: Unused; kept so existing calls keep working.
        TICK_PADDING: Padding passed to ``tick_params`` for both panels.
        LABELPAD: Padding passed to the axis-label setters.
        filename: If given, the figure is saved as ``f"{filename}.svg"``.
    """
    # "xlstm" is part of the map because the per-grammar frames contain xLSTM
    # runs once they are merged in; without it the lookup raised KeyError.
    colors = {
        "transformer": "tab:blue",
        "lstm": "tab:orange",
        "linear": "tab:green",
        "mamba": "tab:red",
        "xlstm": "purple",
    }
    # Legend spellings that differ from plain capitalisation.
    display_names = {"lstm": "LSTM", "xlstm": "xLSTM"}

    # (accuracy column, y-axis label) for the left and right panel. Raw strings
    # keep the LaTeX "\%" without relying on an invalid escape sequence.
    panels = (
        ("ood_rule_1_accuracy4min_val_loss", r"R1 \%"),
        ("ood_rule_2_completion_accuracy4min_val_loss", r"R2 completion (\%)"),
    )

    fig = plt.figure(figsize=figsizes.icml2022_full(nrows=1, ncols=2)['figure.figsize'])

    for panel_idx, (accuracy_column, ylabel) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, 2, panel_idx)
        ax.grid(True, which="both", ls="-.")
        ax.set_axisbelow(True)

        for model in stats["val_loss"].groups.keys():
            model_runs = df[df.model == model]
            ax.scatter(
                model_runs.min_val_loss,
                100 * model_runs[accuracy_column],  # fraction -> percent
                c=colors.get(model, "tab:gray"),  # unknown models fall back to grey
                label=display_names.get(model, model.capitalize()),
            )

        ax.set_ylabel(ylabel, labelpad=LABELPAD)
        ax.set_xlabel("Minimum test loss", labelpad=LABELPAD)
        ax.tick_params(axis='both', which='major', pad=TICK_PADDING)

    # One legend, attached to the right-hand panel and placed outside of it.
    ax.legend(loc="center right", bbox_to_anchor=(1.75, 0.5))

    if filename is not None:
        plt.savefig(f"{filename}.svg")

## baN

In [None]:
plot_loss_vs_rules(ban_df, ban_stats, filename="ban_loss_vs_rules")

In [None]:
rule_stats2string_per_model(ban_stats, include_r2=False)

## bbaN

In [None]:
plot_loss_vs_rules(bban_df, bban_stats, filename="bban_loss_vs_rules")

In [None]:
rule_stats2string_per_model(bban_stats)

## aNbN

In [None]:
plot_loss_vs_rules(anbn_df, anbn_stats, filename="anbn_loss_vs_rules")

In [None]:
# Saved under its own name: with "anbn_loss_vs_rules" this figure overwrote the
# aNbN plot written by the previous cell.
plot_loss_vs_rules(anbn_parity_df, anbn_parity_stats, filename="anbn_parity_loss_vs_rules")

In [None]:
rule_stats2string_per_model(anbn_stats)

## aNbNcN

In [None]:
# anbncn_stats is computed from anbncn_df_merged, so the same merged frame is
# plotted; anbncn_df alone only holds the runs of the first sweep.
plot_loss_vs_rules(anbncn_df_merged, anbncn_stats, filename="anbncn_loss_vs_rules")

In [None]:
rule_stats2string_per_model(anbncn_stats)

## Dyck

In [None]:
plot_loss_vs_rules(dyck_df, dyck_stats, filename="dyck_loss_vs_rules")

In [None]:
rule_stats2string_per_model(dyck_stats)

## Context-sensitive Dyck

In [None]:
TICK_PADDING = 2
LABELPAD = 1
cmap = "coolwarm"


# cs_dyck_df_merged already contains the xLSTM runs (they are merged in the
# xLSTM section), so it is grouped as is. Concatenating xlstm_cs_dyck_df once
# more counted every xLSTM run twice, which distorts the plotted std.
stats_dict = {
    "CS Dyck": grouped_rule_stats(cs_dyck_df_merged),
}

labels = [
    # r"$b\alpha$",
    # r"$b^na^{2m}$",
    # r"$a^nb^n$",
    # r"$a^nb^nc^n$",
    # "Dyck",
    "CS Dyck",
]

colors = {
        "transformer": "tab:blue",
        "lstm": "tab:orange",
        "linear": "tab:green",
        "mamba": "tab:red",
    # "human" : "black",
    "xlstm": "purple"
        }


x_pos = list(range(len(stats_dict)))
x_offsets = [-0.35, -0.175, 0., .175, 0.35]
x_factor = 1/ len(stats_dict)
width = 0.35
x_stretch = 3
x_pos = [x*x_stretch for x in x_pos]

fig = plt.figure(figsize=figsizes.neurips2022(nrows=1, ncols=3, tight_layout=True, rel_width=3)['figure.figsize'])

ax = fig.add_subplot(121)


models =  colors.keys()

print("--------------------")
print("R1")
print("--------------------")

# One group of bars per grammar on the R1 axis; x is the group index.
for x, (grammar, stats) in enumerate(stats_dict.items()):
    # Grey background rectangle spanning the whole group, as tall as the
    # chance-level accuracy of this grammar (fraction -> percent).
    if grammar in chance_stats.keys():
        chance = chance_stats[grammar]["ood_rule_1"]
        rectangle = matplotlib.patches.Rectangle((x_stretch*(x-0.5),0), width=x_stretch, height=100*chance, color='gray', alpha=.35)
        ax.add_patch(rectangle)

    # i indexes x_offsets, so every model keeps a fixed slot inside the group
    # even when some models are missing for this grammar.
    for i, model in enumerate(models):

        if model not in stats["ood_rule_1"].groups.keys():
            print(f"Model {model} not in {grammar}")
            continue
        mean = stats["ood_rule_1"].get_group(model).mean()
        std = stats["ood_rule_1"].get_group(model).std()

        print(f"{model=} {mean} {std}")



        # Means below 1% are drawn as a point marker instead of a bar.
        # NOTE(review): the mean is scaled by 100 but the error bar by 10
        # (yerr=10 * std). If the error bars are meant to show one standard
        # deviation in percent this should be 100 * std - confirm the intent.
        # The same scaling is used in the other summary-plot cells.
        if mean < 0.01:
            ax.errorbar(x_stretch*(x + x_offsets[i]), 100 * mean, yerr=10 * std, fmt="o",  label=model, c=colors[model])
        else:
            ax.bar(x_stretch*(x + x_offsets[i]), 100 * mean, yerr=10 * std, width=width,  label=model, color=colors[model])


    # dashed separator at the right edge of the group
    ax.axvline(x_stretch*(x+0.5), color='black', linestyle='--', linewidth=0.5)


ax.set_ylabel("R1 (\%)", labelpad=LABELPAD)

# set xtick names
ax.set_xticks(x_pos)
ax.set_xticklabels(labels)
ax.set_xlim(x_stretch*x_offsets[0]-.35,x_stretch*(len(stats_dict)-1+x_offsets[-1])+.25)

ax2 = fig.add_subplot(122)
print("--------------------")
print("R2 completion")
print("--------------------")
for x, (grammar, stats) in enumerate(stats_dict.items()):

    if grammar in chance_stats.keys():
        chance = chance_stats[grammar]["ood_rule_2_completion"]
        rectangle = matplotlib.patches.Rectangle((x_stretch*(x-0.5),0), width=x_stretch, height=100*chance, color='gray', alpha=0.35)
        ax2.add_patch(rectangle)


    for i, model in enumerate(models):


        if model not in stats["ood_rule_2_completion"].groups.keys():
            continue
        mean = stats["ood_rule_2_completion"].get_group(model).mean()
        std = stats["ood_rule_2_completion"].get_group(model).std()

        print(f"{model=} {mean} {std}")

        if mean < 0.01:
            ax2.errorbar(x_stretch*(x + x_offsets[i]), 100 * mean, yerr=10 * std, fmt="o",  label=model, c=colors[model])
        else:
            ax2.bar(x_stretch*(x + x_offsets[i]), 100 * mean, yerr=10 * std, width=width,  label=model, color=colors[model])


    ax2.axvline(x_stretch*(x+0.5), color='black', linestyle='--', linewidth=0.5)


# set xtick names
ax2.set_xticks(x_pos)
ax2.set_xticklabels(labels)
ax2.set_ylabel("R2 completion (\%)", labelpad=LABELPAD)
ax2.set_xlim(x_stretch*x_offsets[0]-.35,x_stretch*(len(stats_dict)-1+x_offsets[-1])+.25)

handles = [mlines.Line2D([], [], color=colors[model], marker='o', label=(model.capitalize() if model != "xlstm" else "xLSTM") if model != "lstm" else model.upper()) for model in models] + [matplotlib.patches.Patch(color='gray', alpha=0.35, label='Chance-level')]
ax2.legend(handles=handles, loc="center right", bbox_to_anchor=(1.5, 0.5))#, loc='upper center',  ncol=4)



# plt.legend()
ax.tick_params(axis='both', which='major', pad=TICK_PADDING)
ax.set_ylim(0, 100)
ax2.set_ylim(0, 100)

plt.savefig("ood_summary_cs_dyck.svg")


In [None]:
cs_dyck_df.model.unique()

In [None]:
# cs_dyck_df_merged already includes the xLSTM runs; concatenating
# xlstm_cs_dyck_df again would count them twice.
rule_stats2string_per_model(grouped_rule_stats(cs_dyck_df_merged), include_r2=True)

## Plot for all languages

In [None]:
TICK_PADDING = 2
LABELPAD = 1
cmap = "coolwarm"


# The per-grammar frames already contain the xLSTM runs (they are merged in the
# xLSTM section), so each frame is grouped as is. Concatenating the xlstm_*
# frames once more counted every xLSTM run twice, which distorts the plotted std.
stats_dict = {
    "baN": grouped_rule_stats(ban_df),
    "bbaN": grouped_rule_stats(bban_df),
    "aNbN": grouped_rule_stats(anbn_df),
    "aNbNcN": grouped_rule_stats(anbncn_df_merged),
    "Dyck": grouped_rule_stats(dyck_df),
    "CS Dyck": grouped_rule_stats(cs_dyck_df_merged),
}

labels = [
    r"$b\alpha$",
    r"$b^na^{2m}$",
    r"$a^nb^n$",
    r"$a^nb^nc^n$",
    "Dyck",
    "CS Dyck",
]

colors = {
        "transformer": "tab:blue",
        "lstm": "tab:orange",
        "linear": "tab:green",
        "mamba": "tab:red",
    # "human" : "black",
    "xlstm": "purple"
        }


x_pos = list(range(len(stats_dict)))
x_offsets = [-0.35, -0.175, 0., .175, 0.35]
x_factor = 1/ len(stats_dict)
width = 0.35
x_stretch = 3
x_pos = [x*x_stretch for x in x_pos]

fig = plt.figure(figsize=figsizes.neurips2022(nrows=1, ncols=3, tight_layout=True, rel_width=3)['figure.figsize'])

ax = fig.add_subplot(121)


models =  colors.keys()
for x, (grammar, stats) in enumerate(stats_dict.items()):
    #
    if grammar in chance_stats.keys():
        chance = chance_stats[grammar]["ood_rule_1"]
        rectangle = matplotlib.patches.Rectangle((x_stretch*(x-0.5),0), width=x_stretch, height=100*chance, color='gray', alpha=.35)
        ax.add_patch(rectangle)

    for i, model in enumerate(models):

        if model not in stats["ood_rule_1"].groups.keys():
            print(f"Model {model} not in {grammar}")
            continue
        mean = stats["ood_rule_1"].get_group(model).mean()
        std = stats["ood_rule_1"].get_group(model).std()



        if mean < 0.01:
            ax.errorbar(x_stretch*(x + x_offsets[i]), 100 * mean, yerr=10 * std, fmt="o",  label=model, c=colors[model])
        else:
            ax.bar(x_stretch*(x + x_offsets[i]), 100 * mean, yerr=10 * std, width=width,  label=model, color=colors[model])


    ax.axvline(x_stretch*(x+0.5), color='black', linestyle='--', linewidth=0.5)


ax.set_ylabel("R1 (\%)", labelpad=LABELPAD)

# set xtick names
ax.set_xticks(x_pos)
ax.set_xticklabels(labels)
ax.set_xlim(x_stretch*x_offsets[0]-.35,x_stretch*(len(stats_dict)-1+x_offsets[-1])+.25)

ax2 = fig.add_subplot(122)

for x, (grammar, stats) in enumerate(stats_dict.items()):

    if grammar in chance_stats.keys():
        chance = chance_stats[grammar]["ood_rule_2_completion"]
        rectangle = matplotlib.patches.Rectangle((x_stretch*(x-0.5),0), width=x_stretch, height=100*chance, color='gray', alpha=0.35)
        ax2.add_patch(rectangle)


    for i, model in enumerate(models):


        if model not in stats["ood_rule_2_completion"].groups.keys():
            continue
        mean = stats["ood_rule_2_completion"].get_group(model).mean()
        std = stats["ood_rule_2_completion"].get_group(model).std()

        if mean < 0.01:
            ax2.errorbar(x_stretch*(x + x_offsets[i]), 100 * mean, yerr=10 * std, fmt="o",  label=model, c=colors[model])
        else:
            ax2.bar(x_stretch*(x + x_offsets[i]), 100 * mean, yerr=10 * std, width=width,  label=model, color=colors[model])


    ax2.axvline(x_stretch*(x+0.5), color='black', linestyle='--', linewidth=0.5)


# set xtick names
ax2.set_xticks(x_pos)
ax2.set_xticklabels(labels)
ax2.set_ylabel("R2 completion (\%)", labelpad=LABELPAD)
ax2.set_xlim(x_stretch*x_offsets[0]-.35,x_stretch*(len(stats_dict)-1+x_offsets[-1])+.25)

handles = [mlines.Line2D([], [], color=colors[model], marker='o', label=(model.capitalize() if model != "xlstm" else "xLSTM") if model != "lstm" else model.upper()) for model in models] + [matplotlib.patches.Patch(color='gray', alpha=0.35, label='Chance-level')]
ax2.legend(handles=handles, loc="center right", bbox_to_anchor=(1.5, 0.5))#, loc='upper center',  ncol=4)



# plt.legend()
ax.tick_params(axis='both', which='major', pad=TICK_PADDING)
ax.set_ylim(0, 100)
ax2.set_ylim(0, 100)

plt.savefig("ood_summary_xlstm.svg")


In [None]:
[len(x) for x in stats_dict.values()]

## Plot for sampling

In [None]:
TICK_PADDING = 2
LABELPAD = 1
cmap = "coolwarm"


stats_dict = {
    "baN": ban_sampling_stats,
    "bbaN": bban_sampling_stats,
    "aNbN": anbn_sampling_stats,
    "aNbNcN": anbncn_sampling_stats,
    "Dyck": dyck_sampling_stats,
}

labels = [
    r"$b\alpha$",
    r"$b^na^{2m}$",
    r"$a^nb^n$",
    r"$a^nb^nc^n$",
    "Dyck"
]

colors = {
        "transformer": "tab:blue",
        "lstm": "tab:orange",
        "linear": "tab:green",
        "mamba": "tab:red",
    # "human" : "black",
    # "chance": "purple"
        }


x_pos = list(range(len(stats_dict)))
x_offsets = [-0.3, -0.1, .1, 0.3]
x_factor = 1/ len(stats_dict)
width = 0.35
x_stretch = 3
x_pos = [x*x_stretch for x in x_pos]

fig = plt.figure(figsize=figsizes.neurips2022(nrows=1, ncols=3, tight_layout=True, rel_width=2)['figure.figsize'])

ax = fig.add_subplot(121)
# ax.grid(True, which="both", ls="-.")
# ax.set_axisbelow(True)

models =  colors.keys()
for x, (grammar, stats) in enumerate(stats_dict.items()):
    #
    if grammar in chance_stats.keys():
        chance = chance_stats[grammar]["ood_rule_1"]
        rectangle = matplotlib.patches.Rectangle((x_stretch*(x-0.5),0), width=x_stretch, height=100*chance, color='gray', alpha=.35)
        ax.add_patch(rectangle)

    for i, model in enumerate(models):


        if model not in stats["ood_rule_1"].groups.keys():
            continue
        mean = stats["ood_rule_1"].get_group(model).mean()
        std = stats["ood_rule_1"].get_group(model).std()



        if mean < 0.01:
            ax.errorbar(x_stretch*(x + x_offsets[i]), 100 * mean, yerr=10 * std, fmt="o",  label=model, c=colors[model])
        else:
            ax.bar(x_stretch*(x + x_offsets[i]), 100 * mean, yerr=10 * std, width=width,  label=model, color=colors[model])


    ax.axvline(x_stretch*(x+0.5), color='black', linestyle='--', linewidth=0.5)


ax.set_ylabel("R1 (\%)", labelpad=LABELPAD)

# set xtick names
ax.set_xticks(x_pos)
ax.set_xticklabels(labels)
ax.set_xlim(x_stretch*x_offsets[0]-.35,x_stretch*(len(stats_dict)-1+x_offsets[-1])+.25)

ax2 = fig.add_subplot(122)

for x, (grammar, stats) in enumerate(stats_dict.items()):

    if grammar in chance_stats.keys():
        chance = chance_stats[grammar]["ood_rule_2_completion"]
        rectangle = matplotlib.patches.Rectangle((x_stretch*(x-0.5),0), width=x_stretch, height=100*chance, color='gray', alpha=0.35)
        ax2.add_patch(rectangle)


    for i, model in enumerate(models):


        if model not in stats["ood_rule_2_completion"].groups.keys():
            continue
        mean = stats["ood_rule_2_completion"].get_group(model).mean()
        std = stats["ood_rule_2_completion"].get_group(model).std()

        if mean < 0.01:
            ax2.errorbar(x_stretch*(x + x_offsets[i]), 100 * mean, yerr=10 * std, fmt="o",  label=model, c=colors[model])
        else:
            ax2.bar(x_stretch*(x + x_offsets[i]), 100 * mean, yerr=10 * std, width=width,  label=model, color=colors[model])


    ax2.axvline(x_stretch*(x+0.5), color='black', linestyle='--', linewidth=0.5)


# set xtick names
ax2.set_xticks(x_pos)
ax2.set_xticklabels(labels)
ax2.set_ylabel("R2 completion (\%)", labelpad=LABELPAD)
ax2.set_xlim(x_stretch*x_offsets[0]-.35,x_stretch*(len(stats_dict)-1+x_offsets[-1])+.25)

handles = [mlines.Line2D([], [], color=colors[model], marker='o', label=model.capitalize() if model != "lstm" else model.upper()) for model in models] + [matplotlib.patches.Patch(color='gray', alpha=0.35, label='Chance-level')]
ax2.legend(handles=handles, loc="center right", bbox_to_anchor=(1.5, 0.5))#, loc='upper center',  ncol=4)



# plt.legend()
ax.tick_params(axis='both', which='major', pad=TICK_PADDING)
ax.set_ylim(0, 100)
ax2.set_ylim(0, 100)

plt.savefig("ood_summary_sampling.svg")


In [None]:
dyck_sampling_df.model

## aNbN parity

In [None]:
rule_stats2string_per_model(grouped_rule_stats(anbn_parity_df), plot=("val_loss", "rule_1", "rule_3", "ood_rule_1", "ood_rule_3"), include_r2=False)

In [None]:
TICK_PADDING = 2
LABELPAD = 1
cmap = "coolwarm"


stats_dict = {
    "aNbN": grouped_rule_stats(anbn_parity_df),
}

labels = [
    r"$a^nb^n$",
]

colors = {
        "transformer": "tab:blue",
        "lstm": "tab:orange",
        "linear": "tab:green",
        "mamba": "tab:red",
    # "human" : "black",
    "xlstm": "purple"
        }




fig = plt.figure(figsize=figsizes.neurips2022(nrows=1, ncols=1, tight_layout=True, rel_width=1)['figure.figsize'])

ax = fig.add_subplot(111)




# Derive the model list from this cell's own colour map instead of relying on
# a `models` variable left over from an earlier cell (a stale view of a
# different `colors` dict). Only models that actually have parity runs are
# kept, so no empty series is plotted or listed in the legend.
present_models = set(anbn_parity_df.model.unique())
models = [model for model in colors if model in present_models]

for model in models:
    model_runs = anbn_parity_df[anbn_parity_df.model == model]
    print("------------------")
    print(f"(OOD R1){model=} {100*model_runs.ood_rule_1_accuracy4min_val_loss.mean()} {100*model_runs.ood_rule_1_accuracy4min_val_loss.std()}")
    # The R3 line used to repeat the R1 columns.
    print(f"(OOD R3){model=} {100*model_runs.ood_rule_3_accuracy4min_val_loss.mean()} {100*model_runs.ood_rule_3_accuracy4min_val_loss.std()}")
    print("------------------")

    # x = OOD R1 accuracy, y = OOD R3 accuracy, both in percent
    ax.scatter(
        100*model_runs.ood_rule_1_accuracy4min_val_loss,
        100*model_runs.ood_rule_3_accuracy4min_val_loss,
        c=colors[model],
        label=(model.capitalize() if model != "xlstm" else "xLSTM") if model != "lstm" else model.upper(),
    )


# Axis labels follow the scatter arguments above (they were swapped before).
ax.set_xlabel(r"R1 (\%)", labelpad=LABELPAD)
ax.set_ylabel(r"R3 (\%)", labelpad=LABELPAD)

# set xtick names


handles = [mlines.Line2D([], [], color=colors[model], marker='o', label=(model.capitalize() if model != "xlstm" else "xLSTM") if model != "lstm" else model.upper()) for model in models]
ax.legend(handles=handles, loc="center right", bbox_to_anchor=(1.5, 0.5))#, loc='upper center',  ncol=4)



# plt.legend()
ax.tick_params(axis='both', which='major', pad=TICK_PADDING)
ax.set_ylim(0, 100)
ax.set_xlim(0, 100)
plt.savefig("ood_anbn_parity.svg")

## Hyperparameter search plot

In [None]:
# Hyperparameter-search summary: a 3x2 grid of OOD rule-1 accuracy bar charts,
# saved as "ood_summary_hyper.svg".
#   * The top-left panel is drawn inline from `stats_dict` (the main dataframes
#     concatenated with the xLSTM ones).
#   * The remaining five panels are drawn by `plot_row` from `stats_dict_hyper`,
#     one per (optimizer, learning rate) combination.
# NOTE(review): `plot_row` is defined in the cell *after* this one, so on a fresh
# kernel that cell has to be executed first. `chance_stats` and the `*_df`
# dataframes also come from other cells.
TICK_PADDING = 2
LABELPAD = 1
cmap = "coolwarm"


labels = [
    r"$b\alpha$",
    r"$b^na^{2m}$",
    r"$a^nb^n$",
    r"$a^nb^nc^n$",
    # "Dyck",
    # "CS Dyck",
]

colors = {
    "transformer": "tab:blue",
    "lstm": "tab:orange",
    "linear": "tab:green",
    "mamba": "tab:red",
    # "human" : "black",
    # "chance": "purple"
}


fig = plt.figure(figsize=figsizes.neurips2022(nrows=3, ncols=3, tight_layout=True, rel_width=2)['figure.figsize'])

ax = fig.add_subplot(321)
ax2 = fig.add_subplot(322)
ax3 = fig.add_subplot(323)
ax4 = fig.add_subplot(324)
ax5 = fig.add_subplot(325)
ax6 = fig.add_subplot(326)

models =  colors.keys()
optimizers = hyper_ban_df.optimizer.unique()
lrs = hyper_ban_df.lr.unique()

stats_dict_hyper = {
    "baN": grouped_rule_stats(hyper_ban_df, ["optimizer", "lr", "model"]),
    "bbaN": grouped_rule_stats(hyper_bban_df, ["optimizer", "lr", "model"]),
    "aNbN": grouped_rule_stats(hyper_anbn_df, ["optimizer", "lr", "model"]),
    "aNbNcN": grouped_rule_stats(hyper_anbncn_df, ["optimizer", "lr", "model"]),
    # "Dyck": grouped_rule_stats(hyper_dyck_df, ["optimizer", "lr", "model"]),
    # "CS Dyck": cs_dyck_stats,
}


# Horizontal layout: one slot of width `x_stretch` per grammar, with the models
# placed at `x_offsets` inside each slot.
x_pos = list(range(len(stats_dict_hyper)))
x_offsets = [-0.3, -0.1, .1, 0.3]
x_factor = 1/ len(stats_dict_hyper)
width = 0.35
x_stretch = 3
x_pos = [x*x_stretch for x in x_pos]


stats_dict = {
    "baN": grouped_rule_stats(pd.concat([ban_df, xlstm_ban_df])),
    "bbaN": grouped_rule_stats(pd.concat([bban_df, xlstm_bban_df])),
    "aNbN": grouped_rule_stats(pd.concat([anbn_df, xlstm_anbn_df])),
    "aNbNcN": grouped_rule_stats(pd.concat([anbncn_df_merged, xlstm_anbncn_df])),
    # "Dyck": grouped_rule_stats(pd.concat([dyck_df, xlstm_dyck_df])),
    # "CS Dyck": grouped_rule_stats(pd.concat([cs_dyck_df_merged, xlstm_cs_dyck_df]))
}

# --- Top-left panel: drawn inline from `stats_dict` (grouped by model only) ---
for x, (grammar, stats) in enumerate(stats_dict.items()):
    # Grey band marking the chance level for this grammar, when one is known.
    if grammar in chance_stats.keys():
        chance = chance_stats[grammar]["ood_rule_1"]
        rectangle = matplotlib.patches.Rectangle((x_stretch*(x-0.5),0), width=x_stretch, height=100*chance, color='gray', alpha=.35)
        ax.add_patch(rectangle)

    for i, model in enumerate(models):

        if model not in stats["ood_rule_1"].groups.keys():
            print(f"Model {model} not in {grammar}")
            continue
        mean = stats["ood_rule_1"].get_group(model).mean()
        std = stats["ood_rule_1"].get_group(model).std()

        # A bar of (almost) zero height would be invisible, so use a marker instead.
        # NOTE(review): the mean is scaled by 100 but the error bar by 10 -- confirm
        # this is intended.
        if mean < 0.01:
            ax.errorbar(x_stretch*(x + x_offsets[i]), 100 * mean, yerr=10 * std, fmt="o",  label=model, c=colors[model])
        else:
            ax.bar(x_stretch*(x + x_offsets[i]), 100 * mean, yerr=10 * std, width=width,  label=model, color=colors[model])


    # Dashed separator between neighbouring grammars.
    ax.axvline(x_stretch*(x+0.5), color='black', linestyle='--', linewidth=0.5)


# Raw string: keeps the LaTeX-escaped percent sign without an invalid "\%" escape.
ax.set_ylabel(r"R1 (\%)", labelpad=LABELPAD)

# set xtick names
ax.set_xticks(x_pos)
ax.set_xticklabels(labels)
ax.set_xlim(x_stretch*x_offsets[0]-.35,x_stretch*(len(stats_dict)-1+x_offsets[-1])+.25)


ax.tick_params(axis='both', which='major', pad=TICK_PADDING)
ax.set_ylim(0, 100)
ax.set_title(f"Fig. 1: optimizer: {optimizers[0]}, lr: {lrs[0]}")


# --- Remaining panels: one per (optimizer index, learning-rate index) pair ---
# The (0, 0) combination is not drawn by `plot_row`; its slot is the inline panel.
hyper_panels = [
    (ax2, 1, 0),
    (ax3, 0, 1),
    (ax4, 1, 1),
    (ax5, 0, 2),
    (ax6, 1, 2),
]
for panel_ax, optimizer_idx, lr_idx in hyper_panels:
    plot_row(panel_ax, None, stats_dict_hyper, optimizers[optimizer_idx], lrs[lr_idx], "optimizer", "lr")

# handles = [mlines.Line2D([], [], color=colors[model], marker='o',
#                                  label=model.capitalize() if model != "lstm" else model.upper()) for model in models] + [
#                       matplotlib.patches.Patch(color='gray', alpha=0.35, label='Chance-level')]
# ax6.legend(handles=handles, loc="lower center", ncol=6)


plt.savefig("ood_summary_hyper.svg")


In [None]:

def _draw_rule_panel(axis, stats_dict, metric, ylabel, row_key, col_key, row_label, col_label):
    """Draw one bar-chart panel of ``metric`` onto ``axis``.

    For every grammar in ``stats_dict`` this adds a grey chance-level band (if
    ``chance_stats`` has an entry for the grammar), one bar or marker per model
    for the group ``(row_key, col_key, model)``, and a dashed separator line.
    It then applies the shared axis formatting (label, ticks, limits, title).

    Reads the same notebook globals as ``plot_row``.
    """
    for x, (grammar, stats) in enumerate(stats_dict.items()):
        # Grey band marking the chance level for this grammar, when one is known.
        if grammar in chance_stats.keys():
            chance = chance_stats[grammar][metric]
            rectangle = matplotlib.patches.Rectangle((x_stretch * (x - 0.5), 0), width=x_stretch,
                                                     height=100 * chance, color='gray', alpha=.35)
            axis.add_patch(rectangle)

        for i, model in enumerate(models):
            try:
                group = stats[metric].get_group((row_key, col_key, model))
            except KeyError:
                # No runs for this (row_key, col_key, model) combination: leave the slot empty.
                continue
            mean = group.mean()
            std = group.std()

            # A bar of (almost) zero height would be invisible, so use a marker instead.
            # NOTE(review): the mean is scaled by 100 but the error bar by 10 -- confirm
            # this is intended.
            if mean < 0.01:
                axis.errorbar(x_stretch * (x + x_offsets[i]), 100 * mean, yerr=10 * std, fmt="o", label=model,
                              c=colors[model])
            else:
                axis.bar(x_stretch * (x + x_offsets[i]), 100 * mean, yerr=10 * std, width=width, label=model,
                         color=colors[model])

        # Dashed separator between neighbouring grammars.
        axis.axvline(x_stretch * (x + 0.5), color='black', linestyle='--', linewidth=0.5)

    axis.set_ylabel(ylabel, labelpad=LABELPAD)
    # set xtick names
    axis.set_xticks(x_pos)
    axis.set_xticklabels(labels)
    axis.set_xlim(x_stretch * x_offsets[0] - .35, x_stretch * (len(stats_dict) - 1 + x_offsets[-1]) + .25)
    axis.set_title(f"{row_label}: {row_key}, {col_label}: {col_key}")
    axis.set_ylim(0, 100)


def plot_row(ax, ax2, stats_dict, row_key, col_key, row_label, col_label):
    """Plot per-grammar, per-model OOD accuracies for one hyperparameter combination.

    ``ax`` receives the ``"ood_rule_1"`` panel. If ``ax2`` is not ``None`` it
    receives the ``"ood_rule_2_completion"`` panel together with a legend.

    Parameters
    ----------
    ax : matplotlib axes
        Target for the rule-1 panel.
    ax2 : matplotlib axes or None
        Target for the rule-2-completion panel; skipped when ``None``.
    stats_dict : dict
        Maps grammar name to a stats object for which
        ``stats[metric].get_group((row_key, col_key, model))`` yields the values
        to average (presumably the output of ``grouped_rule_stats`` with three
        grouping columns -- confirm against the caller).
    row_key, col_key
        Values of the first and second grouping columns selecting the group.
    row_label, col_label : str
        Human-readable names of those two columns, used only in the title.

    Notes
    -----
    Relies on notebook globals: ``chance_stats``, ``models``, ``colors``,
    ``x_stretch``, ``x_offsets``, ``x_pos``, ``width``, ``labels``,
    ``LABELPAD`` and ``TICK_PADDING``.
    Combinations without a matching group are skipped silently.
    """
    # Raw strings keep the LaTeX-escaped percent sign without an invalid "\%" escape.
    _draw_rule_panel(ax, stats_dict, "ood_rule_1", r"R1 (\%)", row_key, col_key, row_label, col_label)

    if ax2 is not None:
        _draw_rule_panel(ax2, stats_dict, "ood_rule_2_completion", r"R2 completion (\%)",
                         row_key, col_key, row_label, col_label)
        # Legend with one entry per model plus the chance-level band, placed right of ax2.
        handles = [mlines.Line2D([], [], color=colors[model], marker='o',
                                 label=model.capitalize() if model != "lstm" else model.upper()) for model in models] + [
                      matplotlib.patches.Patch(color='gray', alpha=0.35, label='Chance-level')]
        ax2.legend(handles=handles, loc="center right", bbox_to_anchor=(1.5, 0.5))  #, loc='upper center',  ncol=4)

    # Tick padding is applied to the first axes only.
    ax.tick_params(axis='both', which='major', pad=TICK_PADDING)



## Transformer size ablation plot

In [None]:
# Transformer size ablation: a 3x2 grid of OOD rule-1 bar charts, one panel per
# (number of heads, number of decoder layers) combination, drawn by `plot_row`.
# Saved as "ood_summary_transformer_size.svg".
# NOTE(review): relies on `plot_row`, `chance_stats`, `LABELPAD`, `TICK_PADDING`
# and the `transformer_size_*_df` dataframes from other cells.
colors = {
    "transformer": "tab:blue",
    # "lstm": "tab:orange",
    # "linear": "tab:green",
    # "mamba": "tab:red",
    # "human" : "black",
    # "xlstm": "purple"
}


labels = [
    r"$b\alpha$",
    r"$b^na^{2m}$",
    r"$a^nb^n$",
    r"$a^nb^nc^n$",
    "Dyck"
]


fig = plt.figure(figsize=figsizes.neurips2022(nrows=3, ncols=3, tight_layout=True, rel_width=2)['figure.figsize'])


ax = plt.subplot2grid((3,2), (0,0))
ax2 = plt.subplot2grid((3,2), (0,1))
ax3 = plt.subplot2grid((3,2), (1,0))
ax4 = plt.subplot2grid((3,2), (1,1))
ax5 = plt.subplot2grid((3,2), (2,0))
ax6 = plt.subplot2grid((3,2), (2,1))

models =  colors.keys()
num_heads = transformer_size_ban_df.num_heads.unique()
num_decoder_layers = transformer_size_ban_df.num_decoder_layers.unique()

stats_dict_transformer_size = {
    "baN": grouped_rule_stats(transformer_size_ban_df, ["num_heads", "num_decoder_layers", "model"]),
    "bbaN": grouped_rule_stats(transformer_size_bban_df, ["num_heads", "num_decoder_layers", "model"]),
    "aNbN": grouped_rule_stats(transformer_size_anbn_df, ["num_heads", "num_decoder_layers", "model"]),
    "aNbNcN": grouped_rule_stats(transformer_size_anbncn_df, ["num_heads", "num_decoder_layers", "model"]),
    "Dyck": grouped_rule_stats(transformer_size_dyck_df, ["num_heads", "num_decoder_layers", "model"]),
    # "CS Dyck": cs_dyck_stats,
}

# Horizontal layout read by `plot_row`: one slot of width `x_stretch` per grammar.
x_pos = list(range(len(stats_dict_transformer_size)))
x_offsets = [-0.3, -0.1, .1, 0.3]
x_factor = 1/ len(stats_dict_transformer_size)
width = 0.35
x_stretch = 3
x_pos = [x*x_stretch for x in x_pos]

# One panel per (head index, decoder-layer index): heads vary along a row,
# decoder layers down the rows.
size_panels = [
    (ax, 0, 0),
    (ax2, 1, 0),
    (ax3, 0, 1),
    (ax4, 1, 1),
    (ax5, 0, 2),
    (ax6, 1, 2),
]
for panel_ax, head_idx, layer_idx in size_panels:
    # Raw strings keep the LaTeX-escaped "#" without an invalid "\#" escape.
    plot_row(panel_ax, None, stats_dict_transformer_size, num_heads[head_idx], num_decoder_layers[layer_idx],
             r"\# Heads", r"\# Decoder layers")

handles = [mlines.Line2D([], [], color=colors[model], marker='o',
                                 label=model.capitalize() if model != "lstm" else model.upper()) for model in models] + [
                      matplotlib.patches.Patch(color='gray', alpha=0.35, label='Chance-level')]
ax3.legend(handles=handles, loc="upper center", ncol=2)

plt.savefig("ood_summary_transformer_size.svg")