## Validation scripts

To validate our model's performance, we need a consistent way of assessing how it performs.

We will set up two different versions, using top-1 and top-k predictions.

**Method 1 - On the processed encoding set**
Here we simply run the test set through the model again, extract the outputs, and evaluate them.

**Method 2 - On the full text of the test set**
Here we again operate on the test set, but now on the full text. We apply a sliding window over the text, with the window size (in tokens) as a configurable parameter. This lets us experiment with a variety of factors that determine what constitutes good performance. Every window except those containing the answer is labelled `is_impossible=True`, and windows that contain only part of the answer are excluded from evaluation.
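A minimal sketch of that labelling rule in plain Python, working on token indices only. The function names `sliding_windows` and `window_is_impossible` are hypothetical, and the `stride` (how far the window moves per step) is an assumption: the text above only fixes the window size as the configurable parameter.

```python
from typing import List, Optional, Tuple


def sliding_windows(n_tokens: int, window: int, stride: int) -> List[Tuple[int, int]]:
    """Return half-open [start, end) token ranges covering n_tokens tokens."""
    # Assumption: stride <= window, so no token is ever skipped
    if not 0 < stride <= window:
        raise ValueError("need 0 < stride <= window")
    spans = []
    start = 0
    while True:
        end = min(start + window, n_tokens)
        spans.append((start, end))
        if end == n_tokens:
            break
        start += stride
    return spans


def window_is_impossible(span: Tuple[int, int],
                         answer: Optional[Tuple[int, int]]) -> Optional[bool]:
    """Label one window.

    `answer` is an inclusive (start, end) token span, or None when the
    question has no answer in the document.
    Returns False if the whole answer lies inside the window, True if none
    of it does, and None if only part of it does (window is skipped).
    """
    start, end = span
    if answer is None:
        return True
    a_start, a_end = answer
    if start <= a_start and a_end < end:
        return False  # full answer inside the window
    if a_end < start or a_start >= end:
        return True  # window and answer do not overlap
    return None  # partial overlap: excluded from evaluation


windows = sliding_windows(n_tokens=10, window=4, stride=2)
print(windows)  # [(0, 4), (2, 6), (4, 8), (6, 10)]
print([window_is_impossible(w, answer=(3, 4)) for w in windows])  # [None, False, None, True]
```

With overlapping windows (stride smaller than the window), an answer that straddles one window boundary is usually fully contained in a neighbouring window, so fewer answerable examples are lost to the "partial answer" exclusion.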

**Method 1**

In [1]:
import sys
sys.path.append('../')

In [144]:
from datasets import load_dataset
from data import CUADDataset
from models import QAModelBert
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast,AutoTokenizer
import pandas as pd
import torch

In [145]:
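hparams = {
    'lr': 9.5e-6,
    'batch_size': 8,
    'num_workers': 5,
    'num_labels': 2,
    'hidden_size': 768,
    'num_train_epochs': 6,
    'bert_model': 'bert-base-uncased',
    'log_text_every_n_batch': 30,
    'log_text_every_n_batch_valid': 10
}

# True: load the trained model (and its hparams) from disk
# False: build a fresh, untrained model from the hparams above
load_trained_model = True

if load_trained_model:
    model = torch.load('model.model')
    hparams = model.hparams
else:
    model = QAModelBert(hparams, hparams['bert_model'])
# Evaluation mode: disables dropout during inference
model.eval()
tokenizer = AutoTokenizer.from_pretrained(hparams['bert_model'])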

In [146]:
val_encodings = torch.load("../cuad_training/data/test_encodings")
val_dataset = CUADDataset(val_encodings)
val_loader = DataLoader(
    val_dataset, batch_size=hparams.get("batch_size"), shuffle=False)
del val_encodings

In [147]:
res = []
# Inference only: no gradients are needed
with torch.no_grad():
    for batch in val_loader:
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        start_positions = batch['start_positions']
        end_positions = batch['end_positions']
        outputs = model(input_ids, attention_mask=attention_mask,
                        start_positions=start_positions, end_positions=end_positions)
        res.append({'batch': batch, 'outputs': outputs})
        # Only the first batch is evaluated as a quick check;
        # remove this `break` to run the full test set
        break

In [148]:
import numpy as np


def decode(start: torch.Tensor, end: torch.Tensor, topk: int, max_answer_len: int):
    """
    Take the output of any `ModelForQuestionAnswering` and generate probabilities for each span to be the
    actual answer.
    In addition, it filters out unwanted/impossible cases such as an answer longer than max_answer_len or
    an answer end position before the start position. The method supports returning the k best answers through
    the topk argument.

    Args:
        start (`torch.Tensor`): Individual start probabilities for each token.
        end (`torch.Tensor`): Individual end probabilities for each token.
        topk (`int`): Indicates how many possible answer span(s) to extract from the model output.
        max_answer_len (`int`): Maximum size of the answer to extract from the model's output.

    Returns:
        (starts, ends, scores), where scores has shape (1, k).
    """
    # Ensure we have a batch axis
    if start.ndim == 1:
        start = start[None]
    if end.ndim == 1:
        end = end[None]

    # Score each (start, end) pair as the product of its start and end probabilities
    outer = torch.matmul(start.unsqueeze(-1), end.unsqueeze(1))

    # Zero out candidates with end < start or end - start > max_answer_len - 1
    candidates = torch.tril(torch.triu(outer), max_answer_len - 1)
    # Inspired by Chen et al. (https://github.com/facebookresearch/DrQA)
    scores_flat = candidates.flatten()
    if topk == 1:
        # Keep a 1-element tensor (not a Python list) so unravel_index can apply % and div
        idx_sort = torch.argmax(scores_flat).reshape(1)
    elif len(scores_flat) < topk:
        idx_sort = torch.argsort(-scores_flat)
    else:
        # torch.topk already returns the indices sorted by descending score
        idx_sort = torch.topk(scores_flat, topk).indices

    def unravel_index(index, shape):
        # Convert flat indices into one index tensor per dimension (like np.unravel_index)
        out = []
        for dim in reversed(shape):
            out.append(index % dim)
            index = torch.div(index, dim, rounding_mode='trunc')
        return tuple(reversed(out))

    starts, ends = unravel_index(idx_sort, candidates.shape)[1:]
    scores = candidates[:, starts, ends]

    return starts, ends, scores
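For intuition, here is the same span ranking as `decode` written as explicit loops in plain Python (a sketch; `best_spans` is a hypothetical name). `decode` builds the full start-times-end probability matrix with `torch.matmul` and masks invalid spans with `torch.triu`/`torch.tril`; the loops below only enumerate the valid spans. One difference: `decode` zeroes invalid spans instead of dropping them, so it can return zero-score spans when `topk` exceeds the number of valid ones.

```python
from typing import List, Tuple


def best_spans(p_start: List[float], p_end: List[float], topk: int,
               max_answer_len: int) -> List[Tuple[int, int, float]]:
    """Return the topk (start, end, score) spans, best first.

    A span's score is p_start[s] * p_end[e]; only spans with
    s <= e <= s + max_answer_len - 1 are considered.
    """
    candidates = []
    for s, ps in enumerate(p_start):
        # e runs from s up to (but excluding) s + max_answer_len
        for e in range(s, min(s + max_answer_len, len(p_end))):
            candidates.append((s, e, ps * p_end[e]))
    candidates.sort(key=lambda c: c[2], reverse=True)
    return candidates[:topk]


p_start = [0.1, 0.7, 0.2]
p_end = [0.1, 0.3, 0.6]
for s, e, score in best_spans(p_start, p_end, topk=2, max_answer_len=2):
    print(s, e, round(score, 2))
# 1 2 0.42
# 1 1 0.21
```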

In [149]:
def get_pred_from_batch_outputs(batch, start_logits, end_logits, tokenizer, top_k=2, max_ans_len=200):
    """Take a batch and the model's outputs for it and return the decoded predictions.

    Args:
        batch (dict): Batch from the dataloader
        start_logits (torch.Tensor): Start logits from the model
        end_logits (torch.Tensor): End logits from the model
        tokenizer (PreTrainedTokenizer): Used to decode the tokens
        top_k (int, optional): How many predictions to return. Defaults to 2.
        max_ans_len (int, optional): Limit answers to this length (in tokens). Defaults to 200.

    Returns:
        List[List[Tuple]]: batch x top-k predictions x
            (id, idx_top_k, is_impossible, pred_text, answer_text, confidence, start, end)
    """
    # Convert logits to probabilities with a numerically stable softmax
    start_ = torch.softmax(start_logits.detach(), dim=-1)
    end_ = torch.softmax(end_logits.detach(), dim=-1)

    # One (starts, ends, scores) tuple per example in the batch
    p = [decode(x, y, top_k, max_ans_len) for x, y in zip(start_, end_)]

    predictions = []
    for i, (starts, ends, scores) in enumerate(p):
        temp_collect = []
        _id = batch['id'][i]
        _answer_start = batch['start_positions'][i]
        _answer_end = batch['end_positions'][i]
        _answer = batch['input_ids'][i][_answer_start:_answer_end + 1]
        is_impossible = batch['is_impossible'][i].numpy()

        # Ground-truth text; the span (1, 1) is the placeholder for "no answer"
        if (_answer_start == _answer_end == 1) or is_impossible:
            _answer_text = ""
        else:
            _answer_text = tokenizer.decode(_answer)

        # One entry per top-k prediction
        for idx, (_start, _end) in enumerate(zip(starts, ends)):
            _pred = batch['input_ids'][i][_start:_end + 1]
            confidence = scores[0][idx]
            # A predicted (1, 1) span means "no answer"
            if _start == _end == 1:
                _pred_text = ""
            else:
                _pred_text = tokenizer.decode(_pred)

            temp_collect.append((_id, idx, is_impossible,
                                 _pred_text, _answer_text, confidence, _start, _end))
        predictions.append(temp_collect)
    return predictions
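Why prefer `torch.softmax` over the hand-written form `exp(x - log(sum(exp(x))))`: both give the same probabilities, but the explicit sum of exponentials overflows once a logit is large (float32 `exp` overflows a little above 88). The plain-Python sketch below (function names are hypothetical) shows the standard fix of subtracting the maximum logit first. In Python's `math` module the overflow raises `OverflowError`; in torch it silently produces `inf`. BERT start/end logits are usually well within range, so this is mainly a robustness concern.

```python
import math
from typing import List


def naive_softmax(logits: List[float]) -> List[float]:
    # exp(x - log(sum(exp(x)))): exp() of a large logit overflows
    log_z = math.log(sum(math.exp(x) for x in logits))
    return [math.exp(x - log_z) for x in logits]


def stable_softmax(logits: List[float]) -> List[float]:
    # Shifting by the max logit does not change the result, but every
    # exp() argument is now <= 0, so it cannot overflow
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [v / total for v in exps]


print(stable_softmax([1000.0, 0.0]))  # [1.0, 0.0]
try:
    naive_softmax([1000.0, 0.0])
except OverflowError:
    print("naive_softmax overflowed")
```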

In [150]:
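collected_results = []
# Flatten the per-batch predictions into one list (one entry per example);
# outputs[1] and outputs[2] are the start and end logits
for pred in res:
    collected_results += get_pred_from_batch_outputs(
        pred['batch'], pred['outputs'][1], pred['outputs'][2], tokenizer)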

In [151]:
collected_results[3]

[(tensor(5),
  0,
  array(False),
  'the contract is valid for 5 years, beginning from and ended on',
  'the contract is valid for 5 years, beginning from and ended on.',
  tensor(0.7007),
  tensor(364),
  tensor(376)),
 (tensor(5),
  1,
  array(False),
  'the contract is valid for 5 years, beginning from and ended on.',
  'the contract is valid for 5 years, beginning from and ended on.',
  tensor(0.2896),
  tensor(364),
  tensor(377))]

In [143]:
collected_results[3]

[(tensor(5),
  0,
  array(False),
  'the contract is valid for 5 years, beginning from and ended on.',
  'the contract is valid for 5 years, beginning from and ended on.',
  tensor(0.9370),
  tensor(364),
  tensor(377)),
 (tensor(5),
  1,
  array(False),
  'the contract is valid for 5 years, beginning from and ended on',
  'the contract is valid for 5 years, beginning from and ended on.',
  tensor(0.0049),
  tensor(364),
  tensor(376))]

In [30]:
import re
import string
from collections import Counter
import torch
import copy

def get_jaccard(gt, pred):
    """Word-level Jaccard similarity: |intersection| / |union| of the two word sets."""
    remove_tokens = [".", ",", ";", ":"]
    for token in remove_tokens:
        gt = gt.replace(token, "")
        pred = pred.replace(token, "")
    gt = gt.lower().replace("/", " ")
    pred = pred.lower().replace("/", " ")

    gt_words = set(gt.split(" "))
    pred_words = set(pred.split(" "))

    intersection = gt_words.intersection(pred_words)
    union = gt_words.union(pred_words)
    return len(intersection) / len(union)


def normalize_answer(s):
    '''
    Performs a series of cleaning steps on the ground truth and
    predicted answer: lower-casing, removing punctuation and articles,
    and finally removing all whitespace.
    '''
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s)))).replace(" ", '')
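Because `get_jaccard` compares word *sets*, repeated words do not lower the score: in the results table further down, the prediction `supply contract contract` gets a Jaccard of 1.0 against `supply contract` and is counted as a true positive. The sketch below (simplified, with no punctuation stripping; both function names are hypothetical) contrasts this with a multiset (bag-of-words) Jaccard built on `collections.Counter`, which does penalise the repetition. Whether the stricter variant is preferable is a design choice; the notebook uses the set version.

```python
from collections import Counter


def set_jaccard(gt: str, pred: str) -> float:
    # Word *sets*, as in get_jaccard: repeated words are counted once
    a, b = set(gt.lower().split()), set(pred.lower().split())
    union = a | b
    # Arbitrary convention: two empty strings are a perfect match
    return len(a & b) / len(union) if union else 1.0


def multiset_jaccard(gt: str, pred: str) -> float:
    # Word *counts*: intersection takes the min count, union the max count
    a, b = Counter(gt.lower().split()), Counter(pred.lower().split())
    union = sum((a | b).values())
    # Same arbitrary convention for two empty strings
    return sum((a & b).values()) / union if union else 1.0


print(set_jaccard("supply contract", "supply contract contract"))       # 1.0
print(multiset_jaccard("supply contract", "supply contract contract"))  # 0.6666666666666666
```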

In [63]:
def compute_top_1_scores_from_preds(predictions, iou_threshold=0.5, include_df=False):
    """Compute precision, recall, F1 and exact-match count from the top-1 prediction per example.

    Note: correctly predicted empty answers (tn) are counted as hits in both
    precision and recall, alongside tp.
    """
    df = []
    df_struct = {"id": None,
                 "question": None,
                 "is_impossible": None,
                 "pred": None,
                 "answer": None,
                 "jaccard": None,
                 "doc_length": None,
                 "em": None,
                 "tp": None,
                 "fp": None,
                 "tn": None,
                 "fn": None,
                 "top_k_idx": None,
                 "confidence": None,
                 "start_token_loc_pred": None,
                 "end_token_loc_pred": None}

    # tp: answer exists and the prediction matches it (exact, or Jaccard >= threshold)
    # fp: prediction is non-empty but the question has no answer, or it does not match
    # fn: answer exists but the prediction is empty
    # tn: no answer and an empty prediction
    tp, fp, fn, tn = 0, 0, 0, 0
    # Exact-match count
    em = 0
    for prediction_set in predictions:
        # Only the top-1 prediction is scored
        _id, top_k_idx, is_impossible, _pred_text, answer, confidence, start_loc_pred, end_loc_pred = prediction_set[0]
        _tmp_df = copy.copy(df_struct)

        _tmp_df["is_impossible"] = is_impossible
        _tmp_df["top_k_idx"] = top_k_idx
        _tmp_df["id"] = _id
        _tmp_df["answer"] = answer
        _tmp_df["pred"] = _pred_text
        _tmp_df["confidence"] = confidence
        _tmp_df["start_token_loc_pred"] = start_loc_pred
        _tmp_df["end_token_loc_pred"] = end_loc_pred

        if is_impossible:
            if len(_pred_text) > 0:
                _tmp_df['fp'] = True
                fp += 1
            else:
                _tmp_df['tn'] = True
                tn += 1
        else:
            if len(_pred_text) < 1:
                _tmp_df['fn'] = True
                fn += 1
            else:
                jc = get_jaccard(answer, _pred_text)
                _tmp_df['jaccard'] = jc
                if normalize_answer(answer) == normalize_answer(_pred_text):
                    _tmp_df['em'] = True
                    _tmp_df['tp'] = True
                    em += 1
                    tp += 1
                elif jc >= iou_threshold:
                    # Partial overlap above the threshold still counts as correct
                    _tmp_df['tp'] = True
                    tp += 1
                else:
                    _tmp_df['fp'] = True
                    fp += 1
        df.append(_tmp_df)

    precision = (tp + tn) / (tp + tn + fp) if tp + tn + fp > 0 else np.nan
    recall = (tp + tn) / (tp + tn + fn) if tp + tn + fn > 0 else np.nan
    if np.isnan(precision) or np.isnan(recall):
        f1 = np.nan
    elif precision + recall == 0:
        # Avoid 0 / 0 when nothing was predicted correctly
        f1 = 0.0
    else:
        f1 = 2 * precision * recall / (precision + recall)

    res = {"tp": tp,
           "fp": fp,
           "fn": fn,
           "tn": tn,
           "em": em,
           "precision": precision,
           "recall": recall,
           "count": len(predictions),
           "f1": f1}
    if include_df:
        return res, df
    return res
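The metric step at the end of `compute_top_1_scores_from_preds` can be isolated as a small function. This sketch (the name `metrics_from_counts` and the `count_tn_as_hit` flag are hypothetical) reproduces the notebook's convention of counting tn as hits, adds an opt-out that gives the usual positive-class precision and recall, and guards F1 against 0/0. With the counts in `res2` further down (tp=3, fp=0, fn=0, tn=5) all three metrics are 1.0.

```python
import math


def metrics_from_counts(tp: int, fp: int, fn: int, tn: int,
                        count_tn_as_hit: bool = True) -> dict:
    """Precision, recall and F1 from confusion counts.

    With count_tn_as_hit=True this follows compute_top_1_scores_from_preds,
    where correctly predicted empty answers (tn) count as hits.
    """
    hits = tp + tn if count_tn_as_hit else tp
    precision = hits / (hits + fp) if hits + fp > 0 else math.nan
    recall = hits / (hits + fn) if hits + fn > 0 else math.nan
    if math.isnan(precision) or math.isnan(recall):
        f1 = math.nan
    elif precision + recall == 0:
        f1 = 0.0  # avoid 0 / 0 when nothing was predicted correctly
    else:
        f1 = 2 * precision * recall / (precision + recall)
    return {"precision": precision, "recall": recall, "f1": f1}


print(metrics_from_counts(tp=3, fp=0, fn=0, tn=5))
# {'precision': 1.0, 'recall': 1.0, 'f1': 1.0}
print(metrics_from_counts(tp=0, fp=2, fn=2, tn=0))
# {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}
```

Counting tn as hits can make scores look better on datasets dominated by unanswerable questions: with tp=2, fp=1, fn=1, tn=5, precision is 7/8 under the notebook's convention but 2/3 with `count_tn_as_hit=False`.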
    

In [64]:
res2, df2 = compute_top_1_scores_from_preds(collected_results, include_df=True)

In [65]:
pd.DataFrame(df2)

| | id | question | is_impossible | pred | answer | jaccard | doc_length | em | tp | fp | tn | fn | top_k_idx | confidence | start_token_loc_pred | end_token_loc_pred |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | | False | supply contract contract | supply contract | 1.0 | | | True | | | | 0 | 0.676836 | 5 | 7 |
| 1 | 3 | | True | | | | | | | | True | | 0 | 0.999995 | 1 | 1 |
| 2 | 4 | | True | | | | | | | | True | | 0 | 0.055753 | 1 | 1 |
| 3 | 5 | | False | the contract is valid for 5 years, beginning f... | the contract is valid for 5 years, beginning f... | 1.0 | | True | True | | | | 0 | 0.937044 | 364 | 377 |
| 4 | 6 | | True | | | | | | | | True | | 0 | 0.998718 | 1 | 1 |
| 5 | 7 | | True | | | | | | | | True | | 0 | 0.999959 | 1 | 1 |
| 6 | 8 | | False | it will be governed by the law of the people's... | it will be governed by the law of the people's... | 1.0 | | True | True | | | | 0 | 0.959169 | 293 | 325 |
| 7 | 9 | | True | | | | | | | | True | | 0 | 0.999996 | 1 | 1 |


In [66]:
res2

{'tp': 3,
 'fp': 0,
 'fn': 0,
 'tn': 5,
 'em': 2,
 'precision': 1.0,
 'recall': 1.0,
 'count': 8,
 'f1': 1.0}

In [61]:
print(f"Precision: {res2.get('precision')}")
print(f"Recall: {res2.get('recall')}")
print(f"EM: {res2.get('em')/res2.get('count')}")
print(f"F1: {res2.get('f1')}")

Precision: 3.625
Recall: 3.625
EM: 0.25
F1: 3.625



In [36]:
import json

with open('../cuad_training/data/test_data.json', 'r') as f:
    data = json.load(f)

### Encodings vs original data

In [46]:
pd.DataFrame(data['data'][0:10])

| | start_positions | end_positions | question | context | id | original_id | char_span_start | is_impossible | answer | title |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 14 | 29 | Highlight the parts (if any) of this contract ... | Exhibit 10.16 SUPPLY CONTRACT Contract No: Dat... | 1 | LohaCompanyltd_20191209_F-1_EX-10.16_11917878_... | 0 | False | SUPPLY CONTRACT | LohaCompanyltd_20191209_F-1_EX-10.16_11917878_... |
| 1 | 1 | 1 | Highlight the parts (if any) of this contract ... | ccording to specific order by YICHANGTAI or LE... | 3 | LohaCompanyltd_20191209_F-1_EX-10.16_11917878_... | 3471 | True | | LohaCompanyltd_20191209_F-1_EX-10.16_11917878_... |
| 2 | 1 | 1 | Highlight the parts (if any) of this contract ... | d sellers authorized representative signature ... | 4 | LohaCompanyltd_20191209_F-1_EX-10.16_11917878_... | 702 | True | | LohaCompanyltd_20191209_F-1_EX-10.16_11917878_... |
| 3 | 1816 | 1880 | Highlight the parts (if any) of this contract ... | , shall not exceed 5% of the total value of th... | 5 | LohaCompanyltd_20191209_F-1_EX-10.16_11917878_... | 9169 | False | The Contract is valid for 5 years, beginning f... | LohaCompanyltd_20191209_F-1_EX-10.16_11917878_... |
| 4 | 1 | 1 | Highlight the parts (if any) of this contract ... | force of law.\n\n1\n\nSource: LOHA CO. LTD., F... | 6 | LohaCompanyltd_20191209_F-1_EX-10.16_11917878_... | 2008 | True | | LohaCompanyltd_20191209_F-1_EX-10.16_11917878_... |
| 5 | 1 | 1 | Highlight the parts (if any) of this contract ... | he contract, in favor of the Seller, for 100% ... | 7 | LohaCompanyltd_20191209_F-1_EX-10.16_11917878_... | 3911 | True | | LohaCompanyltd_20191209_F-1_EX-10.16_11917878_... |
| 6 | 1510 | 1678 | Highlight the parts (if any) of this contract ... | exceed 5% of the total value of the goods invo... | 8 | LohaCompanyltd_20191209_F-1_EX-10.16_11917878_... | 9181 | False | It will be governed by the law of the People's... | LohaCompanyltd_20191209_F-1_EX-10.16_11917878_... |
| 7 | 1 | 1 | Highlight the parts (if any) of this contract ... | ed. Under the T/T The trustee of the buyer rem... | 9 | LohaCompanyltd_20191209_F-1_EX-10.16_11917878_... | 4466 | True | | LohaCompanyltd_20191209_F-1_EX-10.16_11917878_... |
| 8 | 1 | 1 | Highlight the parts (if any) of this contract ... | of Credit shall be valid until 90 days after t... | 10 | LohaCompanyltd_20191209_F-1_EX-10.16_11917878_... | 4392 | True | | LohaCompanyltd_20191209_F-1_EX-10.16_11917878_... |
| 9 | 1 | 1 | Highlight the parts (if any) of this contract ... | he named place of destination and bear all ris... | 11 | LohaCompanyltd_20191209_F-1_EX-10.16_11917878_... | 5478 | True | | LohaCompanyltd_20191209_F-1_EX-10.16_11917878_... |
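The table hints at how the raw JSON relates to the text. In row 0, `start_positions`=14 and `end_positions`=29 pick out exactly `SUPPLY CONTRACT` from the context (`Exhibit 10.16 ` is 14 characters). This is consistent with character offsets into `context` with an exclusive end, while unanswerable rows use the placeholder (1, 1). The sketch below assumes that reading, and additionally assumes `char_span_start` is the context's offset within the full contract; neither is confirmed elsewhere in the notebook, and both helper names are hypothetical.

```python
def answer_span(context: str, start: int, end: int, is_impossible: bool) -> str:
    # Assumption: start/end are character offsets into `context`, end exclusive;
    # unanswerable rows carry the placeholder span (1, 1), so return ""
    if is_impossible:
        return ""
    return context[start:end]


def document_span(char_span_start: int, start: int, end: int) -> tuple:
    # Assumption: char_span_start is where `context` begins in the full contract
    return char_span_start + start, char_span_start + end


# Row 0 of the table above (context truncated as displayed)
context = "Exhibit 10.16 SUPPLY CONTRACT Contract No:"
print(answer_span(context, 14, 29, is_impossible=False))  # SUPPLY CONTRACT
print(document_span(0, 14, 29))  # (14, 29)
```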
