In [61]:
import base64
import ast
import json
import regex as re

# Pre-tokenisation pattern used to cut text into word-level chunks before BPE.
# The first alternative splits off a "ܘ" (optionally preceded by one space) when
# it is immediately followed by letters, so that prefix becomes its own chunk;
# the remaining alternatives match letter runs, digit runs, other non-space
# runs and whitespace.
pattern = re.compile(r""" ?ܘ(?=\p{L}+)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

# Read the JSON files as UTF-8 explicitly: without `encoding`, open() falls back
# to the platform's locale encoding, which can mis-decode non-ASCII content.
with open('merges_6000.json', 'r', encoding='utf-8') as f:
    merges_converted = json.load(f)

with open('vocabulary_6000.json', 'r', encoding='utf-8') as f:
    vocab_converted = json.load(f)

def convert_keys_and_values(d):
    """Rebuild a mapping whose keys and values were stringified for JSON storage.

    Key handling:
      * a key wrapped in parentheses, e.g. ``"(220, 152)"``, is parsed back
        into the Python literal it spells (normally a tuple);
      * a key made only of digits becomes an ``int``;
      * any other key is kept unchanged.

    Value handling:
      * ``str`` values are encoded to UTF-8 ``bytes``;
      * every other value is passed through untouched.

    Args:
        d: dict with string keys, as produced by ``json.load``.

    Returns:
        A new dict with the converted keys and values; ``d`` is not modified.

    Raises:
        ValueError, SyntaxError: if a parenthesised key is not a valid Python
            literal.
    """
    converted_dict = {}
    for k, v in d.items():
        if k.startswith('(') and k.endswith(')'):
            # literal_eval only accepts literals (numbers, strings, tuples, ...),
            # so a crafted key in the input file cannot run code as eval() would.
            m = ast.literal_eval(k)
        elif k.isdigit():
            m = int(k)
        else:
            m = k

        # Strings are turned into their UTF-8 byte form; non-string values
        # are stored as they are.
        byte_value = v.encode('utf-8') if isinstance(v, str) else v

        converted_dict[m] = byte_value

    return converted_dict

# Turn the JSON-friendly tables back into the structures used by the tokenizer.
vocab = convert_keys_and_values(vocab_converted)
merges = convert_keys_and_values(merges_converted)

# Print each converted table followed by the raw JSON form it was built from.
for table in (vocab, vocab_converted, merges, merges_converted):
    print(table)

{256: b'\xdc\x90', 257: b'\xdc\x98', 258: b'\xdc\x9d', 259: b'\xdc\xa2', 260: b'\xdc\xa0', 261: b'\xdc\xa1', 262: b'\xdc\x95', 263: b'\xdc\xaa', 264: b'\xdc\x92', 265: b'\xdc\xac', 266: b'\xdc\x97', 267: b' \xdc\x98', 268: b'\xdc\x9f', 269: b' \xdc\x95', 270: b'\xdc\xa5', 271: b' \xdc\x90', 272: b' \xdc\xa0', 273: b'\xdc\xab', 274: b' \xdc\xa1', 275: b'\xdc\x9a', 276: b' \xdc\x92', 277: b'\xdc\xa9', 278: b'\xdc\x98\xdc\xa2', 279: b'\xdc\xa6', 280: b'\xdc\x9d\xdc\xa2', 281: b'\xdc\xa3', 282: b'\xdc\xac\xdc\x90', 283: b'\xdc\xaa\xdc\x9d', 284: b' \xdc\xa5', 285: b'\xdc\x93', 286: b'\xdc\x9b', 287: b'\xdc\xa2\xdc\x90', 288: b' \xdc\x97', 289: b'\xdc\xa0\xdc\x90', 290: b'\xdc\xaa\xdc\x90', 291: b'\xdc\xa1\xdc\x90', 292: b' \xdc\xa1\xdc\xa2', 293: b'\xdc\x99', 294: b' \xdc\xa2', 295: b'\xdc\x9d\xdc\xac', 296: b'\xdc\xa0\xdc\x97', 297: b' \xdc\xac', 298: b'\xdc\x9d\xdc\x90', 299: b'\xdc\x97\xdc\x9d', 300: b' \xdc\x95\xdc\x90', 301: b' \xdc\x9f', 302: b' \xdc\x9a', 303: b'\xdc\xaa\xdc\x9d\xdc

In [62]:
def get_stats(ids):
    """Count adjacent token pairs over a collection of token sequences.

    Args:
        ids: iterable of sliceable token sequences (e.g. a list of lists of ints).

    Returns:
        dict mapping ``(left, right)`` to the number of times that pair occurs
        next to each other inside any one sequence. Pairs never span two
        sequences, and pairs whose right-hand token is 220 are not counted.
    """
    counts = {}
    for seq in ids:
        for left, right in zip(seq, seq[1:]):
            # 220 is 0xDC; NOTE(review): it looks like the UTF-8 lead byte of the
            # letters in this corpus, so such pairs are deliberately skipped.
            if right == 220:
                continue
            key = (left, right)
            counts[key] = counts.get(key, 0) + 1
    return counts


def simple_get_stats(ids):
    """Count adjacent token pairs in one flat token sequence.

    Args:
        ids: sliceable sequence of tokens (e.g. a list of ints).

    Returns:
        dict mapping ``(first, second)`` to its number of adjacent occurrences.
        A pair is left out when its second token is a space (32) or 220.
    """
    # Tokens that may not sit on the right-hand side of a counted pair.
    excluded_right = (ord(' '), 220)
    counts = {}
    for first, second in zip(ids, ids[1:]):
        if second in excluded_right:
            continue
        key = (first, second)
        counts[key] = counts.get(key, 0) + 1
    return counts

#top_pair = max(stats, key=stats.get)

In [63]:
def merge(ids, pair, idx):
    """Replace every occurrence of ``pair`` with ``idx`` in each token sequence.

    Args:
        ids: list of token sequences (list of lists of ints).
        pair: the two consecutive tokens to look for, as ``(first, second)``.
        idx: the token written in place of each matched pair.

    Returns:
        A new list of new lists; the input is not modified. Each sequence is
        scanned left to right and matches do not overlap.
    """
    first, second = pair[0], pair[1]
    merged = []
    for seq in ids:
        out = []
        pos = 0
        last = len(seq) - 1
        while pos <= last:
            # A match needs a following element, so it cannot start at `last`.
            if pos < last and seq[pos] == first and seq[pos + 1] == second:
                out.append(idx)
                pos += 2
            else:
                out.append(seq[pos])
                pos += 1
        merged.append(out)
    return merged

def simple_merge(ids, pair, idx):
    """Replace every occurrence of ``pair`` with ``idx`` in one flat token list.

    Args:
        ids: list of tokens.
        pair: the two consecutive tokens to look for, as ``(first, second)``.
        idx: the token written in place of each matched pair.

    Returns:
        A new list; ``ids`` is not modified. The scan runs left to right and
        matches do not overlap.
    """
    first, second = pair[0], pair[1]
    total = len(ids)
    merged = []
    skip_next = False
    for pos, tok in enumerate(ids):
        if skip_next:
            # This element was consumed as the second half of the previous match.
            skip_next = False
            continue
        if pos + 1 < total and tok == first and ids[pos + 1] == second:
            merged.append(idx)
            skip_next = True
        else:
            merged.append(tok)
    return merged

def length(ids):
    """Return the total number of tokens across all sequences in ``ids``."""
    return sum(map(len, ids))



### Encoding

Encoding is the reverse of decoding: given a string, what are the tokens?


In [64]:
def encode_simple(text):
    """Encode one string into a list of token ids.

    The text is turned into its UTF-8 byte values, then pairs listed in the
    module-level ``merges`` table are merged repeatedly. Each round merges the
    present pair with the smallest ``merges`` value, until no mergeable pair
    is left.

    Args:
        text: the string to encode.

    Returns:
        list of ints (token ids); empty for an empty string.
    """
    tokens = list(text.encode("utf-8"))
    while len(tokens) >= 2:
        stats = simple_get_stats(tokens)
        if not stats:
            # simple_get_stats drops pairs ending in a space or 220, so it can
            # return nothing even with two or more tokens (e.g. "a "); min() on an
            # empty dict would raise ValueError.
            break
        # Pairs missing from `merges` rank as infinity, so they are only picked
        # when no pair in the sequence can be merged.
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break  # nothing else can be merged
        tokens = simple_merge(tokens, pair, merges[pair])
    return tokens

def encode(texts):
    """Encode several strings, returning one token-id list per string.

    Each string is turned into its UTF-8 byte values, then pairs listed in the
    module-level ``merges`` table are merged repeatedly. Each round merges the
    present pair with the smallest ``merges`` value, until no mergeable pair
    is left.

    Args:
        texts: iterable of strings (for example the chunks produced by the
            pre-tokenisation pattern).

    Returns:
        list of lists of ints, in the same order as ``texts``.
    """
    encoded_texts = []
    for text in texts:
        tokens = list(text.encode("utf-8"))
        while len(tokens) >= 2:
            stats = simple_get_stats(tokens)
            if not stats:
                # simple_get_stats drops pairs ending in a space or 220, so it
                # can return nothing even with two or more tokens (e.g. "a ");
                # min() on an empty dict would raise ValueError.
                break
            # Pairs missing from `merges` rank as infinity, so they are only
            # picked when no pair in the sequence can be merged.
            pair = min(stats, key=lambda p: merges.get(p, float("inf")))
            if pair not in merges:
                break  # nothing else can be merged
            tokens = simple_merge(tokens, pair, merges[pair])
        encoded_texts.append(tokens)
    return encoded_texts


def decode_simple(ids, sep=b"."):
    """Turn a flat list of token ids back into text.

    The byte string of each token is looked up in the module-level ``vocab``
    and the pieces are joined with ``sep``. With the default ``b"."`` a dot is
    inserted between tokens, so the result shows the token boundaries and is
    NOT the original text; pass ``sep=b""`` for a plain reconstruction.

    Args:
        ids: list of int token ids.
        sep: bytes placed between consecutive tokens (default ``b"."``).

    Returns:
        The decoded string; byte sequences that are not valid UTF-8 are
        replaced with U+FFFD.

    Raises:
        KeyError: if an id of 256 or above is missing from ``vocab``.
    """
    def token_bytes(idx):
        # NOTE(review): the loaded vocab appears to start at id 256, so an
        # unmerged raw byte id would otherwise fail with KeyError. Ids below
        # 256 are raw byte values; fall back to that single byte.
        if idx not in vocab and isinstance(idx, int) and 0 <= idx < 256:
            return bytes([idx])
        return vocab[idx]

    tokens = sep.join(token_bytes(idx) for idx in ids)
    return tokens.decode("utf-8", errors="replace")

def decode(ids, token_sep=b".", word_sep="."):
    """Turn a list of per-word token-id lists back into one string.

    For every word, the byte string of each token is looked up in the
    module-level ``vocab`` and joined with ``token_sep``; ``word_sep`` is then
    appended after each decoded word. With the defaults a dot is placed
    between tokens and after each word, so the result shows the token
    boundaries and is NOT the original text. Pass ``token_sep=b""`` and
    ``word_sep=""`` for a plain reconstruction.

    Args:
        ids: list of lists of int token ids, one inner list per word.
        token_sep: bytes placed between tokens of one word (default ``b"."``).
        word_sep: string appended after every word (default ``"."``).

    Returns:
        The decoded string with leading and trailing whitespace stripped.
        A trailing ``word_sep`` is kept unless it is whitespace. Byte
        sequences that are not valid UTF-8 are replaced with U+FFFD.

    Raises:
        KeyError: if an id of 256 or above is missing from ``vocab``.
    """
    def token_bytes(idx):
        # NOTE(review): the loaded vocab appears to start at id 256, so an
        # unmerged raw byte id would otherwise fail with KeyError. Ids below
        # 256 are raw byte values; fall back to that single byte.
        if idx not in vocab and isinstance(idx, int) and 0 <= idx < 256:
            return bytes([idx])
        return vocab[idx]

    pieces = []
    for word in ids:
        raw = token_sep.join(token_bytes(idx) for idx in word)
        pieces.append(raw.decode("utf-8", errors="replace") + word_sep)
    # Join once instead of growing a string with += inside the loop.
    return "".join(pieces).strip()

In [67]:
text = "ܘܡܠܟܐ ܕܘܝܕ ܣܐܒ ܘܥܠ ܒܫܢܝܐ ܘܡܟܣܝܢ ܗܘܘ ܠܗ ܒܠܒܘܫܐ ܘܠܐ ܫܚܢ ܘܐܡܪܘ ܠܗ ܥܒܕܘܗܝ ܗܐ ܥܒܕܝܟ ܩܕܡܝܟ ܢܒܥܘܢ ܠܡܪܢ"

# Pre-tokenise: divide the text into word-level chunks with the compiled pattern.
text_words = pattern.findall(text)

# Raw UTF-8 byte values of every chunk, before any merge is applied.
words_tokens = [list(chunk.encode("utf-8")) for chunk in text_words]
print(words_tokens[:2])

# Round trip: BPE-encode each chunk, then decode the ids back into text.
text_encoded = encode(text_words)
print(text_encoded)
text_decoded = decode(text_encoded)
print(text_decoded)


[[220, 152], [220, 161, 220, 160, 220, 159, 220, 144]]
[[257], [1196], [836], [5562], [267], [367], [5563], [267], [261, 647, 280], [568], [333], [4677], [267], [289], [307, 1596], [267], [770], [333], [856], [553], [1974], [819], [294, 2156], [5564]]
ܘ.ܡܠܟܐ. ܕܘܝܕ. ܣܐܒ. ܘ.ܥܠ. ܒܫܢܝܐ. ܘ.ܡ.ܟܣ.ܝܢ. ܗܘܘ. ܠܗ. ܒܠܒܘܫܐ. ܘ.ܠܐ. ܫ.ܚܢ. ܘ.ܐܡܪܘ. ܠܗ. ܥܒܕܘܗܝ. ܗܐ. ܥܒܕܝܟ. ܩܕܡܝܟ. ܢ.ܒܥܘܢ. ܠܡܪܢ.
