In [95]:
import tokenizers
from tokenizers import Tokenizer
import os, time, re,json, copy
from collections import OrderedDict

In [96]:
# Root directory that holds every tokenizer used below. It defaults to the
# original location and can be redirected with the TOKENIZER_MODELS_DIR
# environment variable, so the notebook runs on another machine without
# editing three separate paths.
MODELS_DIR = os.environ.get("TOKENIZER_MODELS_DIR", "/home/karyo/corpus/models")

# tokenizer whose added tokens and settings the merged result is built on
base_tokenizer_fn = f"{MODELS_DIR}/04-25_half/kr_57k/tokenizer.json"
# tokenizer whose extra merges are unioned into the base one
plus_tokenizer_fn = f"{MODELS_DIR}/04-25_half/ec_57k/tokenizer.json"
# destination of the merged tokenizer.json
final_json_fn = f"{MODELS_DIR}/union/tokenizer.json"

In [97]:
# Tokenizer objects loaded through the tokenizers library.
base = Tokenizer.from_file(base_tokenizer_fn)
plus = Tokenizer.from_file(plus_tokenizer_fn)

# Raw JSON of the same two files; the cells below read and edit these dicts.
# Opened in a `with` block so the handles are closed, and decoded explicitly
# as UTF-8 so the result does not depend on the machine's locale.
with open(base_tokenizer_fn, encoding="utf-8") as base_file:
    base_json = json.load(base_file)
with open(plus_tokenizer_fn, encoding="utf-8") as plus_file:
    plus_json = json.load(plus_file)

In [98]:
# size budget used when computing how many merges fit (see target_merge)
target = 102_400
byte_level_alphabet = 256
whitespace_reservation = 24


def get_counts(tokenizer_json : dict) -> tuple:
    """Summarise the sizes of a parsed tokenizer.json.

    Reads ``tokenizer_json['added_tokens']``, ``['model']['merges']`` and
    ``['model']['vocab']`` and returns
    ``(added_token_count, merge_count, vocab_size, no_merge_count)``, where
    ``no_merge_count`` is ``byte_level_alphabet`` plus the added-token count.

    Raises AssertionError when the vocab size is not exactly
    ``no_merge_count + merge_count``.
    """
    n_added = len(tokenizer_json['added_tokens'])
    bpe_model = tokenizer_json['model']
    n_merges, n_vocab = len(bpe_model['merges']), len(bpe_model['vocab'])
    n_unmerged = n_added + byte_level_alphabet
    # every vocab entry must be either an unmerged entry or the product of one merge
    assert n_unmerged + n_merges == n_vocab
    return n_added, n_merges, n_vocab, n_unmerged

In [99]:
# acquire token counts from both tokenizers
base_added_token_count, base_merge_count, base_vocab_size, base_nomerge_count = get_counts(base_json)
# the plus_* counts have to be measured on plus_json; taking them from
# base_json would make every plus_* value a silent copy of the base one
plus_added_token_count, plus_merge_count, plus_vocab_size, plus_nomerge_count = get_counts(plus_json)

# how many merges fit in the final target: the target size minus
# byte_level_alphabet, whitespace_reservation and base's added-token count
target_merge = target - byte_level_alphabet - whitespace_reservation - base_added_token_count

In [100]:
# Work on independent deep copies of each {token: id} vocab so that nothing
# done to them afterwards can reach back into base_json / plus_json.
base_vocab, plus_vocab = (
    copy.deepcopy(tokenizer_json['model']['vocab'])
    for tokenizer_json in (base_json, plus_json)
)

In [101]:
# vocabs are stored as {token: id}; also keep the reverse {id: token} view
inverse_base_vocab = dict((index, vocab) for vocab, index in base_vocab.items())
inverse_plus_vocab = dict((index, vocab) for vocab, index in plus_vocab.items())


def _print_misaligned(label, merges, inverse_vocab, merge_count, first_merge_id):
    """Print ``label`` and ``i`` for each misaligned merge rule.

    Merge ``i`` is misaligned when its text with the spaces removed differs
    from the token stored at id ``first_merge_id + i`` in ``inverse_vocab``.
    Only the first ``merge_count`` rules are inspected.
    """
    for i in range(merge_count):
        if merges[i].replace(" ", "") != inverse_vocab[first_merge_id + i]:
            print(label, i)


# check that merge rule i lines up with the token at id (no-merge count + i);
# any index printed here is a mismatch
_print_misaligned("base", base_json['model']['merges'], inverse_base_vocab,
                  base_merge_count, base_nomerge_count)
_print_misaligned("plus", plus_json['model']['merges'], inverse_plus_vocab,
                  plus_merge_count, plus_nomerge_count)

