In [None]:
import glob
import os
import pickle

import pandas as pd

# plot heatmap
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

In [None]:
import torch
from analyze import get_dfa_probs as calculate_dfa_probs
from ngram import (
    predict_with_n_gram_back_off,
    prob_distance,
    prob_distance_dfa,
    prob_distance_dfa_ngram,
)

from batched_baum_welch import predict_with_baumwelch

In [None]:
class Vocab:
    """Two-way lookup between vocabulary symbols and their integer ids.

    The id of a symbol is its position in the list given to the constructor.
    """

    def __init__(self, vocab: list):
        # id -> symbol
        self.vocab = vocab
        # symbol -> id; if a symbol is repeated, its last position wins
        self.inv_vocab = {}
        for index, symbol in enumerate(vocab):
            self.inv_vocab[symbol] = index

    def __len__(self):
        """Number of entries in the vocabulary list."""
        return len(self.vocab)

    def get_id(self, char):
        """Return the id registered for the symbol ``char``."""
        return self.inv_vocab[char]

    def get_vocab(self, id):
        """Return the symbol stored at position ``id``."""
        return self.vocab[id]


def get_ngram_probs(results, ngram=3, uniform=False, backoff=False, addone=False):
    """Run the n-gram baseline on the input of every example in ``results``.

    Args:
        results: sequence of per-example records. Each record is read through
            its ``"input"`` key; the first record also supplies ``"vocab"``.
        ngram: forwarded to ``predict_with_n_gram_back_off`` as ``N``.
        uniform: forwarded unchanged to ``predict_with_n_gram_back_off``.
        backoff: forwarded unchanged to ``predict_with_n_gram_back_off``.
        addone: forwarded unchanged to ``predict_with_n_gram_back_off``.

    Returns:
        A list with one ``predict_with_n_gram_back_off`` result per example,
        in the same order as ``results``.
    """
    # The vocabulary of the first record is shared by all examples.
    vocab = Vocab(results[0]["vocab"])
    # Only "input" is needed here; the targets are not part of the prediction,
    # so records without a "target" entry are accepted as well.
    return [
        predict_with_n_gram_back_off(
            example["input"],
            N=ngram,
            global_vocab=vocab,
            uniform=uniform,
            backoff=backoff,
            addone=addone,
        )
        for example in results
    ]


def get_baumwelch_probs(results):
    """Run ``predict_with_baumwelch`` (12 states at most) on every example.

    The vocabulary comes from the first record; each record contributes its
    ``"input"``. Returns the list of predictions in example order.
    """
    shared_vocab = Vocab(results[0]["vocab"])
    return [
        predict_with_baumwelch(results[idx]["input"], shared_vocab, max_states=12)
        for idx in range(len(results))
    ]


def get_dfa_probs(results):
    """Compute the DFA reference probabilities for every example in ``results``.

    Args:
        results: sequence of per-example records. Each record is read through
            its ``"input"`` and ``"dfa"`` keys; the first record also supplies
            ``"vocab"``.

    Returns:
        A list with one ``calculate_dfa_probs`` result per example, in the
        same order as ``results``.
    """
    # The vocabulary of the first record is shared by all examples.
    vocab = Vocab(results[0]["vocab"])
    # The targets are not needed to query the DFA, so they are not touched.
    return [
        calculate_dfa_probs(example["input"], example["dfa"], vocab=vocab)
        for example in results
    ]


def get_model_probs(results, softmax=True):
    """Collect the ``"probs"`` entry of every record, optionally normalised.

    With ``softmax`` enabled each entry is converted to a tensor, passed
    through a softmax over its last axis and returned as a NumPy array.
    Otherwise the stored entries are returned untouched.
    """
    if not softmax:
        return [record["probs"] for record in results]

    normalised = []
    for record in results:
        scores = torch.tensor(record["probs"])
        distribution = torch.softmax(scores, dim=-1)
        normalised.append(distribution.detach().cpu().numpy())
    return normalised


import numpy as np


