# BERT QA サンプル

https://towardsdatascience.com/question-answering-with-a-fine-tuned-bert-bc4dafd45626

## ライブラリ導入

In [1]:
# transformersの導入
!pip install transformers | tail -n 1

Successfully installed huggingface-hub-0.4.0 pyyaml-6.0 sacremoses-0.0.49 tokenizers-0.11.6 transformers-4.17.0


In [31]:
# ライブラリのインポート
import pandas as pd
import numpy as np
from IPython.display import display

import torch
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer

## テキストデータ準備

### coqa データ読み込み

https://stanfordnlp.github.io/coqa/

CoQAは、質問応答システムを構築するための大規模なデータセットです。CoQAチャレンジの目標は、テキストのパッセージを理解し、会話に現れる一連の相互に関連する質問に答える機械学習モデルの性能を測定することです。

In [32]:
# Location of the CoQA training set (JSON) on the Stanford download server
COQA_TRAIN_URL = 'http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json'

# Read the JSON file straight into a DataFrame
coqa = pd.read_json(COQA_TRAIN_URL)

# Preview the first rows
display(coqa.head())

Unnamed: 0,version,data
0,1,"{'source': 'wikipedia', 'id': '3zotghdk5ibi9ce..."
1,1,"{'source': 'cnn', 'id': '3wj1oxy92agboo5nlq4r7..."
2,1,"{'source': 'gutenberg', 'id': '3bdcf01ogxu7zdn..."
3,1,"{'source': 'cnn', 'id': '3ewijtffvo7wwchw6rtya..."
4,1,"{'source': 'gutenberg', 'id': '3urfvvm165iantk..."


### 読み込みデータ確認

In [4]:
# Take the 'data' cell of the row labelled 0 (one complete CoQA record) and print it
item = coqa.at[0, 'data']
print(item)

