In [None]:
import os
from os.path import join
from typing import List

import sys
sys.path.insert(0,'..')
from test_data import TEST_EXAMPLES


from ipywidgets import interact, Checkbox
from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast

from plots import plot_histogram, compare_vocab, plot_overview, plot_timelines, plot_overview_data, plot_vocab_size


import numpy as np
import seaborn as sns
from itertools import product

from termcolor import colored

In [None]:
# Directory that is scanned for models; each model is later opened as
# `<OUTPUT_DIR>/<model>/tokenizer.json`. The path is relative to the notebook's location.
OUTPUT_DIR = "../output"

In [None]:
def get_models(output_dir=None) -> List[str]:
    """Return the sorted names of all tokenizer models found on disk.

    Every model is opened later as `<output_dir>/<model>/tokenizer.json`, i.e. it
    must be a sub-directory. Plain files and hidden entries (for example
    `.DS_Store` or `.ipynb_checkpoints`) are therefore skipped instead of being
    offered as models.

    Args:
        output_dir: directory to scan; defaults to the notebook-level `OUTPUT_DIR`.

    Returns:
        Alphabetically sorted list of model directory names.
    """
    if output_dir is None:
        output_dir = OUTPUT_DIR
    return sorted(
        entry
        for entry in os.listdir(output_dir)
        if not entry.startswith(".") and os.path.isdir(os.path.join(output_dir, entry))
    )
    
# Discover the available models once; the interactive cells below use this list
# as the choices for their model arguments.
models = get_models()
models

# 1. Show examples

In [None]:
"\N{ANGSTROM SIGN}", "\N{LATIN CAPITAL LETTER A WITH RING ABOVE}", "\u0041\u030A"

In [None]:
test_examples = TEST_EXAMPLES + [
    'Allmänna Allmänna',
    "<|endoftext|> test"
]

In [None]:
def show_example_model(example, model, show_tokenization):
    """Print how a single model tokenizes `example`.

    Loads `<OUTPUT_DIR>/<model>/tokenizer.json` and prints the pre-tokenization,
    the encoded tokens together with their count, the decoded string and the
    UTF-8 bytes of the decoded string.

    Args:
        example: text to tokenize.
        model: name of the model directory inside `OUTPUT_DIR`.
        show_tokenization: if true, additionally print the text token by token
            in alternating colors, with spaces shown as "-" and newlines
            marked by "↩", followed by the token count.
    """
    tokenizer_file = join(OUTPUT_DIR, model, "tokenizer.json")
    # The raw tokenizer is only used to expose the pre-tokenizer; encoding and
    # decoding go through the transformers wrapper.
    tokenizer = Tokenizer.from_file(tokenizer_file)
    tokenizer_fast = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file)
    encoding = tokenizer_fast.encode(example)

    print(f"============ {model}")
    print(f"example: '{example}'")
    print(f"pre-tok: {tokenizer.pre_tokenizer.pre_tokenize_str(example)}")
    example_encoded = tokenizer_fast.convert_ids_to_tokens(encoding)
    print(f"encoded: {example_encoded} --- {len(example_encoded)}")
    example_decoded = tokenizer_fast.decode(encoding)
    print(f"decoded: '{example_decoded}'")
    example_decoded_bytes = example_decoded.encode("utf-8")
    print(f"decoded as bytes: {example_decoded_bytes}")
    print()

    if show_tokenization:
        # Decode id by id so that token boundaries become visible; only done on
        # request because it costs one decode call per token.
        example_decoded_per_token = [
            tokenizer_fast.decode(elem).replace("\n", "↩\n").replace(" ", "-")
            for elem in encoding
        ]
        colors = ["red", "blue"]
        for i, elem in enumerate(example_decoded_per_token):
            print(colored(elem, colors[i % len(colors)]), end="")
        print()
        print(f"> {len(example_decoded_per_token)} tokens")
        print()

