In [1]:
from collections import Counter
import string
import re
import argparse
import json
import sys
import numpy as np
import nltk
import random
import math
import os
import pickle
from tqdm import tqdm, trange

from sklearn.metrics import confusion_matrix, classification_report

import pandas as pd

import pdb

from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
                                                  BertTokenizer,
                                                  whitespace_tokenize)

import spacy

In [2]:
# Load spaCy's small English pipeline. NOTE(review): `ner` is not referenced
# by any other cell visible in this notebook — confirm it is still needed.
ner = spacy.load("en_core_web_sm")

In [3]:
def pickler(path,pkl_name,obj):
    """Pickle ``obj`` into the file ``pkl_name`` located in directory ``path``.

    The file is opened in binary write mode (an existing file is overwritten)
    and the highest pickle protocol available is used.
    """
    target_file = os.path.join(path, pkl_name)
    with open(target_file, mode='wb') as handle:
        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)

def unpickler(path,pkl_name):
    """Read and return the pickled object stored at ``path``/``pkl_name``.

    Unpickling can run arbitrary code, so this should only be pointed at
    files that come from a trusted source.
    """
    source_file = os.path.join(path, pkl_name)
    with open(source_file, mode='rb') as handle:
        return pickle.load(handle)

In [4]:
class PredictedSpanFormatter:
    """Turns span scores over BERT word-piece ids into answer strings.

    The formatter picks the most confident (start, end) token pair, widens it
    so that it never cuts a word in the middle of its '##' pieces, and maps
    the resulting ids back to text with a per-example ids -> text mapping.
    """

    def __init__(self,max_answer_length):
        """Load the 'bert-base-uncased' vocabulary and store the span length cap.

        max_answer_length: maximum number of tokens a candidate span may cover.
        """
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.max_answer_length = max_answer_length #options.max_answer_length

    @staticmethod
    def _safe_log(probability):
        """Natural log that yields -inf for non-positive scores instead of raising."""
        return math.log(probability) if probability > 0 else float('-inf')

    def un_tokenize(self,ids, tokens_to_text_mapping, bert_tokenizer):
        """Greedily decode ``ids`` into text, longest mapped run first.

        ids: sequence of token ids.
        tokens_to_text_mapping: dict whose keys are tuples of ids and whose
            values are the text those ids stand for.
        bert_tokenizer: object with ``convert_ids_to_tokens``; used for a
            single id when no mapped run starts at the current position.

        Returns the decoded pieces joined by single spaces.
        """
        out_list = []
        num_ids = len(ids)
        start = 0
        while start < num_ids:
            # Try the longest candidate run first and shrink it one id at a time.
            for stop in range(num_ids, start, -1):
                key = tuple(ids[start:stop])
                if key in tokens_to_text_mapping:
                    out_list.append(tokens_to_text_mapping[key])
                    start = stop
                    break
            else:
                # Nothing in the mapping starts here: emit the raw vocabulary token.
                out_list.append(bert_tokenizer.convert_ids_to_tokens([ids[start]])[0])
                start += 1
        return " ".join(out_list)

    def is_word_split(self,word):
        """Return True when ``word`` is a continuation piece (starts with '##')."""
        return word.startswith('##')

    def combine_word_pieces(self, sentence):
        """Merge '##' continuation pieces into the token that precedes them.

        ['runn', '##ing', 'race'] --> ['running', 'race']

        A continuation piece with nothing before it (normally impossible for a
        full sentence) is kept as its own token, without the '##' marker,
        instead of raising IndexError.
        """
        out_tokens = []
        for token in sentence:
            is_continuation = self.is_word_split(token)
            piece = token[2:] if is_continuation else token
            if is_continuation and out_tokens:
                out_tokens[-1] += piece
            else:
                out_tokens.append(piece)
        return out_tokens

    def convert_indices_to_text(self, sentence, start, end, tokens_to_text_mapping):
        ''' (sentence, [10, 12]) --> ['runn', '##ing', 'race'] --> ['running', 'race']
        --> "running race"

        sentence: sequence of token ids.
        start, end: inclusive token positions of the predicted span.
        tokens_to_text_mapping: ids-tuple -> text dict passed on to un_tokenize.

        The span is widened on both sides to whole words before decoding.
        '''
        text = self.tokenizer.convert_ids_to_tokens(sentence)

        # Walk left to the first piece of the word. Position 0 is included: the
        # previous scan stopped at position 1 and could leave a dangling '##'.
        true_start = start
        while true_start > 0 and self.is_word_split(text[true_start]):
            true_start -= 1

        # Walk right over the remaining pieces of the last word, also when
        # those pieces run up to the very end of the sequence.
        true_end = end
        while true_end + 1 < len(text) and self.is_word_split(text[true_end + 1]):
            true_end += 1

        proper_text = self.un_tokenize(sentence[true_start:true_end+1], tokens_to_text_mapping, self.tokenizer)
        return proper_text

    def find_most_confident_span(self, start_scores, end_scores):
        ''' 
        Inputs: masked start_scores and end_scores of a single example
        Output: (i,j) pair with the highest log Pr(i) + log Pr(j), where
                i <= j < i + max_answer_length

        Non-positive scores are treated as log-probability -inf, so they can
        never win but do not raise. Ties keep the first candidate visited
        (smallest i, and for that i the largest j).
        '''
        assert(len(start_scores) == len(end_scores))
        # Take each log once instead of once per (i, j) pair.
        log_start = [self._safe_log(score) for score in start_scores]
        log_end = [self._safe_log(score) for score in end_scores]
        num_tokens = len(log_start)

        best_start = 0
        best_stop = 0
        best_confidence = -1e100
        for i in range(num_tokens):
            last_stop = min(num_tokens, i + self.max_answer_length) - 1
            for j in range(last_stop, i-1, -1):
                confidence = log_start[i] + log_end[j]
                if(confidence > best_confidence):
                    best_start = i
                    best_stop = j
                    best_confidence = confidence
        return best_start, best_stop

    def find_top_n_confident_spans(self, start_scores, end_scores, n):
        ''' 
        Inputs: masked start_scores and end_scores of a single example
        Output: up to n entries [score, i, j, start_scores[i], end_scores[j]]
                sorted by score = log Pr(i) + log Pr(j), best first, where
                i <= j < i + max_answer_length

        Non-positive scores get log-probability -inf and therefore sort last.
        '''
        assert(len(start_scores) == len(end_scores))
        log_start = [self._safe_log(score) for score in start_scores]
        log_end = [self._safe_log(score) for score in end_scores]
        num_tokens = len(log_start)

        scores = []
        for i in range(num_tokens):
            last_stop = min(num_tokens - 1, i + self.max_answer_length - 1)
            for j in range(last_stop, i-1, -1):
                scores.append([log_start[i] + log_end[j], i, j, start_scores[i], end_scores[j]])
        scores.sort(key = lambda x:x[0], reverse = True)
        return scores[:n]

    def format_prediction(self, yes_no_span, start_scores, end_scores, 
                          sequences, tokens_to_text_mappings, 
                          question_ids, max_question_len,official_evalutation=True):
        '''
        input: all numpy arrays
        output: {"question_id": answer_string}

        yes_no_span: per-example scores; with official_evalutation the argmax
            decides the answer type: 0 -> "yes", 1 -> "no", anything else -> span.
        start_scores, end_scores: per-example token scores for the span ends.
        sequences: per example, the chunks of token ids; the first
            max_question_len + 2 ids of every chunk are dropped before the
            chunks are concatenated (presumably the question and special
            tokens — confirm against the preprocessing).
        tokens_to_text_mappings: per-example ids-tuple -> text dicts.
        question_ids: per-example keys of the returned dict.
        '''
        answers = {}
        assert(len(yes_no_span) == len(start_scores) == len(end_scores) == len(sequences))
        
        #TODO use range instead of trange
        for i in trange(len(yes_no_span)):
            if(official_evalutation):
                yns = yes_no_span[i].argmax(axis=-1)
                if(yns == 0):
                    answers[question_ids[i]] = "yes"
                    continue
                elif(yns == 1):
                    answers[question_ids[i]] = "no"
                    continue
            
            start, end = self.find_most_confident_span(start_scores[i], end_scores[i])
            
            # Span indices refer to the context-only concatenation of all chunks.
            sequence_chunks_concatenated = []
            for seq in sequences[i]:
                sequence_chunks_concatenated += seq[max_question_len + 2:]
            
            ans = self.convert_indices_to_text(sequence_chunks_concatenated, start, end, tokens_to_text_mappings[i])
            answers[question_ids[i]] = ans
        
        # Fails when two examples share a question id.
        assert(len(answers) == len(sequences))

        return answers

