# Masked-Language Modelling (MLM) with BERT
   
- BERT = Bidirectional Encoder Representations from Transformers
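- An NLP model developed by Google and **pre-trained** on a very large body of online text (Wikipedia, among others). Pre-training covers two tasks at scale: filling in masked words in a sentence, and predicting whether one sentence follows another.
- It can produce contextual distributed representations, that is, vector encodings of words and sentences that take the surrounding context into account.
- In short, it transforms each word (token) into a vector.
- "Transformer": the attention-based architecture that took over the role LSTMs used to play for text. Each position draws not only on the positions just before it but on any position in the sequence that attention scores as relevant.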
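To make the last bullet concrete, here is a minimal sketch of scaled dot-product self-attention in plain Python. It is an illustration only: the toy vectors are made up, and the learned query/key/value projections and multiple heads that a real Transformer uses are left out (identity projections are assumed).

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(vectors):
    """Scaled dot-product self-attention with identity Q/K/V projections (an assumption)."""
    dim = len(vectors[0])
    outputs, all_weights = [], []
    for query in vectors:
        # Score the query against every position, earlier and later alike.
        scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(dim)
                  for key in vectors]
        weights = softmax(scores)
        # The new vector is a weighted mix of all positions.
        mixed = [sum(w * v[d] for w, v in zip(weights, vectors))
                 for d in range(dim)]
        outputs.append(mixed)
        all_weights.append(weights)
    return outputs, all_weights

# Hypothetical 2-dimensional embeddings for a 3-token sequence.
toy_vectors = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
contextual, attention_weights = self_attention(toy_vectors)
```

Each row of `attention_weights` sums to 1, and the first token puts weight on the third token even though it comes later in the sequence, which is the bidirectional behaviour the "B" in BERT refers to.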


In [1]:
from transformers import BertTokenizer, BertForMaskedLM # BERT
import torch # use PyTorch
import numpy as np

In [2]:
# bert-base-uncased = pre-trained English BERT (a Japanese version also exists). The first load takes quite a while.

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 
model     = BertForMaskedLM.from_pretrained('bert-base-uncased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
# Load pickled files

import pickle

urls = ['https://www3.nhk.or.jp/nhkworld/en/news/20211025_13/',
        'https://www3.nhk.or.jp/nhkworld/en/news/20211025_08/',
        'https://www3.nhk.or.jp/nhkworld/en/news/20211025_04/',
        'https://www3.nhk.or.jp/nhkworld/en/news/20211024_11/',
        'https://www3.nhk.or.jp/nhkworld/en/news/20211024_13/',
        'https://www3.nhk.or.jp/nhkworld/en/news/20211022_32/']

# Articles: extract the trailing ID segment of each URL

articles = [url.split('/')[-2] for url in urls]

data = {}
for c in articles:
    with open("articles/" + c + ".txt", "rb") as file:
        data[c] = pickle.load(file)
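The index-based split above relies on every URL ending with a trailing slash. A slightly more tolerant sketch, using the standard-library `urllib.parse`, is shown here. The helper name `article_id` and the second example URL (without a trailing slash) are assumptions made for illustration.

```python
from urllib.parse import urlparse

def article_id(url):
    """Return the last non-empty path segment of a URL (hypothetical helper)."""
    path = urlparse(url).path
    segments = [s for s in path.split('/') if s]
    return segments[-1]

example_urls = [
    'https://www3.nhk.or.jp/nhkworld/en/news/20211025_13/',
    # Same form without the trailing slash, added here for illustration.
    'https://www3.nhk.or.jp/nhkworld/en/news/20211024_13',
]
ids = [article_id(u) for u in example_urls]
```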

In [4]:
# text
text = data['20211024_13'][1]
text

'The competition has been running for about 40 years but was cancelled last year due to the pandemic. Sumida is home to the famous Ryogoku Kokugikan, considered the spiritual heart of the sumo world.'

In [5]:
# Tokenize

inputs = tokenizer.encode(text, return_tensors = 'pt') # 'pt' = return PyTorch tensors
inputs # The leading 101 and trailing 102 are the special tokens [CLS] and [SEP]. 1012 is the period, 28193 is "sumo", and so on.

tensor([[  101,  1996,  2971,  2038,  2042,  2770,  2005,  2055,  2871,  2086,
          2021,  2001,  8014,  2197,  2095,  2349,  2000,  1996,  6090,  3207,
          7712,  1012,  7680,  8524,  2003,  2188,  2000,  1996,  3297, 29431,
         22844,  5283, 12849,  5283,  5856,  9126,  1010,  2641,  1996,  6259,
          2540,  1997,  1996, 28193,  2088,  1012,   102]])
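The tensor holds more IDs than the text has words because the tokenizer splits some words into several subword pieces ("pandemic", for instance, occupies three IDs above). The sketch below shows the greedy longest-match-first idea behind WordPiece in plain Python. The function name and the tiny vocabulary are hypothetical; the real `bert-base-uncased` vocabulary has roughly 30,000 entries and the real tokenizer also handles lowercasing and punctuation.

```python
def wordpiece(word, vocab, unk='[UNK]'):
    """Greedy longest-match-first subword split (toy version)."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        piece = None
        # Try the longest remaining substring first, then shrink it.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = '##' + candidate  # marks a word-internal piece
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk]  # no piece matched, so the whole word is unknown
        pieces.append(piece)
        start = end
    return pieces

# Hypothetical mini-vocabulary, for illustration only.
toy_vocab = {'pan', '##de', '##mic', 'sumo', 'world'}
pieces = wordpiece('pandemic', toy_vocab)
```

With this toy vocabulary, `pieces` is a three-element list, one piece per ID in the spirit of the output above, while a word that is in the vocabulary stays whole.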

In [6]:
# See whether the model can predict "sumo".

import re

text = re.sub('sumo', '[MASK]', text)
print(text)
inputs = tokenizer.encode(text, return_tensors = 'pt')
inputs # the masked word has ID 103

The competition has been running for about 40 years but was cancelled last year due to the pandemic. Sumida is home to the famous Ryogoku Kokugikan, considered the spiritual heart of the [MASK] world.


tensor([[  101,  1996,  2971,  2038,  2042,  2770,  2005,  2055,  2871,  2086,
          2021,  2001,  8014,  2197,  2095,  2349,  2000,  1996,  6090,  3207,
          7712,  1012,  7680,  8524,  2003,  2188,  2000,  1996,  3297, 29431,
         22844,  5283, 12849,  5283,  5856,  9126,  1010,  2641,  1996,  6259,
          2540,  1997,  1996,   103,  2088,  1012,   102]])

In [7]:
# Feed the input to BERT and get scores indicating how likely each vocabulary token is to fill [MASK]

with torch.no_grad():
    output = model(inputs)
    scores = output.logits
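The logits are unnormalized scores, not probabilities. To read them as probabilities over the vocabulary, a softmax can be applied along the vocabulary dimension. The sketch below does this in plain Python on made-up numbers; on the actual tensor, `torch.softmax(logits, dim=-1)` would be the usual route.

```python
import math

def softmax(logits):
    """Convert raw scores to probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for a 4-token vocabulary.
toy_logits = [2.0, 1.0, 0.1, -1.0]
probs = softmax(toy_logits)
```

Softmax preserves the ordering of the scores, so the top-ranked token is the same whether logits or probabilities are compared.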

In [8]:
# Prediction

ID   = inputs[0].tolist().index(103)
best = scores[0, ID].argmax(-1).item()
token_best = tokenizer.convert_ids_to_tokens(best)
token_best # not the correct word, but the sentence still reads naturally

'western'
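Since the top prediction is not the masked word, it can be useful to ask where the expected word ranks among all vocabulary scores. A minimal sketch in plain Python on made-up scores follows. The helper name `rank_of` is hypothetical; with this notebook's objects, the score list could plausibly come from `scores[0, ID].tolist()` and the target index from `tokenizer.convert_tokens_to_ids('sumo')`, though that wiring is an assumption and is not run here.

```python
def rank_of(scores, target_index):
    """1-based rank of scores[target_index] when sorted in descending order."""
    target = scores[target_index]
    # Rank is one more than the number of strictly higher scores.
    return 1 + sum(1 for s in scores if s > target)

# Hypothetical scores for a 4-token vocabulary.
toy_scores = [0.1, 2.5, 1.7, 3.0]
rank = rank_of(toy_scores, 2)
```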

In [10]:
# Top-ranked results

text = 'the spiritual heart of the [MASK] world'

def topk(k):
    """Return `text` with [MASK] filled by each of the k highest-scoring tokens."""
    ID = inputs[0].tolist().index(103)  # position of [MASK] in the input
    top = scores[0, ID].topk(k)  # k best scores and their token IDs
    tokens_topk = tokenizer.convert_ids_to_tokens(top.indices.tolist())  # token strings

    # Replace [MASK] in the text with each of the tokens found above.
    return [text.replace('[MASK]', token, 1) for token in tokens_topk]

topk(43) # "sumo" ranks 43rd. Many of the candidates are Japan-related, even though the text contains no word such as "Japanese".

['the spiritual heart of the western world',
 'the spiritual heart of the modern world',
 'the spiritual heart of the japanese world',
 'the spiritual heart of the buddhist world',
 'the spiritual heart of the muslim world',
 'the spiritual heart of the christian world',
 'the spiritual heart of the anime world',
 'the spiritual heart of the islamic world',
 'the spiritual heart of the spirit world',
 'the spiritual heart of the asian world',
 'the spiritual heart of the entire world',
 'the spiritual heart of the human world',
 'the spiritual heart of the spiritual world',
 'the spiritual heart of the magical world',
 'the spiritual heart of the baseball world',
 'the spiritual heart of the new world',
 'the spiritual heart of the buddhism world',
 'the spiritual heart of the whole world',
 'the spiritual heart of the ancient world',
 'the spiritual heart of the natural world',
 'the spiritual heart of the outside world',
 'the spiritual heart of the demon world',
 'the spiritual hear