In [1]:
import tokenizers
from tokenizers import Tokenizer
import os, time, re,json, copy
from collections import OrderedDict

In [2]:
# Input tokenizers to merge: `base` supplies the starting document and wins on
# shared tokens; `plus` contributes the tokens that `base` does not have.
base_tokenizer_fn = "/home/karyo/corpus/models/04-25_half/kr_57k/tokenizer.json"
plus_tokenizer_fn = "/home/karyo/corpus/models/04-25_half/ec_57k/tokenizer.json"
# Destination of the merged tokenizer JSON.
# NOTE(review): absolute, machine-specific paths — adjust before running elsewhere.
final_json_fn = "/home/karyo/corpus/models/union/tokenizer.json"

In [3]:
# Load both tokenizers through the library.
base = Tokenizer.from_file(base_tokenizer_fn)
plus = Tokenizer.from_file(plus_tokenizer_fn)

# Also keep the raw JSON documents, which are what the rest of the notebook edits.
# Read them explicitly as UTF-8 (the vocab contains non-ASCII tokens) and close
# the handles deterministically instead of leaking them.
with open(base_tokenizer_fn, encoding="utf-8") as base_f:
    base_json = json.load(base_f)
with open(plus_tokenizer_fn, encoding="utf-8") as plus_f:
    plus_json = json.load(plus_f)

In [4]:
target = 102400                # size the final vocabulary is budgeted against
byte_level_alphabet = 256      # number of byte-level base symbols in each vocab
whitespace_reservation = 24    # slots held back from the merge budget (presumably for whitespace tokens — confirm)


def get_counts(tokenizer_json : dict) -> tuple:
    """Summarise the sizes found in a tokenizer JSON document.

    Args:
        tokenizer_json: parsed tokenizer JSON with an ``added_tokens`` list and
            a ``model`` mapping holding ``merges`` and ``vocab``.

    Returns:
        ``(added_token_count, merge_count, vocab_size, no_merge_count)`` where
        ``no_merge_count`` is the byte-level alphabet plus the added tokens.

    Raises:
        AssertionError: if the vocab size is not exactly the no-merge tokens
            plus one entry per merge rule.
    """
    n_added = len(tokenizer_json['added_tokens'])
    model = tokenizer_json['model']
    n_merges = len(model['merges'])
    n_vocab = len(model['vocab'])
    n_without_merges = n_added + byte_level_alphabet
    # every vocab entry must be either a no-merge token or the product of one merge
    assert n_without_merges + n_merges == n_vocab
    return n_added, n_merges, n_vocab, n_without_merges

In [5]:
# acquire token counts from both tokenizers
base_added_token_count, base_merge_count, base_vocab_size, base_nomerge_count = get_counts(base_json)
# the plus_* counts must come from the plus tokenizer, not from the base one
plus_added_token_count, plus_merge_count, plus_vocab_size, plus_nomerge_count = get_counts(plus_json)

# how many merges the final tokenizer may hold: the target size minus the
# byte-level alphabet, the reserved slots and the base tokenizer's added tokens
target_merge = target - byte_level_alphabet - whitespace_reservation - base_added_token_count

In [6]:
# work on independent copies so the loaded JSON documents are left untouched
base_vocab, plus_vocab = (
    copy.deepcopy(source['model']['vocab']) for source in (base_json, plus_json)
)

In [7]:
# vocabs map token -> index; build the reverse index -> token lookups
inverse_base_vocab = dict(zip(base_vocab.values(), base_vocab.keys()))
inverse_plus_vocab = dict(zip(plus_vocab.values(), plus_vocab.keys()))

# Check that the i-th merge rule, with its separating space removed, spells the
# token stored at index (no-merge count + i). Mismatches are only printed as
# "<which tokenizer> <merge index>"; nothing is raised here.
for label, merges, index_to_token, merge_count, first_merge_index in (
    ("base", base_json['model']['merges'], inverse_base_vocab, base_merge_count, base_nomerge_count),
    ("plus", plus_json['model']['merges'], inverse_plus_vocab, plus_merge_count, plus_nomerge_count),
):
    for idx in range(merge_count):
        joined = merges[idx].replace(" ", "")
        if joined != index_to_token[first_merge_index + idx]:
            print(label, idx)

In [8]:
# go through each base merge rule and key it by the token it produces
# (the token stored at index no-merge count + merge position)
base_vocab_merge = {
    inverse_base_vocab[position + base_nomerge_count]: base_json['model']['merges'][position]
    for position in range(base_merge_count)
}
len(base_vocab_merge)

57592

In [9]:
# same token -> merge-rule mapping, built for the plus tokenizer
plus_vocab_merge = {
    inverse_plus_vocab[position + plus_nomerge_count]: plus_json['model']['merges'][position]
    for position in range(plus_merge_count)
}
len(plus_vocab_merge)

57592

In [10]:
union_vocab_merge = {}
# every token produced by a merge in either tokenizer
union_keys = set(base_vocab_merge).union(plus_vocab_merge)
# merge budget, then the union size and how far it falls short of that budget
print(target_merge)
print(len(union_keys), target_merge - len(union_keys))

