# Notes and Questions on Tokenization
## Unicode encodings
The Unicode standard defines a mapping from characters to code points (integers). A Python `str` is a sequence of code points, and `ord` returns the code point of a single character.

In [35]:
print([s for s in "é😊"])
print([ord(c) for c in "é😊"])

['é', '😊']
[233, 128522]
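A small companion sketch (not part of the original notes): `chr` inverts `ord`, code points are conventionally written in hexadecimal as `U+XXXX`, and `sys.maxunicode` reports the largest code point Python accepts.

```python
import sys

s = "é😊"
code_points = [ord(c) for c in s]

# chr is the inverse of ord
assert "".join(chr(cp) for cp in code_points) == s

# conventional U+XXXX notation, zero-padded to at least 4 hex digits
labels = [f"U+{cp:04X}" for cp in code_points]
print(labels)          # ['U+00E9', 'U+1F60A']
print(sys.maxunicode)  # 1114111, i.e. 0x10FFFF
```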


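There are roughly 150,000 assigned characters in the Unicode standard (out of a code space of 1,114,112 code points), so using code points directly as the vocabulary would give a large, sparse vocabulary in which most entries are rare. Word-level vocabularies are worse still: the set of possible words is unbounded, so we would constantly run into out-of-vocabulary (OOV) words. Instead, it's common to convert sequences of Unicode code points into sequences of bytes.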

Because there are only `2^8 = 256` distinct byte values (8 bits per byte) and every Unicode character can be encoded as a sequence of bytes, we can represent any corpus with a vocabulary of size 256, at the cost of longer sequences. This eliminates the OOV problem: every possible input maps to tokens that are already in the vocabulary.
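A minimal sketch of the byte-level view: UTF-8 turns any string into integers in `range(256)`, decoding the bytes recovers the string exactly, and the sequence gets longer whenever a character needs more than one byte. The example string is arbitrary.

```python
text = "héllo 😊"

# UTF-8 turns the string into a sequence of ints, each in range(256)
byte_ids = list(text.encode("utf-8"))
assert all(0 <= b < 256 for b in byte_ids)

# the mapping is lossless: decoding recovers the original string
assert bytes(byte_ids).decode("utf-8") == text

# the price is length: 'é' takes 2 bytes and '😊' takes 4
print(len(text), len(byte_ids))  # 7 11
```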


There are three common Unicode encodings, each of which represents a code point as a sequence of one or more bytes: UTF-8 (used by roughly 98% of websites), UTF-16, and UTF-32.

**Q**: What are some reasons to prefer training our tokenizer on UTF-8 encoded bytes, rather than
UTF-16 or UTF-32?

In [13]:
x = "hello world it's me!"
x_utf8 = x.encode("utf-8")
x_utf16 = x.encode("utf-16")
x_utf32 = x.encode("utf-32")

print(x_utf8)
print(x_utf16)
print(x_utf32)

print(len(x_utf8))
print(len(x_utf16))
print(len(x_utf32))

print(
    [x for x in x_utf8]
)

print(
    [x for x in x_utf16]
)

print(
    [x for x in x_utf32]
)

b"hello world it's me!"
b"\xff\xfeh\x00e\x00l\x00l\x00o\x00 \x00w\x00o\x00r\x00l\x00d\x00 \x00i\x00t\x00'\x00s\x00 \x00m\x00e\x00!\x00"
b"\xff\xfe\x00\x00h\x00\x00\x00e\x00\x00\x00l\x00\x00\x00l\x00\x00\x00o\x00\x00\x00 \x00\x00\x00w\x00\x00\x00o\x00\x00\x00r\x00\x00\x00l\x00\x00\x00d\x00\x00\x00 \x00\x00\x00i\x00\x00\x00t\x00\x00\x00'\x00\x00\x00s\x00\x00\x00 \x00\x00\x00m\x00\x00\x00e\x00\x00\x00!\x00\x00\x00"
20
42
84
[104, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100, 32, 105, 116, 39, 115, 32, 109, 101, 33]
[255, 254, 104, 0, 101, 0, 108, 0, 108, 0, 111, 0, 32, 0, 119, 0, 111, 0, 114, 0, 108, 0, 100, 0, 32, 0, 105, 0, 116, 0, 39, 0, 115, 0, 32, 0, 109, 0, 101, 0, 33, 0]
[255, 254, 0, 0, 104, 0, 0, 0, 101, 0, 0, 0, 108, 0, 0, 0, 108, 0, 0, 0, 111, 0, 0, 0, 32, 0, 0, 0, 119, 0, 0, 0, 111, 0, 0, 0, 114, 0, 0, 0, 108, 0, 0, 0, 100, 0, 0, 0, 32, 0, 0, 0, 105, 0, 0, 0, 116, 0, 0, 0, 39, 0, 0, 0, 115, 0, 0, 0, 32, 0, 0, 0, 109, 0, 0, 0, 101, 0, 0, 0, 33, 0, 0, 0]


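**A**: UTF-8 is a variable-width encoding: it represents the most common code points (the ASCII range, U+0000 to U+007F) with a single byte and uses longer sequences (2 to 4 bytes) for less common ones, such as accented letters and emoji. UTF-32 always uses 4 bytes per code point, and UTF-16 uses 2 bytes for code points in the Basic Multilingual Plane and 4 bytes (a surrogate pair) for everything else, e.g. most emoji. For English and code-heavy text, UTF-8 is therefore the most compact. The outputs above also show that UTF-16 and UTF-32 pad ASCII text with zero bytes, which lengthens sequences without carrying any information. (Python's `utf-16` and `utf-32` codecs also prepend a byte-order mark, `\xff\xfe` and `\xff\xfe\x00\x00`, which accounts for the extra 2 and 4 bytes in the lengths 42 and 84.)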
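To check these widths, a short sketch that encodes single characters one at a time. It uses the little-endian codec variants (`utf-16-le`, `utf-32-le`) so that no byte-order mark is added; the helper name `byte_widths` is just for illustration.

```python
ENCODINGS = ("utf-8", "utf-16-le", "utf-32-le")


def byte_widths(c: str) -> dict[str, int]:
    """Number of bytes a single character takes in each encoding (no BOM)."""
    return {enc: len(c.encode(enc)) for enc in ENCODINGS}


for c in ["a", "é", "€", "😊"]:
    print(repr(c), f"U+{ord(c):04X}", byte_widths(c))

# 'a' U+0061 {'utf-8': 1, 'utf-16-le': 2, 'utf-32-le': 4}
# 'é' U+00E9 {'utf-8': 2, 'utf-16-le': 2, 'utf-32-le': 4}
# '€' U+20AC {'utf-8': 3, 'utf-16-le': 2, 'utf-32-le': 4}
# '😊' U+1F60A {'utf-8': 4, 'utf-16-le': 4, 'utf-32-le': 4}
```

`€` is a case where UTF-16 is more compact than UTF-8, which is why UTF-8's size advantage holds for mostly-ASCII text rather than for every script.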

In turn, this means sequence lengths are shortest under UTF-8, which matters because the time complexity of self-attention scales as `O(n^2)` in the sequence length `n`.
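A back-of-the-envelope sketch with the example string above, encoded without byte-order marks via the `-le` codecs. It counts only the attention scores (one per query-key pair) and ignores model width and constant factors, so it is a rough illustration of the quadratic term rather than a cost model.

```python
x = "hello world it's me!"

# sequence length per encoding, and the n * n (query, key) scores attention computes
lengths = {enc: len(x.encode(enc)) for enc in ("utf-8", "utf-16-le", "utf-32-le")}
scores = {enc: n * n for enc, n in lengths.items()}

print(lengths)  # {'utf-8': 20, 'utf-16-le': 40, 'utf-32-le': 80}
print(scores)   # {'utf-8': 400, 'utf-16-le': 1600, 'utf-32-le': 6400}
```

Doubling the length quadruples the score count, and quadrupling it multiplies the count by 16.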

**Q**:
Consider the following (incorrect) function, which is intended to decode a UTF-8 byte string into
a Unicode string. Why is this function incorrect? Provide an example of an input byte string
that yields incorrect results.

In [34]:
def decode_utf8_bytes_to_str_wrong(bytestring: bytes):
    return "".join([bytes([b]).decode("utf-8") for b in bytestring])

x = "é"
x_utf8 = x.encode("utf-8")

print(x)
print(x_utf8)
print(decode_utf8_bytes_to_str_wrong(x_utf8))


é
b'\xc3\xa9'


UnicodeDecodeError: 'utf-8' codec can't decode byte 0xc3 in position 0: unexpected end of data

The function fails for any character that UTF-8 encodes with more than one byte, because it decodes each byte on its own. For example, `"é".encode("utf-8")` gives the 2 bytes `b'\xc3\xa9'`: the leading byte `0xC3` and the continuation byte `0xA9`. A lone leading byte is not a complete UTF-8 sequence, so decoding `b'\xc3'` by itself raises the `UnicodeDecodeError` above. Multi-byte sequences have to be decoded together.
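A sketch of two ways to decode correctly. The simple fix is to decode the whole byte string at once, so multi-byte sequences stay together. When bytes arrive in chunks (e.g. streamed model output), the standard library's incremental decoder buffers an incomplete sequence until its continuation bytes arrive. The function name `decode_utf8_bytes_to_str` is illustrative.

```python
import codecs


def decode_utf8_bytes_to_str(bytestring: bytes) -> str:
    # decode all bytes together so multi-byte characters are handled
    return bytestring.decode("utf-8")


x_utf8 = "é".encode("utf-8")  # b'\xc3\xa9'
print(decode_utf8_bytes_to_str(x_utf8))  # é

# byte-at-a-time decoding that buffers incomplete sequences instead of raising
decoder = codecs.getincrementaldecoder("utf-8")()
pieces = [decoder.decode(bytes([b])) for b in x_utf8]
pieces.append(decoder.decode(b"", final=True))
print(pieces)  # ['', 'é', '']
```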

## Subword Tokenization
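Representing every character as a byte sequence produces prohibitively long sequences. Byte-pair encoding (BPE) is a subword tokenization method that addresses this: starting from the 256 single-byte tokens, it repeatedly finds the most frequent pair of adjacent tokens in the training data and merges it into a new vocabulary entry.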

In [38]:
s = "low low low low low lower lower widest widest widest newest newest newest newest newest newest"

In [39]:
s

'low low low low low lower lower widest widest widest newest newest newest newest newest newest'

In [92]:
vocab = {x: bytes([x]) for x in range(256)}

In [93]:
vocab


{0: b'\x00',
 1: b'\x01',
 2: b'\x02',
 ...
 9: b'\t',
 10: b'\n',
 ...
 32: b' ',
 33: b'!',
 ...
 81: b'Q',
 82: b'R',
 ...
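With this vocabulary, decoding is lookup-and-concatenate: map each token id to its bytes, join them, and decode the result as UTF-8 once at the end. A minimal sketch; the helper `decode_ids` and the name `byte_vocab` are illustrative, and `errors="replace"` substitutes U+FFFD for byte sequences that are not valid UTF-8 instead of raising.

```python
byte_vocab = {x: bytes([x]) for x in range(256)}


def decode_ids(ids: list[int], vocab: dict[int, bytes]) -> str:
    # join the byte strings first, then decode once, so multi-byte characters survive
    return b"".join(vocab[i] for i in ids).decode("utf-8", errors="replace")


ids = list("low é".encode("utf-8"))
print(ids)                          # [108, 111, 119, 32, 195, 169]
print(decode_ids(ids, byte_vocab))  # low é

# a lone leading byte is invalid UTF-8 and becomes the replacement character U+FFFD
print(decode_ids([195], byte_vocab))
```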

In [None]:
from collections import Counter

counter = Counter(s.split(" "))
counter

Counter({'newest': 6, 'low': 5, 'widest': 3, 'lower': 2})

In [68]:
pretoken_counts = {tuple(k.encode("utf-8")): v for k, v in counter.items()}
pretoken_counts

{(108, 111, 119): 5,
 (108, 111, 119, 101, 114): 2,
 (119, 105, 100, 101, 115, 116): 3,
 (110, 101, 119, 101, 115, 116): 6}

In [None]:
def get_pretoken_counts(text: str):
    counter = Counter(text.split(" "))
    return {tuple(k.encode("utf-8")): v for k, v in counter.items()}



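Splitting on `" "` discards the spaces and treats `"lower,"` and `"lower"` as different pre-tokens. GPT-2 style tokenizers instead pre-tokenize with a regular expression that keeps a leading space attached to each word and splits punctuation off. The sketch below is a simplified, ASCII-only approximation of that idea using the standard `re` module; GPT-2's actual pattern uses Unicode classes such as `\p{L}` from the third-party `regex` package. The pattern `PAT` and the function `get_pretoken_counts_regex` are assumptions for illustration.

```python
import re
from collections import Counter

# Simplified, ASCII-only approximation of a GPT-2 style pre-tokenization pattern:
# contractions, optional-space + letters, optional-space + digits,
# optional-space + other symbols, then whitespace runs
PAT = re.compile(r"'(?:[sdmt]|ll|ve|re)| ?[A-Za-z]+| ?[0-9]+| ?[^\sA-Za-z0-9]+|\s+(?!\S)|\s+")


def get_pretoken_counts_regex(text: str) -> dict[tuple[int, ...], int]:
    counter = Counter(PAT.findall(text))
    return {tuple(k.encode("utf-8")): v for k, v in counter.items()}


print(PAT.findall("hello world it's me!"))
# ['hello', ' world', ' it', "'s", ' me', '!']
```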


In [79]:
pretoken_counts

{(108, 111, 119): 5,
 (108, 111, 119, 101, 114): 2,
 (119, 105, 100, 101, 115, 116): 3,
 (110, 101, 119, 101, 115, 116): 6}

In [None]:
from collections import defaultdict


def merge_tuple(t: tuple[int, ...], pair: tuple[int, int], new_index: int) -> tuple[int, ...]:
    """Replace every non-overlapping occurrence of `pair` in `t` with `new_index`."""
    new_keys = []
    i = 0
    while i < len(t):  # a while loop, because a merge has to advance i by 2
        if i < len(t) - 1 and (t[i], t[i + 1]) == pair:
            # use the new id; summing the two ids would collide with existing tokens
            new_keys.append(new_index)
            i += 2
        else:
            new_keys.append(t[i])
            i += 1
    return tuple(new_keys)


def update_pretoken_counts(pretoken_counts, pair, new_index):
    """Apply one merge to every pre-token, keeping its count."""
    d = defaultdict(int)
    for key, value in pretoken_counts.items():
        d[merge_tuple(key, pair, new_index)] += value
    return d


def train_bpe(
    pretoken_counts: dict[tuple[int, ...], int],
    vocab: dict[int, bytes],
    n_merges: int,
):
    merges = {}  # (token_a, token_b) -> merged token id, in the order merges were learned

    for i in range(n_merges):
        new_index = 256 + i
        # count adjacent pairs, weighted by how often each pre-token occurs
        byte_pair_counts = defaultdict(int)
        for k, value in pretoken_counts.items():
            for j in range(len(k) - 1):  # j, so the outer loop variable is not shadowed
                byte_pair_counts[(k[j], k[j + 1])] += value

        # on ties, max returns the pair that was counted first
        most_frequent_byte_pair = max(byte_pair_counts, key=byte_pair_counts.get)
        print(most_frequent_byte_pair)
        merges[most_frequent_byte_pair] = new_index
        vocab[new_index] = vocab[most_frequent_byte_pair[0]] + vocab[most_frequent_byte_pair[1]]
        # keep the merged pre-tokens; otherwise the same pair wins every iteration
        pretoken_counts = update_pretoken_counts(pretoken_counts, most_frequent_byte_pair, new_index)

    return byte_pair_counts, vocab, merges

In [155]:
byte_pair_counts, vocab, merges = train_bpe(
    pretoken_counts=pretoken_counts,
    vocab=vocab,
    n_merges=3
)
merges


(101, 115)
(256, 116)
(108, 111)


{(101, 115): 256, (256, 116): 257, (108, 111): 258}

In [143]:
byte_pair_counts

defaultdict(int,
            {(108, 111): 7,
             (111, 119): 7,
             (119, 101): 2,
             (101, 114): 2,
             (119, 105): 3,
             (105, 100): 3,
             (100, 257): 3,
             (110, 101): 6,
             (101, 119): 6,
             (119, 257): 6})
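Once training has produced a merges table, new text can be encoded by replaying the merges in the order they were learned. A minimal sketch, assuming merges are stored as a dict mapping `(token_a, token_b)` to the merged token id, whose insertion order (which Python dicts preserve) is the order the merges were learned. The helper `encode_word` and the `example_merges` table are hypothetical, and the sketch handles a single pre-token.

```python
def encode_word(word: str, merges: dict[tuple[int, int], int]) -> list[int]:
    """Encode one pre-token by applying learned merges in creation order."""
    ids = list(word.encode("utf-8"))
    for pair, new_id in merges.items():  # dicts iterate in insertion order
        i, out = 0, []
        while i < len(ids):
            if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
                out.append(new_id)  # replace the pair with its merged token
                i += 2
            else:
                out.append(ids[i])
                i += 1
        ids = out
    return ids


# hypothetical merges table: 'e'+'s' -> 256, 'es'+'t' -> 257, 'l'+'o' -> 258
example_merges = {(101, 115): 256, (256, 116): 257, (108, 111): 258}
print(encode_word("lowest", example_merges))  # [258, 119, 257]
```

Replaying every merge costs O(number of merges × word length) per word; a common alternative is to repeatedly merge whichever adjacent pair present in the word has the lowest merge rank, which only touches pairs that actually occur.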

In [144]:
vocab

{0: b'\x00',
 1: b'\x01',
 2: b'\x02',
 ...
 9: b'\t',
 10: b'\n',
 ...
 32: b' ',
 33: b'!',
 ...
 81: b'Q',
 82: b'R',
 ...

In [None]:
def merge_tuple(t: tuple, pair: tuple, new_index: int):
    # Debug version: replace non-overlapping occurrences of `pair` with `new_index`
    print(f"Original tuple: {t}")
    new_keys = []
    i = 0
    while i < len(t):
        if i < len(t) - 1 and (t[i], t[i + 1]) == pair:
            print(f"{pair} found at position {i}! Merging to {new_index}")
            new_keys.append(new_index)
            i += 2  # skip past both halves of the merged pair
        else:
            new_keys.append(t[i])
            i += 1
    return tuple(new_keys)

In [190]:
def merge(key: tuple[int, ...], merged: list[tuple[int, int]]):
    l = list(key)
    # a conditional expression must come before the `for` clause in a comprehension
    return [-990 if t in merged else t for t in zip(l, l[1:])]

tuples = (119, 105, 100, 101, 115, 116)
merges = [(105, 100)]
merge(tuples, merges)


[(119, 105), -990, (100, 101), (101, 115), (115, 116)]

In [None]:
l1 = [1, 3, 5, 7, 9]
def merge(key: tuple[int, ...], merged: list[tuple[int, int]]):
    l = list(key)
    return [-990 if t in merged else t for t in zip(l, l[1:])]

merge(l1, [])


[(1, 3), (3, 5), (5, 7), (7, 9)]
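`zip(l, l[1:])` is the usual idiom for adjacent pairs. Since Python 3.10 the standard library also provides `itertools.pairwise`, which yields the same pairs without building the slice. The sketch falls back to the `zip` idiom on older versions.

```python
try:
    from itertools import pairwise  # Python 3.10+
except ImportError:  # older Pythons: equivalent fallback built on the zip idiom
    def pairwise(iterable):
        items = list(iterable)
        return zip(items, items[1:])

l1 = [1, 3, 5, 7, 9]
assert list(pairwise(l1)) == list(zip(l1, l1[1:]))
print(list(pairwise(l1)))  # [(1, 3), (3, 5), (5, 7), (7, 9)]
```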

In [172]:
l = [(105, 100)]
(105, 100) in l

True
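`in` on a list compares against every element (O(n)), while `in` on a set or dict hashes the tuple (O(1) on average). Since BPE checks membership for every adjacent pair, keying the merges by pair in a dict is both faster and directly gives the id of the merged token. A small sketch; `merge_list` and `merge_ids` are hypothetical names, and ids starting at 256 assume the byte-level base vocabulary.

```python
merge_list = [(105, 100), (101, 115)]  # hypothetical merges, in the order learned

# dict keyed by pair: average O(1) membership test, and it stores the merged token id
merge_ids = {pair: 256 + rank for rank, pair in enumerate(merge_list)}

print((105, 100) in merge_list)  # True, but scans the list element by element
print((105, 100) in merge_ids)   # True, via a hash lookup
print(merge_ids[(101, 115)])     # 257
print((100, 105) in merge_ids)   # False: the order within a pair matters
```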

##