In [4]:
import json
import torch
import torch.nn as nn
import random
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [5]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

In [9]:
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.cuda()
_ = model.eval()

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.attn.masked_bias', 'h.1.attn.masked_bias', 'h.2.attn.masked_bias', 'h.3.attn.masked_bias', 'h.4.attn.masked_bias', 'h.5.attn.masked_bias', 'h.6.attn.masked_bias', 'h.7.attn.masked_bias', 'h.8.attn.masked_bias', 'h.9.attn.masked_bias', 'h.10.attn.masked_bias', 'h.11.attn.masked_bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
dev_dict = json.load(open('../data/coqa-dev-v1.0.json', encoding='utf8'))
dev_list = json.load(open('../data/qa_dev_list.json', encoding='utf8'))

In [13]:
dev_dict['data'][0]['story']

'Once upon a time, in a barn near a farm house, there lived a little white kitten named Cotton. Cotton lived high up in a nice warm place above the barn where all of the farmer\'s horses slept. But Cotton wasn\'t alone in her little home above the barn, oh no. She shared her hay bed with her mommy and 5 other sisters. All of her sisters were cute and fluffy, like Cotton. But she was the only white one in the bunch. The rest of her sisters were all orange with beautiful white tiger stripes like Cotton\'s mommy. Being different made Cotton quite sad. She often wished she looked like the rest of her family. So one day, when Cotton found a can of the old farmer\'s orange paint, she used it to paint herself like them. When her mommy and sisters found her they started laughing. \n\n"What are you doing, Cotton?!" \n\n"I only wanted to be more like you". \n\nCotton\'s mommy rubbed her face on Cotton\'s and said "Oh Cotton, but your fur is so pretty and special, like you. We would never want yo

In [61]:
def get_text_from_data_item(item, max_num_questions=0, question_number=-1, last_question=True):
    text = 'In the text below two people are discussing a story.\n\n'
    text += 'Story:\n' + item['story'] + '\n\n'
    text += 'Discussion:\n'
    text += '\n'.join(['Q: ' + q['input_text'] 
                       + '\nA: ' + a['input_text'] 
                       for q, a in zip(item['questions'][max(0,question_number-max_num_questions):question_number+1], 
                                       item['answers'][max(0,question_number-max_num_questions):question_number+1]) 
                      ])
    if not last_question:
        text = '\n'.join(text.split('\n')[:-1]) + '\n'
    return text

In [62]:
doc=0
number = 1
small_text = get_text_from_data_item(dev_dict['data'][doc], 
                                     max_num_questions=5, 
                                     question_number=number,
                                     last_question=False)
first_answer = generate_first_answer(model, small_text)
answers = generate_multiple_answers(model, small_text)

KeyboardInterrupt: 

In [None]:
def get_description_from_data_item(item):
    return item['story']

def get_dialogue_from_data_item(item, max_num_questions=0, question_number=-1, last_question=True):
    text = ''
    text += ' '.join([q['input_text'] + ' ' + a['input_text'] + '.'
                       for q, a in zip(item['questions'][max(0,question_number-max_num_questions):question_number+1], 
                                       item['answers'][max(0,question_number-max_num_questions):question_number+1]) 
                      ])
    if not last_question:
        text = '?'.join(text.split('?')[:-1]) + '?'
    return text

In [None]:
def create_claim_from_description_and_dialogue(description, dialogue):
    if dialogue[-1] == '.':
        dialogue = dialogue[:-1]    
    text = 'Evidence:\n'
    text += description.replace('\n\n', '\n') + '\n\n'
    text += 'Claim:\n'
    text += dialogue + '\n\n'
    text += 'The evidence supports the claim:\n'
    return text

In [None]:
def select_answer(description, dialogue, first_answer, answers):
    return first_answer
    best_perplexity = 1e6
    
    current_dialogue = dialogue + ' ' + first_answer
    text = create_claim_from_description_and_dialogue(description, current_dialogue)
    y_or_n, perplexity = generate_answer(fact_checking_model, text)
    if y_or_n == 'Y':
        return first_answer, best_perplexity
    
    best_answer = ''
    for answer in answers:
        current_dialogue = dialogue + ' ' + answer
        text = create_claim_from_description_and_dialogue(description, current_dialogue)
        y_or_n, perplexity = generate_answer(fact_checking_model, text)
        if perplexity < best_perplexity and y_or_n == 'Y':
            best_perplexity = perplexity
            best_answer = answer
    if not best_answer:
        best_answer = first_answer
    return best_answer, best_perplexity

In [45]:
doc = 0
number = 6
small_text = get_text_from_data_item(dev_dict['data'][doc], 
                                     max_num_questions=5, 
                                     question_number=number,
                                     last_question=False)

In [46]:
doc = 0
number = 0
description = get_description_from_data_item(dev_dict['data'][doc])
small_text = get_text_from_data_item(dev_dict['data'][doc], 
                                     max_num_questions=5, 
                                     question_number=number,
                                     last_question=False)
