In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import json

from nsc.api import tables
from nsc.api import utils
from nsc.utils import io

from spell_checking import BENCHMARK_DIR

In [5]:
def get_table_data(file_path: str, task: str) -> tuple:
    """Read a runtime-stats JSON file and turn it into table rows.

    The file must contain a JSON object that maps each model name to a
    two-element ``[runtime, file_size]`` entry. Models are processed in
    sorted (alphabetical) key order.

    Args:
        file_path: Path to the JSON stats file.
        task: Label written into the first column of every row.

    Returns:
        A ``(data, runtimes)`` pair. ``data`` is a list of rows
        ``[task, model, runtime, throughput]`` where runtime and throughput
        are strings formatted with one decimal place. ``runtimes`` holds the
        unformatted runtime of each row, in the same order, so callers can
        sort on the numeric value instead of the formatted string.

    Raises:
        ZeroDivisionError: If a model's runtime is 0.
    """
    # JSON is UTF-8 by specification; do not rely on the platform's locale
    # encoding when decoding the file.
    with open(file_path, "r", encoding="utf-8") as inf:
        json_data = json.load(inf)

    data = []
    runtimes = []
    for model in sorted(json_data):
        runtime, file_size = json_data[model]
        # file_size / 1000 -> kilo-units (presumably bytes -> kB, matching the
        # "kB/s" column header used by the caller), divided by the runtime.
        kbs = (file_size / 1000) / runtime
        data.append([task, model, f"{runtime:.1f}", f"{kbs:.1f}"])
        runtimes.append(runtime)
    return data, runtimes

In [20]:
def _load_group(sources):
    """Load one table section and return its rows, slowest model first.

    Args:
        sources: Iterable of ``(stats_file, task_label)`` pairs. Each pair is
            passed to ``get_table_data`` and the results are concatenated in
            the given order before sorting.

    Returns:
        List of table rows ordered by descending runtime. Rows with equal
        runtimes keep the order in which they were loaded. The list is empty
        when the stats files contain no models.
    """
    rows, runtimes = [], []
    for stats_file, task_label in sources:
        file_rows, file_runtimes = get_table_data(stats_file, task_label)
        rows.extend(file_rows)
        runtimes.extend(file_runtimes)
    # Sort on the numeric runtimes, not on the formatted strings stored in
    # the rows; sorting indices avoids the zip/sort/unzip round trip, which
    # fails to unpack when there are no rows at all.
    order = sorted(range(len(rows)), key=lambda i: runtimes[i], reverse=True)
    return [rows[i] for i in order]


headers = [["Task", "Model/Pipeline", "Runtime in s", "kB/s"]]

# Each group becomes one section of the table, sorted by runtime internally.
tr_data = _load_group([
    ("runtime_stats/tr_stats.json", "TR"),
])
sed_words_data = _load_group([
    ("runtime_stats/sed_words.json", "SEDS/SEDW"),
])
sec_data = _load_group([
    ("runtime_stats/sec_nmt_stats.json", "SEC"),
    ("runtime_stats/sec_with_sed_stats.json", r"SEDW $\rightarrow$ SEC"),
    ("runtime_stats/sec_neuspell.json", "SEC"),
])
sec_with_tr_data = _load_group([
    ("runtime_stats/sec_with_tr_stats.json", r"TR \& SEC"),
    ("runtime_stats/sec_tok_plus_stats.json", r"TR $\rightarrow$ SEDW $\rightarrow$ SEC"),
    ("runtime_stats/tr_pipeline_stats.json", r"TR $\rightarrow$ SEDW $\rightarrow$ SEC"),
])

data = []
horizontal_lines = []
for group_rows in (tr_data, sed_words_data, sec_data, sec_with_tr_data):
    if not group_rows:
        # An empty group contributes no rows and therefore no flags, keeping
        # horizontal_lines the same length as data.
        continue
    data.extend(group_rows)
    # One flag per row; only the last row of a group is flagged (presumably
    # generate_table draws a rule after flagged rows -- confirm in nsc.api.tables).
    horizontal_lines.extend([False] * (len(group_rows) - 1) + [True])

# Write the LaTeX version to the benchmark directory ...
latex_table = tables.generate_table(
    headers,
    data,
    horizontal_lines=horizontal_lines,
    fmt="latex"
)
utils.save_text_file(os.path.join(BENCHMARK_DIR, "test", "runtime_tables", "runtimes_generated.tex"), [latex_table])
# ... and show the same table as Markdown in the notebook output.
print(tables.generate_table(
    headers,
    data,
    horizontal_lines=horizontal_lines,
    fmt="markdown"
))

| Task | Model/Pipeline | Runtime in s | kB/s |
| :-- | --: | --: | --: |
| TR | eo large | 7.9 | 29.9 |
| TR | eo medium | 6.1 | 38.5 |
| TR | eo small | 5.8 | 40.4 |
| SEDS/SEDW | tokenization repair\textsuperscript{+}/tokenization repair\textsuperscript{++} | 13.1 | 18.0 |
| SEDS/SEDW | gnn\textsuperscript{+} | 9.7 | 24.3 |
| SEDS/SEDW | gnn | 9.2 | 25.7 |
| SEDS/SEDW | transformer\textsuperscript{+} | 5.6 | 42.3 |
| SEDS/SEDW | transformer | 4.8 | 49.5 |
| SEC | transformer | 73.5 | 3.2 |
| SEDW $\rightarrow$ SEC | transformer\textsuperscript{+} $\rightarrow$ transformer | 47.2 | 5.0 |
| SEDW $\rightarrow$ SEC | gnn\textsuperscript{+} $\rightarrow$ transformer | 46.4 | 5.1 |
| SEC | transformer word | 37.2 | 6.4 |
| SEC | neuspell bert | 21.8 | 10.8 |
| SEDW $\rightarrow$ SEC | gnn\textsuperscript{+} $\rightarrow$ transformer word | 13.7 | 17.3 |
| SEDW $\rightarrow$ SEC | transformer\textsuperscript{+} $\rightarrow$ transformer word | 13.5 | 17.5 |
| TR \& SEC | transformer with t