## Use a pre-trained Hugging Face BERT QA model

In [1]:
#!pip install transformers

In [2]:
import torch
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer

Load the BERT-large model (uncased, whole-word masking) fine-tuned on the SQuAD dataset, along with its tokenizer

In [3]:
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
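For reference, the checkpoint loaded above scores a question/context pair packed in BERT's standard sentence-pair layout, `[CLS] question [SEP] context [SEP]`, with token type (segment) ID 0 for the first part and 1 for the second. The minimal sketch below builds that layout by hand from already-tokenized IDs, assuming the uncased BERT vocabulary, where `[CLS]` is 101 and `[SEP]` is 102 (the loaded tokenizer exposes these as `tokenizer.cls_token_id` and `tokenizer.sep_token_id`). `pack_pair` is a hypothetical illustrative helper, not part of `transformers`, and the word-piece IDs in the example are placeholders.

```python
# Assumed special-token IDs from the uncased BERT vocabulary.
CLS_ID, SEP_ID = 101, 102


def pack_pair(question_ids, context_ids):
    """Pack two tokenized sequences as [CLS] question [SEP] context [SEP].

    Token type 0 covers [CLS], the question and the first [SEP];
    token type 1 covers the context and the final [SEP].
    """
    input_ids = [CLS_ID] + question_ids + [SEP_ID] + context_ids + [SEP_ID]
    num_seg_a = len(question_ids) + 2  # [CLS] + question + [SEP]
    num_seg_b = len(context_ids) + 1   # context + final [SEP]
    token_type_ids = [0] * num_seg_a + [1] * num_seg_b
    return input_ids, token_type_ids


# Hypothetical word-piece IDs, for illustration only.
ids, types = pack_pair([2040, 2003], [3000, 4000, 5000])
print(ids)    # [101, 2040, 2003, 102, 3000, 4000, 5000, 102]
print(types)  # [0, 0, 0, 0, 1, 1, 1, 1]
```

This is the same split that `question_answer` recovers by locating the first `[SEP]` in the encoded IDs.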

QA function

In [4]:
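def question_answer(question, text):
    """Print the answer span the model predicts for `question` within `text`."""

    # Tokenize question and text as a pair: [CLS] question [SEP] text [SEP].
    # BERT accepts at most 512 tokens, so truncate the text (never the question).
    input_ids = tokenizer.encode(question, text, truncation="only_second", max_length=512)

    # String version of the token IDs, used to rebuild the answer
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    # Segment IDs: the first [SEP] token marks the end of the question
    sep_idx = input_ids.index(tokenizer.sep_token_id)
    # Number of tokens in segment A ([CLS] + question + [SEP])
    num_seg_a = sep_idx + 1
    # Number of tokens in segment B (text + final [SEP])
    num_seg_b = len(input_ids) - num_seg_a

    # 0s for segment A and 1s for segment B
    segment_ids = [0] * num_seg_a + [1] * num_seg_b
    assert len(segment_ids) == len(input_ids)

    # Run the model; gradients are not needed for inference
    with torch.no_grad():
        output = model(torch.tensor([input_ids]), token_type_ids=torch.tensor([segment_ids]))

    # Most likely start and end positions of the answer span
    answer_start = torch.argmax(output.start_logits).item()
    answer_end = torch.argmax(output.end_logits).item()

    # Rebuild the answer, gluing "##" word pieces onto the previous token.
    # If the end precedes the start, fall through to the "no answer" case.
    answer = "[CLS]"
    if answer_end >= answer_start:
        answer = tokens[answer_start]
        for i in range(answer_start + 1, answer_end + 1):
            if tokens[i].startswith("##"):
                answer += tokens[i][2:]
            else:
                answer += " " + tokens[i]

    # A span starting at [CLS] means the model found no answer
    if answer.startswith("[CLS]"):
        answer = "Unable to find the answer to your question."

    print("\nPredicted answer:\n{}".format(answer.capitalize()))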
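Taking `argmax` of the start and end logits independently, as `question_answer` does, can return an end position before the start, or a span that begins inside the question. A common alternative, which is not what the function above does, is to score every pair `(start, end)` with `start <= end` inside the context segment and keep the pair with the largest `start_logit + end_logit`. The minimal sketch below does that on plain Python lists and also merges WordPiece `##` continuations back into words. The helper names, the `max_len=30` cap, and the tokens and logits in the example are all made up for illustration.

```python
def best_span(start_logits, end_logits, context_start, context_end, max_len=30):
    """Return the (start, end) pair maximizing start_logits[s] + end_logits[e],
    subject to context_start <= s <= e < context_end and e - s < max_len.
    Returns None if the context range is empty."""
    best, best_score = None, float("-inf")
    for s in range(context_start, context_end):
        for e in range(s, min(context_end, s + max_len)):
            score = start_logits[s] + end_logits[e]
            if score > best_score:
                best, best_score = (s, e), score
    return best


def merge_wordpieces(tokens):
    """Join WordPiece tokens, gluing '##' continuations onto the previous word."""
    words = []
    for tok in tokens:
        if tok.startswith("##") and words:
            words[-1] += tok[2:]
        else:
            words.append(tok)
    return " ".join(words)


# Made-up tokens and logits, for illustration only.
tokens = ["[CLS]", "who", "lost", "?", "[SEP]", "sat", "##ish", "kumar", "lost", "[SEP]"]
start_logits = [0.1, 0.0, 0.0, 0.0, 0.0, 5.0, 0.2, 0.1, 0.0, 0.0]
end_logits = [4.5, 0.0, 0.0, 0.0, 0.0, 0.1, 0.3, 4.0, 0.5, 0.0]

# Independent argmax: the end lands on [CLS], before the start.
naive_start = max(range(len(start_logits)), key=start_logits.__getitem__)
naive_end = max(range(len(end_logits)), key=end_logits.__getitem__)
print(naive_start, naive_end)  # 5 0

# Constrained search over the context (between the first and last [SEP]).
context_start = tokens.index("[SEP]") + 1
context_end = len(tokens) - 1
s, e = best_span(start_logits, end_logits, context_start, context_end)
print(s, e, merge_wordpieces(tokens[s:e + 1]))  # 5 7 satish kumar
```

The double loop is quadratic in the context length, which is why a span-length cap such as `max_len` is usually applied.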

Test Model
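The test cell below keeps asking whether to continue with a Y/N prompt. One way to make that prompt tolerate empty or lower-case replies, and to exercise it without a keyboard, is a small helper that takes the input function as a parameter. `ask_yes_no` and its `read` parameter are hypothetical names used only for this sketch.

```python
def ask_yes_no(prompt, read=input):
    """Ask until the reply starts with Y or N (any case); return True for yes."""
    while True:
        reply = read(prompt).strip().upper()
        if reply.startswith("Y"):
            return True
        if reply.startswith("N"):
            return False


# Scripted replies stand in for the keyboard: "" and "maybe" are rejected.
replies = iter(["", "maybe", "yes"])
print(ask_yes_no("Another question (Y/N)? ", read=lambda _prompt: next(replies)))  # True
```

In the notebook itself, `read` keeps its default of `input`.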

In [7]:
text = input("Please enter your text: \n")
question = input("\nPlease enter your question: \n")
while True:
    question_answer(question, text)

    # Re-prompt until the reply starts with Y or N (case-insensitive);
    # an empty reply no longer raises an IndexError.
    response = ""
    while response not in ("Y", "N"):
        response = input("\nDo you want to ask another question based on this text (Y/N)? ").strip().upper()[:1]

    if response == "N":
        print("\nBye!")
        break

    question = input("\nPlease enter your question: \n")

Please enter your text: 
Indian boxer Satish Kumar lost by a unanimous decision against Bakhodir Jalolov of Uzbekistan in the men's Super Heavy (+91kg) category quarterfinal today. In equestrian, India's Fouaad Mirza and Seigneur Medicott were at the 22nd position with a total penalty points of 39.20. The eventing, jumping Qualifiers will be held on August 2. Golfer Anirban Lahiri ended Round 4 playing 5 under par (-5) while Udayan Mane finished Round 4 with playing 3 over par (+3). Other Indians in action on Sunday, August 1 at the Tokyo Olympics include shuttler PV Sindhu and the men's hockey team. Sindhu will play for the bronze medal after facing defeat in the semi-final match on Saturday. The men's hockey team will take on Great Britain with the hope of making the last four.  World number one tennis player Novak Djokovic of Serbia went down in three sets to Spain's Pablo Carreno Busta in the bronze medal match of men's tennis singles event yesterday while China's Ma Long became th