In [None]:
from wandb import Api
from collections import defaultdict
from tqdm.auto import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import chain

# Authenticate from the WANDB_API_KEY environment variable or a prior `wandb login`
# instead of pasting a key into the notebook (it would leak via git / shared copies).
api = Api()

In [None]:
def get_best_metric_from_run(run, metric: str = "max_accuracy"):
    """Return ``run.summary[metric]`` for a finished W&B run, or ``None`` to skip it.

    Args:
        run: A W&B public-API run; only ``state``, ``summary`` and ``id`` are read.
        metric: Summary key to read. Defaults to ``"max_accuracy"``.

    Returns:
        The summary value, or ``None`` when the run is not in the ``"finished"``
        state or its summary has no ``metric`` entry. A message naming the
        run id is printed whenever a run is skipped.
    """
    if run.state != "finished":
        print(f"Skipping unfinished run {run.id}")
        return None
    try:
        return run.summary[metric]
    except KeyError:
        # A finished run can still lack the key; skip it like an unfinished run
        # rather than aborting the whole sweep-collection loop.
        print(f"Skipping run {run.id}: summary has no {metric!r}")
        return None

sweep_main_fig = api.sweep("entity/project/sweep")

# `sweep.runs` is already iterable; wrapping a single iterable in
# `itertools.chain` was a no-op, so materialise it directly.
all_runs = list(sweep_main_fig.runs)

# Column name -> list of values, one entry per kept run.
metrics = defaultdict(list)

for run in tqdm(all_runs):
    accuracy = get_best_metric_from_run(run)
    if accuracy is None:
        continue
    # Read every field before appending anything, so a missing config key
    # cannot leave the columns with different lengths in the kernel state.
    # NOTE(review): the mixer name is read via a flattened dotted key while
    # d_model / input_seq_len come from nested dicts — presumably this mirrors
    # how the sweep logged its config; confirm against one run's `run.config`.
    row = {
        "model": run.config["model.sequence_mixer.name"],
        "lr": run.config["learning_rate"],
        "d_model": run.config["model"]["d_model"],
        "seq_len": run.config["data"]["input_seq_len"],
        "accuracy": accuracy,
    }
    for column, value in row.items():
        metrics[column].append(value)

# Fail here with a clear message rather than with an opaque pandas error later.
assert metrics, "no finished runs with an accuracy were found in the sweep"

In [None]:
# Runs sharing the same (model, d_model, seq_len, lr) are aggregated together
# (presumably repeated seeds — confirm against the sweep definition).
GROUP_KEYS = ["model", "d_model", "seq_len", "lr"]

data = pd.DataFrame(metrics)
grouped_accuracy = data.groupby(GROUP_KEYS)[["accuracy"]]

df_with_acc = grouped_accuracy.mean().reset_index()
# For each (model, d_model, seq_len) keep the learning rate with the highest
# mean accuracy: sort ascending by accuracy and keep the last row per group.
model_results = (
    df_with_acc
    .sort_values(["model", "d_model", "seq_len", "accuracy"])
    .drop_duplicates(["model", "d_model", "seq_len"], keep="last")
)

# Std is NaN for a configuration that has only one run.
df_with_acc_std = grouped_accuracy.std().reset_index()
# Both frames come from the same groupby, so rows with equal index labels
# describe the same configuration. Select by *label* with `.loc`; `.iloc`
# only worked because reset_index happened to make labels equal positions.
model_results_std = df_with_acc_std.loc[model_results.index]

In [None]:
sns.set_style("whitegrid")

fig, axs = plt.subplots(nrows=1, ncols=5, figsize=(12, 2.5), sharey=True)
%config InlineBackend.figure_format='retina'


axs[0].title.set_text('Sequence Length: 128')
axs[0].set_ylabel('Accuracy')

axs[1].title.set_text('Sequence Length: 256')
axs[2].title.set_text('Sequence Length: 512')
axs[3].title.set_text('Sequence Length: 1024')
axs[4].title.set_text('Sequence Length: 2048')


for i in range(5):
  axs[i].set_xticks([0, 1, 2, 3])
  axs[i].set_xticklabels([64, 128, 256, 512])

# One marker per model, indexed by the model's position in the plotting loop.
markers = ["o", "s", "^", "+", "*", "d", 'H']

# Sequence-mixer identifiers used in the results -> legend labels.
model_to_name = {
    "based": "Based",
    "rebased": "ReBased",
    "attention": "Attention",
    # Plain Latin "C": a look-alike Cyrillic "С" (U+0421) breaks text search in
    # the exported PDF and may be rendered from a fallback font.
    "conv_attention": "ConvAttention",
    "rwkv": "RWKV",
    "mamba": "Mamba",
    "conv_rwkv": "ConvRWKV",
}

# Model dims in x-tick order; they match the x tick labels set on every panel.
d_model_ticks = [64, 128, 256, 512]
tick_of_d_model = {d_model: pos for pos, d_model in enumerate(d_model_ticks)}

for i, seq_len in enumerate([128, 256, 512, 1024, 2048]):
    for j, model in enumerate(["rebased", "based", "rwkv", "mamba", "conv_attention"]):
        best = model_results[(model_results.seq_len == seq_len) & (model_results.model == model)]
        line = best.accuracy.tolist()
        # A configuration with a single run has NaN std; draw it with no error
        # bar. Left as NaN, it fails both clipping checks and the bar is drawn
        # across the whole [0, 1] range.
        std = model_results_std.loc[best.index, "accuracy"].fillna(0.0).tolist()
        # Clip the bars so they stay inside the valid accuracy range [0, 1].
        up_std = [min(s, 1 - l) for l, s in zip(line, std)]
        low_std = [min(s, l) for l, s in zip(line, std)]

        # Put each point on its own d_model tick instead of 0..len-1, so a
        # missing d_model cannot shift the remaining points onto wrong labels.
        x = [tick_of_d_model[d_model] for d_model in best.d_model]
        axs[i].errorbar(x, line, yerr=(low_std, up_std), marker=markers[j], label=model_to_name[model], capsize=5)


fig.tight_layout()

# Every panel plots the same set of models, so the last panel's handles are
# enough for a single figure-level legend placed below the panels.
legend_handles, legend_labels = axs[-1].get_legend_handles_labels()
fig.legend(
    legend_handles,
    legend_labels,
    loc='lower center',
    ncol=5,
    bbox_to_anchor=(0.5, -0.15),
)
fig.text(0.5, 0.0, 'Model Dim', ha='center')

plt.savefig("main_fig.pdf", bbox_inches="tight")
plt.show()