In [None]:
# Paths: the SentencePiece training corpus, the directory for the trained
# SentencePiece model, and the directory the Transformers tokenizer is saved to.
train_file = "data/jawiki/20210301/data/train.txt"
spm_model_dir = "output/spm"
tf_model_dir = "output/model"  # receives tokenizer.save_pretrained(...) output

# SentencePiece trainer settings, forwarded via train_args.
vocab_size = 32_000
input_sentence_size = 10_000_000
add_dummy_prefix = False  # forwarded as SentencePiece's add_dummy_prefix option

In [None]:
# Special-token strings shared by the SentencePiece trainer and the
# Transformers tokenizer, so both refer to the same pieces.
# These four are given fixed piece ids when training SentencePiece.
pad_token = "<pad>"
unk_token = "<unk>"
bos_token = "<s>"
eos_token = "</s>"
# These two are registered with SentencePiece as control symbols.
cls_token = "<cls>"
sep_token = "<sep>"

In [None]:
from pathlib import Path

# Normalise to Path so later cells can call .exists()/.mkdir() on it.
# Path(Path(...)) is a no-op, so re-running this cell is safe.
spm_model_dir = Path(spm_model_dir)
# Prefix handed to the SentencePiece trainer as model_prefix.
spm_model_prefix = spm_model_dir / "sp"
# The trained model file loaded by the Transformers tokenizer. It is derived
# from the prefix so the two cannot drift apart if the prefix is renamed.
spm_model_path = spm_model_prefix.with_name(spm_model_prefix.name + ".model")

In [None]:
# Keyword arguments for spm.SentencePieceTrainer.train. The pad/unk/bos/eos
# pieces get fixed ids 0-3; <cls>/<sep> are added as control symbols.
train_args = {
    "model_prefix": spm_model_prefix,
    "vocab_size": vocab_size,
    # Reserved ids for the core special pieces.
    "pad_id": 0,
    "unk_id": 1,
    "bos_id": 2,
    "eos_id": 3,
    # Surface strings for those reserved ids.
    "pad_piece": pad_token,
    "unk_piece": unk_token,
    "bos_piece": bos_token,
    "eos_piece": eos_token,
    "control_symbols": [cls_token, sep_token],
    # Corpus sampling and input handling.
    "input_sentence_size": input_sentence_size,
    "shuffle_input_sentence": True,
    "add_dummy_prefix": add_dummy_prefix,
}

## Train the SentencePiece model

In [None]:
pip install sentencepiece==0.1.91

In [None]:
import sentencepiece as spm

In [None]:
# Create the SentencePiece output directory. exist_ok=True makes re-runs a
# no-op without a separate (racy) existence check.
spm_model_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Train SentencePiece on train_file with the options collected in train_args.
# The model is written using model_prefix (expected at spm_model_path).
spm.SentencePieceTrainer.train(
    **train_args,
    input=train_file,
)

## Convert to a Transformers tokenizer

In [None]:
pip install transformers==4.3.3

In [None]:
import transformers

# Wrap the trained SentencePiece model in a Transformers tokenizer, reusing the
# exact special-token strings that were given to the SentencePiece trainer.
tokenizer = transformers.BertGenerationTokenizer(
    vocab_file=str(spm_model_path),
    # Pieces trained with fixed ids.
    pad_token=pad_token,
    unk_token=unk_token,
    bos_token=bos_token,
    eos_token=eos_token,
    # Pieces trained as control symbols.
    cls_token=cls_token,
    sep_token=sep_token,
)

In [None]:
# Sanity check: every special token passed to the tokenizer is already a piece
# in the trained SentencePiece vocabulary, so no extra entries are expected.
assert len(tokenizer) == vocab_size, (
    f"tokenizer has {len(tokenizer)} entries, expected {vocab_size}"
)
len(tokenizer)

In [None]:
# Persist the tokenizer files to tf_model_dir; the call's return value is
# shown as the cell output.
tokenizer.save_pretrained(save_directory=tf_model_dir)