In [1]:
import json
import torch
import torch.nn as nn
import random
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [2]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

In [3]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
fact_checking_model = GPT2LMHeadModel.from_pretrained('gpt2')
fact_checking_model.cuda()
checkpoint = torch.load('save_fever3')
fact_checking_model.load_state_dict(checkpoint['model_state_dict'])
_ = fact_checking_model.eval()

In [4]:
def get_text_up_to_question(text):
    _claim_yn = 'The evidence supports the claim:\n'
    return text[:text.find(_claim_yn) + len(_claim_yn)]

In [5]:
def get_answer_from_text(text):
    _claim_yn = 'The evidence supports the claim:\n'
    pos = text.find(_claim_yn) + len(_claim_yn)
    return text[pos]

In [6]:
def generate_answer(fact_checking_model, text):
    prompt = get_text_up_to_question(text)
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 1
    tokens_length = tokens.shape[1]
    if tokens_length + _length >= 1024:
        raise RuntimeError('Text is longer than 1024')
    output = fact_checking_model.generate(
             tokens.cuda(),
             max_length=tokens_length + _length, 
             pad_token_id=50256
    )
    to_return = tokenizer.decode(output[0], skip_special_tokens=True)
    perplexity = float(model(output, labels=output)[0])
    return get_answer_from_text(to_return), perplexity

In [7]:
def get_best_answer(fact_checking_model, text):
    prompt_y = get_text_up_to_question(text) + 'Y'
    prompt_n = get_text_up_to_question(text) + 'N'
    tokens_y = tokenizer.encode(prompt_y, return_tensors='pt').cuda()
    tokens_n = tokenizer.encode(prompt_n, return_tensors='pt').cuda()
    perplexity_y = float(model(tokens_y, labels=tokens_y)[0])
    perplexity_n = float(model(tokens_n, labels=tokens_n)[0])
    if perplexity_y < perplexity_n:
        return 'Y', perplexity_y
    return 'N', perplexity_n

# Question Answering part

In [8]:
_question_prompt = '\nQ: '
_answer_prompt = '\nA: '
    
def get_text_up_to_question_number(text, number):
    pos = text.find(_answer_prompt)
    for _ in range(number):
        pos = text.find(_answer_prompt, pos + 1)
    return text[0:pos + 1]
    
def get_answers_number(text):
    return text.count(_answer_prompt)

def get_answer_number(text, number):
    pos = text.find(_answer_prompt)
    for _ in range(number):
        pos = text.find(_answer_prompt, pos + 1)
    end = text.find('\n', pos + len(_answer_prompt))
    return text[pos + len(_answer_prompt):end]

def get_question_number(text, number):
    pos = text.find(_question_prompt)
    for _ in range(number):
        pos = text.find(_question_prompt, pos + 1)
    end = text.find('\n', pos + len(_question_prompt))
    return text[pos + len(_question_prompt):end]

def get_all_answers(dev_dict, dev_index):
    answers = [[item['input_text'] for item in dev_dict['data'][dev_index]['answers']]]
    answers += [[item['input_text'] for item in dev_dict['data'][dev_index]['additional_answers'][str(index)]] for index in range(3)]
    return [list(set([answers[j][i] for j in range(len(answers))])) for i in range(len(answers[0]))]

In [9]:
def get_text_from_data_item(item, max_num_questions=0, question_number=-1, last_question=True):
    text = 'In the text below two people are discussing a story.\n\n'
    text += 'Story:\n' + item['story'] + '\n\n'
    text += 'Discussion:\n'
    text += '\n'.join(['Q: ' + q['input_text'] 
                       + '\nA: ' + a['input_text'] 
                       for q, a in zip(item['questions'][max(0,question_number-max_num_questions):question_number+1], 
                                       item['answers'][max(0,question_number-max_num_questions):question_number+1]) 
                      ])
    if not last_question:
        text = '\n'.join(text.split('\n')[:-1]) + '\n'
    return text

