In [8]:
import json
import torch
import torch.nn as nn
import random
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [9]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

In [10]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
fact_checking_model = GPT2LMHeadModel.from_pretrained('gpt2')
fact_checking_model.cuda()
checkpoint = torch.load('save_fever_with_qa_data_13')
fact_checking_model.load_state_dict(checkpoint['model_state_dict'])
_ = fact_checking_model.eval()

In [11]:
def get_text_up_to_question(text):
    _claim_yn = 'The evidence supports the claim:\n'
    return text[:text.find(_claim_yn) + len(_claim_yn)]

In [12]:
def get_answer_from_text(text):
    _claim_yn = 'The evidence supports the claim:\n'
    pos = text.find(_claim_yn) + len(_claim_yn)
    return text[pos]

In [13]:
def generate_answer(fact_checking_model, text):
    prompt = get_text_up_to_question(text)
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 1
    tokens_length = tokens.shape[1]
    if tokens_length + _length >= 1024:
        raise RuntimeError('Text is longer than 1024')
    output = fact_checking_model.generate(
             tokens.cuda(),
             max_length=tokens_length + _length, 
             pad_token_id=50256
    )
    to_return = tokenizer.decode(output[0], skip_special_tokens=True)
    perplexity = float(model(output, labels=output)[0])
    return get_answer_from_text(to_return), perplexity

In [14]:
def get_best_answer(fact_checking_model, text):
    prompt_y = get_text_up_to_question(text) + 'Y'
    prompt_n = get_text_up_to_question(text) + 'N'
    tokens_y = tokenizer.encode(prompt_y, return_tensors='pt').cuda()
    tokens_n = tokenizer.encode(prompt_n, return_tensors='pt').cuda()
    perplexity_y = float(model(tokens_y, labels=tokens_y)[0])
    perplexity_n = float(model(tokens_n, labels=tokens_n)[0])
    if perplexity_y < perplexity_n:
        return 'Y', perplexity_y
    return 'N', perplexity_n

# Question Answering part

In [15]:
_question_prompt = '\nQ: '
_answer_prompt = '\nA: '
    
def get_text_up_to_question_number(text, number):
    pos = text.find(_answer_prompt)
    for _ in range(number):
        pos = text.find(_answer_prompt, pos + 1)
    return text[0:pos + 1]
    
def get_answers_number(text):
    return text.count(_answer_prompt)

def get_answer_number(text, number):
    pos = text.find(_answer_prompt)
    for _ in range(number):
        pos = text.find(_answer_prompt, pos + 1)
    end = text.find('\n', pos + len(_answer_prompt))
    return text[pos + len(_answer_prompt):end]

def get_question_number(text, number):
    pos = text.find(_question_prompt)
    for _ in range(number):
        pos = text.find(_question_prompt, pos + 1)
    end = text.find('\n', pos + len(_question_prompt))
    return text[pos + len(_question_prompt):end]

def get_all_answers(dev_dict, dev_index):
    answers = [[item['input_text'] for item in dev_dict['data'][dev_index]['answers']]]
    answers += [[item['input_text'] for item in dev_dict['data'][dev_index]['additional_answers'][str(index)]] for index in range(3)]
    return [list(set([answers[j][i] for j in range(len(answers))])) for i in range(len(answers[0]))]

In [16]:
def get_text_from_data_item(item, max_num_questions=0, question_number=-1, last_question=True):
    text = 'In the text below two people are discussing a story.\n\n'
    text += 'Story:\n' + item['story'] + '\n\n'
    text += 'Discussion:\n'
    text += '\n'.join(['Q: ' + q['input_text'] 
                       + '\nA: ' + a['input_text'] 
                       for q, a in zip(item['questions'][max(0,question_number-max_num_questions):question_number+1], 
                                       item['answers'][max(0,question_number-max_num_questions):question_number+1]) 
                      ])
    if not last_question:
        text = '\n'.join(text.split('\n')[:-1]) + '\n'
    return text

In [17]:
def generate_multiple_answers(model, prompt, num_replicas=5):
    model.train()
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 50
    tokens_length = tokens.shape[1]
    if tokens_length + _length > 1024:
        return ''
    
    outputs = []
    for _ in range(num_replicas):
        output = model.generate(
             tokens.cuda(),
             max_length=tokens_length + _length,
             pad_token_id=50256
        )
        output = tokenizer.decode(output[0], skip_special_tokens=True)
        offset = len(prompt)
        start = offset + 1
        end = output.find('\n', start)
        outputs.append(output[start:end].split(':')[-1].strip())
        
    return list(set(outputs))