def get_greedy_dfa_accuracy(probs, dfa_probs, offset=0, max_len=None):
    """Fraction of positions whose greedy prediction is allowed by the reference.

    A position counts as correct when the reference array assigns a strictly
    positive value to the symbol with the highest model score.

    Args:
        probs: iterable of model arrays, indexed as ``[position, symbol]``.
        dfa_probs: iterable of reference arrays, indexed the same way and
            paired with ``probs`` example by example.
        offset: first position to score.
        max_len: exclusive end position; ``None`` scores up to the end.

    Returns:
        ``correct / total`` over all scored positions of all examples.
    """
    total = 0.0
    correct = 0.0
    for p1, pdfa in zip(probs, dfa_probs):
        # Score the window [offset, max_len). A ``None`` upper bound keeps the
        # slice open-ended, so ``offset`` is honoured in that case as well.
        pdfa = pdfa[offset:max_len]
        # Take the predictions for the same positions as the reference window.
        indices = p1.argmax(axis=-1)[offset : offset + len(pdfa)]
        correct += (pdfa[np.arange(len(pdfa)), indices] > 0).sum()
        total += len(pdfa)
    return correct / total


# Added inside the logarithms so that zero probabilities give finite values.
EPS = 1e-7


def get_kl(probs, dfa_probs, offset=0, max_len=None):
    """Average per-position KL divergence from the reference to the model.

    For every scored position the term ``sum(pdfa * (log(pdfa) - log(p1)))``
    is accumulated (with ``EPS`` added inside both logarithms), and the sum is
    divided by the number of scored positions.

    Args:
        probs: iterable of model probability arrays, first axis = position.
        dfa_probs: iterable of reference arrays paired with ``probs``.
        offset: first position to score.
        max_len: exclusive end position; ``None`` scores up to the end.

    Returns:
        The accumulated divergence divided by the number of scored positions.
    """
    total = 0.0
    kl_sum = 0.0
    for p1, pdfa in zip(probs, dfa_probs):
        # Score the window [offset, max_len); ``None`` leaves the end open.
        pdfa = pdfa[offset:max_len]
        # Compare against the model rows for the same positions.
        log_p1 = np.log(p1[offset : offset + len(pdfa)] + EPS)
        log_pdfa = np.log(pdfa + EPS)
        kl_sum += (pdfa * (log_pdfa - log_p1)).sum()
        total += len(pdfa)
    return kl_sum / total


def get_l1_loss(probs1, probs2, probsdfa, offset=0, max_len=None):
    """Average per-position L1 distance between two sets of predictions.

    Args:
        probs1: iterable of arrays, first axis = position.
        probs2: iterable of arrays paired with ``probs1``.
        probsdfa: iterable of reference arrays; only their length is used, to
            decide how many positions of each example are scored.
        offset: first position to score.
        max_len: exclusive end position; ``None`` scores up to the end.

    Returns:
        Sum of absolute differences over the scored positions divided by the
        number of scored positions.
    """
    total = 0.0
    l1_sum = 0.0
    for p1, p2, pdfa in zip(probs1, probs2, probsdfa):
        # Number of positions in the window [offset, max_len). Counting the
        # sliced length keeps the denominator in step with the rows that are
        # actually summed, also when ``max_len`` is None.
        window = len(pdfa[offset:max_len])
        stop = offset + window
        l1_sum += np.abs(p1[offset:stop] - p2[offset:stop]).sum()
        total += window
    return l1_sum / total

In [None]:
# Find every "*test_batch" generations directory below the experiment runs.
run_folders = glob.glob(
    "experiments/hiddens_*/**/generations/*test_batch/", recursive=True
)

# Index the directories by "<num_examples>/<model>/<variant>", where <variant>
# is empty when "generations" sits directly below the model directory.
name_to_folder = {}
for raw_path in run_folders:
    folder = raw_path.replace("//", "/").strip("/")
    subpaths = folder.split("/")
    name = subpaths[2]
    num_examples = subpaths[1].split("_")[1]
    if subpaths[3] == "generations":
        nlayer = ""
    else:
        nlayer = subpaths[3]
    name = "/".join((num_examples, name, nlayer))
    name_to_folder[name] = folder

In [None]:
# Display the mapping to check which runs were discovered.
name_to_folder

In [None]:
# import functools
# import concurrent
# import pickle
# def read_one(fname, probs_only=False):
#     with open(fname, "rb") as f:
#         data =  pickle.load(f)
#         if "hidden_outputs" in data:
#             del data["hidden_outputs"]
#         if "attention_scores" in data:
#             del data["attention_scores"]
#         if "attention_contexts" in data:
#             del data["attention_contexts"]

#     with open(fname, 'wb') as f:
#         pickle.dump(data, f)

#     return data


# def read_parallel(file_names, probs_only=False):
#     reader = functools.partial(read_one, probs_only=probs_only)
#     with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
#         futures = [executor.submit(reader, f) for f in file_names]
#         return [fut.result() for fut in futures]

