In [None]:
import os
from os.path import join, isfile
from typing import List

import sys
sys.path.insert(0,'..')
from src.test_data import TEST_EXAMPLES


from ipywidgets import interact, Checkbox
from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast
import sentencepiece as spm

from plots import plot_histogram, compare_vocab, plot_overview, plot_timelines, plot_overview_data, plot_vocab_size

import numpy as np
import seaborn as sns
from itertools import product

from termcolor import colored

In [None]:
OUTPUT_DIR = "../output"

In [None]:
def get_models() -> List[str]:
    """List the entries of OUTPUT_DIR that are treated as models.

    Returns:
        The entry names in sorted order, without hidden entries (leading ".")
        and without anything whose name starts with "evaluation".
    """
    excluded_prefixes = (".", "evaluation")
    model_names = []
    for entry in sorted(os.listdir(OUTPUT_DIR)):
        if entry.startswith(excluded_prefixes):
            continue
        model_names.append(entry)
    return model_names
    
models = get_models()
models

# 1. Show examples

In [None]:
"\N{ANGSTROM SIGN}", "\N{LATIN CAPITAL LETTER A WITH RING ABOVE}", "\u0041\u030A"

In [None]:
test_examples = TEST_EXAMPLES + [
    'Allmänna Allmänna',
    "<|endoftext|> test"
]

In [None]:
def decode_hack(_decoded_elementwise):
    """
    Remove one trailing space from every element that consists solely of spaces.

    All other elements (including empty strings) are returned unchanged.

    needs to be improved:
    - should only be applied if add_prefix_space == True & add_whitespace_tokens == 24
    - should only change an element if the next element is a non-whitespace-element
    """
    trimmed = []
    for piece in _decoded_elementwise:
        # non-empty and made of nothing but " " characters
        only_spaces = len(piece) > 0 and piece.strip(" ") == ""
        trimmed.append(piece[:-1] if only_spaces else piece)
    return trimmed
    # return "".join(decoded_elementwise_hack)

def display(_example_decoded_per_token, show_linebreak = False):
    """Print the tokens back to back in alternating colors, then the token count.

    Whitespace is made visible: every space is printed as "-" and every newline
    as "↩".

    Args:
        _example_decoded_per_token: iterable of token strings.
        show_linebreak: if True, a real line break is emitted after each "↩".
    """
    palette = ("red", "blue")
    newline_marker = "↩\n" if show_linebreak else "↩"

    visible_tokens = []
    for token in _example_decoded_per_token:
        visible_tokens.append(token.replace("\n", newline_marker).replace(" ", "-"))

    # Alternate the colors so that neighbouring tokens can be told apart.
    for position, token in enumerate(visible_tokens):
        print(colored(token, palette[position % len(palette)]), end="")
    print()
    print(f"> {len(visible_tokens)} tokens")
    print()

def _sp_decode_elementwise(sp, pieces, decoded_text):
    """Return, for each SentencePiece piece, the text it contributes to `decoded_text`."""
    surfaces = []
    cursor = 0  # position in decoded_text right after the previously matched piece
    for i, piece in enumerate(pieces):
        if piece.startswith("▁"):
            # "▁" marks a word boundary: dropped on the first piece, a space otherwise.
            surface = piece[1:] if i == 0 else piece.replace("▁", " ")
        else:
            surface = piece

        if surface.startswith("<") and surface.endswith(">"):
            # presumably special / byte-fallback pieces; let SentencePiece render them
            surface = sp.decode(surface)

        found_at = decoded_text.find(surface, cursor)
        if found_at < 0:
            # The piece cannot be located in the decoded text. Keep its own text and
            # leave the cursor alone instead of slicing at a bogus offset.
            surfaces.append(surface)
            continue
        cursor = found_at + len(surface)
        surfaces.append(decoded_text[found_at:cursor])
    return surfaces