In [None]:
@interact
def show_examples(example=test_examples, model=["ALL"] + models, show_tokenization=False):
    """Interactive widget: tokenize the chosen example with one model, or with every model for "ALL"."""
    selected_models = sorted(models) if model == "ALL" else [model]
    for selected_model in selected_models:
        show_example_model(example, selected_model, show_tokenization)
        
    

In [None]:
# STOP

In [None]:
# example
"ℌej Hej --- TVÅ TVÅ TVÅ".encode("utf-8")  # ℌ, H --- ÅNGSTRÖM, Å, A+°

In [None]:
# NFC
"ℌej Hej --- TVÅ TVÅ TVÅ".encode("utf-8")  # ℌ, H --- Å, Å, Å

In [None]:
# NFKD
"Hej Hej --- TVÅ TVÅ TVÅ".encode("utf-8")  # H, H --- A+°, A+°, A+°

In [None]:
# NFKC
"Hej Hej --- TVÅ TVÅ TVÅ".encode("utf-8")  # H, H --- Å, Å, Å

# 2. Subwords

### 2a. Subword Length Histograms

In [None]:
@interact
def show_histogram(model_1=models, model_2=[None] + models, xlim=20, ylim=15000):
    """Interactive wrapper around `plot_histogram`.

    `model_2` may be left at `None`. `xlim` and `ylim` are forwarded unchanged
    (presumably the axis limits of the histogram -- confirm in `plots.plot_histogram`).
    """
    plot_histogram(model_1, model_2, xlim, ylim)

### 2b. Overlap

In [None]:
@interact
def show_compare_vocab(model_1=models, model_2=models, nr=5):
    """Interactive widget: compare the vocabularies of two models.

    Prints the first value returned by `compare_vocab`, followed by the first
    `nr` entries of its second and third return values under the headings
    "only model 1" and "only model 2" (presumably the subwords exclusive to
    each model -- confirm in `plots.compare_vocab`).

    Args:
        model_1: name of the first model.
        model_2: name of the second model.
        nr: number of entries to show per model; negative values are treated as 0.
    """
    v, ex1, ex2 = compare_vocab(model_1, model_2)
    # `interact` turns `nr=5` into a slider that also covers negative values; a
    # negative `nr` would slice from the end and print almost the whole list.
    nr = max(nr, 0)
    print(v)
    print()
    print("=== only model 1 ===")
    print(ex1[:nr])
    print()
    print("=== only model 2 ===")
    print(ex2[:nr])

### 2c. Vocabulary Size & Subword Length Mean

In [None]:
@interact
def show_vocab_size(model=models):
    """Interactive wrapper: call `plot_vocab_size` for the selected model."""
    plot_vocab_size(model)

# 3. Multilinguality

In [None]:
# Keep the models whose name contains "_3" but not "all", ordered by the part
# of the name that follows the last "_3".
models_multilinguality = sorted(
    (model for model in models if "_3" in model and "all" not in model),
    key=lambda name: name.split("_3")[-1],
)
models_multilinguality

In [None]:
# overview_corpus(models_multilinguality)

### 3a. Time

In [None]:
# Delegates to `plots.plot_overview_data`, restricted to the model subset selected above.
plot_overview_data(models_multilinguality)

In [None]:
# Delegates to `plots.plot_overview`, restricted to the model subset selected above.
plot_overview(models_multilinguality)

### 3b. Intersection Matrix (Subword Length)

