In [None]:
import os
from os.path import join
from typing import List

import sys
sys.path.insert(0,'..')
from test_data import TEST_EXAMPLES


from ipywidgets import interact, Checkbox
from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast

from plots import plot_histogram, compare_vocab, plot_overview, plot_timelines, plot_overview_data, plot_vocab_size


import numpy as np
import seaborn as sns
from itertools import product

from termcolor import colored

In [None]:
# Root directory of the trained tokenizers (relative to the notebook's working
# directory): every sub-directory is one model and is read as
# `<OUTPUT_DIR>/<model>/tokenizer.json` by show_example_model.
OUTPUT_DIR = "../output"

In [None]:
def get_models(output_dir: "str | None" = None) -> List[str]:
    """Return the names of all model directories inside ``output_dir``.

    Args:
        output_dir: directory to scan. Defaults to ``OUTPUT_DIR``, read at call
            time so that re-assigning ``OUTPUT_DIR`` is picked up.

    Returns:
        Sorted list of sub-directory names. Hidden entries (``.ipynb_checkpoints``,
        ``.DS_Store``, ...) and plain files are skipped: only directories can hold
        a model's ``tokenizer.json``, and stray files would otherwise show up as
        selectable "models" in the widgets.
    """
    if output_dir is None:
        output_dir = OUTPUT_DIR
    return [
        name
        for name in sorted(os.listdir(output_dir))
        if not name.startswith(".") and os.path.isdir(join(output_dir, name))
    ]
    
# Model names found in OUTPUT_DIR; used as the choices of the widgets below.
models = get_models()
models

# 1. Show examples

In [None]:
# Three strings that render as the same "Å" but differ in code points:
# U+212B ANGSTROM SIGN, U+00C5 (precomposed A with ring), and "A" + U+030A COMBINING RING ABOVE.
("\u212B", "\u00C5", "A\N{COMBINING RING ABOVE}")

In [None]:
# Examples offered in the widget: the shared TEST_EXAMPLES plus two extra cases.
# NOTE(review): the two "Allmänna" words presumably differ in Unicode composition
# (precomposed vs. combining "ä") to exercise normalization — confirm.
# "<|endoftext|>" presumably checks whether a special-token string survives as one token.
test_examples = TEST_EXAMPLES + [
    'Allmänna Allmänna',
    "<|endoftext|> test"
]

In [None]:
def show_example_model(example, model, show_tokenization):
    """Print how ``model`` pre-tokenizes, encodes and decodes ``example``.

    Args:
        example: text to tokenize.
        model: name of a sub-directory of ``OUTPUT_DIR`` holding ``tokenizer.json``.
        show_tokenization: if True, additionally print every token decoded on its
            own, in alternating colors, with newlines shown as "↩" and spaces as "-".
    """
    tokenizer_file = join(OUTPUT_DIR, model, "tokenizer.json")
    # Parse tokenizer.json once: the raw `tokenizers.Tokenizer` is exposed as
    # `backend_tokenizer`, so a second `Tokenizer.from_file` is unnecessary.
    tokenizer_fast = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file)
    pre_tokenizer = tokenizer_fast.backend_tokenizer.pre_tokenizer
    encoding = tokenizer_fast.encode(example)

    print(f"============ {model}")
    print(f"example: '{example}'")
    # A tokenizer.json may define no pre-tokenizer; the attribute is then None.
    pre_tokenized = pre_tokenizer.pre_tokenize_str(example) if pre_tokenizer is not None else None
    print(f"pre-tok: {pre_tokenized}")

    example_encoded = tokenizer_fast.convert_ids_to_tokens(encoding)
    print(f"encoded: {example_encoded} --- {len(example_encoded)}")
    example_decoded = tokenizer_fast.decode(encoding)
    print(f"decoded: '{example_decoded}'")
    # Byte view exposes look-alike characters (e.g. different "Å" code points).
    print(f"decoded as bytes: {example_decoded.encode('utf-8')}")
    print()

    if show_tokenization:
        colors = ["red", "blue"]
        # Decode each id on its own and make whitespace visible.
        pieces = [
            tokenizer_fast.decode(token_id).replace("\n", "↩\n").replace(" ", "-")
            for token_id in encoding
        ]
        for i, piece in enumerate(pieces):
            print(colored(piece, colors[i % len(colors)]), end="")
        print()
        print(f"> {len(pieces)} tokens")
        print()

In [None]:
@interact
def show_examples(example=test_examples, model=["ALL"] + models, show_tokenization=False):
    """Widget: show ``example`` for the selected model, or for every model when "ALL" is chosen."""
    selected_models = sorted(models) if model == "ALL" else [model]
    for model_name in selected_models:
        show_example_model(example, model_name, show_tokenization)
        
    

In [None]:
# STOP

In [None]:
# The raw example (no normalization) as UTF-8 bytes: the bytes tell apart the
# look-alike characters listed in the trailing comment.
"ℌej Hej --- TVÅ TVÅ TVÅ".encode("utf-8")  # ℌ, H --- ÅNGSTRÖM, Å, A+°