def show_example_model(example, model, show_tokenization, verbose: bool = False):
    """Tokenize `example` with one trained tokenizer and print the result.

    The tokenizer is looked up in OUTPUT_DIR/<model>: a `tokenizer.json` is loaded
    with `PreTrainedTokenizerFast`; otherwise a `model.model` is loaded with
    SentencePiece. If neither file exists, a notice is printed and nothing else
    happens, so that looping over all models is not interrupted.

    Args:
        example: text to tokenize.
        model: name of the model directory inside OUTPUT_DIR.
        show_tokenization: if True, print the colored tokens, the colored per-token
            decoding, and whether the decoded text equals `example`.
        verbose: if True, additionally print ids, tokens and the decoded text.
    """
    hf_file = join(OUTPUT_DIR, model, "tokenizer.json")
    sp_file = join(OUTPUT_DIR, model, "model.model")

    if isfile(hf_file):
        tokenizer_fast = PreTrainedTokenizerFast(tokenizer_file=hf_file)
        encoding = tokenizer_fast.encode(example)
        example_encoded = tokenizer_fast.convert_ids_to_tokens(encoding)
        example_decoded = tokenizer_fast.decode(encoding)
        example_decoded_elementwise = [tokenizer_fast.decode(elem) for elem in encoding]
    elif isfile(sp_file):
        sp = spm.SentencePieceProcessor(model_file=sp_file)
        encoding = sp.encode(example, out_type=int)
        example_encoded = sp.encode(example, out_type=str)
        example_decoded = sp.decode(example_encoded)
        example_decoded_elementwise = _sp_decode_elementwise(sp, example_encoded, example_decoded)
    else:
        print(f"> skipping '{model}': no tokenizer.json or model.model in {join(OUTPUT_DIR, model)}")
        return

    if verbose:
        print(f"============ {model}")
        print(f"example: '{example}'")
        print(f"\nencoding: {encoding}")
        print(f"\nencoded: {example_encoded} --- {len(example_encoded)}")
        print(f"\ndecoded: '{example_decoded}'")
        print(f"\ndecoded elementwise: {example_decoded_elementwise}")
        print()

    if show_tokenization:
        print("\nencoded:")
        display(example_encoded)
        print("\ndecoded:")
        display(example_decoded_elementwise, show_linebreak=True)
        print(f"\ndecoded = original: {example == example_decoded}")

In [None]:
@interact
def show_examples(example=test_examples, model=["ALL"] + models, show_tokenization=True, verbose=False):
    """Interactive view: tokenize the chosen example with one model or with all of them.

    Choosing "ALL" runs `show_example_model` for every entry of `models` in sorted
    order; any other choice runs it for that single model.
    """
    selected_models = sorted(models) if model == "ALL" else [model]
    for selected_model in selected_models:
        show_example_model(example, selected_model, show_tokenization, verbose)

In [None]:
# STOP

In [None]:
# example
"ℌej Hej --- TVÅ TVÅ TVÅ".encode("utf-8")  # ℌ, H --- ÅNGSTRÖM, Å, A+°

In [None]:
# NFC
"ℌej Hej --- TVÅ TVÅ TVÅ".encode("utf-8")  # ℌ, H --- Å, Å, Å

In [None]:
# NFKD
"Hej Hej --- TVÅ TVÅ TVÅ".encode("utf-8")  # H, H --- A+°, A+°, A+°

In [None]:
# NFKC
"Hej Hej --- TVÅ TVÅ TVÅ".encode("utf-8")  # H, H --- Å, Å, Å

# 2. Subwords

### 2a. Subword Length Histograms

In [None]:
@interact
def show_histogram(model_1=models, model_2=[None] + models, xlim=20, ylim=15000):
    """Interactive wrapper that forwards the selected values to `plot_histogram`.

    Args:
        model_1: model chosen from `models`.
        model_2: second model, or None (the first option) for no second model.
        xlim: passed through as the third argument (presumably the x-axis limit).
        ylim: passed through as the fourth argument (presumably the y-axis limit).
    """
    plot_histogram(model_1, model_2, xlim, ylim)

