In [2]:
import torch
from transformers import BertJapaneseTokenizer, BertForMaskedLM

In [3]:
# Checkpoint identifier handed to `from_pretrained` for both the tokenizer
# and the model in the cells below.
model_name_or_path = "cl-tohoku/bert-base-japanese-v2"

In [6]:
# Load the tokenizer from the same checkpoint identifier that the model is
# loaded from.
tokenizer = BertJapaneseTokenizer.from_pretrained(model_name_or_path)

In [7]:
# Load the checkpoint into a BertForMaskedLM model.
# The load-time warning about unused `cls.seq_relationship.*` weights refers
# to parameters stored in the checkpoint that BertForMaskedLM does not use.
model = BertForMaskedLM.from_pretrained(model_name_or_path)

Downloading: 100%|███████████████████████████████████████████████████████████████████████████████████████| 447M/447M [00:39<00:00, 11.4MB/s]
Some weights of the model checkpoint at cl-tohoku/bert-base-japanese-v2 were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [8]:
# Build the prompt with the tokenizer's own mask token spliced in, then
# encode it to a PyTorch tensor of token ids.
masked_sentence = f"青葉山で{tokenizer.mask_token}の研究をしています。"
input_ids = tokenizer.encode(masked_sentence, return_tensors="pt")

In [9]:
# Inspect the encoded ids; the tensor has a single row (a batch of one sequence).
print(input_ids)

tensor([[    2, 21479,  2077,   889,     4,   896, 11261,   932,   873,   888,
           854, 12343,   829,     3]])


In [10]:
# Map the ids of the first row back to token strings to check the
# tokenization and see where the mask token ended up.
print(tokenizer.convert_ids_to_tokens(input_ids[0].tolist()))

['[CLS]', '青葉', '山', 'で', '[MASK]', 'の', '研究', 'を', 'し', 'て', 'い', 'ます', '。', '[SEP]']


In [11]:
# Find the column of the first mask token in `input_ids`.
# `nonzero(as_tuple=True)` yields one index tensor per dimension; element 1
# holds the column (sequence-position) indices of every match.
is_mask = input_ids == tokenizer.mask_token_id
masked_index = is_mask.nonzero(as_tuple=True)[1][0].item()
print(masked_index)

4


In [12]:
# Run the forward pass for inference only: no gradients are needed here, so
# disable autograd to avoid building and retaining the computation graph.
with torch.no_grad():
    result = model(input_ids)

# result[0] holds the prediction scores (per the transformers docs:
# (batch, sequence, vocabulary) logits). Take the scores at the masked
# position and keep the ids of the 5 highest-scoring entries for the
# first (only) row.
pred_ids = result[0][:, masked_index].topk(5).indices.tolist()[0]

# Convert the prompt ids to a plain list once instead of on every iteration.
original_ids = input_ids[0].tolist()

for pred_id in pred_ids:
    # Work on a fresh copy so each candidate starts from the original prompt.
    output_ids = list(original_ids)
    output_ids[masked_index] = pred_id
    print(tokenizer.decode(output_ids))


[CLS] 青葉 山 で 植物 の 研究 を し て い ます 。 [SEP]
[CLS] 青葉 山 で 鳥類 の 研究 を し て い ます 。 [SEP]
[CLS] 青葉 山 で 野鳥 の 研究 を し て い ます 。 [SEP]
[CLS] 青葉 山 で 恐竜 の 研究 を し て い ます 。 [SEP]
[CLS] 青葉 山 で 昆虫 の 研究 を し て い ます 。 [SEP]
