# Goal: Take the preprocessed gold data, format it the same way model predictions are formatted, and pass it to the evaluator to sanity-check the formatting and evaluation pipeline.

In [1]:
from collections import Counter
import string
import re
import argparse
import json
import sys
import numpy as np
import nltk
import random
import math
import os
import pickle
from tqdm import tqdm, trange

import pdb

from pytorch_pretrained_bert import BertTokenizer

In [2]:
def pickler(path, pkl_name, obj):
    """Serialize *obj* to the file ``path/pkl_name`` with the highest pickle protocol."""
    target_file = os.path.join(path, pkl_name)
    with open(target_file, 'wb') as handle:
        pickle.dump(obj, handle, pickle.HIGHEST_PROTOCOL)

def unpickler(path, pkl_name):
    """Load and return the object pickled in the file ``path/pkl_name``.

    Only use on trusted files: unpickling can execute arbitrary code.
    """
    source_file = os.path.join(path, pkl_name)
    with open(source_file, 'rb') as handle:
        return pickle.load(handle)

In [3]:
class PredictionFormatter:
    """Turn yes/no/span class scores and start/end span scores into answer strings.

    The chosen span of BERT word-piece ids is widened to whole words and
    mapped back to text via a per-example ``tokens_to_text_mapping``.
    """

    def __init__(self, max_answer_length=20, bert_model='bert-base-uncased'):
        """Load the BERT tokenizer used to map ids back to word pieces.

        Args:
            max_answer_length: longest span (in tokens) considered by
                ``find_most_confident_span``; defaults to the previously
                hard-coded 20.
            bert_model: model name passed to ``BertTokenizer.from_pretrained``.
        """
        self.tokenizer = BertTokenizer.from_pretrained(bert_model)
        self.max_answer_length = max_answer_length

    def un_tokenize(self, ids, tokens_to_text_mapping, bert_tokenizer):
        """Convert a run of token ids back to a space-joined string.

        Greedy longest match: from the current position, the longest id
        subsequence that appears (as a tuple) in *tokens_to_text_mapping* is
        replaced by its text. An id that starts no known subsequence is
        emitted as its raw word piece from *bert_tokenizer*.
        """
        words = []
        n = len(ids)
        start = 0
        while start < n:
            # Try the longest candidate first: ids[start:n], ids[start:n-1], ...
            for stop in range(n, start, -1):
                key = tuple(ids[start:stop])
                if key in tokens_to_text_mapping:
                    words.append(tokens_to_text_mapping[key])
                    start = stop
                    break
            else:
                words.append(bert_tokenizer.convert_ids_to_tokens([ids[start]])[0])
                start += 1
        return " ".join(words)

    def is_word_split(self, word):
        """Return True if *word* is a BERT continuation piece (starts with '##')."""
        return word.startswith('##')

    def combine_word_pieces(self, sentence):
        """Merge '##' continuation pieces into the preceding token.

        ['runn', '##ing', 'race'] -> ['running', 'race']. A leading
        continuation piece (nothing to attach to) is kept verbatim instead of
        raising IndexError.
        """
        out_tokens = []
        for token in sentence:
            if out_tokens and self.is_word_split(token):
                out_tokens[-1] += token[2:]
            else:
                out_tokens.append(token)
        return out_tokens

    def convert_indices_to_text(self, sentence, start, end, tokens_to_text_mapping):
        ''' (sentence, 10, 12) --> ['runn', '##ing', 'race'] --> "running race"

        Detokenize the inclusive span [start, end] of the id sequence
        *sentence*, first widening it so no word is cut in half: leftwards
        while the start token is a '##' piece, rightwards over any '##'
        pieces that follow *end*.
        '''
        text = self.tokenizer.convert_ids_to_tokens(sentence)

        # Walk left to the first piece of the word containing `start`
        # (index 0 included).
        true_start = start
        while true_start > 0 and self.is_word_split(text[true_start]):
            true_start -= 1

        # Walk right over trailing continuation pieces, even when they run to
        # the very end of the sequence.
        true_end = end
        while true_end + 1 < len(text) and self.is_word_split(text[true_end + 1]):
            true_end += 1

        return self.un_tokenize(sentence[true_start:true_end + 1], tokens_to_text_mapping, self.tokenizer)

    def find_most_confident_span(self, start_scores, end_scores):
        '''
        Inputs: masked start_scores and end_scores of a single example
        Output: (i, j) maximizing start_scores[i] + end_scores[j] subject to
                i <= j < i + self.max_answer_length.

        Works for probabilities and raw (possibly negative) logits. Ties go to
        the smallest i, then the smallest j. Returns (0, 0) when there is no
        candidate span.
        '''
        assert len(start_scores) == len(end_scores)
        starts = np.asarray(start_scores)
        ends = np.asarray(end_scores)
        n = len(starts)
        if n == 0 or self.max_answer_length <= 0:
            return 0, 0

        # totals[i, j] = start score of i + end score of j.
        totals = starts[:, None] + ends[None, :]
        # offsets[i, j] = j - i; keep only forward spans within the length cap.
        offsets = np.arange(n)[None, :] - np.arange(n)[:, None]
        valid = (offsets >= 0) & (offsets < self.max_answer_length)
        totals = np.where(valid, totals, -np.inf)

        # argmax returns the first maximum in row-major order, i.e. the same
        # lexicographic tie-breaking as a nested i/j loop with a strict '>'.
        best_start, best_stop = divmod(int(np.argmax(totals)), n)
        return best_start, best_stop

    def format_prediction(self, yes_no_span, start_scores, end_scores, sequences, tokens_to_text_mappings, official_evalutation=True):
        '''
        input: all numpy arrays (one row per example)
        output: list of answers in English (ex: ["polar bear", "emperor penguin"])

        With official_evalutation, the argmax of yes_no_span[i] decides the
        answer type: 0 -> "yes", 1 -> "no", anything else -> the best span.
        Without it, every example is answered with a span.
        '''
        assert len(yes_no_span) == len(start_scores) == len(end_scores) == len(sequences)

        answers = []
        for i in range(len(yes_no_span)):
            if official_evalutation:
                yns = yes_no_span[i].argmax(axis=-1)
                if yns == 0:
                    answers.append("yes")
                    continue
                if yns == 1:
                    answers.append("no")
                    continue

            start, end = self.find_most_confident_span(start_scores[i], end_scores[i])
            ans = self.convert_indices_to_text(sequences[i], start, end, tokens_to_text_mappings[i])
            # Glue back any word-piece markers that survived detokenization.
            answers.append(ans.replace(" ##", '').replace("##", ''))

        assert len(answers) == len(sequences)
        return answers


