In [None]:
evaluation_strategy = "!pip install transformers

In [None]:
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch
import matplotlib.pyplot as plt

In [None]:
# Hugging Face checkpoint used for both tokenizer and model; a single constant
# keeps the two in sync. As its name indicates, it is a BERT-large (uncased,
# whole-word-masking) checkpoint fine-tuned on SQuAD.
MODEL_NAME = "bert-large-uncased-whole-word-masking-finetuned-squad"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)



In [None]:
# Display the configuration object of `model.base_model` (the model without its QA head).
model.base_model.config

In [None]:
# Parameter counts for `model.base_model` (the QA output head is not included).
# Total counts every parameter; trainable counts only those with requires_grad.
# (These two were previously swapped, so the printed labels were wrong.)
pytorch_total_params = sum(p.numel() for p in model.base_model.parameters())
pytorch_trainable_params = sum(p.numel() for p in model.base_model.parameters() if p.requires_grad)
print("Total number of params", pytorch_total_params)
print("Total number of trainable params", pytorch_trainable_params)

In [None]:
# Context passage that the questions below are answered against
# (passed as `context` to answer_question).
text = r"""Japan is the eleventh-most populous country in the world, as well as one of the most densely populated and urbanized.
 About three-fourths of the country's terrain is mountainous, concentrating its population of 125.57 million on narrow coastal plains. 
 Japan is divided into 47 administrative prefectures and eight traditional regions.
 Osaka has a big population of 16 million. 
 The Greater Tokyo Area is the most populous metropolitan area in the world, with more than 37.4 million residents. 
"""

In [None]:
import numpy as np
def get_top_answers(possible_starts,possible_ends,input_ids):
  """Decode one candidate answer string per (start, end) index pair.

  The i-th entry of `possible_starts` is paired with the i-th entry of
  `possible_ends` (rank-by-rank, not every combination). End indices are
  inclusive. A pair whose end precedes its start selects no ids, which
  presumably decodes to an empty string (callers filter those out).

  Returns:
    list of decoded strings, one per pair, in pairing order.
  """
  def _decode_span(first, last):
    # `last` is inclusive, so the slice bound is last + 1.
    span_tokens = tokenizer.convert_ids_to_tokens(input_ids[first:last + 1])
    return tokenizer.convert_tokens_to_string(span_tokens)

  return [_decode_span(first, last) for first, last in zip(possible_starts, possible_ends)]

def answer_question(question,context,topN=5):
    """Run extractive question answering on a single question/context pair.

    Args:
        question: the question string.
        context: the passage to extract the answer from. Only the context is
            truncated if the encoded pair exceeds the model's max input length.
        topN: number of highest-scoring start and end indices to keep for the
            secondary "answers" list (default 5).

    Returns:
        dict with keys:
          "answer": decoded text of the best span.
          "answer_start" / "answer_end": 0-d tensors; start index and exclusive
              end index of the best span within "input_ids".
          "input_ids": list of token ids of the encoded question/context pair.
          "answer_start_scores" / "answer_end_scores": start/end logits from the model.
          "inputs": the tokenizer output that was fed to the model.
          "answers": candidate strings from get_top_answers.
          "possible_starts" / "possible_ends": top-topN start/end indices, best first.
    """
    # truncation="only_second" trims the context, never the question; without
    # it an over-long context fails inside the model's forward pass.
    inputs = tokenizer(question, context, add_special_tokens=True,
                       truncation="only_second", return_tensors="pt")
    input_ids = inputs["input_ids"].tolist()[0]

    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        model_out = model(**inputs)

    answer_start_scores = model_out["start_logits"]
    answer_end_scores = model_out["end_logits"]

    # Indices of the topN largest logits, highest first.
    possible_starts = np.argsort(answer_start_scores.cpu().detach().numpy()).flatten()[::-1][:topN]
    possible_ends = np.argsort(answer_end_scores.cpu().detach().numpy()).flatten()[::-1][:topN]

    # Best span = (start, end) maximising start_logit + end_logit subject to
    # end >= start. Taking the two argmaxes independently can give end < start,
    # i.e. an empty answer; when they are already ordered this picks the same pair.
    start_logits = answer_start_scores[0]   # assumes batch size 1, as input_ids[0] above
    end_logits = answer_end_scores[0]
    pair_scores = start_logits[:, None] + end_logits[None, :]   # [start, end]
    idx = torch.arange(pair_scores.shape[0], device=pair_scores.device)
    end_before_start = idx[None, :] < idx[:, None]
    pair_scores = pair_scores.masked_fill(end_before_start, float("-inf"))
    best_start, best_end = divmod(int(torch.argmax(pair_scores)), pair_scores.shape[1])
    answer_start = torch.tensor(best_start)
    answer_end = torch.tensor(best_end + 1)   # exclusive bound for slicing

    answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))

    answers = get_top_answers(possible_starts, possible_ends, input_ids)

    return {"answer": answer, "answer_start": answer_start, "answer_end": answer_end,
            "input_ids": input_ids,
            "answer_start_scores": answer_start_scores, "answer_end_scores": answer_end_scores,
            "inputs": inputs, "answers": answers,
            "possible_starts": possible_starts, "possible_ends": possible_ends}

