In [83]:
import os
import regex as re

In [84]:
# Pre-tokenization regex, assembled from its alternatives. The \p{L} / \p{N}
# Unicode property classes require the third-party `regex` module (imported as
# `re` in this notebook); the stdlib `re` rejects them.
PAT = "|".join(
    [
        r"'(?:[sdmt]|ll|ve|re)",   # English contraction suffixes
        r" ?\p{L}+",               # optional space + run of letters
        r" ?\p{N}+",               # optional space + run of numeric chars
        r" ?[^\s\p{L}\p{N}]+",     # optional space + run of other symbols
        r"\s+(?!\S)",              # whitespace not directly followed by a non-space
        r"\s+",                    # any remaining whitespace
    ]
)

In [85]:
from cs336_basics.myBPE import my_train_bpe

# Training corpus. Override with the CS336_CORPUS environment variable instead of
# editing a machine-specific absolute path; the default keeps the original location.
corpus_file = os.environ.get(
    "CS336_CORPUS", "/root/workspace/cs336/assignment1/tests/fixtures/corpus.en"
)
VOCAB_SIZE = 500  # target vocabulary size handed to my_train_bpe
special_tokens_train = ["<|endoftext|>"]
vocab, merges = my_train_bpe(corpus_file, VOCAB_SIZE, special_tokens_train, PAT)

In [89]:
class myTokenizer:
    """Byte-level BPE tokenizer driven by a trained vocabulary and merge list.

    Args:
        vocab: mapping ``id -> token bytes``. It must contain every single-byte
            token ``bytes([0])`` .. ``bytes([255])``; a KeyError is raised otherwise.
            The mapping is copied, so the caller's dict is never modified.
        merges: ordered ``(left_bytes, right_bytes)`` pairs, applied in list
            order. Both operands and their concatenation must be keys of the
            byte->id inverse of ``vocab`` (KeyError at construction otherwise).
        pattern: pre-tokenization regex used with ``re.findall`` on every
            non-special chunk of text.
        special_tokens: optional list of strings that are emitted as single ids
            and never pre-tokenized or merged. Any not already present in
            ``vocab`` get fresh ids after the current maximum id.
    """

    def __init__(self, vocab, merges, pattern, special_tokens=None) -> None:
        # Work on a copy: registering special tokens below must not mutate
        # the dict the caller passed in.
        self.vocab = dict(vocab)
        self.merges = merges
        self.special_tokens = special_tokens  # list[str] or None
        self.pattern = pattern
        # Inverse lookup: token bytes -> id.
        self.tk_to_id = {token: tok_id for tok_id, token in self.vocab.items()}
        # Id of every raw byte value, indexed by the byte value itself.
        self.byte_id = [self.tk_to_id[bytes([i])] for i in range(256)]

        # Register special tokens that are missing from the vocab with fresh ids.
        if special_tokens:
            id_allo = max(self.vocab) + 1
            for tk_str in special_tokens:
                tk = tk_str.encode("utf-8")
                if tk not in self.tk_to_id:
                    self.tk_to_id[tk] = id_allo
                    self.vocab[id_allo] = tk
                    id_allo += 1
        self._special_set = frozenset(special_tokens or ())

        # Resolve each merge to (left_id, right_id, merged_id) once, rather than
        # doing three dict lookups per byte position on every encode call.
        self._merge_ids = [
            (self.tk_to_id[left], self.tk_to_id[right], self.tk_to_id[left + right])
            for left, right in merges
        ]

        # Longest-first alternation so overlapping specials (e.g. a doubled
        # marker vs. the single marker) match the longest one; the capturing
        # group makes re.split keep the special tokens as their own chunks.
        special_pat = None
        if self.special_tokens:
            toks = sorted(set(self.special_tokens), key=lambda s: (-len(s), s))
            special_pat = "|".join(map(re.escape, toks))
        self.special_re = re.compile(f"({special_pat})") if special_pat else None

    def _apply_merges(self, ids: list[int]) -> list[int]:
        """Apply every merge, in training order, to one pre-token's id list."""
        for left, right, merged in self._merge_ids:
            if len(ids) < 2:
                break  # a single token can never be merged further
            if left not in ids:
                continue  # this merge cannot match; skip the Python-level pass
            out = []
            i = 0
            n = len(ids)
            while i < n:
                if i + 1 < n and ids[i] == left and ids[i + 1] == right:
                    out.append(merged)
                    i += 2
                else:
                    out.append(ids[i])
                    i += 1
            ids = out
        return ids

    def encode(self, text: str) -> list[int]:
        """Encode ``text`` into a list of token ids.

        Special tokens are split out first and mapped straight to their ids;
        the remaining chunks are pre-tokenized with ``pattern``, converted to
        byte ids and merged.
        """
        if self.special_re is not None:
            chunks = [s for s in self.special_re.split(text) if s]
        else:
            chunks = [text]
        token_ids: list[int] = []
        for chunk in chunks:
            if chunk in self._special_set:
                token_ids.append(self.tk_to_id[chunk.encode("utf-8")])
                continue
            for pretoken in re.findall(self.pattern, chunk):
                byte_ids = [self.byte_id[b] for b in pretoken.encode("utf-8")]
                token_ids.extend(self._apply_merges(byte_ids))
        return token_ids

    def encode_iterable(self, iterable):
        """Lazily yield token ids for each string produced by ``iterable``."""
        for text in iterable:
            yield from self.encode(text)

    def decode(self, ids: list[int], errors: str = "replace") -> str:
        """Decode token ids back to a string.

        Elements that are lists are skipped, as in the original implementation.
        Malformed UTF-8 (e.g. a lone partial-character token) becomes U+FFFD by
        default; pass ``errors="ignore"`` to drop it instead.
        """
        data = b"".join(self.vocab[tok_id] for tok_id in ids if not isinstance(tok_id, list))
        return data.decode("utf-8", errors=errors)
        

In [90]:
# Register both the single and the doubled end-of-text marker; the tokenizer
# matches the longest special token first, so the doubled form stays one token.
special_tokens = [
    "<|endoftext|>",
    "<|endoftext|><|endoftext|>",
]
tokenizer = myTokenizer(vocab, merges, PAT, special_tokens=special_tokens)

str

In [91]:
# Mixed text with single and doubled <|endoftext|> markers; decoding each id on
# its own shows where the token boundaries fall.
test = "HelNonelo, <|endoftext|><|endoftext|>how None<|endoftext|>are you?<|endoftext|>"
ids = tokenizer.encode(test)
tokenized_string = list(map(lambda token_id: tokenizer.decode([token_id]), ids))
tokenized_string


['H',
 'el',
 'N',
 'on',
 'el',
 'o',
 ',',
 ' ',
 '<|endoftext|><|endoftext|>',
 'h',
 'ow',
 ' ',
 'N',
 'on',
 'e',
 '<|endoftext|>',
 'are',
 ' you',
 '?',
 '<|endoftext|>']

In [92]:
# Round-trip check: decoding the encoded ids must reproduce the input exactly.
# Asserting makes a regression fail loudly under Restart & Run All.
roundtrip_ok = tokenizer.decode(tokenizer.encode(test)) == test
assert roundtrip_ok, "encode/decode round trip changed the text"
roundtrip_ok

True

In [9]:
# Scratch check: slicing from index 0 to the end yields all elements.
a = list(range(1, 4))
a[:]

[1, 2, 3]