In [5]:
class SupportingFactFormatter:
    '''
    Translates per-sentence supporting-fact flags into [paragraph name, sentence index] pairs.

    inputs: 
    - a binary array for each question, holding 1 where the corresponding sentence is a supporting fact and 0 elsewhere
    - question id
    - names of the paragraphs in the context
    - which paragraphs were placed in which chunk
    - number of sentences in each paragraph
    
    output:
    A list like this
    [['Bridgeport, Connecticut', 5], ['Brookhaven National Laboratory', 1]]
    '''

    def __init__(self, num_chunks, num_sentences_per_chunk):
        '''Remember how the flat sentence axis is laid out: num_chunks blocks of num_sentences_per_chunk slots.'''
        self.num_sentences_per_chunk = num_sentences_per_chunk
        self.num_chunks = num_chunks

    def find_all_indices(self, the_array, the_value):
        '''Return, as a list, every position of the 1-D array where the element equals the_value.'''
        assert the_array.ndim == 1
        matches = np.nonzero(the_array == the_value)[0]
        return [position for position in matches]

    def find_paragraph_and_sentence_index(self, sent_index, paragraph_chunk_indices, num_sentences_in_paragraphs):
        '''
        Resolve a flat sentence slot to (paragraph index, sentence index inside that paragraph).

        sent_index: position on the flat chunk-major sentence axis
        paragraph_chunk_indices: for each chunk, the paragraph indices it holds, in order
        num_sentences_in_paragraphs: sentence count of every paragraph

        Returns (-1, -1) when the slot lies past the last sentence of its chunk.
        '''
        chunk_index, slot_in_chunk = divmod(sent_index, self.num_sentences_per_chunk)
        assert chunk_index < self.num_chunks

        first_slot = 0  # slot of the current paragraph's first sentence within the chunk
        for p_index in paragraph_chunk_indices[chunk_index]:
            next_first_slot = first_slot + num_sentences_in_paragraphs[p_index]
            if first_slot <= slot_in_chunk < next_first_slot:
                return p_index, slot_in_chunk - first_slot
            first_slot = next_first_slot
        return -1, -1

    def find_paragraph_name(self, para_index, paragraph_names):
        '''Look up the name of a paragraph; negative indices are rejected.'''
        assert para_index >= 0
        return paragraph_names[para_index]

    def format_supporting_facts(self, predictions, question_ids, 
                                paragraph_names, paragraph_chunk_indices, 
                                num_sentences_in_paragraphs):
        '''
        Build {question id: [[paragraph name, sentence index], ...]} from rows of 0/1 flags.

        All five arguments are parallel, one entry per question. Flagged slots
        that do not map to a real sentence are left out.
        '''
        assert( len(predictions) == len(question_ids) == len(paragraph_names) == len(paragraph_chunk_indices)
               == len(num_sentences_in_paragraphs) )

        per_question = zip(predictions, question_ids, paragraph_names,
                           paragraph_chunk_indices, num_sentences_in_paragraphs)
        out_records = {}
        for flags, q_id, names, chunk_layout, sentence_counts in per_question:
            facts = []
            for flagged_slot in self.find_all_indices(the_array=flags, the_value=1):
                para_idx, sentence_idx = self.find_paragraph_and_sentence_index(
                    sent_index=flagged_slot,
                    paragraph_chunk_indices=chunk_layout,
                    num_sentences_in_paragraphs=sentence_counts)
                if para_idx >= 0 and sentence_idx >= 0:
                    name = self.find_paragraph_name(para_index=para_idx, paragraph_names=names)
                    facts.append([name, sentence_idx])
            out_records[q_id] = facts

        return out_records
                
    