### 2b. Overlap

In [None]:
@interact
def show_compare_vocab(model_1=models, model_2=models, nr=30):
    """Interactive view: print the vocabulary comparison of two models.

    Prints the first value returned by `compare_vocab`, followed by the first `nr`
    entries of the second and third returned values under the headings
    "only model 1" and "only model 2".
    """
    overlap, only_in_1, only_in_2 = compare_vocab(model_1, model_2, 1000000, 1000000)
    print(overlap)
    sections = (
        ("=== only model 1 ===", only_in_1),
        ("=== only model 2 ===", only_in_2),
    )
    for heading, examples in sections:
        print()
        print(heading)
        print(examples[:nr])

### 2c. Vocabulary Size & Subword Length Mean

In [None]:
@interact
def show_vocab_size(model=models):
    """Interactive wrapper: pass the model chosen from `models` to `plot_vocab_size`."""
    plot_vocab_size(model)

In [None]:
# STOP

# 3. Multilinguality

In [None]:
# Candidates are all models whose name contains "_3".
_candidates = [model for model in models if "_3" in model]

# The "core" is the middle part of the name (between the first and the last "_")
# of a candidate ending in "da". Sorting makes the choice reproducible when
# several cores exist; no core at all means there is nothing to compare.
_cores = sorted({"_".join(model.split("_")[1:-1]) for model in _candidates if model.endswith("da")})

if _cores:
    core = _cores[0]
    print(core)
    _candidates = [model for model in _candidates if core in model]
    print(_candidates)
    _candidates.sort(key=lambda x: x.split("_3")[-1])
    # Map the suffix after "_3" to the full model name.
    models_multilinguality = {model.split("_3")[-1]: model for model in _candidates}
else:
    # Always a dict, so that later cells can rely on its type.
    models_multilinguality = {}

models_multilinguality

In [None]:
# Split the multilinguality models into groups by key:
#   complete = every key, all = keys starting with "all", pure = the remaining keys.
if len(models_multilinguality):
    lang_complete = list(models_multilinguality.keys())
    lang_all = [l for l in lang_complete if l.startswith("all")]
    lang_pure = [l for l in lang_complete if not l.startswith("all")]

    models_complete = {k: models_multilinguality[k] for k in lang_complete}
    models_all = {k: models_multilinguality[k] for k in lang_all}
    models_pure = {k: models_multilinguality[k] for k in lang_pure}
else:
    # Independent empty containers of the same types as in the branch above.
    # The global `models` list must not be touched here.
    lang_complete, lang_all, lang_pure = [], [], []
    models_complete, models_all, models_pure = {}, {}, {}

lang_complete, lang_all, lang_pure

In [None]:
models_pure

In [None]:
# overview_corpus(models_multilinguality)

### 3a. Time

In [None]:
if len(models_multilinguality):
    plot_overview_data(models_pure.values())

In [None]:
if len(models_multilinguality):
    plot_overview(models_pure.values())

### 3b. Intersection Matrix (Subword Length)

In [None]:
# get_intersection_matrix()

In [None]:
# get_intersection_matrix(0)

In [None]:
# get_intersection_matrix(10)

In [None]:
# get_intersection_matrix(10000)

### 3c. Evaluation #1: Vocabulary Intersection

In [None]:
def get_intersection(lang_1, lang_2, vocab_1, vocab_2):
    """Return the "intersection" entry of `compare_vocab` for two multilinguality models.

    Args:
        lang_1: key into `models_multilinguality` for the first model.
        lang_2: key into `models_multilinguality` for the second model.
        vocab_1: third argument for `compare_vocab` (belongs to the first model).
        vocab_2: fourth argument for `compare_vocab` (belongs to the second model).
    """
    first_model, second_model = (models_multilinguality[lang] for lang in (lang_1, lang_2))
    stats, _only_1, _only_2 = compare_vocab(first_model, second_model, vocab_1, vocab_2)
    return stats["intersection"]