In [10]:
def generate_multiple_answers(model, prompt, num_replicas=25):
    model.train()
    outputs = []
    with torch.no_grad():
        tokens = tokenizer.encode(prompt, return_tensors='pt')
        tokens = tokens.repeat(num_replicas,1)
        _length = 50
        tokens_length = tokens.shape[1]
        if tokens_length + _length > 1024:
            return ''

        
        output = model.generate(
             tokens.cuda(),
             max_length=tokens_length + _length,
             pad_token_id=50256
        )
        for index in range(num_replicas):
            text = tokenizer.decode(output[index, :], skip_special_tokens=True)
            offset = len(prompt)
            start = offset + 1
            end = text.find('\n', start)
            outputs.append(text[start:end].split(':')[-1].strip())

    return outputs

In [11]:
def get_answer_from_text(text):
    _claim_yn = 'The evidence supports the claim:\n'
    pos = text.find(_claim_yn) + len(_claim_yn)
    return text[pos]

In [12]:
def generate_multiple_y_n_answers(model, prompt, num_replicas=25):
    model.train()
    outputs_count = {}
    with torch.no_grad():
        tokens = tokenizer.encode(prompt, return_tensors='pt')
        tokens = tokens.repeat(num_replicas,1)
        _length = 50
        tokens_length = tokens.shape[1]
        if tokens_length + _length > 1024:
            return ''

        
        output = model.generate(
             tokens.cuda(),
             max_length=tokens_length + _length,
             pad_token_id=50256
        )
        for index in range(num_replicas):
            text = tokenizer.decode(output[index, :], skip_special_tokens=True)
            answer = get_answer_from_text(text)
            outputs_count.setdefault(answer, 0)
            outputs_count[answer] += 1

    total = sum(v for v in outputs_count.values())
    return [(k, v / total) for k, v in outputs_count.items()]

In [13]:
def generate_first_answer(model, prompt):
    model.eval()
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 50
    tokens_length = tokens.shape[1]
    if tokens_length + _length > 1024:
        return ''
    
    output = model.generate(
             tokens.cuda(),
             max_length=tokens_length + _length,
             pad_token_id=50256
    )
    output = tokenizer.decode(output[0], skip_special_tokens=True)
    offset = len(prompt)
    start = offset + 1
    end = output.find('\n', start)
    return output[start:end].split(':')[-1].strip()

In [14]:
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.cuda()
checkpoint = torch.load('save_small' + str(6))
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [15]:
dev_dict = json.load(open('../data/coqa-dev-v1.0.json', encoding='utf8'))
dev_list = json.load(open('../data/qa_dev_list.json', encoding='utf8'))

In [16]:
def get_text_from_data_item(item, max_num_questions=0, question_number=-1, last_question=True):
    text = 'In the text below two people are discussing a story.\n\n'
    text += 'Story:\n' + item['story'] + '\n\n'
    text += 'Discussion:\n'
    text += '\n'.join(['Q: ' + q['input_text'] 
                       + '\nA: ' + a['input_text'] 
                       for q, a in zip(item['questions'][max(0,question_number-max_num_questions):question_number+1], 
                                       item['answers'][max(0,question_number-max_num_questions):question_number+1]) 
                      ])
    if not last_question:
        text = '\n'.join(text.split('\n')[:-1]) + '\n'
    return text

In [17]:
doc=0
number = 1
small_text = get_text_from_data_item(dev_dict['data'][doc], 
                                     max_num_questions=5, 
                                     question_number=number,
                                     last_question=False)
first_answer = generate_first_answer(model, small_text)
answers = generate_multiple_answers(model, small_text)

In [18]:
def get_description_from_data_item(item):
    return item['story']

def get_dialogue_from_data_item(item, max_num_questions=0, question_number=-1, last_question=True):
    text = ''
    text += ' '.join([q['input_text'] + ' ' + a['input_text'] + '.'
                       for q, a in zip(item['questions'][max(0,question_number-max_num_questions):question_number+1], 
                                       item['answers'][max(0,question_number-max_num_questions):question_number+1]) 
                      ])
    if not last_question:
        text = '?'.join(text.split('?')[:-1]) + '?'
    return text

In [19]:
def create_claim_from_description_and_dialogue(description, dialogue):
    if dialogue[-1] == '.':
        dialogue = dialogue[:-1]    
    text = 'Evidence:\n'
    text += description.replace('\n\n', '\n') + '\n\n'
    text += 'Claim:\n'
    text += dialogue + '\n\n'
    text += 'The evidence supports the claim:\n'
    return text