In [6]:
# Scratch formatter for the sanity checks below: 4 chunks of 18 sentence slots each.
sff = SupportingFactFormatter(num_chunks=4, num_sentences_per_chunk=18)

In [7]:
# Sanity check: the value 1 sits at positions 0 and 2 of [1, 2, 1, 2].
sff.find_all_indices(the_array=np.array([1,2,1,2]), the_value=1)

[0, 2]

In [8]:
# Sanity check: flat slot 3 lies in chunk 0 (3 // 18), whose paragraphs 0, 1, 2
# hold 3, 2 and 5 sentences, so it is the first sentence of paragraph 1.
sff.find_paragraph_and_sentence_index(sent_index=3, 
                                      paragraph_chunk_indices=[[0,1,2],[3,6],[4,5],[]], 
                                      num_sentences_in_paragraphs=[3,2,5,4,1,3,6])

(1, 0)

In [9]:
# Scratch check: enumerate() over a 2-D array yields (row number, row) pairs.
a = np.arange(1, 10).reshape(3, 3)

for i, row in enumerate(a):
    print(i, row)

0 [1 2 3]
1 [4 5 6]
2 [7 8 9]


In [10]:
class Evaluator:
    
    '''Adapted from the official evaluation script

    Scores predicted answers and supporting facts against gold records and
    prints answer, supporting-fact ("sp") and joint EM / F1 / precision /
    recall averaged over the gold records.
    '''
    
    def normalize_answer(self, s):
        '''Lower-case s, strip punctuation, drop the articles a/an/the and collapse whitespace.'''

        def remove_articles(text):
            return re.sub(r'\b(a|an|the)\b', ' ', text)

        def white_space_fix(text):
            return ' '.join(text.split())

        def remove_punc(text):
            exclude = set(string.punctuation)
            return ''.join(ch for ch in text if ch not in exclude)

        def lower(text):
            return text.lower()

        return white_space_fix(remove_articles(remove_punc(lower(s))))


    def f1_score(self, prediction, ground_truth):
        '''Return (f1, precision, recall) of the token overlap between two answer strings.

        Both strings are normalized first. 'yes', 'no' and 'noanswer' only
        score when prediction and ground truth agree exactly; otherwise, and
        when no token is shared, the result is (0, 0, 0).
        '''
        normalized_prediction = self.normalize_answer(prediction)
        normalized_ground_truth = self.normalize_answer(ground_truth)

        ZERO_METRIC = (0, 0, 0)

        if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
            return ZERO_METRIC
        if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
            return ZERO_METRIC

        prediction_tokens = normalized_prediction.split()
        ground_truth_tokens = normalized_ground_truth.split()
        # Multiset intersection: a repeated token counts as often as both sides contain it.
        common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            return ZERO_METRIC
        precision = 1.0 * num_same / len(prediction_tokens)
        recall = 1.0 * num_same / len(ground_truth_tokens)
        f1 = (2 * precision * recall) / (precision + recall)
        return f1, precision, recall


    def exact_match_score(self, prediction, ground_truth):
        '''Return True when both strings are equal after normalization.'''
        return (self.normalize_answer(prediction) == self.normalize_answer(ground_truth))

    def update_answer(self, metrics, prediction, gold):
        '''Add one example's answer scores to metrics in place; return (em, precision, recall).'''
        em = self.exact_match_score(prediction, gold)
        f1, prec, recall = self.f1_score(prediction, gold)
        metrics['em'] += float(em)
        metrics['f1'] += f1
        metrics['prec'] += prec
        metrics['recall'] += recall
        return em, prec, recall

    def update_sp(self, metrics, prediction, gold):
        '''Add one example's supporting-fact scores to metrics in place; return (em, precision, recall).

        prediction and gold are iterables of pairs; each pair is turned into a
        tuple and the two collections are compared as sets.
        '''
        cur_sp_pred = set(map(tuple, prediction))
        gold_sp_pred = set(map(tuple, gold))
        tp, fp, fn = 0, 0, 0
        for e in cur_sp_pred:
            if e in gold_sp_pred:
                tp += 1
            else:
                fp += 1
        for e in gold_sp_pred:
            if e not in cur_sp_pred:
                fn += 1
        prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0
        recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0
        f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0
        em = 1.0 if fp + fn == 0 else 0.0
        metrics['sp_em'] += em
        metrics['sp_f1'] += f1
        metrics['sp_prec'] += prec
        metrics['sp_recall'] += recall
        return em, prec, recall

    def eval(self, prediction, gold):
        '''Score prediction against gold and print the averaged metrics.

        prediction: {'answer': {id: answer string}, 'sp': {id: [[name, sentence index], ...]}}
        gold: list of records with the keys '_id', 'answer' and 'supporting_facts'

        A gold record that is missing from either prediction dict is reported
        on stdout and left out of the joint metrics. Every sum is divided by
        len(gold); with no gold records the metrics are printed as zeros
        rather than dividing by zero. Nothing is returned.
        '''
        metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0,
            'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0,
            'joint_em': 0, 'joint_f1': 0, 'joint_prec': 0, 'joint_recall': 0}
        for dp in gold:
            cur_id = dp['_id']
            can_eval_joint = True
            if cur_id not in prediction['answer']:
                print('missing answer {}'.format(cur_id))
                can_eval_joint = False
            else:
                em, prec, recall = self.update_answer(
                    metrics, prediction['answer'][cur_id], dp['answer'])
            if cur_id not in prediction['sp']:
                print('missing sp fact {}'.format(cur_id))
                can_eval_joint = False
            else:
                sp_em, sp_prec, sp_recall = self.update_sp(
                    metrics, prediction['sp'][cur_id], dp['supporting_facts'])            
            
            if can_eval_joint:
                # Joint scores multiply the answer and supporting-fact components.
                joint_prec = prec * sp_prec
                joint_recall = recall * sp_recall
                if joint_prec + joint_recall > 0:
                    joint_f1 = 2 * joint_prec * joint_recall / (joint_prec + joint_recall)
                else:
                    joint_f1 = 0.
                joint_em = em * sp_em

                metrics['joint_em'] += joint_em
                metrics['joint_f1'] += joint_f1
                metrics['joint_prec'] += joint_prec
                metrics['joint_recall'] += joint_recall

        N = len(gold)
        if N > 0:
            for k in metrics.keys():
                metrics[k] /= N

        print(metrics)


