In [1]:
import regex as re

PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

In [71]:
test_text = '''
low low low low low lower lower widest widest widest newest newest newest newest newest newest
'''.strip()

# re.findall(PAT, test_text)
pre_tokens = test_text.split()

In [19]:
# convert single word to bytes tuple
def word_to_byte_tuple(word):
    return tuple(bytes([b]) for b in word.encode('utf-8'))

print([word_to_byte_tuple(p) for p in pre_tokens])

[(b'\xe6', b'\xb5', b'\x8b', b'\xe8', b'\xaf', b'\x95'), (b'l', b'o', b'w'), (b'l', b'o', b'w'), (b'l', b'o', b'w'), (b'l', b'o', b'w'), (b'l', b'o', b'w'), (b'l', b'o', b'w', b'e', b'r'), (b'l', b'o', b'w', b'e', b'r'), (b'w', b'i', b'd', b'e', b's', b't'), (b'w', b'i', b'd', b'e', b's', b't'), (b'w', b'i', b'd', b'e', b's', b't'), (b'n', b'e', b'w', b'e', b's', b't'), (b'n', b'e', b'w', b'e', b's', b't'), (b'n', b'e', b'w', b'e', b's', b't'), (b'n', b'e', b'w', b'e', b's', b't'), (b'n', b'e', b'w', b'e', b's', b't'), (b'n', b'e', b'w', b'e', b's', b't')]


In [33]:
# get frequency table
from collections import Counter, defaultdict
freq_table = Counter(pre_tokens)

print(freq_table)

Counter({'newest': 6, 'low': 5, 'widest': 3, 'lower': 2, '测试': 1})


In [22]:
# combine together
def get_freq_table(token_list): 
    freq_table = Counter(token_list)
    freq_table_bytes = {word_to_byte_tuple(word): freq for word, freq in freq_table.items()}
    return freq_table_bytes

print(get_freq_table(pre_tokens))

{(b'\xe6', b'\xb5', b'\x8b', b'\xe8', b'\xaf', b'\x95'): 1, (b'l', b'o', b'w'): 5, (b'l', b'o', b'w', b'e', b'r'): 2, (b'w', b'i', b'd', b'e', b's', b't'): 3, (b'n', b'e', b'w', b'e', b's', b't'): 6}


In [42]:
def get_bp_freq_table(token_list):
    freq_table_bytes = get_freq_table(token_list)
    bp_freq_table = Counter()
    token_index = defaultdict(set)
    for b, f in freq_table_bytes.items():
        for i in range(len(b) - 1):
            bp_freq_table[(b[i], b[i + 1])] += f
            token_index[(b[i], b[i + 1])].add((b, i))
    return dict(bp_freq_table), token_index

bp_freq_table, token_index = get_bp_freq_table(pre_tokens)
print(bp_freq_table)
print(token_index)