In [18]:
def generate_first_answer(model, prompt):
    model.eval()
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 50
    tokens_length = tokens.shape[1]
    if tokens_length + _length > 1024:
        return ''
    
    output = model.generate(
             tokens.cuda(),
             max_length=tokens_length + _length,
             pad_token_id=50256
    )
    output = tokenizer.decode(output[0], skip_special_tokens=True)
    offset = len(prompt)
    start = offset + 1
    end = output.find('\n', start)
    return output[start:end].split(':')[-1].strip()

In [19]:
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.cuda()
checkpoint = torch.load('save_small' + str(6))
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [20]:
dev_dict = json.load(open('../data/coqa-dev-v1.0.json', encoding='utf8'))
dev_list = json.load(open('../data/qa_dev_list.json', encoding='utf8'))

In [21]:
def get_text_from_data_item(item, max_num_questions=0, question_number=-1, last_question=True):
    text = 'In the text below two people are discussing a story.\n\n'
    text += 'Story:\n' + item['story'] + '\n\n'
    text += 'Discussion:\n'
    text += '\n'.join(['Q: ' + q['input_text'] 
                       + '\nA: ' + a['input_text'] 
                       for q, a in zip(item['questions'][max(0,question_number-max_num_questions):question_number+1], 
                                       item['answers'][max(0,question_number-max_num_questions):question_number+1]) 
                      ])
    if not last_question:
        text = '\n'.join(text.split('\n')[:-1]) + '\n'
    return text

In [22]:
doc=0
number = 1
small_text = get_text_from_data_item(dev_dict['data'][doc], 
                                     max_num_questions=5, 
                                     question_number=number,
                                     last_question=False)
first_answer = generate_first_answer(model, small_text)
answers = generate_multiple_answers(model, small_text)

In [23]:
def get_description_from_data_item(item):
    return item['story']

def get_dialogue_from_data_item(item, max_num_questions=0, question_number=-1, last_question=True):
    text = ''
    text += ' '.join([q['input_text'] + ' ' + a['input_text'] + '.'
                       for q, a in zip(item['questions'][max(0,question_number-max_num_questions):question_number+1], 
                                       item['answers'][max(0,question_number-max_num_questions):question_number+1]) 
                      ])
    if not last_question:
        text = '?'.join(text.split('?')[:-1]) + '?'
    return text

In [24]:
def create_claim_from_description_and_dialogue(description, dialogue):
    if dialogue[-1] == '.':
        dialogue = dialogue[:-1]    
    text = 'Evidence:\n'
    text += description.replace('\n\n', '\n') + '\n\n'
    text += 'Claim:\n'
    text += dialogue + '\n\n'
    text += 'The evidence supports the claim:\n'
    return text

In [31]:
def select_answer(description, dialogue, first_answer, answers):
    best_perplexity = 1e6
    
    #current_dialogue = dialogue + ' ' + first_answer
    #text = create_claim_from_description_and_dialogue(description, current_dialogue)
    #y_or_n, perplexity = generate_answer(fact_checking_model, text)
    #if y_or_n == 'Y':
    #    return first_answer, best_perplexity
    
    best_answer = ''
    for answer in answers:
        current_dialogue = dialogue + ' ' + answer
        text = create_claim_from_description_and_dialogue(description, current_dialogue)
        y_or_n, perplexity = generate_answer(fact_checking_model, text)
        if perplexity < best_perplexity and y_or_n == 'Y':
            best_perplexity = perplexity
            best_answer = answer
#    if not best_answer:
#        best_perplexity = 0
#        for answer in answers:
#            current_dialogue = dialogue + ' ' + answer
#            text = create_claim_from_description_and_dialogue(description, current_dialogue)
#            y_or_n, perplexity = generate_answer(fact_checking_model, text)
#            if perplexity > best_perplexity and y_or_n == 'N':
#                best_perplexity = perplexity
#                best_answer = answer            
    if not best_answer:
        best_answer = first_answer
    return best_answer, best_perplexity

In [32]:
doc = 0
number = 6
small_text = get_text_from_data_item(dev_dict['data'][doc], 
                                     max_num_questions=5, 
                                     question_number=number,
                                     last_question=False)

In [33]:
doc = 0
number = 0
description = get_description_from_data_item(dev_dict['data'][doc])
small_text = get_text_from_data_item(dev_dict['data'][doc], 
                                     max_num_questions=5, 
                                     question_number=number,
                                     last_question=False)
first_answer = generate_first_answer(model, small_text)
answers = generate_multiple_answers(model, small_text)
dialogue = get_dialogue_from_data_item(dev_dict['data'][doc],
                                       max_num_questions=5, 
                                       question_number=number,
                                       last_question=False)
print(answers)
select_answer(description, dialogue, first_answer, answers)

['white', 'White', 'orange']


('orange', 3.2131409645080566)

# Computing accuracy after fact checking