In [11]:
class AnswerPredictionMaker:
    """Builds near-one-hot score arrays out of the labels stored in a data dict.

    Every position gets the tiny score 1e-10 and the labelled position gets 1,
    so no score is exactly zero.
    """

    def prepare_span_predictions(self, data):
        """Return (start_scores, end_scores) arrays with one row per example.

        Each row has (max_seq_len - max_question_len - 2) * num_chunks entries.
        Only the first entry of an example's answer start / end index lists is
        marked; an example with no start indices gets two unmarked rows.
        """
        row_width = (data['max_seq_len'] - data['max_question_len'] - 2) * data['num_chunks']
        epsilon = 1e-10
        start_rows, end_rows = [], []

        for example_no, starts in enumerate(data['answer_start_indices']):
            start_row = [epsilon] * row_width
            end_row = [epsilon] * row_width
            if len(starts) != 0:
                start_row[starts[0]] = 1.0
                end_row[data['answer_end_indices'][example_no][0]] = 1.0
            start_rows.append(start_row)
            end_rows.append(end_row)

        return np.array(start_rows), np.array(end_rows)

    def prepare_yes_no_span_pred(self, data):
        """Return one 3-entry row per example with a 1 at index data['yes_no_span'][example]."""
        epsilon = 1e-10
        rows = []
        for label in data['yes_no_span']:
            row = [epsilon] * 3
            row[label] = 1
            rows.append(row)
        return np.array(rows)

    def preprare_predictions(self, raw_data):
        """Return (yes_no_span scores, start scores, end scores) built from raw_data."""
        span_starts, span_ends = self.prepare_span_predictions(raw_data)
        answer_type_scores = self.prepare_yes_no_span_pred(raw_data)
        return answer_type_scores, span_starts, span_ends

