In [1]:
import os

# Directory holding the pretrained model files used below: `vocab.txt` is read
# from it by the tokenizer and it is passed to `BertForMaskedLM.from_pretrained`.
# Set the BERT_BASE_DIR environment variable to point at your own copy; the
# fallback is the original author's local path, kept for backward compatibility.
BERT_BASE_DIR = (
    os.environ.get('BERT_BASE_DIR')
    or '/Users/m-suzuki/work/japanese-bert/jawiki-20190901/mecab-ipadic-bpe-32k'
)

In [2]:
import torch
from transformers import BertForMaskedLM
from tokenization import MecabBertTokenizer

I1001 22:27:35.863064 4618630592 file_utils.py:39] PyTorch version 1.2.0 available.
I1001 22:27:36.331261 4618630592 modeling_xlnet.py:194] Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .


In [3]:
# Build the project-local MecabBertTokenizer from the vocabulary file that
# lives in the model directory.
vocab_path = f'{BERT_BASE_DIR}/vocab.txt'
tokenizer = MecabBertTokenizer(vocab_file=vocab_path)

In [4]:
# Example sentence whose tokens are fed to the masked language model below.
text = "今年の冬は友達と北海道に行きました。"

In [5]:
# Tokenize the sentence and put the '[CLS]' marker in front of the tokens.
tokens = ['[CLS]', *tokenizer.tokenize(text)]

In [6]:
# Display the token list ('[CLS]' followed by the tokenizer's output for `text`).
tokens

['[CLS]', '今年', 'の', '冬', 'は', '友達', 'と', '北海道', 'に', '行き', 'まし', 'た', '。']

In [7]:
# Position of the token to hide from the model; in the token list displayed
# for this sentence, index 7 is '北海道'.
MASK_POSITION = 7
tokens[MASK_POSITION] = '[MASK]'

In [8]:
# Display the token list again to confirm that index 7 now holds '[MASK]'.
tokens

['[CLS]', '今年', 'の', '冬', 'は', '友達', 'と', '[MASK]', 'に', '行き', 'まし', 'た', '。']

In [9]:
# Map each token string (including '[CLS]' and '[MASK]') to its vocabulary id.
token_ids = tokenizer.convert_tokens_to_ids(tokens)

In [10]:
# Display the vocabulary ids, one per token.
token_ids

[2, 18337, 5, 2558, 9, 11680, 13, 4, 7, 2563, 3926, 10, 8]

In [11]:
# Turn the ids into a tensor of shape (1, sequence_length), i.e. a batch
# containing the single example sentence.
# `as_tensor(...).reshape(1, -1)` keeps this cell idempotent: `token_ids` is
# rebound to its own converted form, so wrapping it in a new list on every run
# (as `torch.tensor([token_ids])` does) misbehaves when the cell is re-executed.
# Re-running this version on the already converted tensor leaves it unchanged.
token_ids = torch.as_tensor(token_ids).reshape(1, -1)

In [12]:
# Display the batched id tensor (a single row holding the sentence's ids).
token_ids

tensor([[    2, 18337,     5,  2558,     9, 11680,    13,     4,     7,  2563,
          3926,    10,     8]])

In [13]:
# Load the pretrained masked-language-model from the model directory; per the
# log output this reads `config.json` and `pytorch_model.bin` from BERT_BASE_DIR.
model = BertForMaskedLM.from_pretrained(BERT_BASE_DIR)

I1001 22:27:42.803631 4618630592 configuration_utils.py:148] loading configuration file /Users/m-suzuki/work/japanese-bert/jawiki-20190901/mecab-ipadic-bpe-32k/config.json
I1001 22:27:42.805089 4618630592 configuration_utils.py:168] Model config {
  "attention_probs_dropout_prob": 0.1,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "num_labels": 2,
  "output_attentions": false,
  "output_hidden_states": false,
  "pruned_heads": {},
  "torchscript": false,
  "type_vocab_size": 2,
  "use_bfloat16": false,
  "vocab_size": 32000
}

I1001 22:27:42.806373 4618630592 modeling_utils.py:334] loading weights file /Users/m-suzuki/work/japanese-bert/jawiki-20190901/mecab-ipadic-bpe-32k/pytorch_model.bin
I1001 22:27:45.143396 4618630592 modeling_utils.py:408] Weights fr

