In [43]:
from transformers import AlbertTokenizer, AlbertForQuestionAnswering
import torch
import numpy as np

In [44]:
tokenizer = AlbertTokenizer.from_pretrained('twmkn9/albert-base-v2-squad2')
model = AlbertForQuestionAnswering.from_pretrained('twmkn9/albert-base-v2-squad2')
# inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
# start_positions = torch.tensor([1])
# end_positions = torch.tensor([3])
# outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions)
# loss, start_scores, end_scores = outputs[:3]


In [75]:
question = "What was the amount of children murdered?"
answer_text = 'NEW DELHI, India (CNN) -- A high court in northern India on Friday acquitted a wealthy businessman facing the death sentence for the killing of a teen in a case dubbed "the house of horrors."\n\n\n\nMoninder Singh Pandher was sentenced to death by a lower court in February.\n\n\n\nThe teen was one of 19 victims -- children and young women -- in one of the most gruesome serial killings in India in recent years.\n\n\n\nThe Allahabad high court has acquitted Moninder Singh Pandher, his lawyer Sikandar B. Kochar told CNN.\n\n\n\nPandher and his domestic employee Surinder Koli were sentenced to death in February by a lower court for the rape and murder of the 14-year-old.\n\n\n\nThe high court upheld Koli\'s death sentence, Kochar said.\n\n\n\nThe two were arrested two years ago after body parts packed in plastic bags were found near their home in Noida, a New Delhi suburb. Their home was later dubbed a "house of horrors" by the Indian media.\n\n\n\nPandher was not named a main suspect by investigators initially, but was summoned as co-accused during the trial, Kochar said.\n\n\n\nKochar said his client was in Australia when the teen was raped and killed.\n\n\n\nPandher faces trial in the remaining 18 killings and could remain in custody, the attorney said.'

In [76]:
input_ids = tokenizer.encode(question, answer_text)
print('The input has a total of {:} tokens.'.format(len(input_ids)))


The input has a total of 293 tokens.


In [88]:
tokens = tokenizer.convert_ids_to_tokens(input_ids)

# For each token and its id...
for token, id in zip(tokens, input_ids):
    
    # If this is the [SEP] token, add some space around it to make it stand out.
    if id == tokenizer.sep_token_id:
        print('')
    
    # Print the token string and its ID in two columns.
    print('{:<12} {:>6,}'.format(token, id))

    if id == tokenizer.sep_token_id:
        print('')

[CLS]             2
▁what            98
▁was             23
▁the             14
▁amount       2,006
▁of              16
▁children       391
▁murdered     6,103
?                60

[SEP]             3

▁new             78
▁delhi        5,999
,                15
▁india          739
▁                13
(                 5
cn            9,881
n               103
)                 6
▁                13
-                 8
-                 8
▁a               21
▁high           183
▁court          495
▁in              19
▁northern       743
▁india          739
▁on              27
▁friday       4,619
▁acquitted   20,649
▁a               21
▁wealthy      6,574
▁businessman  5,960
▁facing       4,325
▁the             14
▁death          372
▁sentence     5,123
▁for             26
▁the             14
▁killing      2,389
▁of              16
▁a               21
▁teen         9,503
▁in              19
▁a               21
▁case           610
▁dubbed       9,343
▁                13
"                 

In [89]:
# Search the input_ids for the first instance of the `[SEP]` token.
sep_index = input_ids.index(tokenizer.sep_token_id)

# The number of segment A tokens includes the [SEP] token istelf.
num_seg_a = sep_index + 1

# The remainder are segment B.
num_seg_b = len(input_ids) - num_seg_a

# Construct the list of 0s and 1s.
segment_ids = [0]*num_seg_a + [1]*num_seg_b

# There should be a segment_id for every input token.
assert len(segment_ids) == len(input_ids)

In [90]:
# Run our example through the model.
start_scores, end_scores = model(torch.tensor([input_ids]), # The tokens representing our input text.
                                 token_type_ids=torch.tensor([segment_ids])) # The segment IDs to differentiate question from answer_text


In [91]:
# Find the tokens with the highest `start` and `end` scores.
k = 3 # number of top answers returned
start_scores = start_scores.detach().numpy()[0]
start_scores[0] = 0
end_scores = end_scores.detach().numpy()[0]
end_scores[0] = 0
# answer_start = torch.topk(start_scores, k, largest=True)
# answer_end = torch.topk(end_scores, k, largest=True)

In [81]:

def gen_i_j_scores(start_scores, end_scores):
    '''
    i = start score index
    j = end score index
    score = sum of start_scores[i] and start_scores[j]
    '''
    start_scores_copy = start_scores.copy()
    end_scores_copy = end_scores.copy()
    i_j_scores = {}
    for x in range(len(start_scores)):
        score = 0
        i = np.argmax(start_scores_copy)
        score += start_scores_copy[i]
        start_scores_copy[i] = 0
        j = np.argmax(end_scores_copy)
        score += end_scores_copy[j]
        end_scores_copy[j] = 0
        i_j_scores[(i, j)] = score
    return i_j_scores

In [82]:
sorted_i_j = []
for key in gen_i_j_scores(start_scores, end_scores):
    sorted_i_j.append(key)

In [83]:
counter = 0
answers_indices = [] # i, j pairs of answers
answers = [] # string answers
# for index, pair in enumerate(sorted_i_j): 
while len(answers_indices) < k and counter < len(sorted_i_j):
    pair = sorted_i_j[counter]
    if pair[1] >= pair[0]: #end token after start token
        if counter == 0: # first i,j
            answer = tokens[pair[0]] #first token (i)
            for i in range(pair[0] + 1, pair[1] + 1):
                if tokens[i][0:1] != '▁':
                    answer += tokens[i]
                else:
                    answer += ' ' + tokens[i]
            print('Answer: "' + answer + '"')
            answers_indices.append(sorted_i_j[counter])
            answers.append(answer)
        elif (pair[0] >= sorted_i_j[counter-1][1]) or (sorted_i_j[counter-1][0] >= pair[1]):
            # ^start token of current span is after end token of previous span; end token of current span is before start of prev
            answer = tokens[pair[0]]
            for i in range(pair[0] + 1, pair[1] + 1):
                if tokens[i][0:1] != '▁':
                    answer += tokens[i]
                else:
                    answer += ' ' + tokens[i]
            print('Answer: "' + answer + '"')
            answers_indices.append(sorted_i_j[counter])
            answers.append(answer)
    counter += 1


Answer: "▁19 ▁victims ▁-- ▁children ▁and ▁young ▁women"
Answer: "▁18"
Answer: "▁and ▁young ▁women ▁-- ▁in ▁one ▁of ▁the ▁most ▁gruesome ▁serial ▁killings ▁in ▁india ▁in ▁recent ▁years"


In [84]:
if len(answers_indices) < k:
    while len(answers_indices) < k:
        answers_indices.append((-1, -1))

In [85]:
for i in answers_indices:
    if i == (-1, -1):
        answers.append('')

In [86]:
print(answers)

['▁19 ▁victims ▁-- ▁children ▁and ▁young ▁women', '▁18', '▁and ▁young ▁women ▁-- ▁in ▁one ▁of ▁the ▁most ▁gruesome ▁serial ▁killings ▁in ▁india ▁in ▁recent ▁years']
