In [3]:
import pickle
import regex as re
from collections import Counter

# File name used by the (currently disabled) pickle load/save cells.
pickle_tokenizer = 'tokenizer-01.pkl'

# Pre-tokenization split patterns. They use \p{...} Unicode classes and
# possessive quantifiers, which is why `regex` is imported as `re` above.
GPT2_SPLIT_PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

# Special token string -> reserved token id.
GPT4_SPECIAL_TOKENS = {
    '<|endoftext|>': 100257,
    '<|fim_prefix|>': 100258,
    '<|fim_middle|>': 100259,
    '<|fim_suffix|>': 100260,
    '<|endofprompt|>': 100276,
}

# Training corpus, read fully into memory as one string.
with open('wizard_of_oz.txt', encoding='utf-8') as f:
    text = f.read()

In [4]:
class GPT4Tokenizer:
    """Byte-level BPE tokenizer with regex pre-splitting and special tokens.

    Text is first cut into chunks with ``self.pattern``; byte-pair merges are
    learned and applied strictly inside each chunk, never across a chunk
    boundary. Special tokens are matched as exact strings and mapped to
    reserved ids without going through BPE.

    Attributes:
        vocab: token id -> bytes. Starts with the 256 single-byte tokens.
        merges: (id0, id1) -> merged token id, in the order they were learned.
        pattern: split pattern passed to ``re.findall``.
        special_tokens: special token string -> id.
        inverse_special_tokens: id -> special token string.
    """

    def __init__(self, special_tokens=None, pattern=None):
        """Create an untrained tokenizer.

        Args:
            special_tokens: mapping of special token string to id; defaults to
                ``GPT4_SPECIAL_TOKENS``.
            pattern: split pattern; defaults to ``GPT4_SPLIT_PATTERN``.
        """
        self.vocab = {idx: bytes([idx]) for idx in range(256)}  # idx: bytes
        self.merges = {}  # (p0, p1): idx
        self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern
        self.special_tokens = GPT4_SPECIAL_TOKENS if special_tokens is None else special_tokens  # str: idx
        self.inverse_special_tokens = {v: k for k, v in self.special_tokens.items()}

    @staticmethod
    def _get_stats(ids, counts=None):
        """Count how often each pair of adjacent ids occurs in ``ids``.

        Args:
            ids: list of token ids.
            counts: optional dict to accumulate into (updated in place), so
                several sequences can be tallied into one table.

        Returns:
            dict mapping (id0, id1) -> number of occurrences.
        """
        counts = {} if counts is None else counts
        for pair in zip(ids, ids[1:]):
            counts[pair] = counts.get(pair, 0) + 1
        return counts

    @staticmethod
    def _merge(ids, pair, idx):
        """Return a copy of ``ids`` with every consecutive occurrence of ``pair`` replaced by ``idx``.

        Matching is left to right and non-overlapping.
        """
        newids = []
        i = 0
        while i < len(ids):
            # if we are not at the very last position AND the pair matches, replace it
            if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
                newids.append(idx)
                i += 2
            else:
                newids.append(ids[i])
                i += 1
        return newids

    def _encode_chunk(self, chunk_bytes):
        """Apply the learned merges to a single pre-split chunk.

        Args:
            chunk_bytes: the chunk as UTF-8 ``bytes``.

        Returns:
            list of token ids for this chunk.
        """
        tokens = list(chunk_bytes)
        while len(tokens) >= 2:
            stats = self._get_stats(tokens)
            # pick the present pair that was learned earliest (lowest merge id)
            pair = min(stats, key=lambda p: self.merges.get(p, float('inf')))
            if pair not in self.merges:
                break  # nothing left that can be merged
            tokens = self._merge(tokens, pair, self.merges[pair])
        return tokens

    def train(self, text, vocab_size, verbose=False):
        """Grow the vocabulary up to ``vocab_size`` by learning merges from ``text``.

        Can be called on an already trained tokenizer; existing merges are
        applied first and new ids continue from ``len(self.vocab)``. Stops
        early if no adjacent pair is left inside any chunk.

        Args:
            text: training text.
            vocab_size: desired final size of ``self.vocab``.
            verbose: if True, print every merge as it is learned.
        """
        start_vocab = len(self.vocab)
        num_merges = vocab_size - start_vocab

        # pre-split so that merges never span a chunk boundary
        chunk_ids = [
            self._encode_chunk(chunk.encode('utf-8'))
            for chunk in re.findall(self.pattern, text)
        ]

        for step in range(num_merges):
            # one shared table for all chunks; first-seen order is kept, so
            # ties in max() resolve to the pair that appears first in the text
            stats = {}
            for ids in chunk_ids:
                self._get_stats(ids, stats)

            if not stats:
                print(f"No more possible chunks, current/max vocab = {len(self.vocab)}")
                break

            pair = max(stats, key=stats.get)
            idx = start_vocab + step

            chunk_ids = [self._merge(ids, pair, idx) for ids in chunk_ids]

            self.merges[pair] = idx
            self.vocab[idx] = self.vocab[pair[0]] + self.vocab[pair[1]]

            if verbose:
                print(f"merging {pair[0], self.decode([pair[0]])} and {pair[1], self.decode([pair[1]])} into a new token {idx, self.decode([idx])}")
        print(f"Training Complete, chars added:{len(self.vocab)-start_vocab}, new vocab size:{len(self.vocab)}")

    def decode(self, ids, skip_special_tokens=True):
        """Turn a list of token ids back into a Python string.

        Args:
            ids: list of token ids.
            skip_special_tokens: if True, special-token ids are dropped;
                otherwise their string form is emitted.

        Returns:
            The decoded text; invalid UTF-8 is replaced, not raised.

        Raises:
            KeyError: if an id is neither a special token nor in ``self.vocab``.
        """
        tokens = []
        for idx in ids:
            # special ids are checked first, so they win over a vocab entry with the same id
            if idx in self.inverse_special_tokens:
                if skip_special_tokens:
                    continue
                tokens.append(self.inverse_special_tokens[idx].encode('utf-8'))
            else:
                tokens.append(self.vocab[idx])
        return b"".join(tokens).decode('utf-8', errors='replace')

    def encode_ordinary(self, text):
        """Encode ``text`` to token ids, treating special tokens as plain text.

        The text is split with ``self.pattern`` exactly as in ``train`` and
        each chunk is encoded on its own. Characters that the pattern does not
        match are not part of any chunk and are therefore not encoded.
        """
        ids = []
        for chunk in re.findall(self.pattern, text):
            ids.extend(self._encode_chunk(chunk.encode('utf-8')))
        return ids

    def encode(self, text, allowed_special="none_raise"):
        """Encode ``text`` to a list of token ids, with special-token handling.

        Args:
            text: Python string to encode.
            allowed_special: one of
                "all"        - every known special token is recognised,
                "none"       - special tokens are encoded as ordinary text,
                "none_raise" - like "none", but assert that none occurs in ``text``,
                a ``set``    - only the listed special tokens are recognised.

        Returns:
            list of token ids.

        Raises:
            AssertionError: with "none_raise" when ``text`` contains a special token.
            ValueError: if ``allowed_special`` is none of the above.
        """
        if allowed_special == "all":
            special = self.special_tokens
        elif allowed_special == "none":
            special = {}
        elif allowed_special == "none_raise":
            special = {}
            assert all(token not in text for token in self.special_tokens), \
                "text contains a special token but allowed_special='none_raise'"
        elif isinstance(allowed_special, set):
            special = {k: v for k, v in self.special_tokens.items() if k in allowed_special}
        else:
            raise ValueError(f"allowed_special={allowed_special} not understood")
        if not special:
            # shortcut: if no special tokens, just use the ordinary encoding
            return self.encode_ordinary(text)

        # Split on exact matches of the allowed special tokens. Wrapping the
        # alternation in () makes it a capturing group, so re.split keeps the
        # special tokens themselves in the resulting list.
        special_pattern = "(" + "|".join(re.escape(k) for k in special) + ")"
        ids = []
        for part in re.split(special_pattern, text):
            if part in special:
                # a special token maps straight to its reserved id
                ids.append(special[part])
            else:
                # ordinary text between special tokens
                ids.extend(self.encode_ordinary(part))
        return ids