{(b'\xe6', b'\xb5'): 1, (b'\xb5', b'\x8b'): 1, (b'\x8b', b'\xe8'): 1, (b'\xe8', b'\xaf'): 1, (b'\xaf', b'\x95'): 1, (b'l', b'o'): 7, (b'o', b'w'): 7, (b'w', b'e'): 8, (b'e', b'r'): 2, (b'w', b'i'): 3, (b'i', b'd'): 3, (b'd', b'e'): 3, (b'e', b's'): 9, (b's', b't'): 9, (b'n', b'e'): 6, (b'e', b'w'): 6}
defaultdict(<class 'set'>, {(b'\xe6', b'\xb5'): {((b'\xe6', b'\xb5', b'\x8b', b'\xe8', b'\xaf', b'\x95'), 0)}, (b'\xb5', b'\x8b'): {((b'\xe6', b'\xb5', b'\x8b', b'\xe8', b'\xaf', b'\x95'), 1)}, (b'\x8b', b'\xe8'): {((b'\xe6', b'\xb5', b'\x8b', b'\xe8', b'\xaf', b'\x95'), 2)}, (b'\xe8', b'\xaf'): {((b'\xe6', b'\xb5', b'\x8b', b'\xe8', b'\xaf', b'\x95'), 3)}, (b'\xaf', b'\x95'): {((b'\xe6', b'\xb5', b'\x8b', b'\xe8', b'\xaf', b'\x95'), 4)}, (b'l', b'o'): {((b'l', b'o', b'w', b'e', b'r'), 0), ((b'l', b'o', b'w'), 0)}, (b'o', b'w'): {((b'l', b'o', b'w', b'e', b'r'), 1), ((b'l', b'o', b'w'), 1)}, (b'w', b'e'): {((b'n', b'e', b'w', b'e', b's', b't'), 2), ((b'l', b'o', b'w', b'e', b'r'), 2)}, (b

In [37]:
def get_most_freq_pair(bp_freq_table):
    # get top 1, if tied, return lexicographically
    return max(bp_freq_table.items(), key=lambda kv: (kv[1], kv[0]))

print(get_most_freq_pair(bp_freq_table))

((b's', b't'), 9)


In [90]:
def one_merging_step(bp_freq_table, token_freq_table, token_index):
    (x, y), _ = get_most_freq_pair(bp_freq_table)
    merged = x + y

    occurrences = list(token_index.pop((x, y)))
    tok_pos = defaultdict(list)
    for tok, i in occurrences:
        tok_pos[tok].append(i)

    def bump(bp, freq_delta):
        if bp is None:
            return
        new_freq = bp_freq_table.get(bp, 0) + freq_delta
        if new_freq > 0:
            bp_freq_table[bp] = new_freq
        else:
            bp_freq_table.pop(bp)
        print(f"Bumping {bp} to {new_freq} with delta {freq_delta}")

    def remove_token_from_index(tok_to_remove):
        L = len(tok_to_remove)
        for i in range(L - 1):
            pair = (tok_to_remove[i], tok_to_remove[i+1])
            s = token_index.get(pair)
            if s is None:
                continue
            s.discard((tok_to_remove, i))
            if not s:
                token_index.pop(pair)
            # print(f"Removing {tok} from {tok_to_remove}")

    def add_token_to_index(tok_to_add, freq):
        L = len(tok_to_add)
        for i in range(L - 1):
            pair = (tok_to_add[i], tok_to_add[i+1])
            token_index.setdefault(pair, set()).add((tok_to_add, i))
            # bump(pair, freq)
        # print(f"Adding {tok_to_add, freq}")
        
        
    # print(tok_pos)
    for tok, pos in tok_pos.items():
        freq = token_freq_table[tok]
        pos.sort(reverse=True)

        cur = list(tok)
        # print(cur)
        merged_start_idx = float("inf")

        remove_token_from_index(tok)
        for i in pos:
            if i + 1 >= merged_start_idx:
                continue
            if i + 1 >= len(cur) or cur[i] != x or cur[i+1] != y:
                continue

            L_sym = cur[i-1] if i-1 >= 0 else None
            R_sym = cur[i+2] if i+2 < len(cur) else None
            # print(i, pos)
            # print(L_sym, R_sym, merged, merged_start_idx)
            if L_sym is not None:
                bump((L_sym, x), -freq)
            bump((x, y), -freq)
            if R_sym is not None:
                bump((y, R_sym), -freq)

            cur[i:i+2] = [merged]
            # print(cur)
            if L_sym is not None:
                bump((L_sym, merged), +freq)
            if R_sym is not None:
                bump((merged, R_sym), +freq)
                
            merged_start_idx = i

        new_tok = tuple(cur)
        if new_tok != tok:
            token_freq_table.pop(tok)
            token_freq_table[new_tok] = freq
            print(f"Adding {new_tok} with freq {freq}", tok)
        add_token_to_index(new_tok, freq)

    bp_freq_table.pop((x, y), None)

    return merged

freq_table = get_freq_table(pre_tokens)
bp_freq_table, token_index = get_bp_freq_table(pre_tokens)
print(one_merging_step(bp_freq_table, freq_table, token_index), token_index)

Bumping (b'e', b's') to 3 with delta -6
Bumping (b's', b't') to 3 with delta -6
Bumping (b'e', b'st') to 6 with delta 6
Adding (b'n', b'e', b'w', b'e', b'st') with freq 6 (b'n', b'e', b'w', b'e', b's', b't')
Bumping (b'e', b's') to 0 with delta -3
Bumping (b's', b't') to 0 with delta -3
Bumping (b'e', b'st') to 9 with delta 3
Adding (b'w', b'i', b'd', b'e', b'st') with freq 3 (b'w', b'i', b'd', b'e', b's', b't')
b'st' defaultdict(<class 'set'>, {(b'l', b'o'): {((b'l', b'o', b'w', b'e', b'r'), 0), ((b'l', b'o', b'w'), 0)}, (b'o', b'w'): {((b'l', b'o', b'w', b'e', b'r'), 1), ((b'l', b'o', b'w'), 1)}, (b'w', b'e'): {((b'l', b'o', b'w', b'e', b'r'), 2), ((b'n', b'e', b'w', b'e', b'st'), 2)}, (b'e', b'r'): {((b'l', b'o', b'w', b'e', b'r'), 3)}, (b'n', b'e'): {((b'n', b'e', b'w', b'e', b'st'), 0)}, (b'e', b'w'): {((b'n', b'e', b'w', b'e', b'st'), 1)}, (b'e', b'st'): {((b'w', b'i', b'd', b'e', b'st'), 3), ((b'n', b'e', b'w', b'e', b'st'), 3)}, (b'w', b'i'): {((b'w', b'i', b'd', b'e', b'st'), 

In [94]:
freq_table = get_freq_table(pre_tokens)
bp_freq_table, token_index = get_bp_freq_table(pre_tokens)
for _ in range(12):
    print(_, Counter(bp_freq_table), "==\n", one_merging_step(bp_freq_table, freq_table, token_index))

Bumping (b'e', b's') to 3 with delta -6
Bumping (b's', b't') to 3 with delta -6
Bumping (b'e', b'st') to 6 with delta 6
Adding (b'n', b'e', b'w', b'e', b'st') with freq 6 (b'n', b'e', b'w', b'e', b's', b't')
Bumping (b'e', b's') to 0 with delta -3
Bumping (b's', b't') to 0 with delta -3
Bumping (b'e', b'st') to 9 with delta 3
Adding (b'w', b'i', b'd', b'e', b'st') with freq 3 (b'w', b'i', b'd', b'e', b's', b't')
0 Counter({(b'e', b's'): 9, (b's', b't'): 9, (b'w', b'e'): 8, (b'l', b'o'): 7, (b'o', b'w'): 7, (b'n', b'e'): 6, (b'e', b'w'): 6, (b'w', b'i'): 3, (b'i', b'd'): 3, (b'd', b'e'): 3, (b'e', b'r'): 2}) ==
 b'st'
Bumping (b'd', b'e') to 0 with delta -3
Bumping (b'e', b'st') to 6 with delta -3
Bumping (b'd', b'est') to 3 with delta 3
Adding (b'w', b'i', b'd', b'est') with freq 3 (b'w', b'i', b'd', b'e', b'st')
Bumping (b'w', b'e') to 2 with delta -6
Bumping (b'e', b'st') to 0 with delta -6
Bumping (b'w', b'est') to 6 with delta 6
Adding (b'n', b'e', b'w', b'est') with freq 6 (b'n', 

In [2]:
max([(b'as', b't'), (b' .', b'..')])

(b'as', b't')

In [2]:
b'\x80'.decode("utf-8", errors="strict")

UnicodeDecodeError: 'utf-8' codec can't decode byte 0x80 in position 0: invalid start byte

In [2]:
import json

with open("/Users/haotiansun/cs336/assignment1-basics/tokenizers/tinystories/vocab.json", "r") as f:
    vocab_str = json.load(f)

In [8]:
vocab_str

list(vocab_str.keys())[13].encode('utf-8')

b'\x0c'