{'source': 'wikipedia', 'id': '3zotghdk5ibi9cex97fepx7jetpso7', 'filename': 'Vatican_Library.txt', 'story': 'The Vatican Apostolic Library (), more commonly called the Vatican Library or simply the Vat, is the library of the Holy See, located in Vatican City. Formally established in 1475, although it is much older, it is one of the oldest libraries in the world and contains one of the most significant collections of historical texts. It has 75,000 codices from throughout history, as well as 1.1 million printed books, which include some 8,500 incunabula. \n\nThe Vatican Library is a research library for history, law, philosophy, science and theology. The Vatican Library is open to anyone who can document their qualifications and research needs. Photocopies for private study of pages from books published between 1801 and 1990 can be requested in person or by mail. \n\nIn March 2014, the Vatican Library began an initial four-year project of digitising its collection of manuscripts, to be 

In [5]:
# List the keys available in one record
print(item.keys())

dict_keys(['source', 'id', 'filename', 'story', 'questions', 'answers', 'name'])


In [6]:
# Print the story, i.e. the passage that the questions are asked about
print(item['story'])

The Vatican Apostolic Library (), more commonly called the Vatican Library or simply the Vat, is the library of the Holy See, located in Vatican City. Formally established in 1475, although it is much older, it is one of the oldest libraries in the world and contains one of the most significant collections of historical texts. It has 75,000 codices from throughout history, as well as 1.1 million printed books, which include some 8,500 incunabula. 

The Vatican Library is a research library for history, law, philosophy, science and theology. The Vatican Library is open to anyone who can document their qualifications and research needs. Photocopies for private study of pages from books published between 1801 and 1990 can be requested in person or by mail. 

In March 2014, the Vatican Library began an initial four-year project of digitising its collection of manuscripts, to be made available online. 

The Vatican Secret Archives were separated from the library at the beginning of the 17th

In [7]:
# Print every question turn of the record, one dict per line
for question_turn in item['questions']:
    print(question_turn)

{'input_text': 'When was the Vat formally opened?', 'turn_id': 1}
{'input_text': 'what is the library for?', 'turn_id': 2}
{'input_text': 'for what subjects?', 'turn_id': 3}
{'input_text': 'and?', 'turn_id': 4}
{'input_text': 'what was started in 2014?', 'turn_id': 5}
{'input_text': 'how do scholars divide the library?', 'turn_id': 6}
{'input_text': 'how many?', 'turn_id': 7}
{'input_text': 'what is the official name of the Vat?', 'turn_id': 8}
{'input_text': 'where is it?', 'turn_id': 9}
{'input_text': 'how many printed books does it contain?', 'turn_id': 10}
{'input_text': 'when were the Secret Archives moved from the rest of the library?', 'turn_id': 11}
{'input_text': 'how many items are in this secret collection?', 'turn_id': 12}
{'input_text': 'Can anyone use this library?', 'turn_id': 13}
{'input_text': 'what must be requested to view?', 'turn_id': 14}
{'input_text': 'what must be requested in person or by mail?', 'turn_id': 15}
{'input_text': 'of what books?', 'turn_id': 16}
{'

In [8]:
# Print every answer turn of the record, one dict per line
for answer_turn in item['answers']:
    print(answer_turn)

{'span_start': 151, 'span_end': 179, 'span_text': 'Formally established in 1475', 'input_text': 'It was formally established in 1475', 'turn_id': 1}
{'span_start': 454, 'span_end': 494, 'span_text': 'he Vatican Library is a research library', 'input_text': 'research', 'turn_id': 2}
{'span_start': 457, 'span_end': 511, 'span_text': 'Vatican Library is a research library for history, law', 'input_text': 'history, and law', 'turn_id': 3}
{'span_start': 457, 'span_end': 545, 'span_text': 'Vatican Library is a research library for history, law, philosophy, science and theology', 'input_text': 'philosophy, science and theology', 'turn_id': 4}
{'span_start': 769, 'span_end': 879, 'span_text': 'March 2014, the Vatican Library began an initial four-year project of digitising its collection of manuscripts', 'input_text': 'a  project', 'turn_id': 5}
{'span_start': 1048, 'span_end': 1127, 'span_text': 'Scholars have traditionally divided the history of the library into five period', 'input_text': 

### データ加工

In [9]:
# Flatten CoQA into one row per (text, question, answer) triple.
# A story carries several question/answer turns, so its text is repeated once per turn.
cols = ["text","question","answer"]

# One [text, question, answer] record per conversation turn
comp_list = []
for entry in coqa["data"]:
    # Answers are matched to questions by position: the i-th answer belongs to the i-th question
    for turn, question_turn in enumerate(entry["questions"]):
        comp_list.append([
            entry["story"],
            question_turn["input_text"],
            entry["answers"][turn]["input_text"],
        ])

# Turn the list of records into a DataFrame
data = pd.DataFrame(comp_list, columns=cols)

# Also save the result as CSV so that later runs can reuse it
data.to_csv("CoQA_data.csv", index=False)

### 加工後データ確認

In [10]:
# Show the first and the last rows of the flattened table
display(data.head(), data.tail())

Unnamed: 0,text,question,answer
0,"The Vatican Apostolic Library (), more commonl...",When was the Vat formally opened?,It was formally established in 1475
1,"The Vatican Apostolic Library (), more commonl...",what is the library for?,research
2,"The Vatican Apostolic Library (), more commonl...",for what subjects?,"history, and law"
3,"The Vatican Apostolic Library (), more commonl...",and?,"philosophy, science and theology"
4,"The Vatican Apostolic Library (), more commonl...",what was started in 2014?,a project


Unnamed: 0,text,question,answer
108642,(CNN) -- Cristiano Ronaldo provided the perfec...,Who was a sub?,Xabi Alonso
108643,(CNN) -- Cristiano Ronaldo provided the perfec...,Was it his first game this year?,Yes
108644,(CNN) -- Cristiano Ronaldo provided the perfec...,What position did the team reach?,third
108645,(CNN) -- Cristiano Ronaldo provided the perfec...,Who was ahead of them?,Barca.
108646,(CNN) -- Cristiano Ronaldo provided the perfec...,By how much?,six points


In [30]:
# Inspect the text, question and answer stored in the row labelled 10
index = 10
sample_row = data.loc[index]
print(f'Text: \n{sample_row.text}\n' )
print(f'Question: {sample_row.question}\n' )
print(f'Answer: {sample_row.answer}\n')

Text: 
The Vatican Apostolic Library (), more commonly called the Vatican Library or simply the Vat, is the library of the Holy See, located in Vatican City. Formally established in 1475, although it is much older, it is one of the oldest libraries in the world and contains one of the most significant collections of historical texts. It has 75,000 codices from throughout history, as well as 1.1 million printed books, which include some 8,500 incunabula. 

The Vatican Library is a research library for history, law, philosophy, science and theology. The Vatican Library is open to anyone who can document their qualifications and research needs. Photocopies for private study of pages from books published between 1801 and 1990 can be requested in person or by mail. 

In March 2014, the Vatican Library began an initial four-year project of digitising its collection of manuscripts, to be made available online. 

The Vatican Secret Archives were separated from the library at the beginning of t

In [12]:
# Total number of rows, i.e. question/answer pairs
print("Number of question and answers: ", len(data))

Number of question and answers:  108647


## BERT QAモデル読み込み
今回利用するモデルは、QA形式のファインチューニングも済んでいて、入力データを渡すだけで結果を返すことが可能

In [14]:
# Checkpoint name; the same name is used for both the model weights and the tokenizer
MODEL_NAME = 'bert-large-uncased-whole-word-masking-finetuned-squad'

model = BertForQuestionAnswering.from_pretrained(MODEL_NAME)
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)

Downloading:   0%|          | 0.00/443 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.25G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

## BERTによる予測

### 特定の１セット取得

In [33]:
# Pick one (text, question, answer) set by its row label
index = 10
sample_row = data.loc[index]
text, question, answer = sample_row["text"], sample_row["question"], sample_row["answer"]

### 入力データのエンコード

In [17]:
# Upper bound on the encoded length. BERT checkpoints have a fixed number of position
# embeddings (512 for bert-large), and a longer sequence makes the model call fail.
MAX_INPUT_TOKENS = 512

# Encode the (question, text) pair into token ids.
# truncation="only_second" shortens only the text, never the question, so that stories
# longer than the limit still run. NOTE: the answer may be cut off in that case.
input_ids = tokenizer.encode(
    question,
    text,
    truncation="only_second",
    max_length=MAX_INPUT_TOKENS,
)

# Map the ids back to their token strings for inspection
tokens = tokenizer.convert_ids_to_tokens(input_ids)

### エンコード結果の確認

In [18]:
# Length of input_ids, i.e. the number of tokens in the encoded input
input_len = len(input_ids)
print(f'入力文字列の長さ: {input_len}')

入力文字列の長さ: 294


In [19]:
# First 20 entries of input_ids
print(input_ids[:20])

# First 20 entries of tokens
print(tokens[:20])

# First 20 tokens side by side with their ids.
# The loop variable is named token_id rather than id so the builtin id() is not
# shadowed in the notebook's global namespace.
for token, token_id in zip(tokens[:20], input_ids[:20]):
    print('{:8}{:8,}'.format(token, token_id))

[101, 2043, 2020, 1996, 3595, 8264, 2333, 2013, 1996, 2717, 1997, 1996, 3075, 1029, 102, 1996, 12111, 11815, 3075, 1006]
['[CLS]', 'when', 'were', 'the', 'secret', 'archives', 'moved', 'from', 'the', 'rest', 'of', 'the', 'library', '?', '[SEP]', 'the', 'vatican', 'apostolic', 'library', '(']
[CLS]        101
when       2,043
were       2,020
the        1,996
secret     3,595
archives   8,264
moved      2,333
from       2,013
the        1,996
rest       2,717
of         1,997
the        1,996
library    3,075
?          1,029
[SEP]        102
the        1,996
vatican   12,111
apostolic  11,815
library    3,075
(          1,006


### segment_idsの計算
QAモデルでは入力として、インデックス化したテキストだけでなく、segmentベクトルを引数で渡す必要がある。  
以下の実装でその準備を行う。

In [20]:
# Position of the first [SEP] token in the encoded input
sep_idx = input_ids.index(tokenizer.sep_token_id)
print("SEP token index: ", sep_idx)

SEP token index:  14


In [22]:
# Number of tokens in segment A.
# sep_idx is a zero-based position, so the count up to and including the first [SEP] is sep_idx + 1.
num_seg_a = sep_idx+1
print("Number of tokens in segment A: ", num_seg_a)

Number of tokens in segment A:  15


In [23]:
# Number of tokens in segment B: everything after segment A
num_seg_b = len(input_ids) - num_seg_a
print("Number of tokens in segment B: ", num_seg_b)

Number of tokens in segment B:  279


In [24]:
# Build segment_ids: 0 for every segment A token followed by 1 for every segment B token
segment_ids = [0]*num_seg_a + [1]*num_seg_b
print(segment_ids)

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


In [25]:
# Sanity check: segment_ids must have exactly one entry per entry of input_ids
assert len(segment_ids) == len(input_ids)

### 予測の実施

In [26]:
# Run the QA model on input_ids and segment_ids.
# Both lists are wrapped in an outer list so that the tensors get a leading batch dimension of size 1.
# This is inference only, so autograd is switched off: no computation graph is built,
# which saves memory and time.
with torch.no_grad():
    output = model(torch.tensor([input_ids]),  token_type_ids=torch.tensor([segment_ids]))

In [27]:
# Inspect the raw model output
print(output)

QuestionAnsweringModelOutput(loss=None, start_logits=tensor([[-4.7418, -6.1644, -8.3419, -8.5675, -9.0231, -9.4152, -7.4713, -7.6652,
         -7.9159, -8.4479, -8.3411, -8.4345, -8.8540, -9.9689, -4.7418, -7.0036,
         -6.0015, -7.4322, -7.4095, -9.0583, -8.3897, -8.6382, -7.5039, -8.0999,
         -8.1766, -7.6513, -6.2807, -8.0295, -8.8092, -8.1086, -8.0572, -6.1988,
         -8.8412, -8.8514, -7.6907, -7.9228, -7.0875, -8.6589, -8.3033, -6.9598,
         -8.2578, -8.9226, -7.2532, -8.2714, -5.2682, -7.7425, -8.1994, -5.3182,
         -6.1210, -6.2023, -2.2322, -5.5933, -8.2697, -7.2247, -7.7075, -8.3552,
         -7.7833, -7.7108, -8.7252, -6.9845, -8.4068, -7.4159, -8.3782, -5.1614,
         -7.2312, -7.7564, -8.5597, -8.5845, -7.6947, -9.1013, -7.5146, -7.6808,
         -8.6196, -7.1953, -8.3952, -8.1451, -7.1554, -8.8171, -7.5913, -7.3029,
         -8.5125, -7.4493, -8.1198, -6.9437, -8.6964, -8.6875, -7.1083, -8.3913,
         -8.6075, -8.1642, -7.9004, -8.9003, -8.6015, -8

In [28]:
# Pick the answer span: the positions with the highest start score and the highest end score.
# torch.argmax without a dim returns the index into the flattened logits.
answer_start = torch.argmax(output.start_logits)
answer_end = torch.argmax(output.end_logits)
if answer_end >= answer_start:
    # Join the tokens of the span (end position inclusive). WordPiece marks the continuation
    # of a word with a leading "##", so those pieces are glued back onto the preceding token
    # instead of being shown as separate words.
    answer = " ".join(tokens[answer_start:answer_end+1]).replace(" ##", "")
else:
    # The predicted end lies before the predicted start: no valid span
    answer = ("I am unable to find the answer to this question. Can you please ask another question?")

# Show the predicted span boundaries and the extracted answer
print(answer_start, answer_end, answer)


tensor(209) tensor(213) beginning of the 17th century


In [29]:
# Format question and answer for display.
# str.capitalize() upper-cases the first character and lower-cases the rest.
question_cap = question.capitalize()
answer_cap = answer.capitalize()

# Print the question followed by the answer
print(f"\nQuestion:\n{question_cap}")
print(f"\nAnswer:\n{answer_cap}")


Question:
When were the secret archives moved from the rest of the library?

Answer:
Beginning of the 17th century