In [None]:
# Broad questions about the passage. Stored under its own name: the next cell
# assigns `questions`, which previously overwrote this list on every full run.
questions_general = [
    "How many states in Japan?",
    "What is the population of Japan?",
    "Which city is most crowded in the world?",
    "What is the city with most population?",
    "What is the topic here?",
    "What are we talking about?",
    "What is the main idea here?",
]

In [None]:
# Paraphrases of "where/what is the most populous place?" to probe how
# sensitive the model's answer is to wording.
questions = [
    "Where is the most populous metropolitan area in the world ?",
    "Where is the most populous city ?",
    "What is the most populous city ?",
    "Where is most populous in the world ?",
    "Where is most populous?",
    "What is Population of Tokyo ?",
    "Which city is most crowded in the world ?",
    "Which city has biggest population ?",
    "Which city has most population ?",
]

# Terse variants; not used by the loop below.
questions3 = [
    "Most populous city ?",
    "Most populous city",
]

TOP_N = 5  # candidate answers requested per question

for q in questions:
  answer_map = answer_question(q, text, TOP_N)
  print("Question:", q)
  print("Answers:")
  # Filter empty candidates before numbering so the printed ranks have no gaps.
  non_empty_answers = [ans for ans in answer_map["answers"] if len(ans) > 0]
  for rank, ans in enumerate(non_empty_answers, start=1):
    print(rank, " ) ", ans)
  

In [None]:
# One question whose inputs and logits are inspected in the following cells.
answer_map = answer_question(
    question="Where is most populous in the world?",
    context=text,
    topN=3,
)

In [None]:
# Show the three tensors the tokenizer produced for the question/context pair.
encoded_inputs = answer_map["inputs"]
print("input_ids:", encoded_inputs["input_ids"])
print("token_type_ids:", encoded_inputs["token_type_ids"])
print("attention_mask:", encoded_inputs["attention_mask"])

In [None]:
# Number of token ids in the encoded question/context pair.
print( len(answer_map["input_ids"] ))
# Decode the ids back to text; as the cell's last expression it is displayed.
tokenizer.decode( answer_map["input_ids"] )  

In [None]:
def plot_possible_answer(answer_map,expected_start,expected_end):
  """Bar-plot the per-token start and end logits from one answer_question result.

  Args:
    answer_map: dict returned by answer_question; this reads its "input_ids",
      "answer_start_scores" and "answer_end_scores" entries.
    expected_start: token index marked with a yellow line on the start panel.
    expected_end: token index marked with a red line on the end panel.
  """
  input_ids = answer_map["input_ids"]
  # One label per input id, so labels line up 1:1 with the logits.
  token_labels = tokenizer.convert_ids_to_tokens(input_ids)
  positions = list(range(len(token_labels)))

  # .cpu() so this also works if the model ran on a GPU.
  y_start = answer_map["answer_start_scores"].detach().cpu().numpy().flatten()
  y_end = answer_map["answer_end_scores"].detach().cpu().numpy().flatten()

  fig, axes = plt.subplots(2, 1, figsize=(20, 5))

  # Bars go at integer positions, not at the token strings: a categorical
  # x-axis maps repeated labels (e.g. a word occurring in both question and
  # context) to one shared position. That stacks their bars on top of each
  # other and can make the numeric axvline markers point at the wrong token.
  axes[0].bar(positions, y_start)
  axes[0].set_title("start scores( "+ str( len( y_start ) ) +")" )
  axes[0].axvline(expected_start, color="yellow")

  axes[1].bar(positions, y_end, color="orange")
  axes[1].set_title("end scores( "+ str( len( y_end ) ) +")" )
  axes[1].axvline(expected_end, color="red")

  for ax in axes:
    ax.set_xticks(positions)
    ax.set_xticklabels(token_labels, rotation=90)
    ax.set_ylabel("logit")
    ax.autoscale(tight=True)
  fig.tight_layout()


In [None]:
# Mark token index 10 as the expected start and 11 as the expected end.
# NOTE(review): these indices are hard-coded for the question above — confirm
# they match the intended answer span in the decoded input.
plot_possible_answer(
    answer_map,
    expected_start=10,
    expected_end=11,
)