In [None]:
import os
import regex as re
import heapq

In [2]:
# GPT-2 style pre-tokenization regex. The \p{...} classes need the third-party
# `regex` module (imported as `re` above); the stdlib `re` does not support them.
# Alternatives, tried left to right:
#   - common English contraction suffixes ('s 'd 'm 't 'll 've 're)
#   - optional leading space + run of letters
#   - optional leading space + run of numeric characters
#   - optional leading space + run of other non-space symbols
#   - whitespace run not immediately followed by a non-space char (so the last
#     space before a word is left for that word's " ?" prefix)
#   - any remaining whitespace
PAT = "|".join([
    r"'(?:[sdmt]|ll|ve|re)",
    r" ?\p{L}+",
    r" ?\p{N}+",
    r" ?[^\s\p{L}\p{N}]+",
    r"\s+(?!\S)",
    r"\s+",
])

In [3]:
from cs336_basics.myBPE import my_train_bpe
# Root of the assignment checkout. Override with the CS336_ASSIGNMENT_DIR env var
# instead of editing a hard-coded absolute path.
ASSIGNMENT_DIR = os.environ.get("CS336_ASSIGNMENT_DIR", "/root/workspace/cs336/assignment1")
corpus_file = os.path.join(ASSIGNMENT_DIR, "tests", "fixtures", "corpus.en")
VOCAB_SIZE = 500  # second positional arg of my_train_bpe; presumably the target vocab size — confirm
special_tokens_train = ["<|endoftext|>"]
vocab, merges = my_train_bpe(corpus_file, VOCAB_SIZE, special_tokens_train, PAT)

