In [None]:
from src.requirements import *
from src.audio_handler import *

In [None]:
class Tokenizer:
    """Character-level tokenizer with a reserved blank token at id 0.

    Text is NFD-normalized before tokenization, so precomposed characters
    (e.g. "é") are split into a base character plus combining marks; decoded
    text is recomposed with NFC. Id 0 is always ``<blank>`` and is dropped by
    ``decode`` (presumably the CTC blank used by the model — confirm against
    the training code). Every other character gets an id in order of
    descending corpus frequency.
    """

    def __init__(self, vocab=None):
        """Create a tokenizer, optionally from a previously saved vocabulary.

        Args:
            vocab: ``None`` for an empty tokenizer (call ``build_vocab`` next),
                or a dict as written by ``save`` with keys ``"tokens"``
                (token -> id), ``"ids"`` (id -> token; keys may be strings,
                as after a JSON round trip) and ``"size"`` (vocabulary size).
        """
        self.blank_token = "<blank>"
        self.blank_id = 0

        if vocab is not None:
            self.token_to_id = vocab["tokens"]
            # JSON object keys are always strings; restore integer ids.
            self.id_to_token = {int(k): v for k, v in vocab["ids"].items()}
            self.vocab_size = vocab["size"]
            return

        self.token_to_id = {}
        self.id_to_token = {}
        # Defined up front so len() works before build_vocab() is called.
        self.vocab_size = 0

    # Normalization
    def normalize(self, text: str) -> str:
        """Return ``text`` in Unicode NFD (canonically decomposed) form."""
        return unicodedata.normalize("NFD", text)

    def denormalize(self, text: str) -> str:
        """Return ``text`` in Unicode NFC (canonically composed) form."""
        return unicodedata.normalize("NFC", text)

    # Vocab
    def build_vocab(self, texts):
        """Build the vocabulary from an iterable of strings.

        Replaces any existing vocabulary. ``<blank>`` is always id 0; every
        distinct NFD-normalized character is then numbered from 1 in order of
        descending frequency, ties broken by first occurrence.

        Args:
            texts: iterable of strings. A single ``str`` is also accepted and
                is processed character by character.
        """
        counter = Counter()
        for text in texts:
            counter.update(self.normalize(text))

        self.token_to_id = {self.blank_token: self.blank_id}
        self.id_to_token = {self.blank_id: self.blank_token}

        # Ids follow frequency order, so " " lands on id 1 only when it is the
        # most frequent character in the corpus; nothing enforces that.
        for token_id, (ch, _) in enumerate(counter.most_common(), start=1):
            self.token_to_id[ch] = token_id
            self.id_to_token[token_id] = ch

        # Blank plus one entry per distinct character.
        self.vocab_size = len(self.id_to_token)

    # Encoding / Decoding
    def encode(self, text: str):
        """Convert ``text`` (NFD-normalized first) to a list of token ids.

        Raises:
            ValueError: if a normalized character is not in the vocabulary.
        """
        ids = []
        for ch in self.normalize(text):
            try:
                ids.append(self.token_to_id[ch])
            except KeyError:
                raise ValueError(f"Unknown character: {repr(ch)}") from None
        return ids

    def decode(self, ids):
        """Convert token ids back to NFC-normalized text, dropping blanks.

        Consecutive duplicate ids are NOT collapsed; any merging of repeats
        (as greedy CTC decoding requires) must happen before this call.

        Args:
            ids: iterable of integer ids. Integer-like scalars (e.g. NumPy
                integers or single-element tensors) are converted with
                ``int()`` so the dict lookup does not depend on their hash.

        Raises:
            KeyError: if an id is not in the vocabulary.
        """
        chars = []
        for raw_id in ids:
            i = int(raw_id)
            if i == self.blank_id:
                continue
            chars.append(self.id_to_token[i])
        return self.denormalize("".join(chars))

    def save(self, path):
        """Write the vocabulary to ``path`` as UTF-8 JSON.

        Integer ids in ``id_to_token`` become string keys in JSON; ``load``
        converts them back.
        """
        data = {"tokens": self.token_to_id, "ids": self.id_to_token, "size": self.vocab_size}
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

    @classmethod
    def load(cls, path):
        """Create a tokenizer from a JSON file previously written by ``save``."""
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return cls(vocab=data)

    def __len__(self):
        """Return the vocabulary size, including the blank token."""
        return self.vocab_size

In [None]:
# Load the raw text corpus from data/text; `text` feeds the tokenizer
# vocabulary build in the next cell. `load_text` is presumably provided by one
# of the star imports at the top (src.audio_handler?) — TODO confirm its return
# type (expected: an iterable of strings).
path = os.path.join("data", "text")
text = load_text(path)

In [None]:
# Prefer the cached vocabulary at data/tokenizer.json; build it from the
# loaded corpus and write the cache only when the file is missing.
# NOTE: a stale cache is not detected — delete the file to force a rebuild
# after the corpus changes.
token_path = os.path.join("data", "tokenizer.json")
if os.path.exists(token_path):
    tokenizer = Tokenizer.load(token_path)
else:
    tokenizer = Tokenizer()
    tokenizer.build_vocab(text)
    tokenizer.save(token_path)

print("Vocab size:", len(tokenizer))

In [None]:
# tokenizer = Tokenizer()
# tokenizer.build_vocab(text)

# print("Vocab size:", len(tokenizer))

In [None]:
# Show the character -> id mapping as this cell's output (the notebook echoes
# the value of the cell's last expression).
tokenizer.token_to_id