# Simple implementation of Byte Pair Encoding Algorithm

A simple implmeentation of the Byte Pair Encoding (BPE) in language model tokenizers

In [6]:
import collections
from collections import defaultdict

The BPE algorithm operates at the character level and optionally at the byte level for encoding. In this implementation, UTF-8 encoding is utilized as the character-level encoding mechanism.

In [25]:
text = 'The trial is a 12-week single-arm pilot study. Twenty female breast cancer survivors will engage in a supervised moderate intensity walking intervention in small groups in a nature reserve for 50 minutes three times per week. Data will be collected at baseline and end of study, and include assessment of inflammatory cytokines and anti-inflammatory myokines (TNF-α, IL-1ß, IL-6, CRP, TGF-ß, IL-10, IL-13), as well as ageing (DNA methylation, ageing genes) biomarkers; surveys (Patient-Reported Outcomes Measurement Information System-29, Functional Assessment of Cancer Therapy-General, Post-Traumatic Growth Inventory); and fitness assessments (6 min Walk Test, Grip-Strength, One Repetition-Maximum Leg Press). Participants will also complete weekly surveys assessing social support and participate in an exit interview. This is an important first step for future research on the influence of exercise environment on cancer survivor PA outcomes.'
tokens = text.encode('utf-8')
#vocab = map(int)
print(len(tokens))
vocab = {'tokens_ids' : {chr(token) : token for token in tokens},
        'ids_tokens' : {token : chr(token) for token in tokens}}
len(vocab['tokens_ids'])

951


56

The next step is to find the frequency of occurrence of pairs of consecutive characters.

In [11]:
#calculates the frequency of consecutive characters
def get_freq(tokens):
    freq = defaultdict(int)
    for pair in zip(tokens, tokens[1:]):
        freq[pair] += 1
    return freq
get_freq(tokens)

defaultdict(int,
            {(84, 104): 3,
             (104, 101): 3,
             (101, 32): 17,
             (32, 116): 4,
             (116, 114): 2,
             (114, 105): 2,
             (105, 97): 2,
             (97, 108): 9,
             (108, 32): 8,
             (32, 105): 13,
             (105, 115): 5,
             (115, 32): 17,
             (32, 97): 20,
             (97, 32): 4,
             (32, 49): 1,
             (49, 50): 1,
             (50, 45): 1,
             (45, 119): 1,
             (119, 101): 5,
             (101, 101): 4,
             (101, 107): 3,
             (107, 32): 2,
             (32, 115): 12,
             (115, 105): 3,
             (105, 110): 21,
             (110, 103): 7,
             (103, 108): 1,
             (108, 101): 4,
             (101, 45): 1,
             (45, 97): 1,
             (97, 114): 5,
             (114, 109): 2,
             (109, 32): 2,
             (32, 112): 3,
             (112, 105): 1,
             (105, 108):

A merge function combines the pair with the highest frequency and assigns a new identifier to the merged token.

In [12]:
#this function gets the pair and creates a new id for the new token
def merge(ids, pair, idx):
    new_ids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
            new_ids.append(idx)
            i += 2
        else:
            new_ids.append(i)
            i+= 1
    return new_ids


The number of merges is a hyperparameter defined during the training process. A higher number of merges leads to a larger vocabulary size and a more compressed tokenized sequence length and embedding dimension. If UTF-8 encoding is used, the initial vocabulary size would be 256. Adding 20 merges would increase the vocabulary size to 276.

In [28]:
def byte_pair_encoding(tokens, num_merges, vocab):
    ids = list(tokens)
    for i in range(num_merges):
        pair_freq = get_freq(ids)
        top_pair = max(pair_freq, key = pair_freq.get)
        idx = 256 + i
        vocab['tokens_ids'][chr(top_pair[0]) + chr(top_pair[1])] = idx
        vocab['ids_tokens'][idx] = chr(top_pair[0]) + chr(top_pair[1])
        ids = merge(ids, top_pair, idx)
    return ids, vocab

byte_pair_encoding(tokens, 20, vocab)    

([275,
  2,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  31,
  32,
  33,
  34,
  35,
  36,
  37,
  38,
  39,
  40,
  41,
  42,
  43,
  44,
  45,
  46,
  47,
  48,
  49,
  50,
  51,
  52,
  53,
  54,
  55,
  56,
  57,
  58,
  59,
  60,
  61,
  62,
  63,
  64,
  65,
  66,
  67,
  68,
  69,
  70,
  71,
  72,
  73,
  74,
  75,
  76,
  77,
  78,
  79,
  80,
  81,
  82,
  83,
  84,
  85,
  86,
  87,
  88,
  89,
  90,
  91,
  92,
  93,
  94,
  95,
  96,
  97,
  98,
  99,
  100,
  101,
  102,
  103,
  104,
  105,
  106,
  107,
  108,
  109,
  110,
  111,
  112,
  113,
  114,
  115,
  116,
  117,
  118,
  119,
  120,
  121,
  122,
  123,
  124,
  125,
  126,
  127,
  128,
  129,
  130,
  131,
  132,
  133,
  134,
  135,
  136,
  137,
  138,
  139,
  140,
  141,
  142,
  143,
  144,
  145,
  146,
  147,
  148,
  149,
  150,
  151,
  152,
  153,
  154,
  155,
  156,
  157,
  158,


In [29]:
#calculate compression rate
print(f"the compression rate is {len(tokens)/len(ids)}")

the compression rate is 1.0554938956714761
