In [63]:
from datasets import list_datasets, load_dataset
import tokenizers.normalizers as tn
import tokenizers.pre_tokenizers as tp
import tokenizers.models as tm
from tokenizers.processors import TemplateProcessing
from tokenizers.trainers import WordPieceTrainer
from tokenizers.decoders import WordPiece
from tokenizers import Tokenizer

In [11]:
wiki = load_dataset('wikitext', 'wikitext-2-raw-v1')

Reusing dataset wikitext (/home/xevaquor/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/47c57a6745aa5ce8e16a5355aaa4039e3aa90d1adad87cef1ad4e0f29e74ac91)


In [12]:
with open('wiki.raw', 'wt') as f:
    for l in wiki['train']['text']:
        f.write(l)

In [66]:
vocab_size = 10_000

model = tm.WordPiece()

tokenizer = Tokenizer(model)
tokenizer.normalizer = tn.Sequence([
    tn.NFD(), # unicode normalization
    tn.StripAccents(), # strip accents
    tn.Lowercase(),
    tn.Strip() # strip starting and trailing whitespaces
])
tokenizer.pre_tokenizer = tp.Whitespace() # pretokenize on spaces
# add starting [CLS] token
tokenizer.post_processor = TemplateProcessing(single="[CLS] $A", special_tokens=[('[CLS]', 1)])

tokenizer.decoder = WordPiece()


trainer = WordPieceTrainer(
    vocab_size=vocab_size,
    special_tokens=["[UNK]", "[CLS]", "[PAD]", "[MASK]"]    
)


tokenizer.train(trainer, ['wiki.raw'])


In [68]:
tokenizer.save('tokenizer.json', pretty=True)

In [69]:
tok = Tokenizer.from_file('tokenizer.json')

In [75]:
seq_len = 10
tok.enable_padding(pad_token='[PAD]', length=seq_len)
tok.enable_truncation(max_length=seq_len)
enc = tok.encode('I like pizza')
enc

Encoding(num_tokens=10, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])

In [77]:
enc.tokens

['[CLS]', 'i', 'like', 'p', '##iz', '##za', '[PAD]', '[PAD]', '[PAD]', '[PAD]']

In [90]:
enc.special_tokens_mask

[1, 0, 0, 0, 0, 0, 1, 1, 1, 1]

In [91]:
enc.attention_mask

[1, 1, 1, 1, 1, 1, 0, 0, 0, 0]

In [88]:
ids = enc.ids
ids[4] = tok.token_to_id('[MASK]')
ids

[1, 50, 2026, 57, 3, 9861, 0, 0, 0, 0]

In [89]:
tok.decode(ids)

'i like pza'