In [12]:
def prepare_gt_for_question_ids(data, question_ids):
    """Build gold records for the given question ids.

    data: dict with the parallel sequences 'question_ids', 'answers_string'
        and 'supporting_facts_raw'.
    question_ids: ids to look up; the output follows this order.

    Returns a list of {'_id', 'answer', 'supporting_facts'} dicts. When an id
    occurs several times in data['question_ids'], its first occurrence is
    used. Raises ValueError for an id that is not present.
    """
    # One pass to index the ids replaces a linear list.index() scan per lookup.
    first_position = {}
    for position, known_id in enumerate(data['question_ids']):
        first_position.setdefault(known_id, position)

    out_records = []
    for q_id in question_ids:
        if q_id not in first_position:
            raise ValueError('{!r} is not in list'.format(q_id))
        question_index = first_position[q_id]
        out_records.append({
            '_id': q_id,
            'answer': data['answers_string'][question_index],
            'supporting_facts': data['supporting_facts_raw'][question_index],
        })
    return out_records


def prepare_gt_for_question_indices(data, question_indices=None):
    """Build gold records for the given positions in the data.

    data: dict with the parallel sequences 'question_ids', 'answers_string'
        and 'supporting_facts_raw'.
    question_indices: iterable of positions to export (a list or a numpy
        array); None selects every question in order.

    Returns a list of {'_id', 'answer', 'supporting_facts'} dicts.
    """
    # `is None` rather than `== None`: comparing a numpy array with == yields
    # an array, whose truth value is ambiguous.
    if question_indices is None:
        question_indices = range(len(data['question_ids']))
    return [
        {
            '_id': data['question_ids'][i],
            'answer': data['answers_string'][i],
            'supporting_facts': data['supporting_facts_raw'][i],
        }
        for i in question_indices
    ]