class Evaluator:

    ''' Adapted from the official HotpotQA evaluation script.

    Answer normalization, exact-match / token-F1 scoring, and helpers that
    aggregate those scores over a batch of predictions.
    '''

    # Punctuation stripped by normalize_answer; built once rather than per call.
    _PUNCTUATION = frozenset(string.punctuation)

    def normalize_answer(self, s):
        """Lower-case *s*, drop punctuation and the articles a/an/the, and
        collapse whitespace runs to single spaces."""
        punctuation = self._PUNCTUATION

        def remove_articles(text):
            return re.sub(r'\b(a|an|the)\b', ' ', text)

        def white_space_fix(text):
            return ' '.join(text.split())

        def remove_punc(text):
            return ''.join(ch for ch in text if ch not in punctuation)

        def lower(text):
            return text.lower()

        return white_space_fix(remove_articles(remove_punc(lower(s))))

    def f1_score(self, prediction, ground_truth):
        """Return (f1, precision, recall) over normalized answer tokens.

        If either side normalizes to 'yes', 'no' or 'noanswer', only an exact
        normalized match can score; otherwise (0, 0, 0) is returned. The same
        zero triple is returned when the answers share no tokens.
        """
        normalized_prediction = self.normalize_answer(prediction)
        normalized_ground_truth = self.normalize_answer(ground_truth)

        ZERO_METRIC = (0, 0, 0)

        if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
            return ZERO_METRIC
        if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
            return ZERO_METRIC

        prediction_tokens = normalized_prediction.split()
        ground_truth_tokens = normalized_ground_truth.split()
        common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            return ZERO_METRIC
        precision = 1.0 * num_same / len(prediction_tokens)
        recall = 1.0 * num_same / len(ground_truth_tokens)
        f1 = (2 * precision * recall) / (precision + recall)
        return f1, precision, recall

    def exact_match_score(self, prediction, ground_truth):
        """Return True if both answers are equal after normalize_answer."""
        return self.normalize_answer(prediction) == self.normalize_answer(ground_truth)

    def compute_yes_no_span_em(self, gt, pred):
        """Fraction of rows whose argmax over the last axis of *pred* equals gt[i].

        Returns 0.0 for an empty batch instead of dividing by zero.
        """
        if len(pred) == 0:
            return 0.0
        pred_classes = pred.argmax(axis=-1)
        num_correct = sum(1 for i in range(len(pred)) if pred_classes[i] == gt[i])
        return num_correct / len(pred)

    def evaluate_answers_official(self, gold, prediction, question_indices, divide_by_len_pred = True):
        """Average EM / F1 / precision / recall of *prediction* against *gold*.

        prediction[i] is scored against gold[question_indices[i]]. The sums
        are divided by len(prediction) when divide_by_len_pred is True,
        otherwise by len(gold). A zero denominator yields 0.0 metrics instead
        of ZeroDivisionError.

        Returns:
            (metrics, indices_with_0_em): the averaged metrics dict and the
            positions in *prediction* that failed exact match.
        """
        per_question = self.individual_question_metrics(gold, prediction, question_indices)
        denominator = len(prediction) if divide_by_len_pred else len(gold)

        metrics = {}
        for key, values in per_question.items():
            metrics[key] = sum(values) / denominator if denominator else 0.0

        indices_with_0_em = [i for i, em in enumerate(per_question['em']) if em == 0]
        return metrics, indices_with_0_em

    def individual_question_metrics(self, gold, prediction, question_indices):
        """Per-prediction scores.

        Returns a dict with keys 'em', 'f1', 'precision', 'recall'. Each value
        is a list aligned with *prediction*, where prediction[i] is scored
        against gold[question_indices[i]].
        """
        assert len(gold) >= len(prediction)
        assert len(prediction) == len(question_indices)

        metrics = {'em': [], 'f1': [], 'precision': [], 'recall': []}
        for pred, q_idx in zip(prediction, question_indices):
            truth = gold[q_idx]
            em = self.exact_match_score(pred, truth)
            f1, prec, recall = self.f1_score(pred, truth)
            metrics['em'].append(float(em))
            metrics['f1'].append(f1)
            metrics['precision'].append(prec)
            metrics['recall'].append(recall)
        return metrics

    def evaluate_full_recall_only(self, gold, prediction, question_indices, full_recall_indices, divide_by_len_pred = True):
        """Averaged metrics that count only predictions listed in *full_recall_indices*.

        Every other prediction contributes 0. Sums are divided by len(gold);
        an empty *gold* yields 0.0 metrics.

        NOTE(review): divide_by_len_pred is accepted but ignored (the divisor
        is always len(gold)). This is kept as-is to preserve existing results.
        """
        all_metrics = self.individual_question_metrics(gold, prediction, question_indices)
        # A set gives O(1) membership instead of scanning a list per prediction.
        keep = set(full_recall_indices)

        final_metrics = {}
        for key, values in all_metrics.items():
            total = sum(value for i, value in enumerate(values) if i in keep)
            final_metrics[key] = total / len(gold) if len(gold) else 0.0
        return final_metrics