In [34]:
def compute_original_accuracy_of_model(model):
    total_number_of_questions = 0
    correct_answers = 0
    wrong_predictions = []

    false_positives = []
    dlist = dev_list[:10]
    for index, text in tqdm(enumerate(dlist), total=len(dlist)):
        total_questions = get_answers_number(text)
        all_answers = get_all_answers(dev_dict, index)
        for number in range(total_questions):
            small_text = get_text_from_data_item(dev_dict['data'][index], 
                                                 max_num_questions=5,
                                                 question_number=number,
                                                 last_question=False)
            #description = get_description_from_data_item(dev_dict['data'][index])
            #dialogue = get_dialogue_from_data_item(dev_dict['data'][index],
            #                           max_num_questions=5, 
            #                           question_number=number,
            #                           last_question=False)
            first_answer = generate_first_answer(model, small_text)
            #answers = generate_multiple_answers(model, small_text)
            prediction, _ = select_answer(description, dialogue, first_answer, [])
            if not prediction:
                print('NO PREDICTION!!')
                continue
            prediction = prediction.replace('.', '').replace('"', '')
            it_was_answered = False
            for label in all_answers[number]:
                label = label.replace('.', '').replace('"', '')

                if prediction.lower() != 'unknown' and label.lower() == 'unknown':
                    false_positives.append(prediction)
                
                if prediction.lower() == label.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                elif prediction.lower() in label.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                elif label.lower() in prediction.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                else:
                    wrong_predictions.append({'label': label, 'prediction': prediction})
            total_number_of_questions += 1

    return correct_answers/total_number_of_questions, wrong_predictions, false_positives

In [35]:
def compute_accuracy_of_model(model):
    total_number_of_questions = 0
    correct_answers = 0
    wrong_predictions = []

    false_positives = []
    dlist = dev_list[:10]
    for index, text in tqdm(enumerate(dlist), total=len(dlist)):
        total_questions = get_answers_number(text)
        all_answers = get_all_answers(dev_dict, index)
        for number in range(total_questions):
            small_text = get_text_from_data_item(dev_dict['data'][index], 
                                                 max_num_questions=5,
                                                 question_number=number,
                                                 last_question=False)
            description = get_description_from_data_item(dev_dict['data'][index])
            dialogue = get_dialogue_from_data_item(dev_dict['data'][index],
                                       max_num_questions=5, 
                                       question_number=number,
                                       last_question=False)
            first_answer = generate_first_answer(model, small_text)
            answers = generate_multiple_answers(model, small_text)
            prediction, _ = select_answer(description, dialogue, first_answer, answers)
            if not prediction:
                print('NO PREDICTION!!')
                continue
            prediction = prediction.replace('.', '').replace('"', '')
            it_was_answered = False
            for label in all_answers[number]:
                label = label.replace('.', '').replace('"', '')

                if prediction.lower() != 'unknown' and label.lower() == 'unknown':
                    false_positives.append(prediction)
                
                if prediction.lower() == label.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                elif prediction.lower() in label.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                elif label.lower() in prediction.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                else:
                    wrong_predictions.append({'label': label, 'prediction': prediction})
            total_number_of_questions += 1

    return correct_answers/total_number_of_questions, wrong_predictions, false_positives

In [36]:
compute_accuracy_of_model(model)

100%|██████████| 10/10 [06:42<00:00, 40.22s/it]


(0.6099290780141844,
 [{'label': 'with her mommy and 5 sisters',
   'prediction': 'her mommy and 5 other sisters'},
  {'label': 'paint herself like them', 'prediction': 'she painted herself'},
  {'label': 'the farmer', 'prediction': "Cotton's mommy's"},
  {'label': "the old farmer's", 'prediction': "Cotton's mommy's"},
  {'label': "the farmer's", 'prediction': "Cotton's mommy's"},
  {'label': 'started laughing', 'prediction': 'They laughed'},
  {'label': 'they started laughing', 'prediction': 'They laughed'},
  {'label': 'rubbed her face', 'prediction': 'They laughed'},
  {'label': 'dropped her into a big bucket of water',
   'prediction': 'a bucket of water'},
  {'label': 'a big bucket of water', 'prediction': 'a bucket of water'},
  {'label': 'into a big bucket of water', 'prediction': 'a bucket of water'},
  {'label': 'No', 'prediction': 'yes'},
  {'label': 'no', 'prediction': 'yes'},
  {'label': 'the bottle', 'prediction': "It looked like a bird's belly"},
  {'label': 'a bottle', '

# Results
Original accuracy: 0.7092198581560284

using save_fever2: 0.7092198581560284

using save_fever_with_qa_data_7: 0.7021276595744681

## Repeat the test with save number 7 <= there was an error with the evaluation function