In [5]:
# Start from a fresh, untrained tokenizer.
t = GPT4Tokenizer()

# To resume from a previously saved tokenizer instead, uncomment the lines below.
# NOTE: pickle.load can execute arbitrary code; only load files you created yourself.
# print('loading params...')
# with open(pickle_tokenizer, 'rb') as f:
#     t = pickle.load(f)
# print('loaded successfully')

"print('loading params...')\nwith open(pickle_tokenizer, 'rb') as f:\n    t = pickle.load(f)\nprint('loaded succesfuly')"

In [6]:
# Learn merges on the corpus until the vocabulary holds 300 tokens, printing each merge.
t.train(text, vocab_size=300, verbose=True)

merging (32, ' ') and (116, 't') into a new token (256, ' t')
merging (104, 'h') and (101, 'e') into a new token (257, 'he')
merging (32, ' ') and (97, 'a') into a new token (258, ' a')
merging (256, ' t') and (257, 'he') into a new token (259, ' the')
merging (114, 'r') and (101, 'e') into a new token (260, 're')
merging (105, 'i') and (110, 'n') into a new token (261, 'in')
merging (32, ' ') and (115, 's') into a new token (262, ' s')
merging (32, ' ') and (119, 'w') into a new token (263, ' w')
merging (110, 'n') and (100, 'd') into a new token (264, 'nd')
merging (32, ' ') and (111, 'o') into a new token (265, ' o')
merging (101, 'e') and (100, 'd') into a new token (266, 'ed')
merging (32, ' ') and (98, 'b') into a new token (267, ' b')
merging (111, 'o') and (117, 'u') into a new token (268, 'ou')
merging (10, '\n') and (10, '\n') into a new token (269, '\n\n')
merging (104, 'h') and (97, 'a') into a new token (270, 'ha')
merging (258, ' a') and (264, 'nd') into a new token (271,

In [10]:
# Round-trip check: encode with every special token recognised, then decode.
# decode() skips special tokens by default, so they are absent from the result.
# NOTE(review): the non-ASCII run looks like mis-decoded UTF-8 (mojibake) - confirm against the original notebook.
sample_text = "hello world!!!? (ì•ˆë…•í•˜<|fim_middle|>ì„¸ìš”!) lol123 ðŸ˜‰<|endoftext|>"
sample_ids = t.encode(sample_text, allowed_special="all")
t.decode(sample_ids)

'hello world!!!? (ì•ˆë…•í•˜ì„¸ìš”!) lol123 ðŸ˜‰'

'hello world!!!? (ì•ˆë…•í•˜ì„¸ìš”!) lol123 ðŸ˜‰'

In [None]:
"""with open(pickle_tokenizer, 'wb') as f:
    pickle.dump(t, f)
print("tokenizer saved")"""