In [None]:
def get_intersection_matrix(subword_length_threshold = None, normalize = True, plot = True):
    """Compute the pairwise vocabulary intersection of all models in `models_multilinguality`.

    Entry [i, j] is the "intersection" value that
    `compare_vocab(model_i, model_j, subword_length_threshold)` reports.

    Args:
        subword_length_threshold: forwarded unchanged to `compare_vocab`
            (presumably restricts the comparison by subword length -- confirm there).
        normalize: if true, every row i is divided by its diagonal entry [i, i]
            and the diagonal is set to 1.0. Rows whose diagonal entry is 0 get
            0.0 off the diagonal instead of NaN.
        plot: if true, draw the matrix as an annotated seaborn heatmap (tick
            labels are the model-name suffixes after the last "_3") and return
            None; otherwise return the matrix.

    Returns:
        The (N, N) float matrix when `plot` is false, otherwise None.
    """
    lang = [model.split("_3")[-1] for model in models_multilinguality]
    N = len(lang)
    intersection_matrix = np.zeros([N, N])

    for i, j in product(range(N), range(N)):
        v, _, _ = compare_vocab(
            models_multilinguality[i], models_multilinguality[j], subword_length_threshold
        )
        intersection_matrix[i, j] = v["intersection"]

    if normalize:
        self_overlap = intersection_matrix.diagonal().copy()
        empty_rows = self_overlap == 0
        # Divide by 1 where the diagonal is 0 so that no 0/0 -> NaN appears;
        # those rows are explicitly zeroed right afterwards.
        divisor = np.where(empty_rows, 1.0, self_overlap)
        intersection_matrix = intersection_matrix / divisor[:, np.newaxis]
        intersection_matrix[empty_rows] = 0.0
        np.fill_diagonal(intersection_matrix, 1.0)

    if plot:
        sns.heatmap(
            intersection_matrix,
            xticklabels=lang,
            yticklabels=lang,
            cmap="binary",
            vmin=0,
            annot=True,
        )
        return None
    return intersection_matrix

In [None]:
# get_intersection_matrix()

In [None]:
# get_intersection_matrix(1)

In [None]:
# get_intersection_matrix(2)

In [None]:
# get_intersection_matrix(3)

### 3c. Intersection Timeline (Subword Length)

In [None]:
# Largest subword-length threshold for which an intersection matrix is computed
# below (thresholds 1..MAX). The trailing "# 20" is presumably the value for a full run.
MAX = 2 # 20

In [None]:
# One label per model: the part of the model name after the last "_3",
# in the same order as `models_multilinguality`.
lang = [model.split("_3")[-1] for model in models_multilinguality]
lang

In [None]:
# One normalized intersection matrix per subword-length threshold 1..MAX, keyed by the threshold.
imatrices = {i: get_intersection_matrix(i, normalize=True, plot=False) for i in range(1, MAX+1)}
# imatrices

In [None]:
# p_all = p(x|all) = 1st row of the normalized matrix:
# entry [0, x] = intersection(model 0, model x) / intersection(model 0, model 0).
# Model 0 is the first entry of `models_multilinguality`; the plot labels below treat it as "all".
p_all = {i: imatrices[i][0] if len(imatrices[i]) > 0 else None for i in range(1, MAX+1)} 
# p_all

In [None]:
# For every language label: its p(x|all) value at each threshold 1..MAX, in threshold order.
timelines_all = {
    lang[l]: [p_all[i][l] for i in range(1, MAX+1)]
    for l in range(len(lang))   
}
# timelines_all
    

In [None]:
# p_x = p(all|x) = 1st column of the normalized matrix:
# entry [x, 0] = intersection(model x, model 0) / intersection(model x, model x).
# Model 0 is the first entry of `models_multilinguality`; the plot labels below treat it as "all".
p_x = {i: imatrices[i][:,0] if len(imatrices[i]) > 0 else None for i in range(1, MAX+1)}
# p_x

In [None]:
# For every language label: its p(all|x) value at each threshold 1..MAX, in threshold order.
timelines_x = {
    lang[l]: [p_x[i][l] for i in range(1, MAX+1)]
    for l in range(len(lang))   
}
# timelines_x
    

In [None]:
# Hand both timeline dictionaries to `plots.plot_timelines`; the y-labels and titles
# are listed in the same order as the timelines (p(x|all) first, p(all|x) second).
plot_timelines(
    [timelines_all, timelines_x], 
    lang, 
    ylabel=["p(x|all)", "p(all|x)"], 
    title=[
        "Of the subwords in the ALL vocab, how many are in the X vocab?", 
        "Of the subwords in the X vocab, how many are in the ALL vocab?"
    ]
)