In [102]:
# go through each item to retrieve merge rules in order
base_vocab_merge = {}
for idx in range(base_merge_count):
    vocab = inverse_base_vocab[idx+base_nomerge_count]
    base_vocab_merge[vocab] = base_json['model']['merges'][idx]

In [103]:
# Same pairing for the plus tokenizer: {token: merge rule}, in merge order,
# where rule number `rank` belongs to the token at id plus_nomerge_count + rank.
plus_vocab_merge = {
    inverse_plus_vocab[rank + plus_nomerge_count]: plus_json['model']['merges'][rank]
    for rank in range(plus_merge_count)
}

In [104]:
union_vocab_merge = {}
# every token that is produced by a merge in either tokenizer
union_keys = set(base_vocab_merge).union(plus_vocab_merge)
union_size = len(union_keys)
# merge budget, then the union size and how much of the budget is left over
print(target_merge)
print(union_size, target_merge - union_size)

102112
101716 396


In [105]:
# build union of base, plus vocab
# Start from the base rules (kept in base merge order) and append the tokens
# that only plus knows, in plus merge order. The mapping is the same as
# looking every union key up in base first and plus second, but the
# insertion order is now deterministic. Iterating a set of strings instead
# would reshuffle between kernel runs (string hashes are randomized), and
# the stable length sort further down would carry that shuffle into the
# final merge ranks and token ids.
union_vocab_merge = dict(base_vocab_merge)
for token, merge_rule in plus_vocab_merge.items():
    # setdefault keeps base's rule whenever both tokenizers define the token
    union_vocab_merge.setdefault(token, merge_rule)

# each rule, with its separating space removed, must spell its own token
for vocab, merge in union_vocab_merge.items():
    assert vocab == merge.replace(" ", ""), (vocab, merge)

In [106]:
# sort union_vocab_merge in the two cells below

In [107]:
# token strings, longest first (ties keep their current relative order)
sorted_union_keys = list(union_vocab_merge)
sorted_union_keys.sort(key=len, reverse=True)
# snapshot of the union in its current order; the sort cell reads from this
ordered_union_vocab_merge = OrderedDict(union_vocab_merge)

In [108]:
# This is where the merge order is actually changed; drop this cell to keep
# the order the union was built in.
# Alternative sort keys on the (token, merge) items:
#   lambda item: item[0]       -> alphabetical by token
#   lambda item: len(item[0])  -> by token length (used here, shortest first)
# list.sort is stable, so tokens of equal length keep their previous order.
length_ordered_items = list(ordered_union_vocab_merge.items())
length_ordered_items.sort(key=lambda item: len(item[0]))
union_vocab_merge = dict(length_ordered_items)


In [109]:
# Merge rules in their final order.
union_merges = list(union_vocab_merge.values())
# Hand out consecutive ids to the merged tokens, continuing right after the
# base tokenizer's no-merge entries.
inverse_union_vocab = {
    base_nomerge_count + position: token
    for position, token in enumerate(union_vocab_merge)
}
# flip back to the {token: id} layout used by tokenizer.json
union_vocab = {token: token_id for token_id, token in inverse_union_vocab.items()}
assert len(union_merges) == len(union_vocab)

In [110]:
# Preview only the first tokens: displaying the full key view would put
# every merged token into the cell output and bloat the notebook.
list(union_vocab)[:20]



In [111]:
# copy original base json to add to
final_json = copy.deepcopy(base_json)
final_json['model']['merges'] = union_merges
# union_vocab is the right-hand operand, so every merged token takes its
# newly assigned id; base entries that are not merge products keep theirs
final_vocab = final_json['model']['vocab'] | union_vocab
# two tokens sharing one id would produce a corrupt tokenizer, so fail here
# instead of writing the file
assert len(set(final_vocab.values())) == len(final_vocab), "duplicate token ids in merged vocab"
final_json['model']['vocab'] = final_vocab

In [112]:
# number of entries in the merged vocab
final_vocab_size = len(final_json['model']['vocab'])
final_vocab_size

101980

In [113]:
# ensure_ascii=False writes non-ASCII tokens as raw characters, so the file
# must be opened as UTF-8 explicitly; otherwise the locale's default
# encoding decides what bytes are written (or whether the write fails).
with open(final_json_fn, "w", encoding="utf-8") as output_f:
    json.dump(final_json, output_f, indent=2, sort_keys=False, ensure_ascii=False)

In [114]:
# Load the merged JSON at final_json_fn back through Tokenizer.from_file.
# NOTE(review): presumably a sanity check that the library accepts the
# merged file — confirm.
final_tokenizer = Tokenizer.from_file(final_json_fn)

In [115]:
# Save the loaded tokenizer to the same path, overwriting the json.dump output.
# NOTE(review): presumably so the library re-serializes the file in its own
# layout — confirm.
final_tokenizer.save(final_json_fn)