first_answer = generate_first_answer(model, small_text)
answers = generate_multiple_answers(model, small_text)
dialogue = get_dialogue_from_data_item(dev_dict['data'][doc],
                                       max_num_questions=5, 
                                       question_number=number,
                                       last_question=False)
print(answers)
select_answer(description, dialogue, first_answer, answers)

['white', 'White']


('White', 1000000.0)

# Computing accuracy after fact checking

In [53]:
def compute_original_accuracy_of_model(model):
    total_number_of_questions = 0
    correct_answers = 0
    wrong_predictions = []

    false_positives = []
    dlist = dev_list[:10]
    for index, text in tqdm(enumerate(dlist), total=len(dlist)):
        total_questions = get_answers_number(text)
        all_answers = get_all_answers(dev_dict, index)
        for number in range(total_questions):
            small_text = get_text_from_data_item(dev_dict['data'][index], 
                                                 max_num_questions=5,
                                                 question_number=number,
                                                 last_question=False)
            #description = get_description_from_data_item(dev_dict['data'][index])
            #dialogue = get_dialogue_from_data_item(dev_dict['data'][index],
            #                           max_num_questions=5, 
            #                           question_number=number,
            #                           last_question=False)
            first_answer = generate_first_answer(model, small_text)
            #answers = generate_multiple_answers(model, small_text)
            prediction, _ = select_answer(description, dialogue, first_answer, [])
            if not prediction:
                print('NO PREDICTION!!')
                continue
            prediction = prediction.replace('.', '').replace('"', '')
            it_was_answered = False
            for label in all_answers[number]:
                label = label.replace('.', '').replace('"', '')

                if prediction.lower() != 'unknown' and label.lower() == 'unknown':
                    false_positives.append(prediction)
                
                if prediction.lower() == label.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                elif prediction.lower() in label.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                elif label.lower() in prediction.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                else:
                    wrong_predictions.append({'label': label, 'prediction': prediction})
            total_number_of_questions += 1

    return correct_answers/total_number_of_questions, wrong_predictions, false_positives

In [57]:
def compute_accuracy_of_model(model):
    total_number_of_questions = 0
    correct_answers = 0
    wrong_predictions = []

    false_positives = []
    dlist = dev_list[:10]
    for index, text in tqdm(enumerate(dlist), total=len(dlist)):
        total_questions = get_answers_number(text)
        all_answers = get_all_answers(dev_dict, index)
        for number in range(total_questions):
            small_text = get_text_from_data_item(dev_dict['data'][index], 
                                                 max_num_questions=5,
                                                 question_number=number,
                                                 last_question=False)
            description = get_description_from_data_item(dev_dict['data'][index])
            dialogue = get_dialogue_from_data_item(dev_dict['data'][index],
                                       max_num_questions=5, 
                                       question_number=number,
                                       last_question=False)
            first_answer = generate_first_answer(model, small_text)
            answers = generate_multiple_answers(model, small_text)
            prediction, _ = select_answer(description, dialogue, first_answer, answers)
            if not prediction:
                print('NO PREDICTION!!')
                continue
            prediction = prediction.replace('.', '').replace('"', '')
            it_was_answered = False
            for label in all_answers[number]:
                label = label.replace('.', '').replace('"', '')

                if prediction.lower() != 'unknown' and label.lower() == 'unknown':
                    false_positives.append(prediction)
                
                if prediction.lower() == label.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                elif prediction.lower() in label.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                elif label.lower() in prediction.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                else:
                    wrong_predictions.append({'label': label, 'prediction': prediction})
            total_number_of_questions += 1

    return correct_answers/total_number_of_questions, wrong_predictions, false_positives

In [58]:
compute_accuracy_of_model(model)

100%|██████████| 10/10 [06:53<00:00, 41.40s/it]


(0.7021276595744681,
 [{'label': 'paint herself like them', 'prediction': 'she painted herself'},
  {'label': "the farmer's", 'prediction': 'her mommy'},
  {'label': 'the farmer', 'prediction': 'her mommy'},
  {'label': "the old farmer's", 'prediction': 'her mommy'},
  {'label': 'rubbed her face', 'prediction': 'laughed'},
  {'label': 'started laughing', 'prediction': 'laughed'},
  {'label': 'they started laughing', 'prediction': 'laughed'},
  {'label': 'dropped her into a big bucket of water',
   'prediction': 'a bucket of water'},
  {'label': 'a big bucket of water', 'prediction': 'a bucket of water'},
  {'label': 'the bottle', 'prediction': 'It was hard and clear'},
  {'label': 'a bottle', 'prediction': 'It was hard and clear'},
  {'label': 'No', 'prediction': 'Yes'},
  {'label': 'no', 'prediction': 'Yes'},
  {'label': "They took the note to Asta's papa",
   'prediction': "took it to Asta's papa"},
  {'label': 'unknown', 'prediction': "took it to Asta's papa"},
  {'label': "They too

# Results
Original accuracy: 0.7092198581560284

using save_fever2: 0.7092198581560284

using save_fever_with_qa_data_7: 0.7021276595744681

## Repeat the test with save number 7 <= there was an error with the evaluation function