In [92]:
import regex as rex

In [93]:
split_gpt4 = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

not_token = r"[A-Z]{2,}|[.,:-;\!\?]"

all_caps = r"[A-Z]{2,}"

pattern_general = rex.compile(split_gpt4)
pattern_all_caps = rex.compile(all_caps)
pattern_not_token = rex.compile(not_token)

rex.findall(pattern_general, rex.sub(pattern_not_token, "", "a KING of the Relm!?"))

['a', ' ', ' of', ' the', ' Relm']

In [119]:
books = ["tiny_shakespeare.txt", 
         "dracula.txt",
         "blake.txt",
         "pickwick.txt", "twist.txt", "hard times.txt", "dorrit.txt",
         "decline1.txt",
         "vanity.txt",
         "folly.txt",
         "white company.txt",
         "heights.txt",
         "secret agent.txt",
         "nonsense.txt",
         "Middlemarch.txt", "brother jacob.txt", "mill on the floss.txt", "the lifted veil.txt",
         "alice.txt", "hunting of the snark.txt", "Through the looking glass.txt", "a tangled.txt", "bruno.txt",
         "jude.txt", "mayor of castle.txt", "return of the native.txt", "Tess of the.txt", "mayor of castle.txt", "adam bede.txt",
         "Northanger Abbey.txt", "mansfield.txt", "emma.txt", "sense and.txt",
         "treasure island.txt", "kidnapped.txt"]

lines = []
token_counts = {" ": 0}
s_names = {"juliet": 0}

for book in books:
    with open(book, 'r', encoding='utf-8') as f:
        lines += f.readlines()

for line in lines:
    for name in rex.findall(pattern_all_caps, line):
        n = name.lower().strip()
        exists = s_names.get(n, 0)
        s_names[n] = exists + 1
    lts = rex.findall(pattern_general, rex.sub(pattern_not_token, "", line))
    for lt in lts:
        is_name = s_names.get(lt.lower().strip(), False)
        if is_name is True:
            continue
        exists = token_counts.get(lt, 0)
        token_counts[lt] = exists + 1

In [190]:
possible_names = []

for (t, v) in token_counts.items():
    tl = t.lower()
    if t != tl:
        add = token_counts.get(tl, 0)
        # token_counts[t] = v + add
        if add != 0:
            token_counts[tl] = v + add
        else:
            add = token_counts.get(" " + tl.strip(), 0)
            if add == 0:
                possible_names.append(t)

In [136]:
token_counts
    
top_tokens = sorted([(v, k) for (k, v) in token_counts.items()], reverse=True)

# for i, t in enumerate(top_tokens[:200]):
#     print(i, t)

In [137]:
token_counts
    
top_tokens = sorted([(v, k) for (k, v) in token_counts.items()], reverse=True)

# for i, t in enumerate(top_tokens[:200]):
#     print(i, t)

In [None]:
Casterbridge Clym Clare James Lucy Henchard Glegg Yeobright Lydgate Eustacia Dorothea Catherine Casaubon Bulstrode Rosamond Alice Ladislaw Middlemarch Farebrother Maggie Jude Tess Elizabeth Tulliver Lucetta

In [84]:
name_counts = {}
for name in possible_names:
    count = token_counts[name]
    n = name.strip()
    current = name_counts.get(n, 0)
    name_counts[n] = current + count

In [124]:
top_names = sorted([(v, k) for (k, v) in name_counts.items()], reverse=True)
top_names[:10]

[(1946, 'Mr'),
 (999, 'Lydgate'),
 (938, 'Dorothea'),
 (693, 'Mrs'),
 (659, '“I'),
 (645, 'Casaubon'),
 (568, 'Bulstrode'),
 (550, 'Fred'),
 (549, 'Rosamond'),
 (396, 'Alice')]

In [138]:
# s_names

In [77]:
roman = ['i', 'ii', 'iii', 'iv', 'v', 'vi', 'vii', 'viii', 'ix', 'x', 'xi']
titles = ['page', 'king', 'queen', 'prince', 'cardinal', 'bishop', 'archbishop', 'lord', 'lords', 
          'nurse', 'duke', 'duchess', 'earl', 'sir', 'lady', 'gentlemen', 'friar', 'mistress']