In [20]:
device = 'cuda'

In [21]:
import numpy as np
from sentence_transformers import SentenceTransformer


sentence_model = SentenceTransformer('msmarco-distilbert-base-v3')
sentence_model = sentence_model.to(device)

In [22]:
def get_embeddings_from_text(text):
    outputs = sentence_model.encode(text)
    return outputs

def group_similar_answers_and_get_scores(answers):
    answers_dict = {}
    threshold = 0.7
    embeddings = get_embeddings_from_text(answers)
    embeddings = np.array([e/np.linalg.norm(e) for e in embeddings])
    similarity_matrix = np.matmul(embeddings, embeddings.transpose())
    superseded = set()
    superseded_from = {}
    for i in range(len(answers)):
        for j in range(len(answers)):
            if i > j:
                continue
            if i != j and answers[i] == answers[j]:
                continue
            if similarity_matrix[i][j] > threshold :
                answers_dict.setdefault(i, 0)
                answers_dict[i] += 1
                if i != j:
                    superseded.add(j)
                    superseded_from.setdefault(i, [])
                    superseded_from[i].append(j)

    answers_and_scores = [(index, score/len(answers))
                          for index, score in answers_dict.items() 
                          if index not in superseded]
    
    new_scores_dict = {}
    total_score = sum(item[1] for item in answers_and_scores)
    for answer_index, score in answers_and_scores:
        answer_group = [answers[answer_index]]
        if answer_index in superseded_from:
            answer_group += [answers[i] for i in superseded_from[answer_index]]
        answer_group = tuple(set(answer_group))
        if answer_group in new_scores_dict:
            new_scores_dict[answer_group] += score / total_score
        else:
            new_scores_dict[answer_group] = score / total_score
    
    
    return sorted(list(new_scores_dict.items()), key=lambda x: -x[1])

In [23]:
statement_model = GPT2LMHeadModel.from_pretrained('gpt2')
statement_model.cuda()
checkpoint = torch.load('save_statement' + str(9))
statement_model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [24]:
def generate_statement_from_dialogue(model, prompt):
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 50
    tokens_length = tokens.shape[1]
    if tokens_length + _length > 1024:
        return ''
    output = model.generate(
             tokens.cuda(),
             max_length=tokens_length + _length,
             #temperature=0,
             pad_token_id=50256
    )
    output = tokenizer.decode(output[0], skip_special_tokens=True)
    offset = len(prompt)
    start = offset
    end = output.find('\n', start)
    return output[start:end].split(':')[-1].strip()

In [25]:
doc = 0
number = 6
small_text = get_text_from_data_item(dev_dict['data'][doc], 
                                     max_num_questions=5, 
                                     question_number=number,
                                     last_question=False)

In [26]:
def get_statement_prompt(item, max_num_questions=0, question_number=-1, use_answer=None):
    text = 'Discussion:\n'
    text += '\n'.join(['Q: ' + q['input_text'] 
                       + '\nA: ' + a['input_text'] 
                       for q, a in zip(item['questions'][max(0,question_number-max_num_questions):question_number+1], 
                                       item['answers'][max(0,question_number-max_num_questions):question_number+1]) 
                      ])
    if use_answer:
        text = '\n'.join(text.split('\n')[:-1]) + '\n' + 'A: ' + use_answer + '\n'
    text += '\nStatement:\n'
    return text

In [27]:
def select_answer(description, answers, doc, number):
    best_probability = 0
    
    best_answer = ''
    for answer_tuple in answers:
        answer_sample = answer_tuple[0][0]
        answer_score = answer_tuple[1]
        print(answer_score, answer_sample)
        statement_prompt = get_statement_prompt(dev_dict['data'][doc], 
                                 max_num_questions=5,
                                 question_number=number,
                                 use_answer=answer_sample)
        #print(statement_prompt)
        statement = generate_statement_from_dialogue(statement_model, statement_prompt)
        #print(statement)
        text = create_claim_from_description_and_dialogue(description, statement)
        
        y_n_tuples = generate_multiple_y_n_answers(fact_checking_model, text)
        print(y_n_tuples)
        for y_n, score in y_n_tuples:
            if score * answer_score > best_probability and y_n == 'Y':
                best_probability = score * answer_score
                best_answer = answer_tuple[0]

    return best_answer, best_probability

