<a href="https://colab.research.google.com/github/mathluva/BERT-QA/blob/main/BERT_QA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
!pip install transformers
!pip install torch



In [7]:
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer
import torch
import numpy as np

In [8]:
# Question-answering BERT checkpoint, loaded by name; per the checkpoint name it is
# the large uncased whole-word-masking variant already fine-tuned on SQuAD.
model = BertForQuestionAnswering.from_pretrained(
    'bert-large-uncased-whole-word-masking-finetuned-squad'
)

In [9]:
# Tokenizer loaded from the same checkpoint name as the model, so token ids
# produced here match the vocabulary the model expects.
tokenizer_for_bert = BertTokenizer.from_pretrained(
    'bert-large-uncased-whole-word-masking-finetuned-squad'
)

In [12]:
def bert_answering_machine(question, passage):
    """Extract the answer to `question` from `passage` using the loaded BERT QA model.

    Relies on the notebook-level `tokenizer_for_bert` and `model` objects.

    Args:
        question: Question text; encoded as the first sentence of the pair.
        passage: Text to search for the answer; encoded as the second sentence.

    Returns:
        A tuple `(answer_start_index, answer_end_index, start_token_score,
        end_token_score, answer)`:
          * the token positions with the highest start and end scores,
          * those two scores rounded to 2 decimals,
          * the answer string rebuilt from the tokens in that span, or a fixed
            apology message when the prediction matches one of the
            "no answer found" patterns checked at the end of the function.
    """
    # Tokenize question and passage as one sentence pair, with special tokens added.
    # NOTE(review): no truncation is applied here, so very long passages may exceed
    # the model's maximum input length -- confirm against expected inputs.
    input_ids = tokenizer_for_bert.encode(question, passage)

    # The first [SEP] token closes the question; everything after it is the passage.
    # Look the id up on the tokenizer instead of hard-coding it.
    first_sep_index = input_ids.index(tokenizer_for_bert.sep_token_id)
    len_question = first_sep_index + 1           # tokens up to and including first [SEP]
    len_passage = len(input_ids) - len_question  # remaining tokens (sentence 2)

    # Segment ids distinguish the two sentences: 0 for sentence 1, 1 for sentence 2.
    segment_ids = [0] * len_question + [1] * len_passage

    # Token strings for every id, used below to rebuild the answer text.
    tokens = tokenizer_for_bert.convert_ids_to_tokens(input_ids)

    # Run the model ONCE and read both score vectors from the same result
    # (element 0 = start scores, element 1 = end scores). Gradients are not
    # needed for inference, so skip building the autograd graph.
    with torch.no_grad():
        outputs = model(
            torch.tensor([input_ids]),
            token_type_ids=torch.tensor([segment_ids]),
        )

    # Convert score tensors to flat numpy arrays so numpy functions can be used.
    start_token_scores = outputs[0].detach().numpy().flatten()
    end_token_scores = outputs[1].detach().numpy().flatten()

    # Answer span = positions with the highest start score and highest end score.
    answer_start_index = np.argmax(start_token_scores)
    answer_end_index = np.argmax(end_token_scores)

    # Scores of the chosen start/end tokens, rounded to 2 decimal digits.
    start_token_score = np.round(start_token_scores[answer_start_index], 2)
    end_token_score = np.round(end_token_scores[answer_end_index], 2)

    # Rebuild the answer text from the span. A token starting with '##' is the
    # continuation of a split word, so it is glued to the previous token without
    # a space; any other token is appended after a single space.
    answer = tokens[answer_start_index]
    for token in tokens[answer_start_index + 1:answer_end_index + 1]:
        if token.startswith('##'):
            answer += token[2:]
        else:
            answer += ' ' + token

    # Patterns indicating that BERT did not find an answer in the passage:
    # span starts at position 0, negative (rounded) start score, the span is just
    # a [SEP] token, or the predicted end precedes the predicted start.
    no_answer_found = (
        answer_start_index == 0
        or start_token_score < 0
        or answer == '[SEP]'
        or answer_end_index < answer_start_index
    )
    if no_answer_found:
        answer = "Sorry!, I could not find  an answer in the passage."

    return (answer_start_index, answer_end_index, start_token_score, end_token_score, answer)