In [13]:
# class SFPredictionMaker:
    
#     def preprare_predictions(self, raw_data):
#         out_list = []
        
#         for i in range(len(raw_data['question_indices'])):
#             sf_merged = []
#             for sf in raw_data['supporting_facts_expanded'][i]:
#                 sf_merged += sf
#             assert(len(sf_merged) == raw_data['num_chunks']*raw_data['max_num_sentences_per_chunk'])
#             out_list.append(sf_merged)
#         return np.array(out_list)

In [14]:
# Load ./preprocessed_train.pkl. Unpickling can execute arbitrary code, so the
# file must come from a trusted source.
raw_data = unpickler('./', 'preprocessed_train.pkl')

In [15]:
# List the top-level fields of the loaded data.
raw_data.keys()

dict_keys(['question_context_sequences', 'segment_id', 'sentence_start_indices', 'sentence_end_indices', 'answer_start_indices', 'answer_end_indices', 'supporting_facts_expanded', 'question_ids', 'question_indices', 'yes_no_span', 'ids_to_word_mappings', 'max_seq_len', 'max_question_len', 'max_num_sentences_per_chunk', 'num_chunks', 'paragraph_chunk_indices', 'num_sentences_in_paragraphs', 'paragraph_names', 'answers_string', 'supporting_facts_raw'])

In [16]:
# Peek at the answer start indices stored for the first example.
raw_data['answer_start_indices'][0]

[89, 94]

In [17]:
# Per-chunk length used by AnswerPredictionMaker: max_seq_len minus the question
# and 2 further positions (presumably special tokens — confirm).
raw_data['max_seq_len'] - raw_data['max_question_len'] - 2

475

In [18]:
# Helper that turns the labels stored in raw_data into score arrays.
answer_prediction_maker = AnswerPredictionMaker()

In [19]:
# Scorer for answer and supporting-fact predictions.
evaluator = Evaluator()

In [20]:
# max_answer_length=15 caps how many tokens a candidate answer span may cover.
predicted_span_formatter = PredictedSpanFormatter(max_answer_length=15)

In [21]:
# 4 chunks x 18 sentence slots per chunk = 72 flat slots per question.
sf_formatter = SupportingFactFormatter(num_chunks=4, num_sentences_per_chunk=18)

In [22]:
# These "predictions" are derived from the labels stored in raw_data, not from
# a model, so the evaluation below exercises the formatting pipeline itself.
pred_yns, pred_start_scores, pred_end_scores = answer_prediction_maker.preprare_predictions(raw_data)

In [23]:
# Use the stored per-sentence supporting-fact flags as the "predicted" array.
pred_sf = np.array(raw_data['supporting_facts_expanded'])

In [24]:
# Shape check: one row per question, one column per flat sentence slot.
pred_sf.shape

(90447, 72)

In [25]:
# Decode one answer string per question id. Slow cell: the recorded run took
# about 27 minutes for 90447 examples.
pred_ans_str = predicted_span_formatter.format_prediction(pred_yns, pred_start_scores, pred_end_scores, 
                                                          sequences=raw_data['question_context_sequences'], 
                                                          tokens_to_text_mappings=raw_data['ids_to_word_mappings'], 
                                                          question_ids = raw_data['question_ids'],
                                                          max_question_len = raw_data['max_question_len'],
                                                          official_evalutation=True)

100%|██████████| 90447/90447 [27:11<00:00, 57.60it/s]


In [26]:
# Spot-check: print the stored answer string next to the decoded one for the
# first n questions.
n=10
for i in range(n):
    q_id = raw_data['question_ids'][i]
    print("=================")
    print("Original: ", raw_data['answers_string'][i])
    print("Predicted: ", pred_ans_str[q_id])

