In [1]:
# Read the raw corpus; the context manager closes the file even if read() fails.
with open("voa_fa_2003-2008_orig.txt", "r", encoding="utf-8") as reader:
    text = reader.read()

# Drop lines that start with "#" as well as empty / whitespace-only lines.
text = "\n".join(
    line
    for line in text.split("\n")
    if not line.startswith("#") and line.strip() != ""
)

# Save the cleaned corpus to a separate file so the original stays untouched.
with open("voa_fa_2003-2008_clean.txt", "w", encoding="utf-8") as writer:
    writer.write(text)

In [2]:
# Load the cleaned corpus; `with` guarantees the handle is closed even on error.
with open("voa_fa_2003-2008_clean.txt", "r", encoding="utf-8") as reader:
    text = reader.read()

# Byte-level tokens: iterating over a bytes object yields ints in 0..255.
tokens = list(text.encode("utf-8"))

In [3]:
# Load the sample text; `with` guarantees the handle is closed even on error.
# NOTE: this rebinds `text` and `tokens`, replacing whatever an earlier cell stored in them.
with open("text.txt", "r", encoding="utf-8") as reader:
    text = reader.read()

# Byte-level tokens: iterating over a bytes object yields ints in 0..255.
tokens = list(text.encode("utf-8"))

In [4]:
# Peek at the first 12 entries of `tokens` (UTF-8 byte values of `text`).
tokens[:12]

[216, 179, 216, 177, 216, 178, 217, 133, 219, 140, 217, 134]

In [5]:
# Turn the first 34 byte tokens back into readable (Persian) text.
# errors="replace" substitutes U+FFFD instead of raising if the slice is not valid UTF-8.
# NOTE: this rebinds `text` to the short preview string.
text = bytes(tokens[:34]).decode("utf-8", errors="replace")
print(text)

سرزمین ایران با ده 


In [6]:

def count_pairs(tokens: list[int]) -> dict[tuple, int]:
    """
    Count how often each pair of adjacent tokens occurs.

    Args:
        tokens: Sequence of token ids.

    Returns:
        A dict mapping (left, right) to the number of positions where `right`
        immediately follows `left`. Keys appear in order of first occurrence.
        Inputs with fewer than two tokens give an empty dict.
    """
    counts = {}
    # Pair every token with its successor by zipping the list against itself shifted by one.
    for left, right in zip(tokens, tokens[1:]):
        key = (left, right)
        if key in counts:
            counts[key] += 1
        else:
            counts[key] = 1
    return counts

# pairs = count_pairs(tokens)

In [7]:
def find_min_pair(pairs: dict, merges: dict) -> tuple:
    """
    Pick the candidate pair that has the lowest value in `merges`.

    Only the keys of `pairs` are used; their values are ignored. The ranking
    comes entirely from `merges`.

    Args:
        pairs (dict): Candidate pairs (keys) to choose from.
        merges (dict): Maps a pair to a comparable value (e.g. its merge token id).

    Returns:
        The key of `pairs` with the smallest value in `merges`. On ties the
        first such key in iteration order wins. Returns None when no key of
        `pairs` has a value below infinity in `merges` (including empty `pairs`).
    """
    unranked = float("inf")

    # Keep only the candidates that actually have a finite rank in `merges`.
    ranked = [
        candidate
        for candidate in pairs
        if merges.get(candidate, unranked) < unranked
    ]
    if not ranked:
        return None

    # min() returns the first minimal element, matching a strict "<" scan.
    return min(ranked, key=merges.get)


def find_max_pair(pairs: dict) -> tuple:
    """
    Return the key of `pairs` whose value is the largest.

    Args:
        pairs (dict): Maps each pair to a comparable value (e.g. its count).

    Returns:
        The pair with the largest value. On ties the first one in iteration
        order wins. Returns None for an empty dict.
    """
    winner, top = None, float("-inf")
    for candidate in pairs:
        score = pairs[candidate]
        # Strict comparison keeps the earliest pair when several share the maximum.
        if score > top:
            winner, top = candidate, score
    return winner

In [9]:
def merge_top_pair(tokens: list[int], pair: tuple[int, int], new_byte: int) -> list[int]:
    """
    Replace every non-overlapping occurrence of `pair` in `tokens` with `new_byte`.

    The scan runs left to right, so in a run such as [a, a, a] with pair (a, a)
    the first two tokens are merged and the third is kept as is.

    Args:
        tokens: Sequence of token ids. It is not modified.
        pair: The (left, right) token ids to merge.
        new_byte: Token id that replaces each occurrence of `pair`.

    Returns:
        A new list with the merges applied.
    """
    first, second = pair
    last_index = len(tokens) - 1
    new_tokens = []
    i = 0
    while i <= last_index:
        # The bounds check must come first: otherwise tokens[i + 1] raises
        # IndexError whenever the final token equals `first`.
        if i < last_index and tokens[i] == first and tokens[i + 1] == second:
            new_tokens.append(new_byte)
            i += 2
        else:
            new_tokens.append(tokens[i])
            i += 1
    return new_tokens