In [13]:
# Demo query; the indices and scores are discarded and only the answer text is printed.
_, _, _, _, ans = bert_answering_machine(
    question="Where does Tiffany's mom live?",
    passage="Tiffany was born in Norfolk, VA.  Her mom, Cora, was born in Columbus, Ohio.  Cora has also lived in Fl, GA and NJ.  She currently resides in Raleigh,NC.",
)
print(ans)

raleigh , nc


In [16]:
# Demo query; the indices and scores are discarded and only the answer text is printed.
_, _, _, _, ans = bert_answering_machine(
    question="What is Cora's daughter name?",
    passage="Tiffany was born in Norfolk, VA.  Her mom, Cora, was born in Columbus, Ohio.  Cora has also lived in Fl, GA and NJ.  She currently resides in Raleigh,NC.",
)
print(ans)

tiffany


In [17]:
# Demo query; the indices and scores are discarded and only the answer text is printed.
_, _, _, _, ans = bert_answering_machine(
    question="What states did Tiffany's mom live in?",
    passage="Tiffany was born in Norfolk, VA.  Her mom, Cora, was born in Columbus, Ohio.  Cora has also lived in Fl, GA and NJ.  She currently resides in Raleigh,NC.",
)
print(ans)

fl , ga and nj


In [22]:
# Demo query; the indices and scores are discarded and only the answer text is printed.
_, _, _, _, ans = bert_answering_machine(
    question="What is Louis's occupation?",
    passage="Louis is a landscape technician, mechanic and chef. On the weekends, he enjoys spending time with his wife Tiffany and working on his vehicles.",
)
print(ans)

landscape technician , mechanic and chef


In [24]:
# Demo query; the indices and scores are discarded and only the answer text is printed.
# The question text is kept exactly as originally asked so the recorded output stays valid.
_, _, _, _, ans = bert_answering_machine(
    question="What does Louis like to during his spare time?",
    passage="Louis is a landscape technician, mechanic and chef. On the weekends, he enjoys spending time with his wife Tiffany and working on his vehicles.",
)
print(ans)

working on his vehicles


In [25]:
# Demo yes/no-style query; the function can only return a span of the passage,
# so the printed answer is a text excerpt rather than "yes" or "no".
_, _, _, _, ans = bert_answering_machine(
    question="Is Joy's sister married?",
    passage="Joy is a teacher and law student.  She wants to practice restorative justice. She plans to take the bar in February of 2022.  Her older sister, Anika, lives in florida with her wife.",
)
print(ans)

her wife


In [26]:
# Demo query; the indices and scores are discarded and only the answer text is printed.
_, _, _, _, ans = bert_answering_machine(
    question="What does Joy plan to do in February?",
    passage="Joy is a teacher and law student.  She wants to practice restorative justice. She plans to take the bar in February of 2022.  Her older sister, Anika, lives in florida with her wife.",
)
print(ans)

take the bar


In [27]:
# Demo query whose answer is not stated in the passage; only the answer text is printed.
_, _, _, _, ans = bert_answering_machine(
    question="What is restorative justice?",
    passage="Joy is a teacher and law student.  She wants to practice restorative justice. She plans to take the bar in February of 2022.  Her older sister, Anika, lives in florida with her wife.",
)
print(ans)

joy is a teacher and law student


In [29]:
# Demo query; the indices and scores are discarded and only the answer text is printed.
_, _, _, _, ans = bert_answering_machine(
    question="What job does Joy want after graduation?",
    passage="Joy is a teacher and law student.  She wants to practice restorative justice. She plans to take the bar in February of 2022.  Her older sister, Anika, lives in florida with her wife.",
)
print(ans)

teacher


In [30]:
# Same passage as the previous query with a shorter, differently worded question;
# only the answer text is printed.
_, _, _, _, ans = bert_answering_machine(
    question="What job does joy want?",
    passage="Joy is a teacher and law student.  She wants to practice restorative justice. She plans to take the bar in February of 2022.  Her older sister, Anika, lives in florida with her wife.",
)
print(ans)

teacher and law student . she wants to practice restorative justice


In [31]:
# Demo query; the indices and scores are discarded and only the answer text is printed.
_, _, _, _, ans = bert_answering_machine(
    question="Where does her sister live?",
    passage="Joy is a teacher and law student.  She wants to practice restorative justice. She plans to take the bar in February of 2022.  Her older sister, Anika, lives in florida with her wife.",
)
print(ans)

florida