In [28]:
doc = 0
number = 0
description = get_description_from_data_item(dev_dict['data'][doc])
small_text = get_text_from_data_item(dev_dict['data'][doc], 
                                     max_num_questions=5, 
                                     question_number=number,
                                     last_question=False)
answers = generate_multiple_answers(model, small_text)

print(answers)
answers = group_similar_answers_and_get_scores(answers)
select_answer(description, answers, doc, number)

['white', 'White', 'White', 'White', 'white', 'white', 'White', 'white', 'orange', 'white', 'white', 'white', 'white', 'White', 'orange', 'white', 'white', 'White', 'white', 'white', 'white', 'orange', 'white', 'orange', 'White']
0.6666666666666667 White
[('Y', 1.0)]
0.33333333333333337 orange
[('Y', 1.0)]


(('White', 'white'), 0.6666666666666667)

# Computing accuracy after fact checking

In [29]:
def compute_accuracy_of_model(model):
    total_number_of_questions = 0
    correct_answers = 0
    wrong_predictions = []

    false_positives = []
    dlist = dev_list[:10]
    for index, text in tqdm(enumerate(dlist), total=len(dlist)):

        all_answers = get_all_answers(dev_dict, index)
        total_questions = len(all_answers)        
        
        for number in range(total_questions):
            small_text = get_text_from_data_item(dev_dict['data'][index], 
                                                 max_num_questions=8,
                                                 question_number=number,
                                                 last_question=False)
            predictions = generate_multiple_answers(model, small_text)
            predictions = group_similar_answers_and_get_scores(predictions)
            #predictions = predictions[0]
            predictions = select_answer(description, predictions, index, number)
            print(predictions)
            for prediction in predictions[0]:
                it_was_answered = False
                if not prediction:
                    prediction = 'unknown'
                prediction = prediction.replace('.', '').replace('"', '')
                it_was_answered = False
                for label in all_answers[number]:
                    label = label.replace('.', '').replace('"', '')

                    if prediction.lower() != 'unknown' and label.lower() == 'unknown':
                        false_positives.append(prediction)

                    if prediction.lower() == label.lower():
                        correct_answers += 1
                        it_was_answered = True
                        break
                    elif prediction.lower() in label.lower():
                        correct_answers += 1
                        it_was_answered = True
                        break
                    elif label.lower() in prediction.lower():
                        correct_answers += 1
                        it_was_answered = True
                        break
                    else:
                        wrong_predictions.append({'label': label, 'prediction': prediction})
                        
                if it_was_answered:
                    break

            total_number_of_questions += 1

    return correct_answers/total_number_of_questions, wrong_predictions, false_positives

In [30]:
compute_accuracy_of_model(model)

  0%|          | 0/10 [00:00<?, ?it/s]

0.7777777777777777 White
[('Y', 1.0)]
0.22222222222222213 orange
[('Y', 1.0)]
(('White', 'white'), 0.7777777777777777)
0.47058823529411764 in a barn
[('Y', 1.0)]
0.29411764705882354 a farm house
[('Y', 1.0)]
0.1764705882352941 above the barn
[('Y', 1.0)]
0.058823529411764705 a farm
[('Y', 1.0)]
(('in a barn', 'a barn'), 0.47058823529411764)
0.9200000000000002 no
[('Y', 1.0)]
0.07999999999999999 yes
[('Y', 0.96), ('N', 0.04)]
(('no',), 0.9200000000000002)
1.0 her mom and 5 other sisters
[('Y', 1.0)]
(('her mom and 5 other sisters', 'her mommy and sisters', 'her mommy and 5 other sisters', 'her mommy and 5 sisters'), 1.0)
1.0000000000000002 orange
[('Y', 1.0)]
(('orange',), 1.0000000000000002)
1.0000000000000002 no
[('Y', 1.0)]
(('no',), 1.0000000000000002)
0.75 she used her paint
[('Y', 1.0)]
0.16666666666666669 rubbed her face on Cotton's and said "Oh Cotton, but your fur is so pretty and special, like you. We would never want you to be any other way". And with that, Cotton's mommy pic

 10%|█         | 1/10 [00:46<06:58, 46.52s/it]

