In [1]:
import os
import json
from pathlib import Path
from attrs import define, Factory, asdict

from typing import (
    Union,
    Sequence,
    Iterable,
    Optional,
)

from transformers import PreTrainedTokenizerFast
from tokenizers import Tokenizer, models, pre_tokenizers, decoders, trainers, processors

## Constants

In [2]:
# Data paths

DATA_PATH = Path(".")/"input"
TITLE_DATASET = DATA_PATH/"all_titles.txt"

# Tokenizer paths

TOKENIZER_PATH = Path(".")/"model"/"tokenizer"
TOKENIZER_TRAINING_CONFIG = TOKENIZER_PATH/"training_config.json"
PRETRAINED_TOKENIZER_DIR = TOKENIZER_PATH/"pretrained_tokenizer"

# Tokenizer initial parameters

VOCAB_SIZE = 1528  # Number of base characters (see EDA) + four special tokens + 1000 merges
UNK = "[UNK]"
CLS = "[CLS]"
SEP = "[SEP]"
PAD = "[PAD]"
SPECIAL_TOKENS = [PAD, CLS, SEP, UNK]
CONTINUING_SUBWORD_PREFIX = "##"
PADDING_SIDE = "right"

## Tokenizer Training

#### Tokenizer Training Config

In [3]:
@define(kw_only=True)
class TokenizerTrainingConfig:
    dataset_path: Union[str, os.PathLike] = TITLE_DATASET
    output_dir: Union[str, os.PathLike] = PRETRAINED_TOKENIZER_DIR
    vocab_size: int = VOCAB_SIZE
    special_tokens: Sequence[str] = Factory(
        lambda: SPECIAL_TOKENS
    )
    unk_token: str = UNK
    cls_token: str = CLS
    sep_token: str = SEP
    pad_token: str = PAD
    continuing_subword_prefix: str = CONTINUING_SUBWORD_PREFIX
    padding_side: str = PADDING_SIDE
    
    @classmethod
    def from_json(cls, config_json: Union[str, os.PathLike]):
        if not os.path.exists(config_json):
            raise FileNotFoundError(f"Couldn't find {config_json}")
        with open(config_json, "r") as infile:
            config = json.load(infile)
            config["dataset_path"] = Path(config["dataset_path"])
            config["output_dir"] = Path(config["output_dir"])
            return cls(**config)
        
    def to_json(self, config_json: Union[str, os.PathLike]):
        config_dict = asdict(self)
        config_dict["dataset_path"] = str(self.dataset_path)
        config_dict["output_dir"] = str(self.output_dir)
        
        with open(config_json, "w") as outfile:
            json.dump(config_dict, outfile, ensure_ascii=False, indent=4)

In [4]:
if not os.path.exists(TOKENIZER_TRAINING_CONFIG):
    tokenizer_training_config = TokenizerTrainingConfig()
    tokenizer_training_config.to_json(TOKENIZER_TRAINING_CONFIG)
else:
    tokenizer_training_config = TokenizerTrainingConfig.from_json(TOKENIZER_TRAINING_CONFIG)

#### Tokenizer

In [5]:
def fit_tokenizer(config: TokenizerTrainingConfig):
    if not os.path.exists(config.dataset_path):
        raise FileNotFoundError(f"Couldn't find {config.dataset_path}")
    
    def load_data(data_file: Union[str, os.PathLike],
                  encoding: Optional[str] = "utf-8") -> Iterable[str]:
        with open(data_file, "r", encoding=encoding) as infile:
            for line in infile:
                yield line.strip()
                
    tokenizer = Tokenizer(models.BPE(
        unk_token=config.unk_token,
        continuing_subword_prefix=config.continuing_subword_prefix
    ))
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
    tokenizer.decoder = decoders.ByteLevel()
    
    trainer = trainers.BpeTrainer(
        vocab_size=config.vocab_size,
        special_tokens=config.special_tokens,
        continuing_subword_prefix=config.continuing_subword_prefix
    )
    tokenizer.train_from_iterator(load_data(config.dataset_path), trainer=trainer)
    
    cls_token_id = tokenizer.token_to_id(config.cls_token)
    sep_token_id = tokenizer.token_to_id(config.sep_token)
    tokenizer.post_processor = processors.TemplateProcessing(
        single="[CLS] $A [SEP]",
        pair="[CLS] $A [SEP] $B:1 [SEP]:1",
        special_tokens=[
            (config.cls_token, cls_token_id),
            (config.sep_token, sep_token_id),
        ]
    )
    
    wrapped_tokenizer = PreTrainedTokenizerFast(
        tokenizer_object=tokenizer,
        unk_token=config.unk_token,
        pad_token=config.pad_token,
        sep_token=config.sep_token,
        padding_side=config.padding_side,
    )
    wrapped_tokenizer.save_pretrained(config.output_dir)

In [6]:
fit_tokenizer(tokenizer_training_config)

#### Basic Tokenization

In [7]:
tokenizer = PreTrainedTokenizerFast.from_pretrained(PRETRAINED_TOKENIZER_DIR)

In [8]:
sample_text = "Гадаем на свадебных букетах: когда ты выйдешь замуж? 💐"
token_ids = tokenizer.encode(sample_text, add_special_tokens=True)
print(", ".join(list(map(str, token_ids))))

1, 488, 346, 1233, 339, 540, 1490, 333, 518, 25, 921, 325, 1287, 1104, 26, 578, 164, 2


In [9]:
restored_text = tokenizer.decode(token_ids, skip_special_tokens=True)
print(restored_text)

 Гадаем на свад##еб##ных бук##ет##ах: когда ты выйдешь замуж? �##�


In [12]:
for _id in token_ids:
    print(f"{tokenizer.decode([_id], skip_special_tokens=False).strip()}\t{_id}")

[CLS]	1
Гадаем	488
на	346
свад	1233
##еб	339
##ных	540
бук	1490
##ет	333
##ах	518
:	25
когда	921
ты	325
выйдешь	1287
замуж	1104
?	26
�	578
##�	164
[SEP]	2