In [None]:
class myTokenizer:
    """Byte-level BPE tokenizer built from a vocab and an ordered list of merges.

    Encoding works in three steps:
    1. Split the text on special tokens, longest match first.
    2. Pre-tokenize each remaining chunk with ``pattern``.
    3. Map each pre-token's UTF-8 bytes to single-byte ids, then apply the
       merges. Earlier merges in the list take priority.
    """

    def __init__(self, vocab, merges, pattern, special_tokens=None) -> None:
        """Build the lookup tables used by encode/decode.

        Args:
            vocab: mapping ``{token_id: token_bytes}``. It must contain every
                single byte ``bytes([i])`` for i in 0..255, plus every merge
                component and every merge result. The dict is copied, so
                special tokens added here do not leak into the caller's vocab.
            merges: ordered sequence of ``(bytes, bytes)`` pairs. The list
                index is the merge rank; a lower rank is applied first.
            pattern: pre-tokenization regex passed to ``re.compile``.
            special_tokens: optional list of strings that are never split.
                Any that are not already in ``vocab`` get fresh ids after the
                current maximum id.

        Raises:
            KeyError: if a single byte, a merge component, or a merge result
                is missing from ``vocab``.
        """
        self.vocab = dict(vocab)  # copy: we may add special tokens below
        self.special_tokens = special_tokens  # list[str] or None
        self.tk_to_id = {token: token_id for token_id, token in self.vocab.items()}  # bytes -> id
        # byte value (0..255) -> id of the corresponding single-byte token
        self.byte_id = [self.tk_to_id[bytes([i])] for i in range(256)]
        self.merges_rank = {}  # {(left_id, right_id): rank}
        self.merges_id = []  # merges_id[rank] == (left_id, right_id, merged_id)

        for rank, (m1, m2) in enumerate(merges):
            id1 = self.tk_to_id[m1]
            id2 = self.tk_to_id[m2]
            idm = self.tk_to_id[m1 + m2]
            # If a pair were listed twice, keep its earliest (highest-priority) rank.
            self.merges_rank.setdefault((id1, id2), rank)
            self.merges_id.append((id1, id2, idm))

        if special_tokens:
            next_id = max(self.vocab) + 1
            for tk_str in special_tokens:
                tk = tk_str.encode("utf-8")
                if tk not in self.tk_to_id:
                    self.tk_to_id[tk] = next_id
                    self.vocab[next_id] = tk
                    next_id += 1
            # Longest first, so an overlapping token such as "<x><x>" wins over "<x>".
            toks = sorted(set(special_tokens), key=lambda s: (-len(s), s))
            # Capturing group: re.split keeps the matched special tokens in its output.
            self.special_re = re.compile("(" + "|".join(map(re.escape, toks)) + ")")
        else:
            self.special_re = None
        self._special_set = frozenset(special_tokens or ())
        self.sentence_re = re.compile(pattern)

    def _apply_merges(self, ids):
        """Apply BPE merges to a list of token ids and return the merged list.

        On each round, the adjacent pair with the lowest merge rank is chosen.
        Every non-overlapping occurrence of that pair is then replaced, scanning
        left to right. Rounds stop when no adjacent pair has a merge.
        """
        ids = list(ids)
        while len(ids) >= 2:
            best = min(zip(ids, ids[1:]), key=lambda p: self.merges_rank.get(p, float("inf")))
            rank = self.merges_rank.get(best)
            if rank is None:  # no adjacent pair is mergeable
                break
            left, right, merged = self.merges_id[rank]
            out = []
            i = 0
            while i < len(ids):
                if i + 1 < len(ids) and ids[i] == left and ids[i + 1] == right:
                    out.append(merged)
                    i += 2
                else:
                    out.append(ids[i])
                    i += 1
            ids = out
        return ids

    def encode(self, text: str) -> list[int]:
        """Encode ``text`` into a list of token ids.

        Each special token maps to its own single id. All other text is
        pre-tokenized with the configured pattern and BPE-merged one
        pre-token at a time, so merges never cross pre-token boundaries.
        """
        if self.special_re is not None:
            chunks = [s for s in self.special_re.split(text) if s]
        else:
            chunks = [text]
        token_seq = []
        for chunk in chunks:
            if chunk in self._special_set:
                token_seq.append(self.tk_to_id[chunk.encode("utf-8")])
                continue
            for pretoken in self.sentence_re.findall(chunk):
                byte_ids = [self.byte_id[b] for b in pretoken.encode("utf-8")]
                token_seq.extend(self._apply_merges(byte_ids))
        return token_seq

    def encode_iterable(self, iterable):
        """Lazily yield token ids for each string in ``iterable``, one string at a time."""
        for text in iterable:
            yield from self.encode(text)

    def decode(self, ids: list[int]) -> str:
        """Concatenate the bytes of ``ids`` and decode them as UTF-8.

        Invalid UTF-8 sequences are dropped (``errors="ignore"``). For example,
        decoding a single token that holds only part of a multi-byte character
        yields an empty string. An id not present in the vocab raises KeyError.
        """
        res = b"".join(self.vocab[i] for i in ids)
        return res.decode("utf-8", errors="ignore")
        

In [5]:
# A special token plus a doubled variant that overlaps it, to exercise
# longest-match splitting of special tokens.
special_tokens = [
    "<|endoftext|>",
    "<|endoftext|><|endoftext|>",
]
tokenizer = myTokenizer(vocab, merges, PAT, special_tokens=special_tokens)

In [6]:
# Mixed input: ordinary words, the doubled special token, and single special
# tokens between and after words.
test = "HelNonelo, <|endoftext|><|endoftext|>how None<|endoftext|>are you?<|endoftext|>"
ids = tokenizer.encode(test)
# Decode each id on its own to show where the token boundaries fall.
tokenized_string = list(map(lambda token_id: tokenizer.decode([token_id]), ids))
tokenized_string


['H',
 'el',
 'N',
 'on',
 'el',
 'o',
 ',',
 ' ',
 '<|endoftext|><|endoftext|>',
 'h',
 'ow',
 ' ',
 'N',
 'on',
 'e',
 '<|endoftext|>',
 'are',
 ' you',
 '?',
 '<|endoftext|>']

In [7]:
# Encoding then decoding should reproduce the original text exactly.
roundtrip_text = tokenizer.decode(tokenizer.encode(test))
roundtrip_text == test

True

In [8]:
# Scratch check: slicing from index 0 returns a shallow copy of the whole list.
a = list(range(1, 4))
a[0:]

[1, 2, 3]