# Evaluate predictions on HotpotQA
- The model predicts whether a sentence is a supporting fact for answering a question
- This notebook regroups the flat per-sentence predictions by document and evaluates them (exact match, precision, recall, F1 averaged over examples), just like the HotpotQA evaluation script

In [1]:
import random
import math
import os
import pickle
from tqdm import tqdm, trange
import numpy as np

from sklearn.metrics import f1_score, accuracy_score
from sklearn.metrics import precision_score, recall_score

import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset

import pdb

In [2]:
def pickler(path,pkl_name,obj):
    """Serialize ``obj`` with pickle into the file ``pkl_name`` inside ``path``.

    The target file is opened in binary write mode, so an existing file with
    the same name is overwritten. The highest pickle protocol available in
    the running interpreter is used.

    Args:
        path: Directory that will contain the pickle file.
        pkl_name: File name of the pickle inside ``path``.
        obj: Any picklable Python object.
    """
    target_file = os.path.join(path, pkl_name)
    with open(target_file, 'wb') as handle:
        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)

def unpickler(path,pkl_name):
    """Load and return the object pickled in the file ``pkl_name`` inside ``path``.

    Note: unpickling can execute arbitrary code, so only use this on files
    from a trusted source.

    Args:
        path: Directory that contains the pickle file.
        pkl_name: File name of the pickle inside ``path``.

    Returns:
        The deserialized Python object.
    """
    source_file = os.path.join(path, pkl_name)
    with open(source_file, 'rb') as handle:
        return pickle.load(handle)

In [3]:
# Directory and file name of the pickled dataset dict that is loaded further
# down with unpickler().
data_pkl_path = "../../data/hotpot/"
data_pkl_name = "preprocessed_dev.pkl"
# Directory and file name of the pickled raw model predictions that are loaded
# further down with unpickler().
predictions_pkl_path = "./"
predictions_pkl_name = "predictions.pkl"

In [4]:
# Toy fixtures used to sanity-check the metric functions defined below:
# one inner list of 0/1 labels per example (ground truth vs. prediction).
gt = [[0,0],[0,1],[1,0],[1,1]]
pred = [[1,1],[1,1],[1,0],[1,1]]

In [5]:
def exact_match(gt, pred):
    """Return the fraction of examples whose prediction equals its ground truth.

    Two examples match when ``gt[i] == pred[i]`` is truthy, i.e. the whole
    per-example entry has to be equal.

    Args:
        gt: Sequence of ground-truth entries, one per example.
        pred: Sequence of predicted entries, aligned with ``gt``.

    Returns:
        Number of matching examples divided by the number of examples.

    Raises:
        AssertionError: If ``gt`` and ``pred`` differ in length.
        ZeroDivisionError: If both inputs are empty.
    """
    assert len(gt) == len(pred)
    n_examples = len(pred)
    n_matching = sum(1 for idx in range(n_examples) if gt[idx] == pred[idx])
    return n_matching / n_examples

In [6]:
exact_match(gt, pred)

0.5

In [7]:
def evaluate(gt, pred, show_progress=True):
    """Compute example-averaged precision, recall, F1 and exact match.

    For every example the positive class is the label equal to 1 (``True``
    counts as 1). Precision, recall and F1 are derived from the example's
    true-positive / false-positive / false-negative counts; a ratio with a
    zero denominator counts as 0. An example is an exact match when
    ``gt[i] == pred[i]``. All four scores are averaged over the examples.

    The counts are computed directly instead of calling scikit-learn once per
    example: the numbers are the same, but the loop is much faster and no
    undefined-metric warnings are emitted for examples without positives.

    Args:
        gt: Sequence of per-example label sequences (0/1 or bool).
        pred: Sequence of per-example predicted label sequences, aligned
            with ``gt``.
        show_progress: When True (the default) the loop over examples is
            wrapped in tqdm's ``trange`` progress bar; pass False to iterate
            silently.

    Returns:
        dict with the keys ``"precision"``, ``"recall"``, ``"f1"`` and
        ``"em"``, each averaged over all examples.

    Raises:
        AssertionError: If ``gt`` and ``pred`` differ in length or are empty.
        ValueError: If an example's ground truth and prediction differ in
            length.
    """
    assert(len(gt) == len(pred))
    total_size = len(pred)
    assert(len(gt) != 0)
    total_precision = 0.0
    total_recall = 0.0
    total_f1 = 0.0
    total_correct = 0
    example_indices = trange(total_size) if show_progress else range(total_size)
    for i in example_indices:
        gt_labels = gt[i]
        pred_labels = pred[i]
        # zip() below would silently truncate, so reject misaligned examples.
        if len(gt_labels) != len(pred_labels):
            raise ValueError(
                "example {}: ground truth has {} labels but prediction has {}".format(
                    i, len(gt_labels), len(pred_labels)))
        if(gt_labels == pred_labels):
            total_correct += 1
        tp = fp = fn = 0
        for g, p in zip(gt_labels, pred_labels):
            if p == 1:
                if g == 1:
                    tp += 1
                else:
                    fp += 1
            elif g == 1:
                fn += 1
        # Undefined ratios (no predicted / no actual positives) score 0.
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        total_precision += precision
        total_recall += recall
        if (precision + recall) > 0:
            total_f1 += 2 * (precision * recall) / (precision + recall)
    return {"precision":total_precision/total_size, "recall":total_recall/total_size,
            "f1":total_f1/total_size, "em":total_correct/total_size}

