In [1]:
from transformers import BertModel

In [2]:
model = BertModel.from_pretrained('bert-base-cased')
model.save_pretrained('./download/')

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
s = 'today is a good day to learn transformers'
tokenizer(s)

{'input_ids': [101, 2052, 1110, 170, 1363, 1285, 1106, 3858, 11303, 1468, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

Understanding the tokenizer's internal steps

In [4]:
# tokenize()
s = 'today is a good day to learn transformers'
tokens = tokenizer.tokenize(s)
tokens

['today', 'is', 'a', 'good', 'day', 'to', 'learn', 'transform', '##ers']
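`tokenize()` uses WordPiece. A word that is in the vocabulary stays whole. Any other word is split into the longest vocabulary prefix followed by `##`-prefixed continuation pieces, which is how `transformers` became `transform` + `##ers`. Below is a minimal sketch of that greedy longest-match-first split, assuming a tiny hypothetical vocabulary `toy_vocab`. The real `bert-base-cased` vocabulary has about 29k entries, and the real tokenizer also splits off punctuation before this step, which the sketch skips.

```python
def wordpiece_tokenize(word, vocab, unk_token="[UNK]"):
    """Greedy longest-match-first WordPiece split of a single word (sketch)."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, shrinking it from the right.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry the ## prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk_token]  # nothing matched: the whole word becomes [UNK]
        pieces.append(match)
        start = end
    return pieces


# Hypothetical toy vocabulary, just large enough for the example sentence.
toy_vocab = {"today", "is", "a", "good", "day", "to", "learn", "transform", "##ers"}

sentence = "today is a good day to learn transformers"
tokens = [p for w in sentence.split() for p in wordpiece_tokenize(w, toy_vocab)]
print(tokens)
```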

In [5]:
# convert_tokens_to_ids()
ids = tokenizer.convert_tokens_to_ids(tokens)
ids

[2052, 1110, 170, 1363, 1285, 1106, 3858, 11303, 1468]
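Compare these ids with the `tokenizer(s)` output in In [3]. They match its `input_ids` except for the first and last entries. Calling the tokenizer directly also wraps the sequence in `[CLS]` (101) and `[SEP]` (102) and builds `token_type_ids` and `attention_mask`. Below is a plain-Python sketch of that final step. `build_single_sequence_inputs` is a hypothetical helper, not part of the transformers API.

```python
CLS_ID, SEP_ID = 101, 102  # [CLS] / [SEP] ids in bert-base-cased, per the decode output further down

def build_single_sequence_inputs(token_ids):
    """Hypothetical helper: what tokenizer(s) adds on top of convert_tokens_to_ids()."""
    input_ids = [CLS_ID] + list(token_ids) + [SEP_ID]
    return {
        "input_ids": input_ids,
        "token_type_ids": [0] * len(input_ids),  # single sentence: every token is segment 0
        "attention_mask": [1] * len(input_ids),  # no padding yet: attend to every position
    }

ids = [2052, 1110, 170, 1363, 1285, 1106, 3858, 11303, 1468]  # convert_tokens_to_ids() output above
encoded = build_single_sequence_inputs(ids)
print(encoded)
```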

In [6]:
# decode()
# automatically joins the ## subword pieces back together
print(tokenizer.decode([11303,1468]))
print(tokenizer.decode(ids)) 
print(tokenizer.decode([101, 2052, 1110, 170, 1363, 1285, 1106, 3858, 11303, 1468, 102]))

transformers
today is a good day to learn transformers
[CLS] today is a good day to learn transformers [SEP]
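The merging that `decode` performs can be approximated in two steps. First, look each id up in the vocabulary. Then join the tokens with spaces and delete every `" ##"`, so that each continuation piece sticks to the token before it. The sketch uses a hypothetical `id_to_token` dict that contains only the pairs visible in the outputs above. The real `decode` also does extra work, such as cleaning up spaces before punctuation, which is left out here.

```python
# Hypothetical id -> token table, read off the tokenize() and convert_tokens_to_ids()
# outputs above, plus [CLS] = 101 and [SEP] = 102 from the last decode output.
id_to_token = {
    101: "[CLS]", 102: "[SEP]",
    2052: "today", 1110: "is", 170: "a", 1363: "good", 1285: "day",
    1106: "to", 3858: "learn", 11303: "transform", 1468: "##ers",
}

def decode(token_ids):
    """Map ids back to tokens, then glue each '##' piece onto the token before it."""
    tokens = [id_to_token[i] for i in token_ids]
    return " ".join(tokens).replace(" ##", "")

ids = [2052, 1110, 170, 1363, 1285, 1106, 3858, 11303, 1468]
print(decode([11303, 1468]))
print(decode(ids))
print(decode([101] + ids + [102]))
```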


The role of attention_mask when processing multiple sequences

In [7]:
from pprint import pprint as print  # pprint makes printed output a bit easier to read (note: this shadows the built-in print)
from transformers import AutoModelForSequenceClassification, AutoTokenizer
checkpoint = 'distilbert-base-uncased-finetuned-sst-2-english'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)

Processing a single text

In [8]:
s = 'Today is a nice day!'
inputs = tokenizer(s,return_tensors='pt')
print(tokenizer.decode(inputs.input_ids[0]))  # decode the ids this tokenizer produced; ids from the earlier bert-base-cased tokenizer don't match this uncased vocab
print(inputs)

'[CLS] today is a nice day! [SEP]'
{'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]]),
 'input_ids': tensor([[ 101, 2651, 2003, 1037, 3835, 2154,  999,  102]])}


In [9]:
model(inputs.input_ids).logits

tensor([[-4.3232,  4.6906]], grad_fn=<AddmmBackward0>)

Processing multiple texts

In [10]:
ss = ['Today is a nice day!',
      'But what about tomorrow? Im not sure.']
inputs = tokenizer(ss, padding=True, return_tensors='pt')
print(inputs)

{'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]),
 'input_ids': tensor([[  101,  2651,  2003,  1037,  3835,  2154,   999,   102,     0,     0,
             0],
        [  101,  2021,  2054,  2055,  4826,  1029, 10047,  2025,  2469,  1012,
           102]])}
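With `padding=True`, the shorter sequence is right-padded with the pad id (0 here) up to the length of the longest sequence in the batch. `attention_mask` then marks real tokens with 1 and padding with 0. Below is a minimal sketch of that bookkeeping, using the ids from the two outputs above. `pad_batch` is a hypothetical helper and only covers right-padding to the longest sequence.

```python
PAD_ID = 0  # [PAD] id in this uncased vocab, matching the zeros in the output above

def pad_batch(batch_ids, pad_id=PAD_ID):
    """Hypothetical helper: right-pad to the longest sequence and build the attention mask."""
    max_len = max(len(ids) for ids in batch_ids)
    input_ids, attention_mask = [], []
    for ids in batch_ids:
        n_pad = max_len - len(ids)
        input_ids.append(list(ids) + [pad_id] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)  # 1 = real token, 0 = padding
    return {"input_ids": input_ids, "attention_mask": attention_mask}

batch = [
    [101, 2651, 2003, 1037, 3835, 2154, 999, 102],
    [101, 2021, 2054, 2055, 4826, 1029, 10047, 2025, 2469, 1012, 102],
]
padded = pad_batch(batch)
print(padded)
```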


In [11]:
model(inputs.input_ids).logits

tensor([[-4.1957,  4.5675],
        [ 3.9803, -3.2120]], grad_fn=<AddmmBackward0>)

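After padding, the first sentence's encoding has changed: it now ends with several 0s (the [PAD] id). By default, self-attention attends to every position, so those padding tokens feed into the representations of the real tokens. As a result, the logits change (-4.1957, 4.5675 instead of -4.3232, 4.6906). So besides input_ids we also need to pass attention_mask. With it, the model skips the masked-out positions when computing attention.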
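A toy single-head attention in plain Python shows why the mask restores the result. Masked key positions get a score of `-inf`, so after the softmax they receive exactly zero weight. The outputs for the real tokens then match a run with no padding at all. The vectors below are made-up 2-d embeddings, and the pad vector stands in for whatever `[PAD]` embedding a real model has. Queries, keys and values are shared here, with no learned projections. The sketch uses `-inf`, whereas real implementations often use a very large negative number instead.

```python
import math

def softmax(xs):
    m = max(xs)  # finite as long as at least one position is unmasked
    exps = [math.exp(x - m) for x in xs]  # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]

def attention(queries, keys, values, mask=None):
    """Single-head scaled dot-product attention over lists of vectors.

    mask[j] == 0 marks key position j as padding: its score becomes -inf,
    so it gets zero weight after the softmax.
    """
    d = len(queries[0])
    outputs = []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        if mask is not None:
            scores = [s if m else float("-inf") for s, m in zip(scores, mask)]
        weights = softmax(scores)
        outputs.append([sum(w * v[i] for w, v in zip(weights, values))
                        for i in range(len(values[0]))])
    return outputs

real = [[1.0, 0.0], [0.5, 1.0], [0.0, 2.0]]  # made-up embeddings for 3 real tokens
pad = [0.3, -0.7]                            # made-up [PAD] embedding
padded = real + [pad, pad]

unpadded = attention(real, real, real)
masked = attention(padded, padded, padded, mask=[1, 1, 1, 0, 0])[:3]
unmasked = attention(padded, padded, padded)[:3]
print(unpadded)
print(masked)    # same as unpadded
print(unmasked)  # shifted by the padding tokens
```

This mirrors the notebook: with the mask, the first row of logits in In [12] matches the single-sentence result from In [9].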

In [12]:
model(inputs.input_ids,inputs.attention_mask).logits

tensor([[-4.3232,  4.6906],
        [ 3.9803, -3.2120]], grad_fn=<AddmmBackward0>)

In [13]:
id2label = model.config.id2label

In [14]:
import torch
predictions = torch.nn.functional.softmax(model(inputs.input_ids,inputs.attention_mask).logits, dim=-1)  
predictions

tensor([[1.2170e-04, 9.9988e-01],
        [9.9925e-01, 7.5180e-04]], grad_fn=<SoftmaxBackward0>)
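Softmax turns each row of logits z into probabilities that sum to 1: p_i = exp(z_i) / sum_j exp(z_j). The plain-Python sketch below recomputes the probabilities from the rounded logits printed in In [12]. It subtracts the row maximum first for numerical stability. Up to rounding, it reproduces the tensor above.

```python
import math

def softmax(logits):
    m = max(logits)  # subtracting the max avoids overflow and doesn't change the result
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [[-4.3232, 4.6906], [3.9803, -3.2120]]  # masked logits from In [12]
probs = [softmax(row) for row in logits]
for row in probs:
    print([f"{p:.4e}" for p in row])
```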

In [15]:
for i in torch.argmax(predictions, dim=-1):
    print(id2label[i.item()])

'POSITIVE'
'NEGATIVE'
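`torch.argmax(..., dim=-1)` picks the index of the highest probability in each row, and `id2label` maps that index to a label name. Below is a plain-Python sketch of the same lookup. The `id2label` dict here is inferred from the printed output (index 1 gave 'POSITIVE', index 0 gave 'NEGATIVE') rather than read from the model config. Because softmax preserves ordering, applying argmax directly to the logits gives the same labels.

```python
# Index -> label mapping inferred from the output above (not read from model.config).
id2label = {0: "NEGATIVE", 1: "POSITIVE"}

def argmax(row):
    """Index of the largest entry in one row, like torch.argmax(..., dim=-1) per row."""
    return max(range(len(row)), key=lambda i: row[i])

probs = [[1.2170e-04, 9.9988e-01], [9.9925e-01, 7.5180e-04]]  # softmax output from In [14]
logits = [[-4.3232, 4.6906], [3.9803, -3.2120]]                 # masked logits from In [12]

labels = [id2label[argmax(row)] for row in probs]
# softmax is monotonic, so argmax on the raw logits picks the same class
labels_from_logits = [id2label[argmax(row)] for row in logits]
print(labels)
print(labels_from_logits)
```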