In [4]:
def prepare_span_predictions(data):
    """Build one-hot start/end "prediction" arrays from gold answer offsets.

    Each example gets a row of length ``data['max_seq_len']``. If the example
    has span offsets, the first entry of ``answer_start_indices_offset`` is
    marked 1.0 in the start row, and the first entry of
    ``answer_end_indices_offset`` is marked 1.0 in the end row. Otherwise both
    rows stay all zero.

    Returns:
        (start_array, end_array), each built with ``np.array`` from the rows.
    """
    seq_len = data['max_seq_len']
    start_rows, end_rows = [], []

    for idx, span_starts in enumerate(data['answer_start_indices_offset']):
        start_row = [0] * seq_len
        end_row = [0] * seq_len
        if len(span_starts) != 0:
            start_row[span_starts[0]] = 1.0
            end_row[data['answer_end_indices_offset'][idx][0]] = 1.0
        start_rows.append(start_row)
        end_rows.append(end_row)

    return np.array(start_rows), np.array(end_rows)

In [5]:
def prepare_yes_no_span_pred(data):
    """One-hot encode ``data['yes_no_span']`` into rows of three integers.

    Each label indexes a row of three zeros that is set to 1 there. The
    resulting columns match the classes read by
    PredictionFormatter.format_prediction: 0 = yes, 1 = no, 2 = span.
    """
    def one_hot(label):
        row = [0, 0, 0]
        row[label] = 1
        return row

    return np.array([one_hot(label) for label in data['yes_no_span']])