randomy = ['another', 'overdone', 'elbow', 'froth', 'measure', 'for', 'sly', 'all', 'men', 'ae', 'of', 'bushy', 'serving', 'scroop', 'serving']

ambiguous = ['richmond', 'paris', 'rivers', 'grey', 'gardener', 'green', 'surrey', 'york', 'oxford', 'dorset', 'hastings', 'northumberland', 'salisbury'
             'carlisle', 'warwick', 'exeter', 'westmoreland', 'somerset', 'gloucester', 'buckingham', 'derby', 'ely']

all_exceptions = roman + titles + randomy

for ex in all_exceptions:
    try:
        s_names.pop(ex)
    finally:
        continue

In [155]:
def top_pairs(ids, pairs=None):
    if pairs is None:
        pairs = {} 
    for pair in zip(ids, ids[1:]):
        pairs[pair] = pairs.get(pair, 0) + 1
    return pairs

def merge(ids, pair, idx):
    """
    In the list of integers (ids), replace all consecutive occurrences
    of pair with the new integer token idx
    Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
    """
    new_ids = []
    i = 0
    while i < len(ids):
        # if not at the very last position AND the pair matches, replace it
        if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]:
            new_ids.append(idx)
            i += 2
        else:
            new_ids.append(ids[i])
            i += 1
    return new_ids

def render_token(t: bytes) -> str:
    # pretty print a token, escaping control characters
    s = t.decode('utf-8', errors='replace')
    # s = replace_control_characters(s)
    return s

In [184]:
class Tokenizer:

    def __init__(self):
        # default: vocab size of 256 (all bytes), no merges, no patterns
        self.merges = {} # (int, int) -> int
        self.pattern = "" # str
        self.special_tokens = {} # str -> int, e.g. {'<|endoftext|>': 100257}
        self.vocab = self.build_vocab() # int -> bytes

    def train(self, text, vocab_size, verbose=False):
        assert vocab_size >= 256
        num_merges = vocab_size - 256

        # input text preprocessing
        text_bytes = text.encode("utf-8") # raw bytes
        ids = list(text_bytes) # list of integers in range 0..255

        merge_start = max(ids)

        print("vocab max", merge_start)

        # iteratively merge the most common pairs to create new tokens
        merges = {} # (int, int) -> int
        vocab = {idx: bytes([idx]) for idx in range(256)} # int -> bytes
        for i in range(num_merges):
            # count up the number of times every consecutive pair appears
            pairs = top_pairs(ids)
            # find the pair with the highest count
            pair = max(pairs, key=pairs.get)
            # mint a new token: assign it the next available id
            idx = 256 + i
            # replace all occurrences of pair in ids with idx
            ids = merge(ids, pair, idx)
            # save the merge
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            # prints
            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")

        # save class variables
        self.merges = merges # used in encode()
        self.vocab = vocab   # used in decode()

    def encode(self, text):
        text_bytes = text.encode("utf-8")
        ids = list(text_bytes)
        while len(ids) > 1:
            # find the pair with the lowest merge index
            pairs = top_pairs(ids)
            pair = min(pairs, key=lambda p: self.merges.get(p, float("inf")))
            
            if pair not in self.merges:
                break # nothing else can be merged anymore
            # otherwise let's merge the best pair (lowest merge index)
            idx = self.merges[pair]
            ids = merge(ids, pair, idx)
        return ids

    def decode(self, ids):
        # given ids (list of integers), return Python string
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        text = text_bytes.decode("utf-8", errors="replace")
        return text

    def build_vocab(self):
        # vocab is simply and deterministically derived from merges
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab


In [191]:
simple_text = ""

for i, t in top_tokens[:1000]:
    simple_text += t


In [192]:
tokenize = Tokenizer()

ids = tokenize.train(simple_text, 600)

len(tokenize.vocab)

vocab max 226


600

In [187]:
ids = [224,225,226]

tokenize.decode(ids)

'���'

In [193]:
save_tokens("v2_600", tokenize.merges, tokenize.vocab)

