# Question Answering in NLP

This notebook implements extractive question answering with a BertForQuestionAnswering model fine-tuned on the Stanford Question Answering Dataset (SQuAD).

In [1]:
import torch
from transformers import BertForQuestionAnswering, BertTokenizer



Load the pretrained model and tokenizer for QA

In [2]:
model = BertForQuestionAnswering.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")
tokenizer = BertTokenizer.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")

Some weights of the model checkpoint at bert-large-uncased-whole-word-masking-finetuned-squad were not used when initializing BertForQuestionAnswering: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Sample question and context

In [3]:
question = "Where is the Great Barrier Reef located?"

In [5]:
answer_text = "The Great Barrier Reef is located in the Coral Sea, off the coast of Australia. It is the largest coral reef system in the world, stretching over 2,300 km and covering an area of approximately 344,400 km². The Great Barrier Reef is home to a diverse range of marine life and is considered one of the seven natural wonders of the world. It is also a UNESCO World Heritage Site threatened by climate change and other environmental factors."

Tokenize the question and context into input IDs

In [6]:
input_ids = tokenizer.encode(question, answer_text)
input_ids

[101,
 2073,
 2003,
 1996,
 2307,
 8803,
 12664,
 2284,
 1029,
 102,
 1996,
 2307,
 8803,
 12664,
 2003,
 2284,
 1999,
 1996,
 11034,
 2712,
 1010,
 2125,
 1996,
 3023,
 1997,
 2660,
 1012,
 2009,
 2003,
 1996,
 2922,
 11034,
 12664,
 2291,
 1999,
 1996,
 2088,
 1010,
 10917,
 2058,
 1016,
 1010,
 3998,
 2463,
 1998,
 5266,
 2019,
 2181,
 1997,
 3155,
 29386,
 1010,
 4278,
 3186,
 1012,
 1996,
 2307,
 8803,
 12664,
 2003,
 2188,
 2000,
 1037,
 7578,
 2846,
 1997,
 3884,
 2166,
 1998,
 2003,
 2641,
 2028,
 1997,
 1996,
 2698,
 3019,
 16278,
 1997,
 1996,
 2088,
 1012,
 2009,
 2003,
 2036,
 1037,
 12239,
 2088,
 4348,
 2609,
 5561,
 2011,
 4785,
 2689,
 1998,
 2060,
 4483,
 5876,
 1012,
 102]
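tokenizer.encode returns only the token IDs, so the model call further down receives no segment information and BERT falls back to treating every token as segment 0. BERT QA checkpoints are normally fed token_type_ids that mark the question as segment 0 and the context as segment 1. The sketch below rebuilds those IDs by hand from the position of the first [SEP] token (ID 102 in the output above). The helper name build_token_type_ids and the shortened ID list are illustrative assumptions, not part of the original notebook.

```python
def build_token_type_ids(input_ids, sep_token_id=102):
    """Return 0 for [CLS] + question + first [SEP], and 1 for everything after."""
    first_sep = input_ids.index(sep_token_id)
    n_question = first_sep + 1              # includes [CLS] and the first [SEP]
    n_context = len(input_ids) - n_question
    return [0] * n_question + [1] * n_context


# Shortened, illustrative version of the IDs printed above:
# [CLS] question [SEP] "the great barrier reef is located in the coral sea" [SEP]
example_ids = [101, 2073, 2003, 1996, 2307, 8803, 12664, 2284, 1029, 102,
               1996, 2307, 8803, 12664, 2003, 2284, 1999, 1996, 11034, 2712, 102]

token_type_ids = build_token_type_ids(example_ids)
print(token_type_ids)
```

In practice, calling the tokenizer directly, as in tokenizer(question, answer_text), returns input_ids, token_type_ids and attention_mask together, and the model accepts token_type_ids as a keyword argument.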

In [9]:
attention_mask = [1] * len(input_ids)
len(attention_mask)

99
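The mask is all ones here because a single, unpadded sequence is passed to the model. It only starts to matter when several sequences of different lengths are batched together: shorter ones are padded, and the mask is 0 at the padded positions so attention ignores them. A minimal sketch of that idea in plain Python follows; pad_batch is a hypothetical helper, and the pad ID of 0 is the [PAD] ID in the BERT uncased vocabulary.

```python
def pad_batch(sequences, pad_token_id=0):
    """Pad every sequence to the longest one and build matching attention masks."""
    max_len = max(len(seq) for seq in sequences)
    padded_ids, attention_masks = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        padded_ids.append(seq + [pad_token_id] * n_pad)
        # 1 marks a real token, 0 marks padding that attention should skip
        attention_masks.append([1] * len(seq) + [0] * n_pad)
    return padded_ids, attention_masks


# Two toy sequences of different lengths (made-up IDs apart from 101 and 102)
batch = [[101, 2073, 2003, 102], [101, 2073, 102]]
padded_ids, attention_masks = pad_batch(batch)
print(padded_ids)
print(attention_masks)
```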

Run the model

In [10]:
with torch.no_grad():  # inference only, so skip gradient tracking
    output = model(torch.tensor([input_ids]), attention_mask=torch.tensor([attention_mask]))

In [11]:
output

QuestionAnsweringModelOutput(loss=None, start_logits=tensor([[-6.4106, -5.9257, -7.6330, -7.7274, -8.1191, -8.6230, -8.6095, -7.8516,
         -8.8621, -6.4106,  1.8369, -0.6083, -3.4098, -4.1775, -2.3809,  0.0152,
          4.8730,  6.7679,  7.2950, -1.0516, -3.6109,  1.6546, -2.7443, -2.8114,
         -4.9304,  1.8826, -6.4102, -3.4538, -7.1368, -6.2586, -5.6499, -5.0297,
         -6.4085, -6.6760, -7.3548, -7.1828, -5.7484, -7.8486, -4.7000, -7.1409,
         -5.4418, -8.2328, -7.1212, -6.8992, -8.1358, -5.6112, -7.2474, -6.3876,
         -8.2483, -5.2105, -4.3644, -8.1883, -7.1706, -6.6519, -8.0507, -4.0760,
         -4.1507, -7.0613, -7.5427, -8.0758, -6.9090, -8.7749, -8.1034, -7.0426,
         -8.2258, -8.8082, -6.8397, -7.7390, -8.6729, -7.8738, -7.2368, -7.4564,
         -8.5908, -8.2528, -6.6169, -6.9555, -7.2141, -8.4457, -7.1364, -6.9217,
         -8.2917, -6.2226, -8.0376, -8.0975, -6.7058, -4.2892, -6.3594, -7.3548,
         -7.0044, -6.3021, -8.6310, -7.1910, -8.1402, -8
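The logits above are unnormalized scores, one per token, for "the answer starts here". A softmax turns them into a probability distribution, which makes the model's preference easier to read. The sketch below applies a numerically stable softmax to the first 24 start logits copied from the output above. Because it normalizes over an excerpt rather than all 99 positions, the resulting values are illustrative only and are not the model's actual probabilities.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# First 24 start logits from the output above (excerpt, not the full sequence)
start_logits_excerpt = [
    -6.4106, -5.9257, -7.6330, -7.7274, -8.1191, -8.6230, -8.6095, -7.8516,
    -8.8621, -6.4106, 1.8369, -0.6083, -3.4098, -4.1775, -2.3809, 0.0152,
    4.8730, 6.7679, 7.2950, -1.0516, -3.6109, 1.6546, -2.7443, -2.8114,
]

probs = softmax(start_logits_excerpt)
best_position = max(range(len(probs)), key=probs.__getitem__)
print(best_position)  # position of the highest-scoring start token in the excerpt
```

Position 18 holds the largest start logit (7.2950), and token 18 in the input IDs is 11034, the first token of the decoded answer.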

Get the answer span from the start and end logits

In [12]:
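# The context starts right after the first [SEP]; search for the answer only there.
context_start = input_ids.index(tokenizer.sep_token_id) + 1
start_index = context_start + int(torch.argmax(output.start_logits[0, context_start:]))
end_index = context_start + int(torch.argmax(output.end_logits[0, context_start:]))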

In [13]:
answer = tokenizer.decode(input_ids[start_index:end_index + 1], skip_special_tokens=True)

In [14]:
answer

'coral sea'