# for model in name_to_folder.keys():
#     if model.startswith("20000"):
#         print(model)
#         folder = name_to_folder[model]
#         files = glob.glob(f"{folder}/*.pkl")
#         print(files[0])
#         read_parallel(files, probs_only=True)

In [None]:
from probe import get_results

In [None]:
# Load the saved outputs of every run into results[num_examples][model].
# Loading is best effort: a run that cannot be read is reported and skipped.
results = {}
for num_examples in (1000, 2500, 5000, 10000, 20000, 40000):
    results[num_examples] = {}
    for model in (
        "transformer",
        "lstm",
        "hyena",
        "h3",
        "s4d",
        "linear_transformer",
        "rwkv",
        "retention",
        "transformer_4",
        "transformer_8",
        "transformer_2",
        "transformer_1",
    ):
        name = f"{num_examples}/{model}/"
        try:
            results[num_examples][model] = get_results(
                name_to_folder[name], probs_only=True
            )
        except Exception as e:
            # Report the cause (e.g. a run missing from name_to_folder) instead
            # of hiding it, and do not swallow KeyboardInterrupt.
            print(e)
            print(f"Failed to load {name}")
    for model in ("2gram", "3gram"):
        file = f"experiments/hiddens_{num_examples}/{model}/probs.pkl"
        try:
            # NOTE: pickle.load can execute arbitrary code; only read files
            # that were written by this notebook.
            with open(file, "rb") as f:
                results[num_examples][model] = pickle.load(f)
        except Exception as e:
            print(e)
            print(f"Failed to load {file}")

In [None]:
# Compute the DFA reference probabilities for every training-set size from the
# transformer run's records and store them under the "dfa" key.
for num_examples in (1000, 2500, 5000, 10000, 20000, 40000):
    results[num_examples]["dfa"] = get_dfa_probs(results[num_examples]["transformer"])

In [None]:
# Inspect the shape of the "probs" entry of the first stored 2-gram record
# (this expects the record form written to probs.pkl).
results[2500]["2gram"][0]['probs'].shape

In [None]:
# Sanity check: shape of the 3-gram prediction for the first example.
# get_ngram_probs returns a plain list (one entry per example), which has no
# .shape of its own, so index into it first.
get_ngram_probs(
    results[1000]["transformer"],
    ngram=3,
    uniform=False,
    backoff=True,
    addone=False,
)[0].shape

In [None]:
# 2-gram predictions (backoff enabled) for the inputs of the 1000-example
# transformer run; one list entry per example.
gram2 = get_ngram_probs(
        results[1000]["transformer"],
        ngram=2,
        uniform=False,
        backoff=True,
        addone=False,
    )

In [None]:
# Shape of the 2-gram prediction for the first example.
gram2[0].shape

In [None]:
# Recompute the n-gram baselines on the transformer run's inputs, keep them in
# `results`, and write them to experiments/hiddens_<N>/<algo>/probs.pkl so they
# can be reloaded without recomputation.
for num_examples in (1000, 2500, 5000, 10000, 20000, 40000):
    infos = results[num_examples]["transformer"]
    folder = os.path.join("experiments", f"hiddens_{num_examples}")
    for algo, order in (("3gram", 3), ("2gram", 2)):
        probs = get_ngram_probs(
            infos,
            ngram=order,
            uniform=False,
            backoff=True,
            addone=False,
        )
        results[num_examples][algo] = probs
        # Store each prediction together with the example it belongs to.
        data = [
            {
                "probs": p,
                "dfa": d["dfa"],
                "input": d["input"],
                "vocab": d["vocab"],
            }
            for p, d in zip(probs, infos)
        ]
        out_dir = os.path.join(folder, algo)
        os.makedirs(out_dir, exist_ok=True)
        with open(os.path.join(out_dir, "probs.pkl"), "wb") as f:
            pickle.dump(data, f)

In [None]:
# import pickle

# for num_examples in (2500,):
#     # save ngrams to a folder
#     folder = f"experiments/hiddens_{num_examples}/"
#     for algo in ("2gram", "3gram"):
#         # makedir
#         probs = results[num_examples][algo]
#         infos = results[num_examples]["transformer"]
#         data = [
#             {
#                 "probs": p,
#                 "dfa": d["dfa"],
#                 "input": d["input"],
#                 "vocab": d["vocab"],
#             }
#             for p, d in zip(probs, infos)
#         ]
#         os.makedirs(f"{folder}/{algo}", exist_ok=True)
#         with open(f"{folder}/{algo}/probs.pkl", "wb") as f:
#             pickle.dump(data, f)