In [None]:
if len(models_multilinguality):
    # An expression nested inside `if` is not echoed by the notebook, so print it.
    print(get_intersection('all-a1.0', 'da', 10000, 10000))

In [None]:
VOCAB = [10000, 20000, 30000, 40000, 51200, 64000, 80000, 96000, 112000, 128000]
VOCAB_1 = VOCAB
VOCAB_2 = VOCAB

# VOCAB_1 = [50000, 100000, 150000, 200000, 250000]
# VOCAB_2 = [100, 1000, 10000, 20000, 30000, 40000, 50000]

In [None]:
# Nested lookup table: intersections[lang_1][lang_2][vocab_1][vocab_2], where
# lang_1 runs over lang_all, lang_2 over lang_complete, vocab_1 over VOCAB_1 and
# vocab_2 over VOCAB_2. Every entry is one call to get_intersection.
# Left as None when there are no multilinguality models.
if len(models_multilinguality):
    intersections = {
        lang_1: {
            lang_2: {
                vocab_1: {
                    vocab_2: get_intersection(lang_1, lang_2, vocab_1, vocab_2)
                    for vocab_2 in VOCAB_2
                }
                for vocab_1 in VOCAB_1
            }
            for lang_2 in lang_complete
        }
        for lang_1 in lang_all
    }
else:
    intersections = None

# intersections

In [None]:
# Re-index the intersections as timelines_abs[lang_1][vocab_2][lang_2]: a list
# with one absolute intersection value per vocab_1 in VOCAB_1 (same order),
# for every lang_2 in lang_pure.
# Left as None when there are no multilinguality models.
if len(models_multilinguality):
    timelines_abs = {
        lang_1: {
            vocab_2: {
                lang_2: 
                [intersections[lang_1][lang_2][vocab_1][vocab_2] for vocab_1 in VOCAB_1]
                for lang_2 in lang_pure
            }
            for vocab_2 in VOCAB_2
        }
        for lang_1 in lang_all
    }
else:
    timelines_abs = None
    
# timelines_abs

In [None]:
# Same layout as timelines_abs, but every value is divided by
# intersections[lang_1][lang_1][vocab_2][vocab_2], i.e. the intersection of the
# lang_1 model with itself at vocab_2.
# NOTE(review): the denominator indexes the vocab_1 level with vocab_2, so every
# value of VOCAB_2 must also be in VOCAB_1 (true while both equal VOCAB);
# otherwise this raises KeyError. A zero denominator raises ZeroDivisionError.
# Left as None when there are no multilinguality models.
if len(models_multilinguality):
    timelines_rel = {
        lang_1: {
            vocab_2: {
                lang_2: 
                [intersections[lang_1][lang_2][vocab_1][vocab_2]/intersections[lang_1][lang_1][vocab_2][vocab_2] for vocab_1 in VOCAB_1]
                for lang_2 in lang_pure
            }
            for vocab_2 in VOCAB_2
        }
        for lang_1 in lang_all
    }
else:
    timelines_rel = None
    
# timelines_rel

In [None]:
%%javascript
// Override the output area's `_should_scroll` hook so that it always answers false.
// Presumably this keeps long outputs (the plots below) from being put into a
// scroll box. NOTE(review): relies on the global `IPython.OutputArea` object;
// confirm it exists in the notebook front end in use.
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

In [None]:
@interact
def show_timelines(tokenizer=lang_all, vocab_size=VOCAB_2):
    """Interactive view: plot absolute and relative timelines for one "all" tokenizer.

    Args:
        tokenizer: key of `timelines_abs` / `timelines_rel`, chosen from `lang_all`;
            None means there is nothing to choose from.
        vocab_size: second-level key of the timelines, chosen from `VOCAB_2`.
    """
    if tokenizer is None:
        print("> lang_all is []")
        return

    absolute = timelines_abs[tokenizer][vocab_size]
    relative = timelines_rel[tokenizer][vocab_size]

    plot_timelines(
        VOCAB_1,
        vocab_size,
        [absolute, relative],
        lang_pure,
        ylim=[1.1*100000, 1.1],
        ylabel=["absolute", "relative"],
        title=["Coverage of single-language tokenizer vocabulary"]*2,
    )

