# GPT-2 
---

In [1]:
import tensorflow as tf
import random

In [2]:
tf.enable_eager_execution()

In [3]:
def nprint(*args):
    """Print the arguments, then a blank line and a dashed separator line."""
    separator = '\n\n-----------------\n'
    print(*args, end=separator)

---

## Byte Pair Encoding Utilities



In [10]:
import os
import json
import regex as re
from functools import lru_cache 

---
### Bytes to unicode

Interesting video [here](https://www.youtube.com/watch?v=MijmeoH9LT4), recommended on the [Python documentation page](https://docs.python.org/3/howto/unicode.html).

In [11]:
@lru_cache()
def bytes_to_unicode():
    """
    Build the lookup table between byte values (0-255) and unicode characters.

    The reversible bpe (Byte Pair Encoding) codes work on unicode strings.
    Working on unicode characters directly would require a large number of
    them in the vocabulary to avoid UNKs (<UNK> being the usual placeholder
    for 'unknown' items): on something like a 10B token dataset, around 5K
    characters are needed for decent coverage, a significant share of a
    normal, say, 32K bpe vocab.
    Working on utf-8 bytes instead only requires 256 base symbols, hence this
    table, which also avoids mapping to the whitespace/control characters the
    bpe code barfs on.

    Returns:
        A dict {byte value: one-character string}. Bytes that are already
        printable map to themselves ({33: '!', 34: '"', ...}); every other
        byte is mapped to a character above code point 255.
        The result is cached by lru_cache, so every call returns the same
        dict object.
    """
    # ord: returns the integer code point of a character.
    # The three ranges of bytes that are kept as they are.
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    byte_values = list(printable)
    code_points = list(printable)

    # Every remaining byte (out of 2**8 = 256) gets the next free code point
    # starting at 256.
    leftovers = [b for b in range(2**8) if b not in printable]
    for offset, b in enumerate(leftovers):
        byte_values.append(b)
        code_points.append(2**8 + offset)

    # chr: returns the character of an integer code point.
    return {b: chr(code) for b, code in zip(byte_values, code_points)}

Quick recap: `ord` gives you the Unicode code point (an integer) of a character, `chr` the character for a given code point.

In [31]:
print(ord("!"), 'is', chr(33))

33 is !


The selected ranges in the function:

In [58]:
range1 = list(range(ord("!"), ord("~")+1))
chars1 = [chr(c) for c in range1]
range2 = list(range(ord("¡"), ord("¬")+1))
chars2 = [chr(c) for c in range2]
range3 =list(range(ord("®"), ord("ÿ")+1))
chars3 = [chr(c) for c in range3]

print([(x,y) for x, y in zip(range1, chars1)])
print()
print([(x,y) for x, y in zip(range2, chars2)])
print()
print([(x,y) for x, y in zip(range3, chars3)])
print()

[(33, '!'), (34, '"'), (35, '#'), (36, '$'), (37, '%'), (38, '&'), (39, "'"), (40, '('), (41, ')'), (42, '*'), (43, '+'), (44, ','), (45, '-'), (46, '.'), (47, '/'), (48, '0'), (49, '1'), (50, '2'), (51, '3'), (52, '4'), (53, '5'), (54, '6'), (55, '7'), (56, '8'), (57, '9'), (58, ':'), (59, ';'), (60, '<'), (61, '='), (62, '>'), (63, '?'), (64, '@'), (65, 'A'), (66, 'B'), (67, 'C'), (68, 'D'), (69, 'E'), (70, 'F'), (71, 'G'), (72, 'H'), (73, 'I'), (74, 'J'), (75, 'K'), (76, 'L'), (77, 'M'), (78, 'N'), (79, 'O'), (80, 'P'), (81, 'Q'), (82, 'R'), (83, 'S'), (84, 'T'), (85, 'U'), (86, 'V'), (87, 'W'), (88, 'X'), (89, 'Y'), (90, 'Z'), (91, '['), (92, '\\'), (93, ']'), (94, '^'), (95, '_'), (96, '`'), (97, 'a'), (98, 'b'), (99, 'c'), (100, 'd'), (101, 'e'), (102, 'f'), (103, 'g'), (104, 'h'), (105, 'i'), (106, 'j'), (107, 'k'), (108, 'l'), (109, 'm'), (110, 'n'), (111, 'o'), (112, 'p'), (113, 'q'), (114, 'r'), (115, 's'), (116, 't'), (117, 'u'), (118, 'v'), (119, 'w'), (120, 'x'), (121, 'y'

The idea is to avoid whitespace and control characters: the bytes that fall outside the three selected ranges are remapped to printable characters above code point 255.

In [48]:
for i in range(126, 162):
    print(chr(i), end=', ')
print()
for i in range(172, 175):
    print(chr(i), end=', ')
print()
for i in range(255, 300):
    print(chr(i), end=', ')

~, , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , ,  , ¡, 
¬, ­, ®, 
ÿ, Ā, ā, Ă, ă, Ą, ą, Ć, ć, Ĉ, ĉ, Ċ, ċ, Č, č, Ď, ď, Đ, đ, Ē, ē, Ĕ, ĕ, Ė, ė, Ę, ę, Ě, ě, Ĝ, ĝ, Ğ, ğ, Ġ, ġ, Ģ, ģ, Ĥ, ĥ, Ħ, ħ, Ĩ, ĩ, Ī, ī, 

Just for fun:

In [59]:
bigrange = list(range(0,3000))
bigrchars = [chr(x) for x in bigrange]
print([(x,y) for x,y in zip(bigrange, bigrchars)])

[(0, '\x00'), (1, '\x01'), (2, '\x02'), (3, '\x03'), (4, '\x04'), (5, '\x05'), (6, '\x06'), (7, '\x07'), (8, '\x08'), (9, '\t'), (10, '\n'), (11, '\x0b'), (12, '\x0c'), (13, '\r'), (14, '\x0e'), (15, '\x0f'), (16, '\x10'), (17, '\x11'), (18, '\x12'), (19, '\x13'), (20, '\x14'), (21, '\x15'), (22, '\x16'), (23, '\x17'), (24, '\x18'), (25, '\x19'), (26, '\x1a'), (27, '\x1b'), (28, '\x1c'), (29, '\x1d'), (30, '\x1e'), (31, '\x1f'), (32, ' '), (33, '!'), (34, '"'), (35, '#'), (36, '$'), (37, '%'), (38, '&'), (39, "'"), (40, '('), (41, ')'), (42, '*'), (43, '+'), (44, ','), (45, '-'), (46, '.'), (47, '/'), (48, '0'), (49, '1'), (50, '2'), (51, '3'), (52, '4'), (53, '5'), (54, '6'), (55, '7'), (56, '8'), (57, '9'), (58, ':'), (59, ';'), (60, '<'), (61, '='), (62, '>'), (63, '?'), (64, '@'), (65, 'A'), (66, 'B'), (67, 'C'), (68, 'D'), (69, 'E'), (70, 'F'), (71, 'G'), (72, 'H'), (73, 'I'), (74, 'J'), (75, 'K'), (76, 'L'), (77, 'M'), (78, 'N'), (79, 'O'), (80, 'P'), (81, 'Q'), (82, 'R'), (83, '

---
Now the actual result (a dictionary):

In [25]:
btu = bytes_to_unicode()
print(btu)

{33: '!', 34: '"', 35: '#', 36: '$', 37: '%', 38: '&', 39: "'", 40: '(', 41: ')', 42: '*', 43: '+', 44: ',', 45: '-', 46: '.', 47: '/', 48: '0', 49: '1', 50: '2', 51: '3', 52: '4', 53: '5', 54: '6', 55: '7', 56: '8', 57: '9', 58: ':', 59: ';', 60: '<', 61: '=', 62: '>', 63: '?', 64: '@', 65: 'A', 66: 'B', 67: 'C', 68: 'D', 69: 'E', 70: 'F', 71: 'G', 72: 'H', 73: 'I', 74: 'J', 75: 'K', 76: 'L', 77: 'M', 78: 'N', 79: 'O', 80: 'P', 81: 'Q', 82: 'R', 83: 'S', 84: 'T', 85: 'U', 86: 'V', 87: 'W', 88: 'X', 89: 'Y', 90: 'Z', 91: '[', 92: '\\', 93: ']', 94: '^', 95: '_', 96: '`', 97: 'a', 98: 'b', 99: 'c', 100: 'd', 101: 'e', 102: 'f', 103: 'g', 104: 'h', 105: 'i', 106: 'j', 107: 'k', 108: 'l', 109: 'm', 110: 'n', 111: 'o', 112: 'p', 113: 'q', 114: 'r', 115: 's', 116: 't', 117: 'u', 118: 'v', 119: 'w', 120: 'x', 121: 'y', 122: 'z', 123: '{', 124: '|', 125: '}', 126: '~', 161: '¡', 162: '¢', 163: '£', 164: '¤', 165: '¥', 166: '¦', 167: '§', 168: '¨', 169: '©', 170: 'ª', 171: '«', 172: '¬', 174: 

---
### Get pairs

Decompose a word into the set of its adjacent (ordered) symbol pairs.

In [12]:
def get_pairs(word):
    """
    Return the set of adjacent symbol pairs in a word.

    Args:
        word: a sequence of symbols, normally a tuple of variable-length
            strings. A plain string also works, each character then being
            one symbol.

    Returns:
        A set of (left, right) tuples, one per couple of neighbouring symbols.
        A word with fewer than two symbols (an empty word included) has no
        pairs, so an empty set is returned.

    Thus, the word ('w', 'o', 'r', 'd') is represented as:
    {('o', 'r'), ('r', 'd'), ('w', 'o')}
    """
    # Pair the word with itself shifted by one symbol: zip stops at the
    # shorter of the two, which yields every (previous, next) couple and
    # copes with empty or single-symbol words without indexing word[0].
    return set(zip(word, word[1:]))

In [65]:
word = 'antidisestablishmentarianism'
print(get_pairs(word))

{('t', 'a'), ('s', 'h'), ('n', 't'), ('t', 'i'), ('a', 'b'), ('i', 'a'), ('d', 'i'), ('i', 's'), ('i', 'd'), ('s', 'm'), ('e', 'n'), ('h', 'm'), ('a', 'n'), ('s', 'e'), ('m', 'e'), ('a', 'r'), ('r', 'i'), ('n', 'i'), ('l', 'i'), ('e', 's'), ('s', 't'), ('b', 'l')}


However, it will be used like so (why is that?):

In [66]:
word = tuple(word)

print()
print(word)

print()
print(get_pairs(word))


('a', 'n', 't', 'i', 'd', 'i', 's', 'e', 's', 't', 'a', 'b', 'l', 'i', 's', 'h', 'm', 'e', 'n', 't', 'a', 'r', 'i', 'a', 'n', 'i', 's', 'm')

{('t', 'a'), ('s', 'h'), ('n', 't'), ('t', 'i'), ('a', 'b'), ('i', 'a'), ('d', 'i'), ('i', 's'), ('i', 'd'), ('s', 'm'), ('e', 'n'), ('h', 'm'), ('a', 'n'), ('s', 'e'), ('m', 'e'), ('a', 'r'), ('r', 'i'), ('n', 'i'), ('l', 'i'), ('e', 's'), ('s', 't'), ('b', 'l')}


A time thing??

In [68]:
%%timeit
word = 'antidisestablishmentarianism'
get_pairs(word)

4.25 µs ± 82.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [69]:
%%timeit
word = 'antidisestablishmentarianism'
word = tuple(word)
get_pairs(word)

4.74 µs ± 112 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


Obviously not. **The answer is** (as usual, read the docs): 

> Word argument is given as a tuple of symbols (symbols being variable-length strings).  

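The loop inside the function iterates over the elements of `word`. Iterating over a plain string can only ever yield single characters, whereas a tuple can hold symbols of any length (e.g. `('Ġt', 'he')`), which is what is needed once symbols start being merged. On a plain string, `tuple()` simply produces one element per character, as shown above.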

---
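N.B.: because the result is a set, words that contain exactly the same adjacent pairs become indistinguishable. For example, every 4-letter window of the periodic sequence `abcabc...` gives the same set, as shown below. (This is not true of [cyclic permutations](https://en.wikipedia.org/wiki/Cyclic_permutation) in general: `'abc'` gives `{('a','b'), ('b','c')}` whereas `'bca'` gives `{('b','c'), ('c','a')}`.) Quite unproblematic here, since the set is only used to pick the next merge while the word itself is kept as an ordered tuple, but still worth knowing.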

In [22]:
print(get_pairs('cabc'))
print(get_pairs('abca'))
print(get_pairs('bcab'))

{('a', 'b'), ('c', 'a'), ('b', 'c')}
{('a', 'b'), ('c', 'a'), ('b', 'c')}
{('a', 'b'), ('c', 'a'), ('b', 'c')}


---

### Get Encoder

Read from json, bpe & create Encoder object.

In [15]:
def get_encoder(model_name, models_dir='../models'):
    """
    Load the vocabulary and the BPE merges of a model and build its Encoder.

    Args:
        model_name: name of the model sub-directory (e.g. '117M').
        models_dir: directory containing the model sub-directories
            (defaults to '../models').

    Returns:
        An Encoder built from <models_dir>/<model_name>/encoder.json and
        <models_dir>/<model_name>/vocab.bpe.
    """
    model_dir = os.path.join(models_dir, model_name)

    # get the vocabulary as a json file (a dict: token string -> integer id);
    # read explicitly as utf-8, like vocab.bpe below, so that the result does
    # not depend on the platform's default encoding
    with open(os.path.join(model_dir, 'encoder.json'), 'r', encoding="utf-8") as f:
        encoder = json.load(f)

    # get the list of merges as a txt file
    with open(os.path.join(model_dir, 'vocab.bpe'), 'r', encoding="utf-8") as f:
        bpe_data = f.read()

    # translates a string format with x y on each line to [(x,y),...]
    # [1:-1]: skip the first line, that has the version, and skip the last
    # element of split, which will be empty
    merge_lines = bpe_data.split('\n')[1:-1]
    bpe_merges = [tuple(merge_str.split()) for merge_str in merge_lines]

    return Encoder(
        encoder=encoder,
        bpe_merges=bpe_merges,
    )

In [85]:
model_name = '117M'

The encoder (json):

In [86]:
with open(os.path.join('../models', model_name, 'encoder.json'), 'r') as f:
        json117 = json.load(f)

In [102]:
print('Size of dict:', len(json117))
print('A random element:', random.choice(list(json117.items())))

Size of dict: 50257
A random element: ('ĠAPIs', 23113)


The bpe_encoder:

In [89]:
with open(os.path.join('../models', model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
    bpe117 = f.read()

In [101]:
print('Length of string:', len(bpe117))
print()
print(bpe117[:25])
print()
print(bpe117.split('\n')[:10])
print(bpe117.split('\n')[-10:])

Length of string: 420572

#version: 0.2
Ġ t
Ġ a
h e

['#version: 0.2', 'Ġ t', 'Ġ a', 'h e', 'i n', 'r e', 'o n', 'Ġt he', 'e r', 'Ġ s']
['Ġ( /', 'âĢ¦ ."', 'Com par', 'Ġampl ification', 'om inated', 'Ġreg ress', 'ĠColl ider', 'Ġinform ants', 'Ġg azed', '']


The bpe merges:

In [283]:
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe117.split('\n')[1:-1]]
print('First ten elements:', bpe_merges[:10], sep='\n')
print()
print('Last ten elements:', bpe_merges[-10:], sep='\n')

First ten elements:
[('Ġ', 't'), ('Ġ', 'a'), ('h', 'e'), ('i', 'n'), ('r', 'e'), ('o', 'n'), ('Ġt', 'he'), ('e', 'r'), ('Ġ', 's'), ('a', 't')]

Last ten elements:
[('Comm', 'ission'), ('Ġ(', '/'), ('âĢ¦', '."'), ('Com', 'par'), ('Ġampl', 'ification'), ('om', 'inated'), ('Ġreg', 'ress'), ('ĠColl', 'ider'), ('Ġinform', 'ants'), ('Ġg', 'azed')]


---
### The Encoder class

In [13]:
class Encoder:
    """
    Byte-level BPE tokenizer: turns text into a list of token ids and back.

    Attributes: 
    - encoder/decoder (dicts): token string -> id, and the reverse
    - errors: option for bytearray.decode()
    - byte_encoder/decoder (dicts): byte value -> unicode character, and the reverse
    - bpe_ranks (dict): merge pair -> rank (the lowest rank is merged first)
    - cache (dict): token -> result of bpe(), filled on the go
    - pat: compiled regex used to split a text into tokens
    """

    def __init__(self, encoder, bpe_merges, errors='replace'):
        """
        Args:
            encoder: dict mapping each token string to its id.
            bpe_merges: sequence of merge pairs; the position of a pair in the
                sequence is its rank.
            errors: option for the bytearray.decode() call used in decode().
                cf Python doc for 'replace': Replace with a suitable
                replacement marker; Python will use the official U+FFFD
                REPLACEMENT CHARACTER for the built-in codecs on decoding,
                and '?' on encoding.
        """
        self.encoder = encoder
        self.decoder = {v:k for k,v in self.encoder.items()}            # simply reversing from {k:v} to {v:k}
        self.errors = errors # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()                          # our look-up table function
        self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} # reversing again, for bytes
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))  # { x0: 0, x1: 1, ...}
        self.cache = {}

        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        # (the flag is added below, as this comment suggests)
        #
        # The alternatives of the regex, in order:
        # - contractions: 's, 't, 're, 've, 'm, 'll, 'd
        # - words: one or more of any letter (\p{L}), preceded by optional space
        # - numbers: one or more of any number (\p{N}), preceded by optional space
        # - punctuation: one or more characters that are not whitespace, a letter or a number,
        #                preceded by optional space
        # - whitespace NOT followed by a non-whitespace character, negative lookahead: (?!\S)
        # - one or more whitespace characters otherwise
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", flags=re.IGNORECASE)

    def bpe(self, token):
        """
        Apply the BPE merges to a token.

        The token is split into single characters, then the pair with the
        lowest rank in self.bpe_ranks is merged repeatedly, until no known
        pair is left or the word is down to one symbol.

        Args:
            token: a string (already mapped through self.byte_encoder when
                called from encode()).

        Returns:
            The resulting symbols joined by single spaces, as one string.
        """
        # don't do the work twice, save words on the go
        if token in self.cache:
            return self.cache[token]

        word = tuple(token)     # turn token to char tuple

        # an empty or single-symbol word has no pairs: nothing to merge
        if len(word) < 2:
            return token

        pairs = get_pairs(word) # get all char pairs: ('w','o','r','d') > { ('w', 'o'), ('o', 'r'), ('r', 'd') }

        while True:
            # float('inf') acts as an unbounded upper value for the comparison:
            # pairs that are not in bpe_ranks can only win if no pair is known.
            # https://stackoverflow.com/a/34264749
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []

            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)   # returns index of searched element (first), starting at i
                except ValueError:
                    # first does not occur any more: copy the rest as it is
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])     # copy the symbols before the occurrence
                i = j

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)   # merge the pair
                    i += 2
                else:
                    new_word.append(word[i])        # first not followed by second: keep it
                    i += 1

            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)

        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """
        Turn a text into a list of token ids.

        Args:
            text: the string to encode.

        Returns:
            A list with the ids (values of self.encoder) of the BPE symbols.

        Raises:
            KeyError: if a symbol produced by bpe() is not in self.encoder.
        """
        bpe_tokens = []

        # for each token found by our regex (contractions, words, numbers, punctuation, whitespace)
        for token in self.pat.findall(text):

            # encode to utf-8 bytes, map each byte to its unicode character, then join in a string
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))

            # merge, split the result into symbols, and look up the id of each symbol
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))

        return bpe_tokens

    def decode(self, tokens):
        """
        Turn a list of token ids back into a text.

        Args:
            tokens: an iterable of ids (keys of self.decoder).

        Returns:
            The decoded string; bytes that are not valid utf-8 are handled
            according to self.errors.
        """
        # ids -> symbols, joined in one string
        text = ''.join([self.decoder[token] for token in tokens])
        # characters -> byte values, then decode the bytes as utf-8
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        return text

A bit of dissection... We can reuse the `117M` encoder we loaded above.

In [275]:
encoder = json117

In [276]:
enc_self_encoder = encoder
enc_self_decoder = {v:k for k,v in enc_self_encoder.items()}            # simply reversing from {k:v} to {v:k}
enc_self_byte_encoder = bytes_to_unicode()                          # our look-up table function
enc_self_byte_decoder = {v:k for k, v in enc_self_byte_encoder.items()} # reversing again, for bytes
enc_self_bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))  # { x0: 0, x1: 1, ...}
enc_self_cache = {}

### Regexes

In [238]:
# The alternatives of the regex, in order:
# contractions: 's, 't, 're, 've, 'm, 'll, 'd
# words: one or more of any letter (\p{L}), preceded by optional space
# numbers: one or more of any number (\p{N}), preceded by optional space
# punctuation: not a space, a letter or a number, one or more times, preceded by optional space
# whitespace not followed by a non-whitespace character, negative lookahead: (?!\S)
#   (in front of a token, this matches all the spaces but the last one, which is left for that token)
# one or more whitespace characters otherwise
enc_self_pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", flags=re.IGNORECASE)

As a test:

In [152]:
reg_test = "Won't? I'm I'll I'd numbers 9012 also multiple spaces ' ?        no code a9877   b002 x0 b9d8"
print(re.findall(enc_self_pat, reg_test))

['Won', "'t", '?', ' I', "'m", ' I', "'ll", ' I', "'d", ' numbers', ' 9012', ' also', ' multiple', ' spaces', " '", ' ?', '       ', ' no', ' code', ' a', '9877', '  ', ' b', '002', ' x', '0', ' b', '9', 'd', '8']


We can also separate the regexes, to see what they do:

In [153]:
enc_self_reg1 = re.compile(r""" ?[^\s\p{L}\p{N}]+""", flags=re.IGNORECASE)
enc_self_reg2 = re.compile(r"""\s+(?!\S)""", flags=re.IGNORECASE)
enc_self_reg3 = re.compile(r"""\s+""", flags=re.IGNORECASE) # all spaces
print(re.findall(enc_self_reg1, reg_test))
print(re.findall(enc_self_reg2, reg_test))
print(re.findall(enc_self_reg3, reg_test))

["'", '?', "'", "'", "'", " '", ' ?']
['       ', '  ']
[' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', '        ', ' ', ' ', '   ', ' ', ' ']


---
### The bpe function

In [327]:
def enc_bpe(token):
    """
    Apply the BPE merges to a token (standalone version of Encoder.bpe).

    The token is split into single characters, then the pair with the lowest
    rank in enc_self_bpe_ranks is merged repeatedly, until no known pair is
    left or the word is down to one symbol. Results are stored in
    enc_self_cache.

    Args:
        token: a string (already mapped through the byte encoder).

    Returns:
        The resulting symbols joined by single spaces, as one string.
    """
    # don't do the work twice, save words on the go
    if token in enc_self_cache:
        return enc_self_cache[token]

    word = tuple(token)     # turn token to char tuple

    # an empty or single-symbol word has no pairs: nothing to merge
    if len(word) < 2:
        return token

    pairs = get_pairs(word) # get all char pairs: ('w','o','r','d') > { ('w', 'o'), ('o', 'r'), ('r', 'd') }

    while True:
        # float('inf') acts as an unbounded upper value for the comparison:
        # pairs that are not in the ranks can only win if no pair is known.
        # https://stackoverflow.com/a/34264749
        bigram = min(pairs, key=lambda pair: enc_self_bpe_ranks.get(pair, float('inf')))
        if bigram not in enc_self_bpe_ranks:
            break
        first, second = bigram
        new_word = []

        i = 0
        while i < len(word):
            try:
                j = word.index(first, i)   # returns index of searched element (first), starting at i
            except ValueError:
                # first does not occur any more: copy the rest as it is
                new_word.extend(word[i:])
                break
            new_word.extend(word[i:j])     # copy the symbols before the occurrence
            i = j

            if word[i] == first and i < len(word)-1 and word[i+1] == second:
                new_word.append(first+second)   # merge the pair
                i += 2
            else:
                new_word.append(word[i])        # first not followed by second: keep it
                i += 1

        word = tuple(new_word)
        if len(word) == 1:
            break
        pairs = get_pairs(word)

    # returns a string
    word = ' '.join(word)
    enc_self_cache[token] = word
    return word

Dissect!

In [245]:
txt = " theory"
token = re.findall(enc_self_pat, txt)[0]  
token = ''.join(enc_self_byte_encoder[b] for b in token.encode('utf-8'))
print(token)

Ġtheory


In [305]:
tktpl = tuple(token)
print(tktpl)
tkpairs = get_pairs(tktpl)
print(tkpairs)

('Ġ', 't', 'h', 'e', 'o', 'r', 'y')
{('h', 'e'), ('Ġ', 't'), ('o', 'r'), ('t', 'h'), ('e', 'o'), ('r', 'y')}


As a reminder, `enc_self_bpe_ranks` is a `dict` mapping each merge pair to its rank (its position in the list of merges).

In [279]:
print('First ten elements:', list(enc_self_bpe_ranks.items())[:10], sep='\n')
print()
index = random.randint(0,len(enc_self_bpe_ranks)-11)
print('Random ten elements:', list(enc_self_bpe_ranks.items())[index:index+10], sep='\n')

First ten elements:
[(('Ġ', 't'), 0), (('Ġ', 'a'), 1), (('h', 'e'), 2), (('i', 'n'), 3), (('r', 'e'), 4), (('o', 'n'), 5), (('Ġt', 'he'), 6), (('e', 'r'), 7), (('Ġ', 's'), 8), (('a', 't'), 9)]

Random ten elements:
[(('ĠP', 'ru'), 35255), (('p', 'un'), 35256), (('ĠL', 'OL'), 35257), (('))', '))'), 35258), (('ĠL', 'iqu'), 35259), (('ĠS', 'AS'), 35260), (('Ġsty', 'ling'), 35261), (('Ġpunish', 'ments'), 35262), (('Ġnum', 'b'), 35263), (('Ġasc', 'ertain'), 35264)]


Method:
- take the `min()` pair according to its ranking in `enc_self_bpe_ranks`;
- for that, use a `lambda` function, that gets the appropriate ranking number, and if the pair is not found, return `float('inf')`, namely, don't select it.

Recap:   
`dict.get()` [documentation](https://docs.python.org/3/library/stdtypes.html?highlight=dict%20get#dict.get): 
> Return the value for key if key is in the dictionary, else default. If default is not given, it defaults to None, so that this method never raises a KeyError.

In [290]:
# the selected pair does not depend on the loop below: compute it once
tkbigram = min(tkpairs, key = lambda tkpair: enc_self_bpe_ranks.get(tkpair, float('inf')))

# show the rank of every pair (inf when the pair is not a known merge)
for tkpair in tkpairs:
    print('Pair:', tkpair, '| Rank:', enc_self_bpe_ranks.get(tkpair, float('inf')))
print()
print('Returned pair:', tkbigram, '| Rank:',  enc_self_bpe_ranks.get(tkbigram))

Pair: ('h', 'e') | Rank: 2
Pair: ('Ġ', 't') | Rank: 0
Pair: ('o', 'r') | Rank: 17
Pair: ('t', 'h') | Rank: 144
Pair: ('e', 'o') | Rank: inf
Pair: ('r', 'y') | Rank: 307

Returned pair: ('Ġ', 't') | Rank: 0


In [301]:
# One merge round, step by step, for the pair selected above (tkbigram)
print(tktpl)
frst, scnd = tkbigram
print(frst, scnd)
nw_wd = []

i = 0
j = None   # so that the 'except' print below works even if frst is never found
while i < len(tktpl):
    try:
        j = tktpl.index(frst, i)   # returns index of searched element (frst), starting at i
    except ValueError:
        # frst does not occur any more: copy the rest as it is
        nw_wd.extend(tktpl[i:])
        print('except:', nw_wd, i, j)
        break
    nw_wd.extend(tktpl[i:j]) # append items from iterable to the end of the array
    print('try:', nw_wd, i, j)
    i = j

    if tktpl[i] == frst and i < len(tktpl)-1 and tktpl[i+1] == scnd:
        nw_wd.append(frst+scnd)
        print('if', nw_wd, i, j)
        i += 2
    else:
        nw_wd.append(tktpl[i])
        print('else', nw_wd, i, j)
        i += 1

('Ġ', 't', 'h', 'e', 'o', 'r', 'y')
Ġ t
try: [] 0 0
if ['Ġt'] 0 0
except: ['Ġt', 'h', 'e', 'o', 'r', 'y'] 2 0


In [328]:
# The complete merge loop of the bpe function, with a print at every step
tktpl = tuple(token)
tkpairs = get_pairs(tktpl)
print('Token tuple:', tktpl)
print('Pairs:', tkpairs)
print('-'*30)

while True:
    # the known pair with the lowest rank (an unknown pair only if none is known)
    bgrm = min(tkpairs, key = lambda pair: enc_self_bpe_ranks.get(pair, float('inf')))
    print('Chosen bigram:', bgrm, '| Rank:', enc_self_bpe_ranks.get(bgrm))
    if bgrm not in enc_self_bpe_ranks:
        break
    frst, scnd = bgrm
    nw_wd = []
    i = 0
    while i < len(tktpl):
        try:
            j = tktpl.index(frst, i)
        except ValueError:
            # frst does not occur any more: copy the rest as it is
            nw_wd.extend(tktpl[i:])
            break
        nw_wd.extend(tktpl[i:j])
        i = j
        if tktpl[i] == frst and i < len(tktpl)-1 and tktpl[i+1] == scnd:
            nw_wd.append(frst+scnd)
            i += 2
        else:
            nw_wd.append(tktpl[i])
            i += 1
    tktpl = tuple(nw_wd)
    print('Token tuple updated:', tktpl)

    if len(tktpl) == 1:
        print()
        print('Length now', len(tktpl), ', breaking')
        break

    # compute the pairs of the updated word once, for the print and the next round
    tkpairs = get_pairs(tktpl)
    print('Repairing, pairs:', tkpairs, end='\n\n')

Token tuple: ('Ġ', 't', 'h', 'e', 'o', 'r', 'y')
Pairs: {('h', 'e'), ('Ġ', 't'), ('o', 'r'), ('t', 'h'), ('e', 'o'), ('r', 'y')}
------------------------------
Chosen bigram: ('Ġ', 't') | Rank: 0
Token tuple updated: ('Ġt', 'h', 'e', 'o', 'r', 'y')
Repairing, pairs: {('h', 'e'), ('o', 'r'), ('Ġt', 'h'), ('e', 'o'), ('r', 'y')}

Chosen bigram: ('h', 'e') | Rank: 2
Token tuple updated: ('Ġt', 'he', 'o', 'r', 'y')
Repairing, pairs: {('he', 'o'), ('Ġt', 'he'), ('o', 'r'), ('r', 'y')}

Chosen bigram: ('Ġt', 'he') | Rank: 6
Token tuple updated: ('Ġthe', 'o', 'r', 'y')
Repairing, pairs: {('Ġthe', 'o'), ('o', 'r'), ('r', 'y')}

Chosen bigram: ('o', 'r') | Rank: 17
Token tuple updated: ('Ġthe', 'or', 'y')
Repairing, pairs: {('Ġthe', 'or'), ('or', 'y')}

Chosen bigram: ('or', 'y') | Rank: 396
Token tuple updated: ('Ġthe', 'ory')
Repairing, pairs: {('Ġthe', 'ory')}

Chosen bigram: ('Ġthe', 'ory') | Rank: 4327
Token tuple updated: ('Ġtheory',)

Length now 1 , breaking


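The result is turned back into a string (the symbols joined by spaces), which the `encode` function below splits on spaces again before looking up each symbol in the vocabulary. Note that the loop below iterates over the *characters* of that string, so it shows the ids of the individual characters rather than the id of the merged symbol `Ġtheory`.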

In [336]:
a = ' '.join(tktpl)
print(a)
for b in a:
    print(b, enc_self_encoder[b])

Ġtheory
Ġ 220
t 83
h 71
e 68
o 78
r 81
y 88


---
### The *encode* function

In [213]:
def enc_encode(text):
    """
    Turn a text into a list of token ids (standalone version of Encoder.encode).

    Each token found by the regex enc_self_pat is encoded to utf-8 bytes, each
    byte is mapped to its unicode character with enc_self_byte_encoder, the
    resulting string goes through enc_bpe, and every symbol that comes out is
    looked up in enc_self_encoder.
    """
    ids = []
    for raw_token in re.findall(enc_self_pat, text):
        utf8_bytes = raw_token.encode('utf-8')
        mapped = ''.join([enc_self_byte_encoder[byte] for byte in utf8_bytes])
        for symbol in enc_bpe(mapped).split(' '):
            ids.append(enc_self_encoder[symbol])
    return ids

As a reminder, our byte encoder:

In [183]:
list(enc_self_byte_encoder.items())[:10]

[(33, '!'),
 (34, '"'),
 (35, '#'),
 (36, '$'),
 (37, '%'),
 (38, '&'),
 (39, "'"),
 (40, '('),
 (41, ')'),
 (42, '*')]

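Steps:
- split the text into tokens with the regex, returned as a list;
- encode each token to its UTF-8 bytes (a sequence of integers between 0 and 255);
- map each byte to its unicode character with the byte encoder, and join the characters into a string;
- apply the `bpe` function;
- look up each resulting symbol in the vocabulary (`encoder`) to get its id.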

In [187]:
# A short sample, defined here so that this cell does not depend on a
# variable that is only created further down in the notebook
sample_text = "In probability theory"

bpe_tkns = []
for token in re.findall(enc_self_pat, sample_text):
    print('Original token:', token)

    # utf-8 bytes of the token, then the unicode character of each byte
    token_bytes = token.encode('utf-8')
    print('In UTF-8:', [b for b in token_bytes])
    print('Now in bytes:', [enc_self_byte_encoder[b] for b in token_bytes])
    token = ''.join(enc_self_byte_encoder[b] for b in token_bytes)

    # merge once, and reuse the result for the prints and the final list
    merged = enc_bpe(token)
    ids = [enc_self_encoder[bpe_token] for bpe_token in merged.split(' ')]
    print('Result of the enc_bpe fn:', merged)
    print('Result of encoder:', ids)
    bpe_tkns.extend(ids)
    print()
nprint()
print('End result', bpe_tkns)

Original token: In
In UTF-8: [73, 110]
Now in bytes: ['I', 'n']
Result of the enc_bpe fn: I n
Result of encoder: [40, 77]

Original token:  probability
In UTF-8: [32, 112, 114, 111, 98, 97, 98, 105, 108, 105, 116, 121]
Now in bytes: ['Ġ', 'p', 'r', 'o', 'b', 'a', 'b', 'i', 'l', 'i', 't', 'y']
Result of the enc_bpe fn: Ġ p r o b a b i l i t y
Result of encoder: [220, 79, 81, 78, 65, 64, 65, 72, 75, 72, 83, 88]

Original token:  theory
In UTF-8: [32, 116, 104, 101, 111, 114, 121]
Now in bytes: ['Ġ', 't', 'h', 'e', 'o', 'r', 'y']
Result of the enc_bpe fn: Ġ t h e o r y
Result of encoder: [220, 83, 71, 68, 78, 81, 88]



-----------------
End result [40, 77, 220, 79, 81, 78, 65, 64, 65, 72, 75, 72, 83, 88, 220, 83, 71, 68, 78, 81, 88]


---
### The *decode* function

In [222]:
def enc_decode(tokens):
    """
    Turn a list of token ids back into a text (standalone version of Encoder.decode).

    Each id is first replaced by its symbol (enc_self_decoder); each character
    of the joined symbols is then replaced by its byte value
    (enc_self_byte_decoder), and the bytes are decoded as utf-8, invalid
    sequences being replaced by the replacement character.
    """
    symbols = [enc_self_decoder[token_id] for token_id in tokens]
    mapped_text = ''.join(symbols)
    raw_bytes = bytearray(enc_self_byte_decoder[char] for char in mapped_text)
    return raw_bytes.decode('utf-8', errors='replace')

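Method:
- map each token id back to its vocabulary string, and join the strings;
- map each character of the result back to its byte value with the byte decoder;
- decode the resulting bytes as UTF-8.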

In [228]:
txt = "Hullo hullo"
tkns = enc_encode(txt)
print(txt)
print(tkns)
print(*zip(list(txt), tkns))

Hullo hullo
[39, 84, 75, 75, 78, 220, 71, 84, 75, 75, 78]
('H', 39) ('u', 84) ('l', 75) ('l', 75) ('o', 78) (' ', 220) ('h', 71) ('u', 84) ('l', 75) ('l', 75) ('o', 78)


As a reminder, our decoder and byte decoder:

In [229]:
print(list(enc_self_decoder.items())[:10])
print(list(enc_self_byte_decoder.items())[:10])

[(0, '!'), (1, '"'), (2, '#'), (3, '$'), (4, '%'), (5, '&'), (6, "'"), (7, '('), (8, ')'), (9, '*')]
[('!', 33), ('"', 34), ('#', 35), ('$', 36), ('%', 37), ('&', 38), ("'", 39), ('(', 40), (')', 41), ('*', 42)]


In [225]:
print([enc_self_decoder[tkn] for tkn in tkns])

['H', 'u', 'l', 'l', 'o', 'Ġ', 'h', 'u', 'l', 'l', 'o']


In [234]:
tkstr = ''.join([enc_self_decoder[tkn] for tkn in tkns])
print(bytearray([enc_self_byte_decoder[c] for c in tkstr]))
print(bytearray([enc_self_byte_decoder[c] for c in tkstr]).decode('utf-8', errors='replace'))

bytearray(b'Hullo hullo')
Hullo hullo


In [316]:
enc_decode(tkns)

'Hullo hullo'

---

### The final wrap-up: loading the full encoder

In [17]:
enc117 = get_encoder('117M')

In [119]:
nprint(list(enc117.encoder.items())[:50])
nprint(list(enc117.decoder.items())[:50])
nprint(list(enc117.byte_encoder.items())[:50])
nprint(list(enc117.byte_decoder.items())[:50])
nprint(list(enc117.bpe_ranks.items())[:50])

[('!', 0), ('"', 1), ('#', 2), ('$', 3), ('%', 4), ('&', 5), ("'", 6), ('(', 7), (')', 8), ('*', 9), ('+', 10), (',', 11), ('-', 12), ('.', 13), ('/', 14), ('0', 15), ('1', 16), ('2', 17), ('3', 18), ('4', 19), ('5', 20), ('6', 21), ('7', 22), ('8', 23), ('9', 24), (':', 25), (';', 26), ('<', 27), ('=', 28), ('>', 29), ('?', 30), ('@', 31), ('A', 32), ('B', 33), ('C', 34), ('D', 35), ('E', 36), ('F', 37), ('G', 38), ('H', 39), ('I', 40), ('J', 41), ('K', 42), ('L', 43), ('M', 44), ('N', 45), ('O', 46), ('P', 47), ('Q', 48), ('R', 49)]

-----------------
[(0, '!'), (1, '"'), (2, '#'), (3, '$'), (4, '%'), (5, '&'), (6, "'"), (7, '('), (8, ')'), (9, '*'), (10, '+'), (11, ','), (12, '-'), (13, '.'), (14, '/'), (15, '0'), (16, '1'), (17, '2'), (18, '3'), (19, '4'), (20, '5'), (21, '6'), (22, '7'), (23, '8'), (24, '9'), (25, ':'), (26, ';'), (27, '<'), (28, '='), (29, '>'), (30, '?'), (31, '@'), (32, 'A'), (33, 'B'), (34, 'C'), (35, 'D'), (36, 'E'), (37, 'F'), (38, 'G'), (39, 'H'), (40, 'I')

In [127]:
text = "In probability theory and statistics, the Jensen–Shannon divergence is a method of measuring the similarity between two probability distributions. It is also known as information radius (IRad)[1] or total divergence to the average.[2] It is based on the Kullback–Leibler divergence, with some notable (and useful) differences, including that it is symmetric and it always has a finite value. The square root of the Jensen–Shannon divergence is a metric often referred to as Jensen-Shannon distance.[3][4][5]"
tok117 = enc117.encode(text)
print(tok117)
print()
txt117 = enc117.decode(tok117)
print(txt117)

[818, 12867, 4583, 290, 7869, 11, 262, 32623, 1906, 2484, 8825, 43366, 318, 257, 2446, 286, 15964, 262, 26789, 1022, 734, 12867, 24570, 13, 632, 318, 635, 1900, 355, 1321, 16874, 357, 4663, 324, 38381, 16, 60, 393, 2472, 43366, 284, 262, 2811, 3693, 17, 60, 632, 318, 1912, 319, 262, 509, 724, 1891, 1906, 3123, 571, 1754, 43366, 11, 351, 617, 12411, 357, 392, 4465, 8, 5400, 11, 1390, 326, 340, 318, 23606, 19482, 290, 340, 1464, 468, 257, 27454, 1988, 13, 383, 6616, 6808, 286, 262, 32623, 1906, 2484, 8825, 43366, 318, 257, 18663, 1690, 6412, 284, 355, 32623, 12, 2484, 8825, 5253, 3693, 18, 7131, 19, 7131, 20, 60]

In probability theory and statistics, the Jensen–Shannon divergence is a method of measuring the similarity between two probability distributions. It is also known as information radius (IRad)[1] or total divergence to the average.[2] It is based on the Kullback–Leibler divergence, with some notable (and useful) differences, including that it is symmetric and it always has a fi

In [269]:
print(list(enc117.bpe_ranks.items())[:10])

[(('Ġ', 't'), 0), (('Ġ', 'a'), 1), (('h', 'e'), 2), (('i', 'n'), 3), (('r', 'e'), 4), (('o', 'n'), 5), (('Ġt', 'he'), 6), (('e', 'r'), 7), (('Ġ', 's'), 8), (('a', 't'), 9)]
