In [1]:
import torch
from pytorch_transformers import BertForMaskedLM, BertConfig
from tokenization import MecabBertTokenizer

In [2]:
# Build the MeCab-based BERT tokenizer from the model's vocabulary file.
# NOTE(review): this is an absolute, machine-specific path — adjust it for your environment.
vocab_path = '/Users/m-suzuki/work/japanese-bert/jawiki-20190701/mecab-ipadic-bpe-32k/vocab.txt'
tokenizer = MecabBertTokenizer(vocab_file=vocab_path)

In [3]:
# Example sentence to run through the masked language model
# (roughly: "This summer I went to Hokkaido with my friends.").
text = '今年の夏は友達と北海道に行きました。'

In [4]:
# Tokenize the sentence and put BERT's '[CLS]' marker in front of it.
# NOTE(review): no trailing '[SEP]' token is appended — confirm this matches
# the input format the model was pretrained with.
tokens = ['[CLS]', *tokenizer.tokenize(text)]

In [5]:
# Inspect the token sequence before anything is masked.
tokens

['[CLS]', '今年', 'の', '夏', 'は', '友達', 'と', '北海道', 'に', '行き', 'まし', 'た', '。']

In [6]:
# Overwrite the token at index 7 ('北海道' in the listing above) with the
# mask placeholder so the model has to predict that position.
mask_position = 7
tokens[mask_position] = '[MASK]'

In [7]:
# Inspect the token sequence again to confirm index 7 is now '[MASK]'.
tokens

['[CLS]', '今年', 'の', '夏', 'は', '友達', 'と', '[MASK]', 'に', '行き', 'まし', 'た', '。']

In [8]:
# Convert each token string into its integer id using the tokenizer.
token_ids = tokenizer.convert_tokens_to_ids(tokens)

In [9]:
# Inspect the id list (one integer per token).
token_ids

[2, 21659, 5, 1431, 9, 13164, 13, 4, 7, 2630, 4110, 10, 8]

In [10]:
# Turn the id list into a tensor with a leading batch dimension of size 1.
# The original `torch.tensor([token_ids])` rebinds `token_ids` to a wrapped
# version of itself, so re-running the cell would nest the tensor again.
# Converting first and adding the batch axis only when it is still missing
# keeps the cell idempotent while producing the same (1, seq_len) tensor.
token_ids = torch.as_tensor(token_ids)
if token_ids.dim() == 1:
    token_ids = token_ids.unsqueeze(0)

In [11]:
# Inspect the batched id tensor that will be fed to the model.
token_ids

tensor([[    2, 21659,     5,  1431,     9, 13164,    13,     4,     7,  2630,
          4110,    10,     8]])

In [12]:
# Read the BERT hyper-parameters from the model's JSON config file.
# NOTE(review): this is an absolute, machine-specific path — adjust it for your environment.
config_path = '/Users/m-suzuki/work/japanese-bert/jawiki-20190701/mecab-ipadic-bpe-32k/bert_config.json'
config = BertConfig.from_json_file(config_path)

In [13]:
# Instantiate the masked-LM model from the config; the checkpoint weights
# are loaded separately in the next cell.
model = BertForMaskedLM(config)

In [14]:
# Load the pretrained weights into the model.
# map_location='cpu' makes the checkpoint loadable even if it was saved from
# GPU tensors and this machine has no CUDA device; the model itself stays on
# the CPU in this notebook.
state_dict = torch.load(
    '/Users/m-suzuki/work/japanese-bert/jawiki-20190701/mecab-ipadic-bpe-32k/pytorch_model.bin',
    map_location='cpu',
)
# strict=False tolerates checkpoint entries that BertForMaskedLM has no
# parameter for (the `cls.seq_relationship.*` keys reported below). The call
# is left as the last expression so the key report is still displayed.
model.load_state_dict(state_dict, strict=False)

_IncompatibleKeys(missing_keys=[], unexpected_keys=['cls.seq_relationship.weight', 'cls.seq_relationship.bias'])