In [None]:
from mpl_toolkits.axes_grid1.inset_locator import inset_axes


def plot_l1_table(results, num_examples=20000, offset=0, max_len=None):
    """Draw a heatmap of the pairwise prediction distance between models.

    Every cell shows ``get_l1_loss(a, b, dfa) / 2`` for one pair of models
    trained with ``num_examples`` examples.

    Args:
        results: nested mapping ``results[num_examples][model]``. The
            ``"dfa"`` entry is used as stored. Entries whose name contains
            ``"gram"`` are used without softmax and may be either raw
            predictions or records with a ``"probs"`` key. All other entries
            go through ``get_model_probs`` with softmax.
        num_examples: which training-set size to plot.
        offset: forwarded to ``get_l1_loss``.
        max_len: forwarded to ``get_l1_loss``.

    Side effects:
        Updates ``matplotlib.rcParams`` (figure size, font, dpi) globally and
        creates a new figure.
    """
    candidates = [
        "dfa",
        "2gram",
        "3gram",
        "hyena",
        "linear_transformer",
        "transformer",
        "lstm",
        "transformer_2",
        "transformer_4",
    ]
    available = results[num_examples]
    # Keep only the models that were actually loaded, in a fixed order.
    models = [model for model in candidates if model in available]
    dfa_probs = available["dfa"]

    def _probs_for(model):
        """Return the per-example probability arrays of one model."""
        entries = available[model]
        if model == "dfa":
            return entries
        if "gram" in model:
            # n-gram results are already probabilities. They are either the
            # raw predictions or the {"probs": ...} records read from disk.
            return [e["probs"] if isinstance(e, dict) else e for e in entries]
        return get_model_probs(entries, softmax=True)

    def _display_name(model):
        """Shorten a model key for the axis tick labels."""
        return (
            model.replace("transformer", "TF")
            .replace("linear_", "L")
            .replace("_", "/")
        )

    # Normalise every model once instead of once per pair.
    probs_by_model = {model: _probs_for(model) for model in models}

    # L1 Differences
    l1_table = []
    for model1 in models:
        for model2 in models:
            value = (
                get_l1_loss(
                    probs_by_model[model1],
                    probs_by_model[model2],
                    dfa_probs,
                    offset=offset,
                    max_len=max_len,
                )
                / 2
            )
            l1_table.append([_display_name(model1), _display_name(model2), value])

    l1_df = pd.DataFrame(l1_table, columns=["model1", "model2", "value"])

    # Rows are indexed by model1, columns by model2.
    l1_df = l1_df.set_index(["model1"]).pivot(columns="model2", values="value")

    # fig size
    matplotlib.rcParams["figure.figsize"] = (6, 6)
    matplotlib.rcParams["font.size"] = 8
    matplotlib.rcParams["font.family"] = "serif"
    matplotlib.rcParams["figure.dpi"] = 300
    fig, ax = plt.subplots()
    sns.heatmap(
        l1_df, annot=True, ax=ax, cbar_kws={"orientation": "horizontal", "pad": 0.01}
    )
    # ax.set_title(f"L1 Loss (N={num_examples})")
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position("top")
    # The heatmap puts the pivot columns (model2) on x and the index (model1) on y.
    ax.set_xlabel("Model-2")
    ax.set_ylabel("Model-1")
    ax.xaxis.set_ticks_position("none")
    ax.yaxis.set_ticks_position("none")
    plt.xticks(rotation=30)

In [None]:
# Pairwise distance heatmap for the 2500-example runs, positions [0, 100).
plot_l1_table(results, num_examples=2500, offset=0, max_len=100)

In [None]:
# Pairwise distance heatmap for the 40000-example runs, positions [0, 100).
plot_l1_table(results, num_examples=40000, offset=0, max_len=100)

In [None]:
# Pairwise distance heatmap for the 5000-example runs, default offset/max_len.
plot_l1_table(results, num_examples=5000)

In [None]:
# For every training-set size, collect the positions of the examples whose DFA
# has more than 7 states and an alphabet of more than 10 symbols.
hard_indices = {}
for num_examples in (1000, 2500, 5000, 10000, 20000, 40000):
    selected = []
    for ind, r in enumerate(results[num_examples]["transformer"]):
        if not len(r["dfa"].dfa._states) > 7:
            continue
        if not len(r["dfa"].dfa.alphabet) > 10:
            continue
        selected.append(ind)
    hard_indices[num_examples] = selected

