tokenizers.__version__: 0.9.4
transformers.__version__: 4.1.1
I'm trying to get the same tokenization from the tokenizers package and the transformers package, but I'm running into issues.
The issues are around the RoBERTa post-processing (adding <s> to the beginning and </s> to the end) as well as the whitespace before <mask> tokens.
A full minimal example and its output are below:
"""
wget https://huggingface.co/roberta-base/resolve/main/merges.txt
wget https://huggingface.co/roberta-base/resolve/main/vocab.json
"""
import tokenizers
from tokenizers import ByteLevelBPETokenizer
from tokenizers.processors import RobertaProcessing
from tokenizers import AddedToken
import transformers
from transformers import RobertaTokenizer
from transformers import RobertaTokenizerFast
SPECIAL_TOKENS = []
SPECIAL_TOKENS.append(AddedToken("<s>"))
SPECIAL_TOKENS.append(AddedToken("<pad>"))
SPECIAL_TOKENS.append(AddedToken("</s>"))
SPECIAL_TOKENS.append(AddedToken("<unk>"))
SPECIAL_TOKENS.append(AddedToken("<mask>", lstrip=False))
print()
print("tokenizers.__version__: ", tokenizers.__version__)
print("transformers.__version__: ", transformers.__version__)
print()
model = "roberta-base"
trf = RobertaTokenizer.from_pretrained(model)
print("transformers-RobertaTokenizer")
print(trf)
print()
trf_fast = RobertaTokenizerFast.from_pretrained(model)
print("transformers-RobertaTokenizerFast")
print(trf_fast)
print()
tkn = ByteLevelBPETokenizer("vocab.json", "merges.txt")
tkn.add_special_tokens(SPECIAL_TOKENS)
tkn.post_processor = RobertaProcessing(
    sep=("</s>", tkn.token_to_id("</s>")),
    cls=("<s>", tkn.token_to_id("<s>")),
)
print("tokenizers-ByteLevelBPETokenizer")
print(tkn)
print()
text = "The <mask> and <pad> and <unk> tokens."
print("transformers-RobertaTokenizer")
print(trf.convert_ids_to_tokens(trf.encode(text)))
print()
print("transformers-RobertaTokenizerFast")
print(trf_fast.convert_ids_to_tokens(trf_fast.encode(text)))
print()
print("tokenizers-ByteLevelBPETokenizer")
print(tkn.encode(text).tokens)
print()

Output:

tokenizers.__version__: 0.9.4
transformers.__version__: 4.1.1
transformers-RobertaTokenizer
PreTrainedTokenizer(name_or_path='roberta-base', vocab_size=50265, model_max_len=512, is_fast=False, padding_side='right', special_tokens={'bos_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'sep_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'pad_token': AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'cls_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=True)})
transformers-RobertaTokenizerFast
PreTrainedTokenizerFast(name_or_path='roberta-base', vocab_size=50265, model_max_len=512, is_fast=True, padding_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False)})
tokenizers-ByteLevelBPETokenizer
Tokenizer(vocabulary_size=50265, model=ByteLevelBPE, add_prefix_space=False, lowercase=False, dropout=None, unicode_normalizer=None, continuing_subword_prefix=None, end_of_word_suffix=None, trim_offsets=False)
transformers-RobertaTokenizer
['<s>', 'The', '<mask>', 'Ġand', 'Ġ', '<pad>', 'Ġand', 'Ġ', '<unk>', 'Ġtokens', '.', '</s>']
transformers-RobertaTokenizerFast
['<s>', 'The', '<mask>', 'Ġand', 'Ġ', '<pad>', 'Ġand', 'Ġ', '<unk>', 'Ġtokens', '.', '</s>']
tokenizers-ByteLevelBPETokenizer
['The', 'Ġ', '<mask>', 'Ġand', 'Ġ', '<pad>', 'Ġand', 'Ġ', '<unk>', 'Ġtokens', '.']
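
If it helps, below is a minimal sketch of a configuration that may bring the standalone tokenizer in line with the transformers ones. It rests on two assumptions: (1) assigning post_processor on the ByteLevelBPETokenizer wrapper does not reach the wrapped tokenizer, so the processor is set on tkn._tokenizer instead (which would explain the missing <s>/</s> above), and (2) the whitespace difference comes from <mask> being registered with lstrip=True in the pretrained transformers tokenizers (visible in their reprs above). Neither point is confirmed here, so treat this as a sketch rather than the official API.

from tokenizers import ByteLevelBPETokenizer, AddedToken
from tokenizers.processors import RobertaProcessing

tkn = ByteLevelBPETokenizer("vocab.json", "merges.txt")

# Assumption: lstrip=True on <mask>, matching the AddedToken shown in the
# transformers reprs, so the preceding space is absorbed into the mask token
# instead of surviving as a separate 'Ġ'.
tkn.add_special_tokens([
    AddedToken("<s>"),
    AddedToken("<pad>"),
    AddedToken("</s>"),
    AddedToken("<unk>"),
    AddedToken("<mask>", lstrip=True),
])

# Assumption: the processor has to be set on the wrapped tokenizer; setting
# tkn.post_processor on the wrapper appears to have no effect (no <s>/</s>
# in the output above).
tkn._tokenizer.post_processor = RobertaProcessing(
    sep=("</s>", tkn.token_to_id("</s>")),
    cls=("<s>", tkn.token_to_id("<s>")),
)

text = "The <mask> and <pad> and <unk> tokens."
print(tkn.encode(text).tokens)

If the assumptions hold, the printed token list should match the transformers output above, including the leading <s> and trailing </s> and no separate 'Ġ' before <mask>.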