In [6]:
# Load the preprocessed training data from the current directory.
raw_data = unpickler(path='./', pkl_name='preprocessed_train.pkl')

In [7]:
# Turn the gold answer offsets into one-hot start/end score arrays.
(start_predictions, end_predictions) = prepare_span_predictions(data=raw_data)

In [8]:
# One-hot yes/no/span class scores derived from the gold labels.
yes_no_span_pred = prepare_yes_no_span_pred(data=raw_data)

In [9]:
# Run the gold-derived "predictions" through the formatter to get answer strings.
formatter = PredictionFormatter()
pred_answers = formatter.format_prediction(
    yes_no_span_pred,
    start_predictions,
    end_predictions,
    raw_data['question_context_sequences'],
    raw_data["ids_to_word_mappings"],
)

In [10]:
# Score the formatted answers against the gold answer strings.
evaluator = Evaluator()
metrics, indices_with_0_em = evaluator.evaluate_answers_official(
    gold=raw_data['answer_string'],
    prediction=pred_answers,
    question_indices=raw_data['question_indices'],
    divide_by_len_pred=True,
)

In [11]:
# Display the averaged metrics dict (keys: em, f1, precision, recall).
metrics

{'em': 0.9834268823275433,
 'f1': 0.9864094607896177,
 'precision': 0.9883172798223105,
 'recall': 0.9866360440195813}

In [12]:
# Number of examples whose formatted answer failed exact match.
len(indices_with_0_em)

1497

In [13]:
# Tally the formatted answer strings that failed exact match.
Counter(pred_answers[i] for i in indices_with_0_em)

Counter({'6 , 960': 1,
         'Opéra': 1,
         'Washington, d . c .': 10,
         '167 , 446': 1,
         'i - 90': 1,
         'quinceanera': 1,
         '3 , 544': 2,
         'Robert': 1,
         '39 , 134': 2,
         '8 , 765 , 000': 1,
         'banjo': 1,
         "the Go - Go ' s": 1,
         '13 - 18': 1,
         'Best Young actor / Actress': 1,
         'Na do - hyang': 1,
         'American': 1,
         'founded': 1,
         'Primera division': 3,
         "defence of fort m ' henry": 1,
         '1 , 425': 1,
         'Benares gharana': 1,
         'Juan Nepomuceno Seguin': 1,
         '8 , 711': 1,
         'Shim hyung - rae': 1,
         'traditionalist': 1,
         'Gelsenkirchen, North rhine - westphalia': 1,
         '15 , 023': 3,
         '66 , 135': 1,
         '124 , 775': 2,
         'yu - gi - oh': 1,
         '542 , 196': 1,
         '2 , 416': 2,
         'the': 37,
         'two': 1,
         '7 , 304': 3,
         'Comedy': 1,
         '87 , 64

In [14]:
# Positions (into pred_answers) of the first ten exact-match failures.
indices_with_0_em[:10]

[190, 262, 346, 384, 389, 410, 425, 443, 507, 513]

In [15]:
# Compare the gold answer and the formatted answer for a single example.
idx = 28
true_answer = raw_data['answer_string'][raw_data['question_indices'][idx]]
print("True answer: {}".format(true_answer))
print("Predicted: {}".format(pred_answers[idx]))

True answer: no
Predicted: no


In [16]:
# Gold label distribution: 0 = yes, 1 = no, anything else = span.
num_yes = sum(1 for label in raw_data["yes_no_span"] if label == 0)
num_no = sum(1 for label in raw_data["yes_no_span"] if label == 1)
num_span = len(raw_data["yes_no_span"]) - num_yes - num_no

In [17]:
for label, count in (("Num yes: ", num_yes),
                     ("Num no: ", num_no),
                     ("Num span: ", num_span),
                     ("Num yes + num no: ", num_yes + num_no)):
    print(label, count)

Num yes:  2748
Num no:  2733
Num span:  84846
Num yes + num no:  5481


In [18]:
yes_no_fraction = (num_yes + num_no) / len(raw_data['yes_no_span'])
print("Fraction of yes/no questions in data: ", yes_no_fraction)

Fraction of yes/no questions in data:  0.06067953103723139