In [175]:
def load(model_file):
        """Inverse of save() but only for the model file"""
        assert model_file.endswith(".model")
        # read the model file
        merges = {}
        idx = 256
        with open(model_file, 'r', encoding="utf-8") as f:
            # read the version
            version = f.readline().strip()
            assert version == "minbpe v1"
            # # read the pattern
            # pattern = f.readline().strip()
            # read the special tokens
            # # num_special = int(f.readline().strip())
            # for _ in range(num_special):
            #     special, special_idx = f.readline().strip().split()
            #     special_tokens[special] = int(special_idx)
            # read the merges
            for line in f:
                idx1, idx2 = map(int, line.split())
                merges[(idx1, idx2)] = idx
                idx += 1
        return merges

In [176]:
merges = load("v1.model")

In [177]:
merges

{(101, 32): 256,
 (116, 32): 257,
 (100, 32): 258,
 (114, 32): 259,
 (110, 32): 260,
 (116, 104): 261,
 (111, 117): 262,
 (105, 110): 263,
 (115, 32): 264,
 (121, 32): 265,
 (101, 259): 266,
 (101, 97): 267,
 (103, 32): 268,
 (226, 128): 269,
 (108, 108): 270,
 (101, 114): 271,
 (111, 119): 272,
 (263, 268): 273,
 (84, 104): 274,
 (111, 109): 275,
 (103, 104): 276,
 (97, 110): 277,
 (101, 260): 278,
 (101, 258): 279,
 (101, 110): 280,
 (111, 32): 281,
 (111, 114): 282,
 (107, 32): 283,
 (270, 32): 284,
 (111, 110): 285,
 (104, 97): 286,
 (101, 108): 287,
 (104, 105): 288,
 (115, 257): 289,
 (115, 256): 290,
 (111, 260): 291,
 (261, 266): 292,
 (102, 32): 293,
 (118, 256): 294,
 (108, 258): 295,
 (97, 114): 296,
 (276, 257): 297,
 (115, 116): 298,
 (111, 111): 299,
 (97, 108): 300,
 (117, 114): 301,
 (269, 156): 302,
 (269, 153): 303,
 (97, 109): 304,
 (97, 265): 305,
 (114, 101): 306,
 (97, 260): 307,
 (272, 32): 308,
 (99, 104): 309,
 (98, 101): 310,
 (66, 101): 311,
 (119, 104): 312,

In [167]:
def save_tokens(file_prefix, merges, vocab):
        """
        Saves two files: file_prefix.vocab and file_prefix.model
        This is inspired (but not equivalent to!) sentencepiece's model saving:
        - model file is the critical one, intended for load()
        - vocab file is just a pretty printed version for human inspection only
        """
        # write the model: to be used in load() later
        model_file = file_prefix + ".model"
        with open(model_file, 'w') as f:
            # write the version, pattern and merges, that's all that's needed
            f.write("minbpe v1\n")
            # f.write(f"{self.pattern}\n")
            # # write the special tokens, first the number of them, then each one
            # f.write(f"{len(self.special_tokens)}\n")
            # for special, idx in self.special_tokens.items():
            #     f.write(f"{special} {idx}\n")
            # the merges dict
            for idx1, idx2 in merges:
                f.write(f"{idx1} {idx2}\n")
        # write the vocab: for the human to look at
        vocab_file = file_prefix + ".vocab"
        inverted_merges = {idx: pair for pair, idx in merges.items()}
        with open(vocab_file, "w", encoding="utf-8") as f:
            for idx, token in vocab.items():
                # note: many tokens may be partial utf-8 sequences
                # and cannot be decoded into valid strings. Here we're using
                # errors='replace' to replace them with the replacement char �.
                # this also means that we couldn't possibly use .vocab in load()
                # because decoding in this way is a lossy operation!
                s = render_token(token)
                # find the children of this token, if any
                if idx in inverted_merges:
                    # if this token has children, render it nicely as a merge
                    idx0, idx1 = inverted_merges[idx]
                    s0 = render_token(vocab[idx0])
                    s1 = render_token(vocab[idx1])
                    f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
                else:
                    # otherwise this is leaf token, just print it
                    # (this should just be the first 256 tokens, the bytes)
                    f.write(f"[{s}] {idx}\n")