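# GPT3 SQuAD

Runs the SQuAD validation questions (DrQA-preprocessed split) through the GPT-3 Answers endpoint and scores the returned answers with the SQuAD v1.1 exact-match and F1 metrics.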

In [4]:
import os
from src.models.drqa.drqa_dataset import *

**Get and prepare data**

In [33]:
VALID_PKL = "../data/processed/squad_drqa/drqa_valid.pkl"

valid_dataset = DRQADataset(VALID_PKL)
# One row per question id: predictions are later gathered into a dict keyed
# by id, where duplicate ids would overwrite each other.
valid_dataset.frame = valid_dataset.frame.drop_duplicates(subset=['id'])
print(len(valid_dataset))

# Dedicated generator for the loader's shuffling. It is not seeded explicitly
# here; call g.manual_seed(...) if a specific batch order must be pinned.
g = torch.Generator()
validloader = torch.utils.data.DataLoader(
    dataset=valid_dataset, batch_size=10, shuffle=True, generator=g
)

5903


**Fetch data**

In [58]:
import openai
from tqdm import tqdm

openai.api_key = os.getenv('OPENAI_APIKEY')

# Collect one response per row and assign the column once at the end.
# The `row` yielded by iterrows() may be a copy, so writing `row.res` is not
# guaranteed to reach the DataFrame (see the pandas iterrows documentation).
responses = []
for _, row in tqdm(valid_dataset.frame.iterrows(), total=len(valid_dataset.frame)):
    try:
        response = openai.Answer.create(
            search_model="ada",
            model="curie",
            question=row.question,
            documents=[row.context],
            examples_context="In 2017, U.S. life expectancy was 78.6 years.",
            examples=[["What is human life expectancy in the United States?","78 years."]],
            # NOTE(review): the token budget is derived from the gold answer's
            # length, so the reference answer leaks into generation. Consider a
            # fixed budget for an unbiased evaluation.
            max_tokens=len(row.answer.split(' '))*3,
            stop=["\n", "<|endoftext|>"],
        )
        responses.append(response.to_dict())
    except Exception as e:
        # Best-effort: record the failure and continue, so a single bad
        # request does not abort the long-running pass.
        responses.append({'error': True, 'args': e.args})

valid_dataset.frame['res'] = responses


5903it [33:04,  2.98it/s]


Process response

In [97]:
import re


def extract_gpt3_answer(res):
    '''
    Return the first answer string from a stored GPT-3 Answers response,
    lower-cased, with periods removed and word-initial "the " dropped.

    Rows whose request failed hold ``{'error': True, ...}`` (no 'answers'
    key) and yield a single space. Non-dict values and an empty 'answers'
    list are treated the same way instead of raising.
    '''
    if not isinstance(res, dict):
        return ' '
    answers = res.get('answers') or [' ']
    answer = answers[0].replace('.', '').lower()
    # Only strip "the " at a word start; str.replace('the ', '') also mangled
    # words ending in "the" (e.g. "breathe deeply" -> "breadeeply").
    return re.sub(r'\bthe ', '', answer)


valid_dataset.frame['res_ans'] = valid_dataset.frame.res.apply(extract_gpt3_answer)

Save data

In [82]:
# Persist the frame, including the raw API responses, so the slow API pass
# (about 33 minutes above) does not have to be repeated.
output_path = './GPT3SQuAD_data.pkl'
valid_dataset.frame.to_pickle(output_path)

Calculate performance

In [98]:
import json
import json
import os
import re
import string
from collections import Counter

def evaluate(predictions, dataset_path='../data/raw/dev-v1.1.json',
             count_unanswered=True, **kwargs):
    '''
    Score predictions against a SQuAD v1.1-style JSON dataset.

    A question in the dataset may have several reference answers, so each
    prediction is compared with all of them and the best exact-match and
    F1 scores are kept (via metric_max_over_ground_truths).

    :param dict predictions: maps question id -> predicted answer string.
    :param str dataset_path: path to the dataset JSON file.
    :param bool count_unanswered: if True (the default, matching the original
        behaviour), questions whose id is missing from ``predictions`` are
        counted in the denominator with a score of 0, and a warning reports
        how many there were. If False, only questions that have a prediction
        are scored.
    :param kwargs: accepted for call compatibility; ignored.
    :returns: tuple ``(exact_match, f1)`` as percentages in [0, 100], or
        ``(0.0, 0.0)`` if no question is scored.
    '''
    import warnings

    with open(dataset_path, 'r', encoding='utf-8') as f:
        dataset = json.load(f)['data']

    f1 = exact_match = total = unanswered = 0
    for article in dataset:
        for paragraph in article['paragraphs']:
            for qa in paragraph['qas']:
                if qa['id'] not in predictions:
                    unanswered += 1
                    continue
                total += 1

                ground_truths = [answer['text'] for answer in qa['answers']]
                prediction = predictions[qa['id']]

                exact_match += metric_max_over_ground_truths(
                    exact_match_score, prediction, ground_truths)
                f1 += metric_max_over_ground_truths(
                    f1_score, prediction, ground_truths)

    if count_unanswered:
        total += unanswered
        if unanswered:
            # Make the denominator choice visible: a partial prediction set
            # otherwise yields silently deflated scores.
            warnings.warn(
                f"{unanswered} of {total} questions in {dataset_path} have no "
                "prediction and score 0; pass count_unanswered=False to score "
                "only answered questions.")

    if total == 0:
        return 0.0, 0.0

    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total
    return exact_match, f1


def evaluate_single(predictions, answers, **kwargs):
    '''
    Score predictions against exactly one reference answer per question.

    Unlike evaluate(), no dataset file is read: ``answers`` supplies the
    single ground truth for every key of ``predictions``.

    :param dict predictions: maps question id -> predicted answer string.
    :param answers: mapping from the same ids -> reference answer string;
        must have the same length as ``predictions``.
    :param kwargs: accepted for call compatibility; ignored.
    :returns: tuple ``(exact_match, f1)`` as percentages in [0, 100], or
        ``(0.0, 0.0)`` when ``predictions`` is empty.
    :raises AssertionError: if the two collections differ in length
        (only while assertions are enabled).
    :raises KeyError: if a prediction id is missing from ``answers``.
    '''
    assert len(predictions) == len(answers)
    total = len(predictions)
    if total == 0:
        # Avoid ZeroDivisionError on empty input.
        return 0.0, 0.0

    f1 = exact_match = 0
    for key, prediction in predictions.items():
        ground_truths = [answers[key]]
        exact_match += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
        f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths)

    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total
    return exact_match, f1


def normalize_answer(s):
    '''
    Canonicalise an answer string before comparison.

    Steps, in order: lower-case, strip characters in string.punctuation,
    replace the standalone words "a"/"an"/"the" with spaces, and collapse
    all whitespace runs to single spaces (trimming both ends).
    '''
    punctuation = frozenset(string.punctuation)
    lowered = s.lower()
    without_punct = ''.join(filter(lambda ch: ch not in punctuation, lowered))
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punct)
    return ' '.join(without_articles.split())


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    '''
    Score ``prediction`` against each reference answer with ``metric_fn``
    and return the largest score.

    :param func metric_fn: callable ``(prediction, ground_truth) -> score``,
                           e.g. exact_match_score or f1_score
    :param str prediction: the model's predicted answer
    :param list ground_truths: reference answers to compare against; must be
                               non-empty, otherwise max() raises ValueError
    '''
    return max(metric_fn(prediction, truth) for truth in ground_truths)


def f1_score(prediction, ground_truth):
    '''
    Token-level F1 overlap between a predicted and a reference answer.

    Both strings are normalised with normalize_answer and split on
    whitespace; shared tokens are counted with multiplicity. Returns 0 when
    no token is shared.
    '''
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if not overlap:
        return 0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def exact_match_score(prediction, ground_truth):
    '''
    True when both answers are identical after normalize_answer, else False.
    '''
    normalized_prediction = normalize_answer(prediction)
    normalized_truth = normalize_answer(ground_truth)
    return normalized_prediction == normalized_truth


def epoch_time(start_time, end_time):
    '''
    Split the span between two timestamps (in seconds) into whole minutes
    and leftover whole seconds, truncating toward zero via int().
    '''
    duration = end_time - start_time
    minutes = int(duration / 60)
    seconds = int(duration - 60 * minutes)
    return minutes, seconds


In [99]:
# Map each question id to GPT-3's cleaned answer and score it.
# NOTE(review): evaluate() adds every question in the dev JSON to the
# denominator, including ids absent from this dict (they score 0). If the
# deduplicated DrQA subset is smaller than the dev set, these numbers
# understate accuracy on the questions actually queried.
gpt3_predictions = dict(zip(valid_dataset.frame['id'], valid_dataset.frame['res_ans']))
evaluate(gpt3_predictions)

(25.10879848628193, 34.42828376790898)