[('Y', 1.0)]
(('no',), 0.8800000000000001)
1.0000000000000002 Asta
[('Y', 1.0)]
(('Asta',), 1.0000000000000002)
0.4761904761904762 A bottle
[('Y', 1.0)]
0.1904761904761905 It was hard and clear.
[('Y', 1.0)]
0.14285714285714285 It was not a bird's belly.
[('N', 0.36), ('Y', 0.64)]
0.09523809523809525 The bottle floated above them.
[('Y', 1.0)]
0.09523809523809525 what did they see
[('Y', 1.0)]
(('A bottle', 'It was a bottle.', 'a bottle', 'It was a bottle', 'bottle', 'A bottle.'), 0.4761904761904762)
0.5217391304347825 Sharkie
[('Y', 1.0)]
0.26086956521739124 Asta's friend
[('Y', 1.0)]
0.13043478260869562 Asta.
[('Y', 1.0)]
0.08695652173913042 a friend
[('Y', 1.0)]
(('Sharkie', 'Sharkie.'), 0.5217391304347825)
0.9523809523809523 yes
[('Y', 1.0)]
0.047619047619047616 no
[('Y', 1.0)]
(('yes', 'Yes', 'Yes.'), 0.9523809523809523)
1.0000000000000002 Yes
[('Y', 1.0)]
(('Yes',), 1.0000000000000002)
0.5652173913043476 a note
[('Y', 1.0)]
0.434782608695652 A bottle
[('Y', 1.0)]
(('a note',), 0.

 20%|██        | 2/10 [01:23<05:26, 40.87s/it]

[('Y', 1.0)]
(('Yes',), 0.9600000000000002)
0.7999999999999998 elderly Chinese lady
[('Y', 1.0)]
0.11999999999999998 Nicole
[('Y', 1.0)]
0.039999999999999994 a little boy
[('Y', 1.0)]
0.039999999999999994 My doorbell
[('Y', 1.0)]
(('elderly Chinese lady', 'The elderly Chinese lady', 'A Chinese lady', 'a lady', 'the elderly Chinese lady', 'a Chinese lady'), 0.7999999999999998)
0.84 yes
[('Y', 1.0)]
0.16 a paper carrier bag
[('Y', 1.0)]
(('yes', 'Yes'), 0.84)
1.0 a paper carrier bag
[('Y', 1.0)]
(('a paper carrier bag', 'A paper carrier bag', 'Paper carrier bag'), 1.0)
0.6799999999999999 Yes
[('Y', 1.0)]
0.3199999999999999 No
[('Y', 1.0)]
(('Yes',), 0.6799999999999999)
1.0000000000000002 Nicole
[('Y', 1.0)]
(('Nicole',), 1.0000000000000002)
1.0000000000000002 Shanghai
[('Y', 0.52), ('N', 0.48)]
(('Shanghai',), 0.5200000000000001)
0.8421052631578947 She is her grandmother
[('Y', 1.0)]
0.07894736842105263 Her mother bought the house next door last October
[('Y', 1.0)]
0.05263157894736842 S

 20%|██        | 2/10 [02:13<08:54, 66.78s/it]


RuntimeError: CUDA out of memory. Tried to allocate 46.00 MiB (GPU 0; 10.76 GiB total capacity; 7.53 GiB already allocated; 24.69 MiB free; 7.96 GiB reserved in total by PyTorch)

## Results with alternative answers

# Results

#### Using dev_dict 10, 20 multiple answers
0.7588652482269503
(without fact checking and 20 multiple answers it is 0.7163120567375887)
(without fact checking and 25 multiple answers it is 0.7659574468085106)

#### Using dev_dict 10, 25 multiple answers



## Repeat the test with save number 7 <= there was an error with the evaluation function

## Todo

* properly train FEVER. Why is it so good on fever dev data but louse on coqa? (leakage?)

* Is the FEVER system skewed to say Y? Should you have a threshold for the "correct" answer?

* maybe you can add some syntethic corereference to the training data for the statements