In [None]:
# NFC form of the example (hard-coded here, not computed), per the trailing
# comment: ℌ is kept, all three "Å" variants become the precomposed Å.
"ℌej Hej --- TVÅ TVÅ TVÅ".encode("utf-8")  # ℌ, H --- Å, Å, Å

In [None]:
# NFKD form of the example (hard-coded here, not computed), per the trailing
# comment: ℌ becomes H, every "Å" becomes "A" + combining ring.
"Hej Hej --- TVÅ TVÅ TVÅ".encode("utf-8")  # H, H --- A+°, A+°, A+°

In [None]:
# NFKC form of the example (hard-coded here, not computed), per the trailing
# comment: ℌ becomes H, every "Å" becomes the precomposed Å.
"Hej Hej --- TVÅ TVÅ TVÅ".encode("utf-8")  # H, H --- Å, Å, Å

# 2. Subwords

### 2a. Subword Length Histograms

In [None]:
@interact
def show_histogram(model_1=models, model_2=[None] + models, xlim=20, ylim=15000):
    """Widget: plot subword-length histograms via ``plots.plot_histogram``.

    ``model_2`` may be None (presumably plotting ``model_1`` alone); ``xlim`` and
    ``ylim`` are forwarded unchanged (presumably axis limits — see plot_histogram).
    Nothing is returned, so the widget shows only the plot.
    """
    plot_histogram(model_1, model_2, xlim, ylim)

### 2b. Overlap

In [None]:
@interact
def show_compare_vocab(model_1=models, model_2=models, nr=5):
    """Widget: compare the vocabularies of ``model_1`` and ``model_2``.

    Prints the first value returned by ``compare_vocab``, then at most ``nr``
    entries from each of the other two returned values (headed "only model 1"
    and "only model 2").
    """
    # `interact` turns the int default into an IntSlider over [-5, 15]; a negative
    # `nr` would slice as `ex[:-k]` and print almost the entire list.
    nr = max(nr, 0)
    v, ex1, ex2 = compare_vocab(model_1, model_2)
    print(v)
    print()
    print("=== only model 1 ===")
    print(ex1[:nr])
    print()
    print("=== only model 2 ===")
    print(ex2[:nr])

### 2c. Vocabulary Size & Subword Length Mean

In [None]:
@interact
def show_vocab_size(model=models):
    """Widget: draw the vocabulary-size / subword-length-mean plot for ``model`` via ``plots.plot_vocab_size``."""
    plot_vocab_size(model)

# 3. Multilinguality

In [None]:
# Models of the multilinguality experiment. From the parsing below, names are
# assumed to look like "<id>_<core...>_3<lang>" — NOTE(review): confirm.
multilingual_candidates = [model for model in models if "_3" in model]

# The shared "core" is taken from the model(s) ending in "da". Iterating a set of
# strings is not reproducible across runs (hash randomization), so choose
# deterministically via sorting and fail clearly when nothing matches.
cores = sorted({"_".join(model.split("_")[1:-1]) for model in multilingual_candidates if model.endswith("da")})
if not cores:
    raise ValueError("no model name containing '_3' and ending with 'da' found in `models`")
core = cores[0]
if len(cores) > 1:
    print(f"NOTE: several model families found {cores}; using {core!r}")

# language suffix (text after the last "_3") -> model name, ordered by suffix
models_multilinguality = {
    model.split("_3")[-1]: model
    for model in sorted(
        (model for model in multilingual_candidates if core in model),
        key=lambda name: name.split("_3")[-1],
    )
}
models_multilinguality

In [None]:
# Language suffixes of the multilinguality models; `lang_pure` drops every suffix
# starting with "all" (presumably the combined tokenizers), keeping the rest.
lang = [*models_multilinguality]
lang_pure = [code for code in lang if not code.startswith("all")]
lang, lang_pure

In [None]:
# overview_corpus(models_multilinguality)

### 3a. Time

In [None]:
# Overview of the multilinguality models' data (see plots.plot_overview_data).
plot_overview_data(models_multilinguality.values())

In [None]:
# Overview plot for the multilinguality models (see plots.plot_overview).
plot_overview(models_multilinguality.values())

### 3b. Intersection Matrix (Subword Length)

In [None]:
# get_intersection_matrix()

In [None]:
# get_intersection_matrix(0)

In [None]:
# get_intersection_matrix(10)

In [None]:
# get_intersection_matrix(10000)

### 3c. Intersection Timeline (Subword Length)

In [None]:
def get_intersection(lang_1, lang_2, vocab_1, vocab_2):
    """Return the "intersection" entry of ``compare_vocab`` for two multilinguality models.

    ``lang_1`` / ``lang_2`` are keys of ``models_multilinguality``; ``vocab_1`` /
    ``vocab_2`` are forwarded to ``compare_vocab`` unchanged.
    """
    stats, _only_1, _only_2 = compare_vocab(
        models_multilinguality[lang_1],
        models_multilinguality[lang_2],
        vocab_1,
        vocab_2,
    )
    return stats["intersection"]