In [14]:
# Forward pass of the masked language model over the batched ids.
# This is pure inference, so run it under `torch.no_grad()` to avoid recording
# an autograd graph that is never used.
# The model call is expected to return exactly one item here; the one-element
# unpacking raises if it returns anything else. The parenthesised target makes
# the (easy to miss) trailing comma explicit.
with torch.no_grad():
    (predictions,) = model(token_ids)

In [15]:
# Keep only the vocabulary ids of the 10 highest-scoring entries along the
# last axis (dim=2) of `predictions`, i.e. the top-10 candidates per position.
# `torch.topk` returns (values, indices); take the indices by position rather
# than unpacking the values into `_`, because rebinding `_` in IPython stops
# the kernel from updating its last-output shortcuts (`_`, `__`, `___`).
top10_pred_ids = torch.topk(predictions, k=10, dim=2)[1]

In [16]:
# Display the top-10 candidate ids; each row corresponds to one input position.
top10_pred_ids

tensor([[[    6,     8,  1191, 11680,    13, 14142,    10, 12944,  4733,  6115],
         [18337, 18822,  2558,  1052,  1331,  1460,  4960,  1322,    19,  7246],
         [    5,     6,    28,     9,    52,    13,    40,    60,  1191,    18],
         [ 2558,  1460,  1383,  1331,  1158,    72,  4587,  7885,  8211,    51],
         [    9,    28,     6,     7,     5,    12, 14966,  1191, 10590,    40],
         [11680,  8080,  3681,  3713,  6296,  2286, 14066, 15884,  1052,  2569],
         [   13,    12,     7, 25350,   996,     5,    14,  4338,    11, 21693],
         [ 8135,  4338,   294,  6128,  1767,   399,   466,  1743,  2711,   292],
         [    7,   118,    16,    12,    14, 28444,    11,  6115,    15,     6],
         [ 2563,   521, 19874,    21,  4154, 11438,  1676,  1258,    15,  1220],
         [ 3926, 13259,  2554,  6771,    15,  3959, 12727,   303,  1158,  3061],
         [   10,    16,  3203, 28445,   807,  3287, 28480,  7428,    15, 17167],
         [    8,    10,   14

In [17]:
# For every input position, print the input token next to the model's ten
# best-scoring candidate tokens (both rendered as lists of strings).
input_ids = token_ids[0].tolist()
candidate_id_rows = top10_pred_ids[0].tolist()
for input_id, candidate_ids in zip(input_ids, candidate_id_rows):
    input_token = tokenizer.convert_ids_to_tokens([input_id])
    candidate_tokens = tokenizer.convert_ids_to_tokens(candidate_ids)
    print(input_token, candidate_tokens)


['[CLS]'] ['、', '。', '初めて', '友達', 'と', 'すご', 'た', 'たくさん', 'そんな', 'って']
['今年'] ['今年', '昨年', '冬', '自分', '秋', '夏', '最近', '私', '年', '今回']
['の'] ['の', '、', 'も', 'は', '一', 'と', 'から', '-', '初めて', '1']
['冬'] ['冬', '夏', '春', '秋', '始め', '時', 'オフ', '冬季', 'クリスマス', '中']
['は'] ['は', 'も', '、', 'に', 'の', 'で', 'いっぱい', '初めて', 'ずっと', 'から']
['友達'] ['友達', 'みんな', '友人', '仲間', '僕', '家族', 'いろいろ', '色々', '自分', '君']
['と'] ['と', 'で', 'に', 'ちゃんと', 'と共に', 'の', 'が', '一緒', 'を', 'だって']
['[MASK]'] ['遊び', '一緒', 'アメリカ', 'ハワイ', '韓国', '東京', '学校', '北海道', '沖縄', '海']
['に'] ['に', 'へ', 'て', 'で', 'が', '##に', 'を', 'って', 'し', '、']
['行き'] ['行き', '行っ', '行け', 'い', 'いき', '帰り', '行い', '来', 'し', '入り']
['まし'] ['まし', 'でし', 'ます', 'ませ', 'し', 'っ', 'だし', 'だっ', '始め', 'です']
['た'] ['た', 'て', 'ちゃ', '##た', 'たり', 'たら', '##こ', 'てる', 'し', 'たき']
['。'] ['。', 'た', 'という', 'ので', '!', '、', 'けど', 'です', 'ね', 'が']