In [None]:
# Score every loaded model against the DFA reference on three metrics.
# The result is one record per (metric, num_examples, model).
metric_order = ("l1", "kl", "acc")
dfa_metrics = []
for num_examples in (1000, 2500, 5000, 10000, 20000, 40000):
    for model, entries in results[num_examples].items():
        if model == "dfa":
            continue
        if model in ("3gram", "2gram"):
            # n-gram results are already probabilities. They are either the raw
            # predictions or the {"probs": ...} records read from disk.
            model_probs = [e["probs"] if isinstance(e, dict) else e for e in entries]
        else:
            model_probs = get_model_probs(entries)
        dfa_probs = results[num_examples]["dfa"]
        # # hard
        # model_probs = [model_probs[i] for i in hard_indices[num_examples]]
        # dfa_probs = [dfa_probs[i] for i in hard_indices[num_examples]]

        # The normalised probabilities are computed once and shared by all
        # three metrics.
        values = {
            "l1": get_l1_loss(model_probs, dfa_probs, dfa_probs) / 2,
            "kl": get_kl(model_probs, dfa_probs),
            "acc": get_greedy_dfa_accuracy(model_probs, dfa_probs),
        }
        for metric in metric_order:
            dfa_metrics.append(
                {
                    "model": model,
                    "metric": metric,
                    "value": values[metric],
                    "num_examples": num_examples,
                }
            )

# Stable sort: group the records by metric while keeping the
# (num_examples, model) order inside each group.
dfa_metrics.sort(key=lambda row: metric_order.index(row["metric"]))

In [None]:
# Turn the list of metric records into a DataFrame with the columns
# "model", "metric", "value" and "num_examples".
df = pd.DataFrame(dfa_metrics)

In [None]:
# make neurips conference quality plots
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt

# Plot styling. The style sheet lives at a machine-specific absolute path, so
# keep the current matplotlib settings when it is not available instead of
# failing the whole notebook.
MPLSTYLE_PATH = "/raid/lingo/akyurek/mplstyle"
if os.path.exists(MPLSTYLE_PATH):
    plt.style.use(MPLSTYLE_PATH)
else:
    print(f"Style file {MPLSTYLE_PATH} not found; keeping the current matplotlib style")
plt.rc("font", serif="Times")
plt.rc("text", usetex=False)
plt.rcParams["figure.dpi"] = 250
plt.rcParams["figure.facecolor"] = "white"

In [None]:
# Axis label for each metric key used in the metric records.
metric_names = {
    "kl": "KL",
    "l1": "TVD",
    "acc": "Accuracy",
}

# Legend label for each model key; keys without an entry keep their raw name
# when the mapping is applied with DataFrame.replace.
model_names = {
    "transformer": "Transformer",
    "transformer_2": "Transformer (2 layers)",
    "transformer_1": "Transformer (1 layer)",
    "lstm": "LSTM",
    "hyena": "Hyena",
    "h3": "H3",
    "s4d": "S4D",
    "linear_transformer": "Linear Transformer",
    "rwkv": "RWKV",
    "retention": "RetNet",
}

In [None]:
# fig size
plt.rcParams.update({"figure.figsize": (6, 4)})
metric = "l1"

# Rows of the chosen metric. The layer-count ablations are left out of the
# line plot, and the n-gram baselines are drawn separately further down.
metric_rows = df[df.metric == metric]
data = metric_rows[
    ~metric_rows.model.isin(
        [
            "transformer_8",
            "transformer_4",
            "transformer_2",
            "transformer_1",
            "3gram",
            "2gram",
        ]
    )
]
data = data.replace({"model": model_names})
ax = sns.lineplot(
    data=data,
    x="num_examples",
    y="value",
    hue="model",
    marker="o",
    linewidth=1.5,
)
ax.set_xlabel("# Training Examples")
ax.set_ylabel(metric_names[metric])
ax.set(xscale="log")
ax.set_xticks([1000, 2500, 5000, 10000, 20000, 40000])
ax.legend(title="Model")
# Plain numbers instead of powers of ten on the log-scaled x axis.
ax.get_xaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())
# show 2gram and 3gram as horizontal line with texts on them
# (value averaged over all training-set sizes; label placed slightly above)
for baseline, label, color, label_offset in (
    ("3gram", "3-gram", "black", 0.02),
    ("2gram", "2-gram", "gray", 0.03),
):
    baseline_value = metric_rows[metric_rows.model == baseline].value.mean()
    ax.axhline(
        y=baseline_value,
        color=color,
        linestyle="--",
        linewidth=0.5,
    )
    ax.text(
        1000,
        baseline_value + label_offset,
        label,
        color=color,
    )