In [15]:
# The model was built directly from its constructor and never switched out of
# training mode, so dropout layers would still be active here and make the
# predictions vary between runs. Put it in evaluation mode first, and disable
# autograd tracking because this is inference only.
model.eval()
with torch.no_grad():
    # The model returns a 1-tuple; unpack its single element (the scores).
    predictions, = model(token_ids)

In [16]:
# Indices of the 10 highest-scoring entries along dim=2 for every position.
# torch.topk returns (values, indices); take the indices by position instead
# of unpacking into `_`, because assigning to `_` in IPython/Jupyter rebinds
# the shell's "last output" variable.
top10_pred_ids = torch.topk(predictions, k=10, dim=2)[1]

In [17]:
# Inspect the raw top-10 ids: one row of 10 candidate ids per input position.
top10_pred_ids

tensor([[[    9,     6,     8,     5,     7,    14,    40,    13,    12,    73],
         [21659,  2442, 19791,  5249,  1431,  8936,  7030,  1485,  1353,   738],
         [    5,  1390, 28443,    71,     9,     7,     6,    75,    28,  1337],
         [ 1431,  2442,  1337, 16045,  1390,   120, 13164, 16832,   288,    51],
         [    9,     6,     7,     5,    40,    13,    11,    12, 13164,    28],
         [13164,  3582,  7791, 10216,  8328,  5738,  1431,  4429,  1703,  3634],
         [   13,    12,   705,     6,     5,     9,    14,  1004,  1763,  7791],
         [ 8085,  4470,   473,  3754,  2756,   298, 13502,  1839,  5452, 19455],
         [    7,    12,   119,    16,    40,    14,    13,    15,  1763, 11957],
         [ 2630,   522, 11957,    15, 20190,  1234,  3487, 12660,    12,  4288],
         [ 4110, 13222,  7025,    15, 11355,  2551,   307,    16, 17365,  2982],
         [   10,    16,    17,   183,    75,    81,    15,  4110,  1520,   203],
         [    8,    40,    1

In [18]:
# For each input position, print the input token next to the model's ten
# highest-ranked candidate tokens for that position.
for input_id, candidate_ids in zip(token_ids[0].tolist(), top10_pred_ids[0].tolist()):
    print(
        tokenizer.convert_ids_to_tokens([input_id]),
        tokenizer.convert_ids_to_tokens(candidate_ids),
    )
    

['[CLS]'] ['は', '、', '。', 'の', 'に', 'が', 'から', 'と', 'で', 'お']
['今年'] ['今年', '冬', '昨年', '最近', '夏', '今度', '今回', '彼女', '私', '今']
['の'] ['の', '春', '##の', 'この', 'は', 'に', '、', 'だ', 'も', '秋']
['夏'] ['夏', '冬', '秋', '夏休み', '春', 'もの', '友達', 'なつ', '間', '中']
['は'] ['は', '、', 'に', 'の', 'から', 'と', 'を', 'で', '友達', 'も']
['友達'] ['友達', '友人', 'みんな', '同級生', '先輩', '皆', '夏', '妹', '友', '仲間']
['と'] ['と', 'で', 'だけ', '、', 'の', 'は', 'が', 'と共に', 'よく', 'みんな']
['[MASK]'] ['遊び', '一緒', '学校', '旅行', '食べ', '海', '会い', '公園', '過ごし', '買い物']
['に'] ['に', 'で', 'へ', 'て', 'から', 'が', 'と', 'し', 'よく', '帰り']
['行き'] ['行き', '行っ', '帰り', 'し', '行け', '来', '行く', '行か', 'で', 'いき']
['まし'] ['まし', 'でし', 'ませ', 'し', 'たかっ', 'ます', 'だっ', 'て', 'ましょ', 'です']
['た'] ['た', 'て', 'な', 'つ', 'だ', 'ない', 'し', 'まし', 'たい', 'う']
['。'] ['。', 'から', 'と', '、', '」', '!', 'ので', 'という', 'の', 'ね']
