In [34]:
from datasets import load_dataset
import tokenizers
import torch

### Load data

In [6]:

imdb_dataset = load_dataset("stanfordnlp/imdb")
split = imdb_dataset['train'].train_test_split(train_size=0.8, seed=42)
imdb_train_set, imdb_valid_set = split['train'], split['test']
imdb_test_set = imdb_dataset['test']


In [8]:
imdb_train_set[1]['text']

"'The Rookie' was a wonderful movie about the second chances life holds for us and also puts an emotional thought over the audience, making them realize that your dreams can come true. If you loved 'Remember the Titans', 'The Rookie' is the movie for you!! It's the feel good movie of the year and it is the perfect movie for all ages. 'The Rookie' hits a major home run!"

### Tokenizer


In [None]:
bpe_model = tokenizers.models.BPE(unk_token="<unk>")
bpe_tokenizer = tokenizers.Tokenizer(bpe_model)
"""分词前先用空格分词"""
# bpe_tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()
"""把所有空格替换成Ġ"""
bpe_tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.ByteLevel()
special_tokens = ['<pad>', "<unk>"]
bpe_trainer = tokenizers.trainers.BpeTrainer(vocab_size=1000, special_tokens=special_tokens)
train_reviews = [review["text"].lower() for review in imdb_train_set]
bpe_tokenizer.train_from_iterator(iterator=train_reviews, trainer=bpe_trainer)









In [50]:
some_review = "what an awesome movie!"

bpe_encoding = bpe_tokenizer.encode(some_review)
bpe_encoding

Encoding(num_tokens=7, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])

In [57]:
print(bpe_encoding.tokens)

print(bpe_encoding.ids)
print(bpe_tokenizer.decode(bpe_encoding.ids))
# bpe_tokenizer.get_vocab()

['Ġwhat', 'Ġan', 'Ġaw', 'es', 'ome', 'Ġmovie', '!']
[354, 216, 561, 148, 244, 232, 2]
Ġwhat Ġan Ġaw es ome Ġmovie !


In [61]:
def restore_from_blevel(bpe_encoding: tokenizers.Encoding, tokenizers: tokenizers.Tokenizer):
    decoded:str = tokenizers.decode(bpe_encoding.ids)
    decoded = decoded.replace(" ", "")
    decoded = decoded.replace("Ġ", " ")
    return decoded

restore_from_blevel(bpe_encoding, bpe_tokenizer)



' what an awesome movie!'

In [53]:
print(bpe_tokenizer.token_to_id("Ġwell"))
bpe_tokenizer.id_to_token(222)


433


'Ġfilm'

In [None]:
bpe_encoding.offsets #每个词 开头结尾的位置

[(0, 4), (5, 7), (8, 10), (10, 12), (12, 15), (16, 21), (21, 22)]

In [31]:
bpe_tokenizer.encode_batch(train_reviews[:3])

[Encoding(num_tokens=281, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing]),
 Encoding(num_tokens=114, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing]),
 Encoding(num_tokens=285, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])]

In [32]:
bpe_tokenizer.enable_padding(pad_id=0, pad_token="<pad>")
bpe_tokenizer.enable_truncation(max_length=500)


In [33]:
bpe_tokenizer.encode_batch(train_reviews[:3])

[Encoding(num_tokens=285, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing]),
 Encoding(num_tokens=285, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing]),
 Encoding(num_tokens=285, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])]

In [36]:
bpe_encodings = bpe_tokenizer.encode_batch(train_reviews[:3])
bpe_batch_ids = torch.tensor([encoding.ids for encoding in bpe_encodings])
bpe_batch_ids.shape

torch.Size([3, 285])

In [39]:
bpe_encodings

[Encoding(num_tokens=285, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing]),
 Encoding(num_tokens=285, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing]),
 Encoding(num_tokens=285, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])]

In [45]:
"""attention mask： 1和0的矩阵 能知道哪个值是padding"""
attention_mask = torch.tensor([encoding.attention_mask for encoding in bpe_encodings])
lengths = attention_mask.sum(dim=-1)
lengths

tensor([281, 114, 285])