Original:  Arthur's Magazine
Predicted:  Arthur's Magazine
Original:  Delhi
Predicted:  Delhi
Original:  President Richard Nixon
Predicted:  President Richard nixon
Original:  American
Predicted:  American
Original:  alcohol
Predicted:  alcohol
Original:  Jonathan Stark
Predicted:  Jonathan Stark
Original:  Crambidae
Predicted:  Crambidae
Original:  Badr Hari
Predicted:  Badr Hari
Original:  2006
Predicted:  2006
Original:  6.213 km long
Predicted:  6.213 km long


In [27]:
# Convert the 0/1 sentence flags into [paragraph name, sentence index] lists
# keyed by question id. The predictions argument rebuilds the same array as
# pred_sf above.
pred_sf_formatted = sf_formatter.format_supporting_facts(predictions = np.array(raw_data['supporting_facts_expanded']), 
                                                         question_ids = raw_data['question_ids'], 
                                                         paragraph_names = raw_data['paragraph_names'], 
                                                         paragraph_chunk_indices = raw_data['paragraph_chunk_indices'], 
                                                         num_sentences_in_paragraphs = raw_data['num_sentences_in_paragraphs'])

In [28]:
# Spot-check: print the raw supporting facts next to the reconstructed ones for
# the first n questions (the order of the pairs may differ).
n=10
for i in range(n):
    q_id = raw_data['question_ids'][i]
    print("=================")
    print("Original: ", raw_data['supporting_facts_raw'][i])
    print("Predicted: ", pred_sf_formatted[q_id])

Original:  [["Arthur's Magazine", 0], ['First for Women', 0]]
Predicted:  [['First for Women', 0], ["Arthur's Magazine", 0]]
Original:  [['Oberoi family', 0], ['The Oberoi Group', 0]]
Predicted:  [['Oberoi family', 0], ['The Oberoi Group', 0]]
Original:  [['Allie Goertz', 0], ['Allie Goertz', 1], ['Allie Goertz', 2], ['Milhouse Van Houten', 0]]
Predicted:  [['Milhouse Van Houten', 0], ['Allie Goertz', 0], ['Allie Goertz', 1], ['Allie Goertz', 2]]
Original:  [['Peggy Seeger', 0], ['Peggy Seeger', 1], ['Ewan MacColl', 0]]
Predicted:  [['Ewan MacColl', 0], ['Peggy Seeger', 0], ['Peggy Seeger', 1]]
Original:  [['Cadmium chloride', 1], ['Ethanol', 0]]
Predicted:  [['Cadmium chloride', 1], ['Ethanol', 0]]
Original:  [['Jonathan Stark (tennis)', 0], ['Jonathan Stark (tennis)', 1], ['Henri Leconte', 1]]
Predicted:  [['Jonathan Stark (tennis)', 0], ['Jonathan Stark (tennis)', 1], ['Henri Leconte', 1]]
Original:  [['Indogrammodes', 0], ['Indogrammodes', 1], ['India', 0], ['India', 1]]
Predicted:

In [29]:
# Bundle both prediction dicts under the 'answer' and 'sp' keys that
# Evaluator.eval reads.
formatted_predictions = {'answer':pred_ans_str, 'sp':pred_sf_formatted}

In [30]:
# gt_sp = {}

# for i in range(len(raw_data['question_ids'])):
#     gt_sp[raw_data['question_ids'][i]] = raw_data['supporting_facts_raw'][i]

# formatted_predictions = {'answer':pred_ans_str, 'sp':gt_sp}

In [31]:
# massaged_gt = prepare_gt_for_question_ids(data=raw_data, question_ids = raw_data['question_ids'])
# question_indices=None exports a gold record for every question in raw_data.
massaged_gt = prepare_gt_for_question_indices(data=raw_data, question_indices=None)

In [32]:
# Print answer, supporting-fact and joint metrics averaged over all gold records.
evaluator.eval(prediction=formatted_predictions, gold=massaged_gt)

{'em': 0.939677380123166, 'f1': 0.9573750549094912, 'prec': 0.956260543619155, 'recall': 0.9606650393543146, 'sp_em': 0.9935874047784891, 'sp_f1': 0.9981272026713048, 'sp_prec': 0.9999391909073823, 'sp_recall': 0.997172456166122, 'joint_em': 0.934912158501664, 'joint_f1': 0.9559789282733778, 'joint_prec': 0.9562135547748595, 'joint_recall': 0.9585402653848422}