### 3d. Evaluation #2: unk_rate & closeness_to_character_level

In [None]:
def get_list_of_results(evaluation_dir=None):
    """Return the identifiers of all `results_<id>.json` files in the evaluation folder.

    Entries that do not follow the `results_<id>.json` pattern (hidden files,
    checkpoint folders, ...) are ignored, because `read_results` could not open
    them afterwards.

    Args:
        evaluation_dir: folder to scan; defaults to OUTPUT_DIR/evaluation.

    Returns:
        The `<id>` parts, in the sorted order of the file names.
    """
    if evaluation_dir is None:
        evaluation_dir = join(OUTPUT_DIR, "evaluation")

    prefix, suffix = "results_", ".json"
    return [
        name[len(prefix):-len(suffix)]
        for name in sorted(os.listdir(evaluation_dir))
        if name.startswith(prefix) and name.endswith(suffix)
    ]

list_of_results = get_list_of_results()
list_of_results

In [None]:
import json

def read_results(_result):
    """Load OUTPUT_DIR/evaluation/results_<_result>.json and return the parsed content.

    Args:
        _result: identifier that is inserted into the file name.

    Returns:
        The object produced by `json.load` for that file.
    """
    _results_path = join(OUTPUT_DIR, "evaluation", f"results_{_result}.json")
    # Read as UTF-8 explicitly so the result does not depend on the platform's
    # default text encoding.
    with open(_results_path, "r", encoding="utf-8") as file:
        return json.load(file)

results = read_results('all-a1.0')
if 0:
    results

In [None]:
def retrieve_bf_cc_from_results(_results):
    """Collect the distinct "bf" and "cc" values found in the model names.

    For every key of `_results`, the bf value is the text between "-bf" and "-cc",
    and the cc value is the text between "-cc" and "-x".

    Args:
        _results: mapping whose keys are model names.

    Returns:
        (bfs, ccs): two lists of unique strings, each sorted so that the order
        (and therefore the default widget selection) is the same on every run.
    """
    model_names = list(_results.keys())
    bfs = sorted({name.split("-bf")[1].split("-cc")[0] for name in model_names})
    ccs = sorted({name.split("-cc")[1].split("-x")[0] for name in model_names})
    return bfs, ccs

In [None]:
def retrieve_parameters_from_results(_bf, _cc, _results):
    """Derive vocab sizes, models, files and languages for one (bf, cc) setting.

    Only models whose name contains both "-bf<_bf>" and "-cc<_cc>" are considered,
    so a vocab size that exists only for another setting no longer causes an
    IndexError.

    Args:
        _bf: bf value to select.
        _cc: cc value to select.
        _results: mapping model name -> mapping file -> metrics.

    Returns:
        vocabs: sorted vocab sizes (the integer between "-v" and the next "_")
            of the selected models.
        vocabs_model: vocab size -> model name (first in sorted order on ties).
        files: the file keys of one reference model.
        languages: the second "_"-separated field of every file's base name.
        languages_files: language -> file.
    """
    all_models = sorted(_results.keys())
    selected_models = [
        model
        for model in all_models
        if f"-bf{_bf}" in model and f"-cc{_cc}" in model
    ]

    model_by_vocab = {}
    for model in selected_models:
        vocab = int(model.split("-v")[1].split("_")[0])
        model_by_vocab.setdefault(vocab, model)
    vocabs = sorted(model_by_vocab)
    vocabs_model = {vocab: model_by_vocab[vocab] for vocab in vocabs}

    # The file list is read from a single reference model, chosen deterministically.
    # Presumably all models were evaluated on the same files.
    reference_models = selected_models or all_models
    files = list(_results[reference_models[0]].keys()) if reference_models else []

    languages = [file.split("/")[-1].split(".json")[0].split("_")[1] for file in files]  # WORKS ONLY FOR 'wiki_??_t1p'!!!
    languages_files = {k: v for k, v in zip(languages, files)}

    return vocabs, vocabs_model, files, languages, languages_files
 
