In [2]:
# Count the lines of ../tokenizers/uralic.txt via the shell.
# NOTE(review): the recorded output below reports "No such file or directory" — confirm the intended path.
!wc -l ../tokenizers/uralic.txt

wc: ../tokenizers/uralic.txt: open: No such file or directory


In [2]:
!pip install -q tqdm


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.1.2[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [3]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# @created: 20.09.2024
# @author: Aleksey Komissarov
# @contact: ad3002@gmail.com

import json
import sys
from tqdm import tqdm

def generate_bytes_char_mapping():
    """
    Builds the two lookup tables used for ByteLevel encoding and decoding.

    Bytes in the ranges 0x21-0x7E, 0xA1-0xAC and 0xAE-0xFF are represented by
    the Unicode character with the same code point. Every other byte (68 of
    them) is assigned, in ascending byte order, to consecutive code points
    starting at U+0100.

    Returns:
        byte_to_char (dict): Maps each byte (0-255) to a unique Unicode character.
        char_to_byte (dict): Reverse mapping from Unicode characters to bytes.
    """
    # Bytes that keep their own code point (188 values).
    self_mapped = [*range(0x21, 0x7F), *range(0xA1, 0xAD), *range(0xAE, 0x100)]
    self_mapped_lookup = set(self_mapped)

    # Remaining bytes, already in ascending order because range(256) is.
    shifted = [value for value in range(256) if value not in self_mapped_lookup]

    byte_to_char = {value: chr(value) for value in self_mapped}
    # Shifted bytes are placed at U+0100, U+0101, ... in order.
    for offset, value in enumerate(shifted):
        byte_to_char[value] = chr(0x0100 + offset)

    char_to_byte = {symbol: value for value, symbol in byte_to_char.items()}
    return byte_to_char, char_to_byte

# Build the module-level tables once; the vocabulary loader reads both.
byte_to_char, char_to_byte = generate_bytes_char_mapping()


def byte_level_decode(encoded_string, char_to_byte, encoding='utf-8'):
    """
    Decodes a ByteLevel encoded string back to the original string using the provided mapping.

    Parameters:
        encoded_string (str): The ByteLevel encoded string.
        char_to_byte (dict): Mapping from Unicode characters to byte values.
        encoding (str): The encoding to use for the output string (default: 'utf-8').

    Returns:
        str: The decoded original string.

    Raises:
        ValueError: If a character of encoded_string is not a key of char_to_byte.
        UnicodeDecodeError: If the recovered byte sequence is not valid in `encoding`.
    """
    decoded_bytes = bytearray()
    for char in encoded_string:
        if char in char_to_byte:
            decoded_bytes.append(char_to_byte[char])
        else:
            raise ValueError(f"Unknown character in encoded string: {char}")

    try:
        return decoded_bytes.decode(encoding)
    except UnicodeDecodeError as e:
        # UnicodeDecodeError takes five positional arguments
        # (encoding, object, start, end, reason); constructing it from a single
        # message string raises TypeError instead of the intended error.
        raise UnicodeDecodeError(
            e.encoding,
            bytes(e.object),
            e.start,
            e.end,
            f"Failed to decode byte sequence: {e.reason}",
        ) from e

def byte_level_decode_custom(encoded_string, char_to_byte, encoding='utf-8'):
    """
    Leniently decodes a ByteLevel encoded string, never raising on bad UTF-8.

    Each character is first mapped back to a byte; characters missing from
    char_to_byte become the byte 0xFF. The byte sequence is then walked one
    lead byte at a time: a well-formed 1-4 byte sequence is decoded to its
    character, and any byte that cannot start or complete a sequence yields
    a single � and advances the scan by exactly one byte.

    Parameters:
        encoded_string (str): The ByteLevel encoded string.
        char_to_byte (dict): Mapping from Unicode characters to byte values.
        encoding (str): The encoding used to decode each candidate sequence (default: 'utf-8').

    Returns:
        str: The decoded string, with one � per undecodable byte.
    """
    replacement = '�'

    # Unknown characters are turned into 0xFF, which is never a valid lead
    # byte and therefore surfaces as � below.
    raw = bytearray(char_to_byte.get(symbol, 0xFF) for symbol in encoded_string)
    total = len(raw)

    def expected_width(lead):
        # Length of the multi-byte sequence announced by a lead byte, or 0
        # when the byte cannot start a multi-byte sequence.
        if 0xC0 <= lead <= 0xDF:
            return 2
        if 0xE0 <= lead <= 0xEF:
            return 3
        if 0xF0 <= lead <= 0xF7:
            return 4
        return 0

    pieces = []
    pos = 0
    while pos < total:
        lead = raw[pos]

        # Plain ASCII passes straight through.
        if lead <= 0x7F:
            pieces.append(chr(lead))
            pos += 1
            continue

        width = expected_width(lead)
        has_room = width > 0 and pos + width <= total
        if has_room and all(0x80 <= b <= 0xBF for b in raw[pos + 1:pos + width]):
            try:
                pieces.append(bytes(raw[pos:pos + width]).decode(encoding))
            except UnicodeDecodeError:
                # Structurally plausible but still invalid (e.g. overlong or
                # surrogate encodings): fall through to the replacement path.
                pass
            else:
                pos += width
                continue

        # Stray continuation byte, invalid lead byte, truncated sequence or
        # failed decode: emit one replacement and resynchronise on the next byte.
        pieces.append(replacement)
        pos += 1

    return ''.join(pieces)



def _byte_token_value(token):
    """Returns the byte value of a '<0xNN>' token, or None if it is not one."""
    if len(token) < 5 or not token.endswith(">"):
        return None
    try:
        value = int(token[3:-1], 16)
    except ValueError:
        return None
    return value if 0 <= value <= 0xFF else None


def load_vocab(tokenizer_file):
    """
    Loads a tokenizer JSON file and returns a normalized token -> id mapping.

    The file is expected to hold the vocabulary under ["model"]["vocab"]. Two
    other layouts are normalized and written back to `tokenizer_file`:
    a top-level "vocab" key, and a flat token -> id dict (recognised by the
    presence of the key "mama").

    Normalization of each vocabulary entry:
      * The most frequent of the characters U+2581, U+0120 and TAB in the raw
        file text is taken as the word-start marker. Unless that marker is
        U+0120, a leading marker is replaced by U+0120 (when the file is
        treated as byte-level encoded) or by a plain space (otherwise).
      * The file is treated as byte-level encoded when the raw text does not
        contain the Cyrillic substring "ма".
      * '<0xNN>' tokens are stored under byte_to_char[NN].
      * Byte-level tokens are decoded with byte_level_decode_custom; tokens
        that do not decode cleanly are stored as '<0y{raw_token}>'.
      * Entries of "added_tokens" are added when their content is not already
        present.

    Parameters:
        tokenizer_file (str): Path to the tokenizer JSON file. The file may be
            rewritten in place when its layout is normalized.

    Returns:
        dict: Mapping from normalized token string to its id.

    Exits:
        Calls sys.exit(1) when the file is not valid JSON or has no usable
        ["model"]["vocab"] dict.
    """
    vocab = {}

    # Tokenizer files contain non-ASCII tokens, so the encoding is pinned
    # instead of relying on the platform default.
    try:
        with open(tokenizer_file, "r", encoding="utf-8") as fr:
            tokenizer = json.load(fr)
    except json.decoder.JSONDecodeError:
        print(f"Bad tokenizer file: {tokenizer_file}")
        print("Please provide a valid tokenizer file.")
        print("Here are the first few lines of the file:")
        with open(tokenizer_file, "r", encoding="utf-8", errors="replace") as fr:
            for i, line in enumerate(fr):
                print(line)
                if i > 10:
                    break
        sys.exit(1)

    if isinstance(tokenizer, dict) and "model" not in tokenizer:
        normalized = False
        if "mama" in tokenizer:
            # Flat token -> id file. Wrap a *copy*: pointing "vocab" back at
            # the same dict creates a cycle that json.dump refuses to write
            # (after the file has already been truncated).
            tokenizer = {"model": {"vocab": dict(tokenizer)}}
            normalized = True
        elif "vocab" in tokenizer:
            tokenizer["model"] = {"vocab": tokenizer["vocab"]}
            normalized = True
        if normalized:
            # ensure_ascii=False keeps marker characters literal so the
            # raw-text scan below still sees them after the rewrite.
            with open(tokenizer_file, "w", encoding="utf-8") as fw:
                json.dump(tokenizer, fw, indent=2, ensure_ascii=False)

    model = tokenizer.get("model") if isinstance(tokenizer, dict) else None
    raw_vocab = model.get("vocab") if isinstance(model, dict) else None
    if not isinstance(raw_vocab, dict):
        # Covers a missing vocab as well as the list-shaped vocab format.
        print("Bad format for vocab")
        sys.exit(1)

    with open(tokenizer_file, encoding="utf-8") as fh:
        text_data = fh.read()

    # Candidate word-start markers: U+2581 (▁), U+0120 (Ġ) and TAB.
    # Ties are broken towards the higher code point, as a sort of
    # (count, char) pairs would do.
    # NOTE(review): this scans the raw file text, so files that store these
    # characters as \uXXXX escapes will not be detected — confirm inputs.
    replace = max(("\u2581", "\u0120", "\t"), key=lambda m: (text_data.count(m), m))
    if replace == "\u0120":
        replace = None

    should_be_fixed = "ма" not in text_data
    word_start = "\u0120" if should_be_fixed else " "

    for raw_token, rid in tqdm(raw_vocab.items()):
        rr = raw_token  # untouched form, kept for the duplicate diagnostic
        if replace and raw_token.startswith(replace) and len(raw_token) > 1:
            raw_token = word_start + raw_token[len(replace):]

        if raw_token.lower().startswith("<0x"):
            # Parsed with int(), not eval(): the token text comes from a file.
            byte_value = _byte_token_value(raw_token)
            if byte_value is not None:
                vocab[byte_to_char[byte_value]] = rid
                continue

        if should_be_fixed:
            token = byte_level_decode_custom(raw_token, char_to_byte)
            if "\ufffd" in token:
                token = f"<0y{raw_token}>"
        elif raw_token in char_to_byte:
            token = str(hex(char_to_byte[raw_token])).upper()
        else:
            token = raw_token

        if token in vocab:
            # Report the collision and keep going (last id wins); an
            # interactive prompt here would hang a non-interactive run.
            print(f"ERROR-{rid}-{token}-|-{raw_token}-{rr}-{len(raw_token)}")
            print([ord(x) for x in token])

        vocab[token] = rid

    for entry in tokenizer.get("added_tokens") or []:
        raw_token = entry["content"]
        rid = entry["id"]
        if raw_token not in vocab:
            vocab[raw_token] = rid

    return vocab

In [4]:
# from tokenizer import load_vocab

In [7]:
# Tokenizer JSON files to compare, keyed by a short label. The labels become
# the keys of `vocabs` and the header columns of the TSV files written below.
# NOTE(review): "se" is presumably Northern Sami and "uralic" the tokenizer
# trained on the combined data — confirm against how the files were produced.
tokenizer_files = {
    "estonian": "../tokenizers/et.wiki_paragraphs.2024-05.json",
    "finnish": "../tokenizers/fi.wiki_paragraphs.2023-05.json",
    "hungarian": "../tokenizers/hu.wiki_paragraphs.2024-05.json",
    "se": "../tokenizers/se.wiki_paragraphs.2024-05.json",
    "uralic": "../tokenizers/uralic.wiki_paragraphs.2024-05.json",
}

In [8]:
# One normalized vocabulary per tokenizer, in the order of tokenizer_files.
vocabs = {
    label: load_vocab(json_path)
    for label, json_path in tokenizer_files.items()
}

100%|█████████████████████████████████████████████████████████| 256000/256000 [00:00<00:00, 287965.26it/s]
100%|█████████████████████████████████████████████████████████| 256000/256000 [00:00<00:00, 292348.94it/s]
100%|█████████████████████████████████████████████████████████| 256000/256000 [00:00<00:00, 274303.70it/s]
100%|███████████████████████████████████████████████████████████| 93188/93188 [00:00<00:00, 293944.66it/s]
100%|█████████████████████████████████████████████████████████| 256000/256000 [00:00<00:00, 292109.71it/s]


In [9]:
# Union of the tokens of every loaded vocabulary; the size is the cell output.
all_tokens = set()
for vocab in vocabs.values():
    all_tokens.update(vocab)
len(all_tokens)

785115

In [10]:
# token -> list with one slot per vocabulary (same order as `vocabs`).
# A slot holds the token's id + 1 in that vocabulary; 0 means "not present".
token_to_ranks = {}
n_vocabs = len(vocabs)
for column, (label, vocab) in enumerate(vocabs.items()):
    print(label, len(vocab))
    for token, token_id in vocab.items():
        slots = token_to_ranks.setdefault(token, [0] * n_vocabs)
        slots[column] = token_id + 1
len(token_to_ranks)

estonian 256000
finnish 256000
hungarian 256000
se 93188
uralic 256000


785115

In [11]:
# Eyeball the first 301 entries (positions 0..300) of the rank table.
for position, (token, slots) in enumerate(token_to_ranks.items()):
    if position > 300:
        break
    print(token, slots)

<s> [1, 1, 1, 1, 1]
<pad> [2, 2, 2, 2, 2]
</s> [3, 3, 3, 3, 3]
<unk> [4, 4, 4, 4, 4]
<mask> [5, 5, 5, 5, 5]
A [6, 6, 6, 19, 6]
B [7, 7, 7, 20, 7]
C [8, 8, 8, 21, 8]
D [9, 9, 9, 22, 9]
E [10, 10, 10, 23, 10]
F [11, 11, 11, 24, 11]
G [12, 12, 12, 25, 12]
H [13, 13, 13, 26, 13]
I [14, 14, 14, 27, 14]
J [15, 15, 15, 28, 15]
K [16, 16, 16, 29, 16]
L [17, 17, 17, 30, 17]
M [18, 18, 18, 31, 18]
N [19, 19, 19, 32, 19]
O [20, 20, 20, 33, 20]
P [21, 21, 21, 34, 21]
Q [22, 22, 22, 35, 22]
R [23, 23, 23, 36, 23]
S [24, 24, 24, 37, 24]
T [25, 25, 25, 38, 25]
U [26, 26, 26, 39, 26]
V [27, 27, 27, 40, 27]
W [28, 28, 28, 41, 28]
X [29, 29, 29, 42, 29]
Y [30, 30, 30, 43, 30]
Z [31, 31, 31, 44, 31]
a [32, 32, 32, 45, 32]
b [33, 33, 33, 46, 33]
c [34, 34, 34, 47, 34]
d [35, 35, 35, 48, 35]
e [36, 36, 36, 49, 36]
f [37, 37, 37, 50, 37]
g [38, 38, 38, 51, 38]
h [39, 39, 39, 52, 39]
i [40, 40, 40, 53, 40]
j [41, 41, 41, 54, 41]
k [42, 42, 42, 55, 42]
l [43, 43, 43, 56, 43]
m [44, 44, 44, 57, 44]
n [45, 45, 

In [12]:
# Translation table: each control character becomes a backslash followed by a letter.
_CONTROL_ESCAPES = str.maketrans({'\t': '\\t', '\r': '\\r', '\n': '\\n'})


def escape_token(token):
    """Returns `token` with tab, carriage return and newline written as \\t, \\r and \\n.

    This keeps every token on one line and in one column when it is written
    to a tab-separated file. Backslashes already in the token are left as is.
    """
    return token.translate(_CONTROL_ESCAPES)

# Full rank table: one row per token, one rank column per vocabulary.
# Tokens are non-ASCII, so the encoding is explicit rather than locale-dependent.
# The header names only the rank columns; the leading token column is unnamed.
with open("../tokenizers/token_ranks", "w", encoding="utf-8") as f:
    header = "\t".join(vocabs.keys())
    f.write(header + "\n")
    for token, ranks in token_to_ranks.items():
        escaped_token = escape_token(token)
        f.write(escaped_token + "\t" + "\t".join(map(str, ranks)) + "\n")

In [13]:
# Tokens present in exactly one vocabulary, ignoring the last column.
# Tokens are non-ASCII, so the encoding is explicit rather than locale-dependent.
# The header names only the rank columns; the leading token column is unnamed.
with open("../tokenizers/token_ranks_uniq.tsv", "w", encoding="utf-8") as f:
    header = "\t".join(vocabs.keys())
    f.write(header + "\n")
    for token, ranks in token_to_ranks.items():
        if sum(1 for rank in ranks[:-1] if rank != 0) == 1:
            escaped_token = escape_token(token)
            f.write(escaped_token + "\t" + "\t".join(map(str, ranks)) + "\n")

In [14]:
# Tokens shared by at least two vocabularies, ignoring the last column.
# Tokens are non-ASCII, so the encoding is explicit rather than locale-dependent.
# The header names only the rank columns; the leading token column is unnamed.
with open("../tokenizers/token_ranks_not_uniq.tsv", "w", encoding="utf-8") as f:
    header = "\t".join(vocabs.keys())
    f.write(header + "\n")
    for token, ranks in token_to_ranks.items():
        if sum(1 for rank in ranks[:-1] if rank != 0) > 1:
            escaped_token = escape_token(token)
            f.write(escaped_token + "\t" + "\t".join(map(str, ranks)) + "\n")

In [15]:
# Tokens present in every vocabulary except (possibly) the last column.
# Expressed with all() instead of a hard-coded count of 4, so the filter
# stays correct if tokenizers are added to or removed from `vocabs`.
# Tokens are non-ASCII, so the encoding is explicit rather than locale-dependent.
# The header names only the rank columns; the leading token column is unnamed.
with open("../tokenizers/token_ranks_all.tsv", "w", encoding="utf-8") as f:
    header = "\t".join(vocabs.keys())
    f.write(header + "\n")
    for token, ranks in token_to_ranks.items():
        if all(rank != 0 for rank in ranks[:-1]):
            escaped_token = escape_token(token)
            f.write(escaped_token + "\t" + "\t".join(map(str, ranks)) + "\n")

In [16]:
# Tokens present in every vocabulary when the last two columns are ignored.
# Expressed with all() instead of a hard-coded count of 3, so the filter
# stays correct if tokenizers are added to or removed from `vocabs`.
# Tokens are non-ASCII, so the encoding is explicit rather than locale-dependent.
# The header names only the rank columns; the leading token column is unnamed.
with open("../tokenizers/token_ranks_all_no_se.tsv", "w", encoding="utf-8") as f:
    header = "\t".join(vocabs.keys())
    f.write(header + "\n")
    for token, ranks in token_to_ranks.items():
        if all(rank != 0 for rank in ranks[:-2]):
            escaped_token = escape_token(token)
            f.write(escaped_token + "\t" + "\t".join(map(str, ranks)) + "\n")

In [18]:
# Tokens shared by at least two vocabularies (last column ignored), keeping
# only those longer than 3 characters after escaping and without a "<".
# Tokens are non-ASCII, so the encoding is explicit rather than locale-dependent.
# The header names only the rank columns; the leading token column is unnamed.
with open("../tokenizers/token_ranks_not_uniq.n3.tsv", "w", encoding="utf-8") as f:
    header = "\t".join(vocabs.keys())
    f.write(header + "\n")
    for token, ranks in token_to_ranks.items():
        if sum(1 for rank in ranks[:-1] if rank != 0) > 1:
            escaped_token = escape_token(token)
            if len(escaped_token) > 3 and "<" not in escaped_token:
                f.write(escaped_token + "\t" + "\t".join(map(str, ranks)) + "\n")

In [21]:
# Tokens present in every vocabulary except (possibly) the last column,
# keeping only those longer than 3 characters after escaping and without a
# "<". Spaces in the written token are shown as "_".
# Expressed with all() instead of a hard-coded count of 4, so the filter
# stays correct if tokenizers are added to or removed from `vocabs`.
# Tokens are non-ASCII, so the encoding is explicit rather than locale-dependent.
# The header names only the rank columns; the leading token column is unnamed.
with open("../tokenizers/token_ranks_all.n3.tsv", "w", encoding="utf-8") as f:
    header = "\t".join(vocabs.keys())
    f.write(header + "\n")
    for token, ranks in token_to_ranks.items():
        if all(rank != 0 for rank in ranks[:-1]):
            escaped_token = escape_token(token)
            if len(escaped_token) > 3 and "<" not in escaped_token:
                escaped_token = escaped_token.replace(" ", "_")
                f.write(escaped_token + "\t" + "\t".join(map(str, ranks)) + "\n")