102112
101716 396


In [11]:
# Sizing notes (vocab size per tokenizer x 2 -> deficit of the union against target_merge):
#   51200 each              -> deficit 11969; add 6144 to each
#   57344 each (2^13 * 7)   -> deficit 1284;  add 512 to each
#   57856 each (2^9 * 113)  -> deficit 396;   fill with buffer tokens

In [12]:
# Build the token -> merge-rule union in a deterministic order: every base merge
# in its original rank order, then the plus-only merges in their original rank
# order. Iterating a set of strings here would give an arbitrary, per-process
# order, and since this dict's order later becomes the merge list (and the token
# ids), that would scramble BPE merge priority and make the output irreproducible.
# Base-first also keeps each rule after the rules that create its two halves.
union_vocab_merge = dict(base_vocab_merge)
for key, value in plus_vocab_merge.items():
    # tokens present in both tokenizers keep the base tokenizer's rule
    union_vocab_merge.setdefault(key, value)

# sanity check: each rule, with its separating space removed, spells its token
for key, value in union_vocab_merge.items():
    assert key == value.replace(" ", "")

In [13]:
# start the merged tokenizer document from an independent copy of the base one
union_json = copy.deepcopy(base_json)
# index -> token view of the copied vocab
union_vocab_vk = dict(zip(union_json['model']['vocab'].values(), union_json['model']['vocab'].keys()))

In [14]:
# Peek at the token stored at index base_nomerge_count, i.e. the first index
# after the byte-level alphabet and the added tokens.
union_vocab_vk[base_nomerge_count]

'Ġì'

In [15]:
#sort union merges, union_vocab_vk here

In [16]:
# The merge list is the rules in the union's iteration order ...
union_merges = list(union_vocab_merge.values())
# ... and each token takes the next free index after the no-merge tokens,
# in that same order.
union_vocab_vk = {
    base_nomerge_count + offset: token
    for offset, token in enumerate(union_vocab_merge)
}
# token -> index view of the same assignment
union_vocab = {token: token_index for token_index, token in union_vocab_vk.items()}
assert len(union_merges) == len(union_vocab_vk)

In [17]:
# swap the copied base merge list for the combined one and show its size
union_json['model'].update(merges=union_merges)
len(union_merges)

101716

In [18]:
# index -> token view of the merged (merge-produced) tokens
union_vocab_index_token = dict(zip(union_vocab.values(), union_vocab.keys()))

In [19]:
# index -> token view of the complete base vocab
base_vocab_index_token = dict(zip(base_json['model']['vocab'].values(), base_json['model']['vocab'].keys()))

In [20]:
# overlay the union onto the base: on a shared index the union's token wins
final_vocab_index_token = {**base_vocab_index_token, **union_vocab_index_token}

In [21]:
# flip back to the token -> index layout used by the tokenizer JSON
final_vocab = dict(zip(final_vocab_index_token.values(), final_vocab_index_token.keys()))

In [22]:
# install the final vocab and merge list into the merged tokenizer document
union_json['model'].update(vocab=final_vocab, merges=union_merges)

In [26]:
# Display the assembled tokenizer document for a visual check.
# NOTE(review): this dumps the entire document, vocab and merges included.
union_json

{'version': '1.0',
 'truncation': None,
 'padding': None,
 'added_tokens': [{'id': 0,
   'content': '<s>',
   'single_word': False,
   'lstrip': False,
   'rstrip': False,
   'normalized': False,
   'special': True},
  {'id': 1,
   'content': '</s>',
   'single_word': False,
   'lstrip': False,
   'rstrip': False,
   'normalized': False,
   'special': True},
  {'id': 2,
   'content': '<|usr|>',
   'single_word': False,
   'lstrip': False,
   'rstrip': False,
   'normalized': False,
   'special': True},
  {'id': 3,
   'content': '<|pad|>',
   'single_word': False,
   'lstrip': False,
   'rstrip': False,
   'normalized': False,
   'special': True},
  {'id': 4,
   'content': '<|sys|>',
   'single_word': False,
   'lstrip': False,
   'rstrip': False,
   'normalized': False,
   'special': True},
  {'id': 5,
   'content': '<|unk|>',
   'single_word': False,
   'lstrip': False,
   'rstrip': False,
   'normalized': False,
   'special': True},
  {'id': 6,
   'content': '<|sep|>',
   'single_wor

In [23]:
# ensure_ascii=False writes non-ASCII tokens verbatim, so the file must be
# opened as UTF-8 explicitly rather than with the platform's default encoding.
with open(final_json_fn, "w", encoding="utf-8") as output_f:
    json.dump(union_json, output_f, indent=2, sort_keys=False, ensure_ascii=False)

In [24]:
# Load the file just written back through the tokenizers library.
final_tokenizer = Tokenizer.from_file(final_json_fn)

In [25]:
# Save through the library to the same path, overwriting the hand-written JSON
# (presumably to normalise it to the library's own serialisation — confirm).
final_tokenizer.save(final_json_fn)