In [None]:
# Sample: intersection of the "all" and "da" models with vocab_1 = vocab_2 = 10 000.
get_intersection(lang_1="all", lang_2="da", vocab_1=10_000, vocab_2=10_000)

In [None]:
# Vocabulary sizes to compare: 100, 1 000, then 10 000 ... 100 000 in steps of 10 000.
# (For a quicker run, [10000, 100000] is enough.)
VOCAB = [100, 1000, *range(10_000, 100_001, 10_000)]
VOCAB_1 = VOCAB  # sizes used as vocab_1 (same list object as VOCAB)
VOCAB_2 = VOCAB  # sizes used as vocab_2 (same list object as VOCAB)

In [None]:
def _intersection_grid(lang_1, lang_2):
    """Nested dict vocab_1 -> vocab_2 -> get_intersection(...) for one language pair."""
    return {
        v1: {v2: get_intersection(lang_1, lang_2, v1, v2) for v2 in VOCAB_2}
        for v1 in VOCAB_1
    }


# intersections[lang_1][lang_2][vocab_1][vocab_2]; only "all" is used as lang_1.
intersections = {
    lang_1: {lang_2: _intersection_grid(lang_1, lang_2) for lang_2 in lang}
    for lang_1 in ["all"]
}

# intersections

In [None]:
def _abs_timeline(lang_1, lang_2, vocab_2):
    """Intersection counts for fixed (lang_1, lang_2, vocab_2), one per vocab_1 in VOCAB_1."""
    grid = intersections[lang_1][lang_2]
    return [grid[vocab_1][vocab_2] for vocab_1 in VOCAB_1]


# timelines_abs[lang_1][vocab_2][lang_2] -> list over VOCAB_1, for every lang_2 in lang_pure.
timelines_abs = {
    lang_1: {
        vocab_2: {lang_2: _abs_timeline(lang_1, lang_2, vocab_2) for lang_2 in lang_pure}
        for vocab_2 in VOCAB_2
    }
    for lang_1 in ["all"]
}
# timelines_abs

In [None]:
def _rel_timeline(lang_1, lang_2, vocab_2):
    """Intersection counts over VOCAB_1, each divided by lang_1's self-intersection at vocab_2.

    NOTE(review): the denominator is intersections[lang_1][lang_1][vocab_2][vocab_2]
    (lang_1 = "all" compared with itself), not lang_2's own vocabulary. The two agree
    only if both vocabularies reach vocab_2 entries — confirm this matches the
    "coverage of single-language tokenizer vocabulary" plot title.
    """
    grid = intersections[lang_1][lang_2]
    return [
        grid[vocab_1][vocab_2] / intersections[lang_1][lang_1][vocab_2][vocab_2]
        for vocab_1 in VOCAB_1
    ]


# timelines_rel[lang_1][vocab_2][lang_2] -> list of ratios over VOCAB_1.
timelines_rel = {
    lang_1: {
        vocab_2: {lang_2: _rel_timeline(lang_1, lang_2, vocab_2) for lang_2 in lang_pure}
        for vocab_2 in VOCAB_2
    }
    for lang_1 in ["all"]
}
# timelines_rel

In [None]:
%%javascript
// Override OutputArea._should_scroll to always return false, so long outputs are
// never collapsed into a scroll box (presumably for the timeline plots below).
// NOTE(review): `IPython.OutputArea` exists only in the classic Notebook front-end;
// in JupyterLab this cell has no effect — confirm the intended environment.
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

In [None]:
@interact
def show_timelines(tokenizer=["all"], vocab_size=VOCAB_2):
    """Widget: plot intersection timelines of ``tokenizer`` against the single-language models.

    For the chosen ``vocab_size`` (used as vocab_2), ``timelines_abs`` and
    ``timelines_rel`` hold one list over VOCAB_1 per language in ``lang_pure``;
    both are handed to ``plot_timelines`` as the absolute and relative panels.
    """
    lang_1 = tokenizer
    vocab_2 = vocab_size
    t_abs = timelines_abs[lang_1][vocab_2]
    t_rel = timelines_rel[lang_1][vocab_2]

    # Upper limit of the absolute panel: 110 % of the largest vocab_1, derived from
    # VOCAB_1 instead of a hard-coded 100000 so editing VOCAB keeps the axis sensible.
    # (Assumes compare_vocab limits vocabulary 1 to vocab_1 entries, bounding the count.)
    abs_ylim = 1.1 * max(VOCAB_1)

    # NOTE(review): the timelines are keyed by `lang_pure`, but `lang` (which also
    # contains the "all*" entries) is passed as labels — confirm what plot_timelines expects.
    plot_timelines(
        VOCAB_1,
        vocab_2,
        [t_abs, t_rel],
        lang,
        ylim=[abs_ylim, 1.1],
        ylabel=["absolute", "relative"],
        title=["Coverage of single-language tokenizer vocabulary"]*2,
    )