In [None]:
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer

import torch


In [None]:
# Hugging Face checkpoint identifier. Going by its name, this is an uncased
# BERT-large variant already fine-tuned for SQuAD-style question answering.
model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"

# The QA model and its tokenizer are loaded from the same checkpoint name so
# that the token ids produced by one are the ids expected by the other.
model = BertForQuestionAnswering.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)

In [None]:
# The question to be answered from the context passage below.
question = "What is the capital of France?"

# Context passage, written one sentence per list entry and joined with single
# spaces. The sentence that contains the answer is the second one.
text = " ".join(
    [
        "France, officially the French Republic, is a country whose territory "
        "consists of metropolitan France in Western Europe and several overseas "
        "regions and territories.",
        "The capital of France is Paris, which is also the largest city in the country.",
        "France is known for its rich history, culture, and influence in various "
        "fields such as art, science, and politics.",
    ]
)


In [None]:
# Tokenize the (question, context) pair in one call. Calling the tokenizer
# directly replaces the deprecated `encode_plus` method; the result still
# exposes the "input_ids" and "token_type_ids" entries read by later cells.
encoding = tokenizer(question, text)
print(encoding)

In [None]:
# Unpack the two id sequences that are fed to the model.
# NOTE: `sqeuence_embeddings` holds the tokenizer's "token_type_ids" entry
# (ids, not embedding vectors). The misspelled name is kept as-is because the
# model-call cell refers to it.
inputs, sqeuence_embeddings = encoding["input_ids"], encoding["token_type_ids"]
print(inputs, sqeuence_embeddings, sep="\n")

# Map the ids back to their token strings so the input can be read by eye.
tokens = tokenizer.convert_ids_to_tokens(inputs)
print(tokens)

In [None]:
# Run a single forward pass on a batch of one: wrapping each id list in an
# outer list gives the tensors a leading batch dimension.
# This is inference only, so autograd is switched off to avoid building a
# computation graph and holding on to intermediate activations.
with torch.no_grad():
    output = model(
        input_ids=torch.tensor([inputs]),
        token_type_ids=torch.tensor([sqeuence_embeddings]),
    )


In [None]:
# Take the highest-scoring start and end positions. `.item()` turns the 0-dim
# tensors returned by argmax into plain Python ints, so they print as numbers
# (not "tensor(N)") and can be used directly as list indices.
start_index = torch.argmax(output.start_logits).item()
end_index = torch.argmax(output.end_logits).item()
print(f"Start index: {start_index}, End index: {end_index}")

# The end position is inclusive, hence the +1. If the predicted end precedes
# the predicted start, the slice is empty and the answer is expected to come
# out blank, which the check below reports as "No answer found."
answer = tokenizer.convert_tokens_to_string(tokens[start_index:end_index + 1])
print(f"Answer: {answer}")
if answer.strip() == "":
    print("No answer found.")
else:
    print(f"Answer found: {answer.strip()}")

In [None]:
# `plt` must be the pyplot submodule: the plotting cell calls plt.figure(),
# plt.bar(), plt.xticks(), plt.show(), etc., which the top-level `matplotlib`
# package does not provide as functions.
import matplotlib.pyplot as plt
import seaborn as sns


In [None]:
# Convert the start/end logits of the first (and only) batch element into flat
# NumPy arrays: index the batch, detach from autograd, convert, then flatten.
s_scores, e_scores = (
    logits[0].detach().numpy().flatten()
    for logits in (output.start_logits, output.end_logits)
)
print(f"Start scores: {s_scores}")
print(f"End scores: {e_scores}")


In [None]:

# One label per token, in the form "<token> - <position>", with the position
# right-aligned to a width of two characters.
token_labels = [f'{token} - {i:>2}' for i, token in enumerate(tokens)]

# Print the header followed by one label per line.
print("Token labels:", *token_labels, sep="\n")

In [None]:
# Grouped bar chart of the start and end score for every token, drawn through
# the explicit figure/axes interface. The wide figure leaves room for one
# rotated tick label per token.
fig, ax = plt.subplots(figsize=(15, 6))

# One x position per token; each token gets a pair of bars, 0.4 wide, shifted
# 0.2 to either side of that position so the pair sits side by side.
x_pos = range(len(token_labels))
start_bar_centres = [i - 0.2 for i in x_pos]
end_bar_centres = [i + 0.2 for i in x_pos]

# Start scores in blue, end scores in red.
ax.bar(start_bar_centres, s_scores, width=0.4, label='Start scores', alpha=0.7, color='blue')
ax.bar(end_bar_centres, e_scores, width=0.4, label='End scores', alpha=0.7, color='red')

# Label every tick with its "<token> - <position>" string, rotated vertically.
ax.set_xticks(x_pos)
ax.set_xticklabels(token_labels, rotation=90, ha='center')

# Light grid, legend, axis labels and title.
ax.grid(True, alpha=0.3)
ax.legend()
ax.set_xlabel('Tokens')
ax.set_ylabel('Scores')
ax.set_title('BERT Question Answering: Start vs End Token Scores')

fig.tight_layout()
plt.show()