In [8]:
evaluate(gt, pred)

  'recall', 'true', average, warn_for)
100%|██████████| 4/4 [00:00<00:00, 87.46it/s]


{'precision': 0.625, 'recall': 0.75, 'f1': 0.6666666666666666, 'em': 0.5}

In [9]:
def reorganize_predictions(predictions, document_lengths):
    """Split a flat sequence into consecutive chunks, one per document.

    The i-th chunk holds the next ``document_lengths[i]`` items of
    ``predictions``. Plain slicing is used, so if ``predictions`` runs out
    early the remaining chunks are shorter (possibly empty), and items beyond
    ``sum(document_lengths)`` are not returned.

    Args:
        predictions: Flat, sliceable sequence of per-sentence values.
        document_lengths: Number of items belonging to each document, in order.

    Returns:
        List with one slice of ``predictions`` per entry of
        ``document_lengths``.
    """
    grouped = []
    offset = 0
    for doc_len in document_lengths:
        chunk_end = offset + doc_len
        grouped.append(predictions[offset:chunk_end])
        offset = chunk_end
    return grouped

In [10]:
reorganize_predictions(list(range(20)), [1,19])

[[0], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]]

In [11]:
dataset = unpickler(data_pkl_path, data_pkl_name)

In [12]:
dataset.keys()

dict_keys(['sequences', 'segment_ids', 'supporting_fact', 'document_lengths'])

In [13]:
predictions_raw = unpickler(predictions_pkl_path, predictions_pkl_name)

In [14]:
type(predictions_raw)

numpy.ndarray

In [15]:
predictions_raw.shape

(306423,)

In [16]:
predictions_raw[:20]

array([ -2.703494 ,  -7.3818617,  -9.597219 ,   4.442234 ,  -4.757165 ,
        -2.452815 ,  -5.1548805, -11.324129 , -10.174879 ,  -7.2726355,
        -8.698872 ,  -4.4058604,  -2.4269946,  -6.133337 ,  -5.7811146,
        -3.2419255,   3.4809365,  -9.175812 ,  -9.545071 ,  -7.4124694],
      dtype=float32)

In [17]:
predictions_raw.min()

-18.612978

In [18]:
predictions_raw.max()

9.374624

In [59]:
# Sigmoid-probability cut-off above which a sentence is labelled a supporting
# fact. Kept as a named constant so the decision threshold is visible and can
# be tuned in one place.
SUPPORTING_FACT_THRESHOLD = 0.9
# Raw scores -> probabilities -> boolean labels, as a flat Python list.
pred_answer_labels = (torch.sigmoid(torch.tensor(predictions_raw)) > SUPPORTING_FACT_THRESHOLD).numpy().tolist()

In [60]:
sum(pred_answer_labels)

20152

In [61]:
pred_answer_labels_reorganized = reorganize_predictions(pred_answer_labels, dataset["document_lengths"])

In [62]:
len(pred_answer_labels_reorganized)

7404

In [63]:
gt_reorganized = reorganize_predictions(dataset["supporting_fact"], dataset["document_lengths"])

In [64]:
len(gt_reorganized)

7404

### Are the lengths the same?

In [65]:
# Every flat label / prediction has to land in exactly one document. If a flat
# sequence is longer than sum(document_lengths), the slicing done by
# reorganize_predictions silently drops the tail, which the per-document
# comparison below cannot detect, so check the totals explicitly first.
total_sentences = sum(dataset["document_lengths"])
assert len(dataset["supporting_fact"]) == total_sentences, \
    "ground truth has {} labels but document_lengths sums to {}".format(
        len(dataset["supporting_fact"]), total_sentences)
assert len(pred_answer_labels) == total_sentences, \
    "predictions have {} labels but document_lengths sums to {}".format(
        len(pred_answer_labels), total_sentences)
assert len(gt_reorganized) == len(pred_answer_labels_reorganized)
for i in range(len(gt_reorganized)):
    assert len(gt_reorganized[i]) == len(pred_answer_labels_reorganized[i]), \
        "document {}: {} ground-truth labels vs {} predictions".format(
            i, len(gt_reorganized[i]), len(pred_answer_labels_reorganized[i]))

### Evaluation

In [66]:
sum(dataset["supporting_fact"])

18001

In [67]:
sum(pred_answer_labels)

20152

In [68]:
evaluate(gt_reorganized, pred_answer_labels_reorganized)

100%|██████████| 7404/7404 [00:13<00:00, 557.12it/s]


{'precision': 0.6867104762921196,
 'recall': 0.7088041213243822,
 'f1': 0.6588673467117654,
 'em': 0.228795245813074}