vocabulary_size = 260
# Ids 0..255 are taken by the raw byte values, so only the remainder are learned merges.
num_merges = vocabulary_size - 256

# Work on a copy so the byte-level `tokens` list stays available for comparison.
tokens_copy = tokens.copy()

# Maps (left_id, right_id) -> new token id, inserted in the order the merges are learned.
merges = {}
for i in range(num_merges):
    pairs = count_pairs(tokens_copy)
    if not pairs:
        # Fewer than two tokens remain: nothing left to merge, and max() on an
        # empty dict would raise ValueError.
        break
    # Most frequent adjacent pair; ties go to the pair seen first.
    max_pair = max(pairs, key=pairs.get)
    new_token = 256 + i
    print(f"merging {max_pair} into a new token {new_token}")
    tokens_copy = merge_top_pair(tokens_copy, max_pair, new_token)
    merges[max_pair] = new_token

merging (32, 216) into a new token 256
merging (216, 167) into a new token 257
merging (219, 140) into a new token 258
merging (216, 177) into a new token 259


In [10]:
# Report how much shorter the token sequence became after applying the merges.
original_length = len(tokens)
merged_length = len(tokens_copy)
print("tokens length:", original_length)
print("new tokens length:", merged_length)
print(f"compression ratio: {original_length / merged_length:.2f}X")

tokens length: 773
new tokens length: 628
compression ratio: 1.23X


In [11]:
# decoding

def decode(tokens: list[int]) -> str:
    """
    Convert a list of token ids back into text.

    Uses the notebook-level `merges` dict (pair -> merged token id) to rebuild
    the byte string behind every token id.

    Args:
        tokens: Token ids, each either a raw byte value (0..255) or an id found
            among the values of `merges`.

    Returns:
        The decoded string. Byte sequences that are not valid UTF-8 are
        replaced with U+FFFD rather than raising.
    """
    # Base vocabulary: every single byte maps to itself.
    vocabulary = {byte_value: bytes([byte_value]) for byte_value in range(256)}

    # Walk `merges` in insertion order so both halves of a pair are already
    # present in the vocabulary when the merged entry is built.
    for (left, right), merged_id in merges.items():
        vocabulary[merged_id] = vocabulary[left] + vocabulary[right]

    raw = b"".join([vocabulary[token] for token in tokens])
    return raw.decode("utf-8", errors="replace")

# Sanity check: decode the merged token sequence and print the resulting text.
print(decode(tokens_copy))

سرزمین ایران با ده هزار سال تاریخ و تمدن، میزبان تمدن‌های کهن متعددی چون ایلام در هزارهٔ چهارم پیش از میلاد بوده است. در سدهٔ هفتم پ. م، پادشاهی ماد بخش‌های قابل‌توجهی از فلات ایران را یکپارچه کرد. در سدهٔ ششم پ. م، شاهنشاهی هخامنشی به‌دست کوروش بزرگ بنیان نهاده شد تا ایران یکی از بزرگ‌ترین امپراتوری‌های تاریخ را تشکیل دهد. در سدهٔ چهارم پ. م، اسکندر مقدونی این قلمرو را تسخیر کرد و ایران به بخشی از سرزمین‌های هلنی تبدیل شد.


In [None]:
# encoding

def encode(text: str, merges: dict) -> list[int]:
    """
    Encode `text` into token ids by applying the learned merges.

    The text is first turned into its UTF-8 byte values. Then, repeatedly, the
    adjacent pair with the lowest id in `merges` is replaced by that id, until
    no adjacent pair appears in `merges`.

    Args:
        text: The string to encode.
        merges: Maps a (left_id, right_id) pair to the token id that replaces it.

    Returns:
        The list of token ids. Empty text gives an empty list.
    """
    # Iterating over a bytes object yields ints in 0..255.
    tokens = list(text.encode("utf-8"))

    # At least two tokens are needed to form a pair.
    while len(tokens) >= 2:
        pairs = count_pairs(tokens)
        # find_min_pair needs `merges` to rank the candidates; calling it with
        # `pairs` alone raises TypeError.
        min_pair = find_min_pair(pairs, merges)
        if min_pair not in merges:
            # No remaining adjacent pair has a learned merge.
            break
        tokens = merge_top_pair(tokens, min_pair, merges[min_pair])
    return tokens

# NOTE(review): this rebinds the notebook-level `tokens` to the encoding of
# whatever `text` currently holds — confirm no later cell still expects the
# original byte-level list.
tokens = encode(text, merges)

[]


In [None]:
# Round-trip a short string through encode and decode and print the result,
# which is expected to match the input string.
print(decode(encode("مادرتو دیدم", merges)))

مادرتو دیدم