if 0:
    bfs, ccs = retrieve_bf_cc_from_results(results)
    vocabs, vocabs_model, files, languages, languages_files = retrieve_parameters_from_results(bfs[0], ccs[0], results)
    print(bfs)
    print(ccs)
    print()
    print(vocabs)
    print(vocabs_model)
    print(files)
    print(languages)
    print(languages_files)

In [None]:
def plot_evaluation_2(_unk_rate, _ctcl, _vocabs, _languages, _ymin, _ymax):
    """Plot unknown rate (left) and closeness to character level (right) per language.

    Args:
        _unk_rate: mapping language -> list of values, one per entry of `_vocabs`.
        _ctcl: mapping language -> list of values, one per entry of `_vocabs`.
        _vocabs: x values (vocabulary sizes).
        _languages: languages to draw; each gets one line per panel.
        _ymin: lower y limit of both panels.
        _ymax: upper y limit of both panels.
    """
    import matplotlib.pyplot as plt

    colors = {"da": "r", "en": "g", "is": "b", "no": "purple", "sv": "orange"}
    panels = [
        (_unk_rate, "unknown rate (lower = better)"),
        (_ctcl, "closeness to character level (lower = better)"),
    ]

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    for language in _languages:
        # NOTE(review): linestyle=None selects matplotlib's default line style, so
        # the markers are connected by a line. Confirm that this is intended.
        style = {"linestyle": None, "marker": "s", "label": language}
        if language in colors:
            style["color"] = colors[language]
        # Languages without a fixed color fall back to matplotlib's color cycle
        # instead of raising KeyError.
        for axis, (series, _title) in zip(axes, panels):
            axis.plot(_vocabs, series[language], **style)

    for axis, (_series, title) in zip(axes, panels):
        axis.set_xlim([0, 150000])
        axis.set_ylim([_ymin, _ymax])
        axis.set_xlabel("vocabulary size")
        axis.set_title(title)
        if len(_languages):
            # Skip the legend when nothing was drawn.
            axis.legend()

In [None]:
@interact
def show_evaluation_2(result=list_of_results):
    """Interactive view: choose a results file, then a (bf, cc) setting, and plot it.

    The outer widget selects the results file. A nested widget then selects bf, cc
    and the y range, and draws the per-language "unk_rate" and
    "closeness_to_character_level" values via `plot_evaluation_2`.
    """
    evaluation = read_results(result)
    bfs, ccs = retrieve_bf_cc_from_results(evaluation)

    @interact
    def show_evaluation_2_detail(bf=bfs, cc=ccs, ymin=0.0, ymax=1.0):
        vocabs, vocabs_models, _files, languages, languages_files = retrieve_parameters_from_results(bf, cc, evaluation)

        setting = f"-bf{bf}-cc{cc}"
        results_filtered = {}
        for model_name, per_file in evaluation.items():
            if setting in model_name:
                results_filtered[model_name] = per_file

        def metric_by_language(metric):
            # language -> values of `metric`, one per vocab size, in `vocabs` order
            series = {}
            for language in languages:
                series[language] = [
                    results_filtered[vocabs_models[vocab]][languages_files[language]][metric]
                    for vocab in vocabs
                ]
            return series

        unk_rate = metric_by_language("unk_rate")
        closeness_to_character_level = metric_by_language("closeness_to_character_level")

        plot_evaluation_2(unk_rate, closeness_to_character_level, vocabs, languages, ymin, ymax)