In [1]:
import json
import torch
import torch.nn as nn
import random
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [2]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

In [3]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
fact_checking_model = GPT2LMHeadModel.from_pretrained('gpt2')
fact_checking_model.cuda()
checkpoint = torch.load('save_fever2')
fact_checking_model.load_state_dict(checkpoint['model_state_dict'])
_ = fact_checking_model.eval()

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.attn.masked_bias', 'h.1.attn.masked_bias', 'h.2.attn.masked_bias', 'h.3.attn.masked_bias', 'h.4.attn.masked_bias', 'h.5.attn.masked_bias', 'h.6.attn.masked_bias', 'h.7.attn.masked_bias', 'h.8.attn.masked_bias', 'h.9.attn.masked_bias', 'h.10.attn.masked_bias', 'h.11.attn.masked_bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
def get_text_up_to_question(text):
    _claim_yn = 'The evidence supports the claim:\n'
    return text[:text.find(_claim_yn) + len(_claim_yn)]

In [5]:
def get_answer_from_text(text):
    _claim_yn = 'The evidence supports the claim:\n'
    pos = text.find(_claim_yn) + len(_claim_yn)
    return text[pos]

In [6]:
def generate_answer(model, text):
    prompt = get_text_up_to_question(text)
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 1
    tokens_length = tokens.shape[1]
    if tokens_length + _length >= 1024:
        raise RuntimeError('Text is longer than 1024')
    output = model.generate(
             tokens.cuda(),
             max_length=tokens_length + _length, 
             pad_token_id=50256
    )
    to_return = tokenizer.decode(output[0], skip_special_tokens=True)
    perplexity = float(model(output, labels=output)[0])
    return get_answer_from_text(to_return), perplexity

# Question Answering part

In [7]:
_question_prompt = '\nQ: '
_answer_prompt = '\nA: '
    
def get_text_up_to_question_number(text, number):
    pos = text.find(_answer_prompt)
    for _ in range(number):
        pos = text.find(_answer_prompt, pos + 1)
    return text[0:pos + 1]
    
def get_answers_number(text):
    return text.count(_answer_prompt)

def get_answer_number(text, number):
    pos = text.find(_answer_prompt)
    for _ in range(number):
        pos = text.find(_answer_prompt, pos + 1)
    end = text.find('\n', pos + len(_answer_prompt))
    return text[pos + len(_answer_prompt):end]

def get_question_number(text, number):
    pos = text.find(_question_prompt)
    for _ in range(number):
        pos = text.find(_question_prompt, pos + 1)
    end = text.find('\n', pos + len(_question_prompt))
    return text[pos + len(_question_prompt):end]

def get_all_answers(dev_dict, dev_index):
    answers = [[item['input_text'] for item in dev_dict['data'][dev_index]['answers']]]
    return [list(set([answers[j][i] for j in range(len(answers))])) for i in range(len(answers[0]))]

In [8]:
def get_text_from_data_item(item, max_num_questions=0, question_number=-1, last_question=True):
    text = 'In the text below two people are discussing a story.\n\n'
    text += 'Story:\n' + item['story'] + '\n\n'
    text += 'Discussion:\n'
    text += '\n'.join(['Q: ' + q['input_text'] 
                       + '\nA: ' + a['input_text'] 
                       for q, a in zip(item['questions'][max(0,question_number-max_num_questions):question_number+1], 
                                       item['answers'][max(0,question_number-max_num_questions):question_number+1]) 
                      ])
    if not last_question:
        text = '\n'.join(text.split('\n')[:-1]) + '\n'
    return text

In [9]:
train_dict = json.load(open('../data/coqa-train-v1.0.json', encoding='utf8'))
train_list = json.load(open('../data/qa_train_list.json', encoding='utf8'))
dev_dict = json.load(open('../data/coqa-dev-v1.0.json', encoding='utf8'))
dev_list = json.load(open('../data/qa_dev_list.json', encoding='utf8'))

In [10]:
def get_text_from_data_item(item, max_num_questions=0, question_number=-1, last_question=True):
    text = 'In the text below two people are discussing a story.\n\n'
    text += 'Story:\n' + item['story'] + '\n\n'
    text += 'Discussion:\n'
    text += '\n'.join(['Q: ' + q['input_text'] 
                       + '\nA: ' + a['input_text'] 
                       for q, a in zip(item['questions'][max(0,question_number-max_num_questions):question_number+1], 
                                       item['answers'][max(0,question_number-max_num_questions):question_number+1]) 
                      ])
    if not last_question:
        text = '\n'.join(text.split('\n')[:-1]) + '\n'
    return text

In [11]:
def get_description_from_data_item(item):
    return item['story']

def get_dialogue_from_data_item(item, max_num_questions=0, question_number=-1, last_question=True):
    text = ''
    text += ' '.join([q['input_text'] + ' ' + a['input_text'] + '.'
                       for q, a in zip(item['questions'][max(0,question_number-max_num_questions):question_number+1], 
                                       item['answers'][max(0,question_number-max_num_questions):question_number+1]) 
                      ])
    if not last_question:
        text = '?'.join(text.split('?')[:-1]) + '?'
    return text

In [12]:
def create_claim_from_description_and_dialogue(description, dialogue):
    if dialogue[-1] == '.':
        dialogue = dialogue[:-1]
    text = 'Evidence:\n'
    text += description.replace('\n\n', '\n') + '\n\n'
    text += 'Claim:\n'
    text += dialogue + '\n\n'
    text += 'The evidence supports the claim:\n'
    return text

In [13]:
def get_positive_list_of_claims():
    claims = []
    
    total_number_of_questions = 0
    correct_answers = 0
    wrong_predictions = []
    false_positives = []
    dlist = train_list
    for index, text in tqdm(enumerate(dlist), total=len(dlist)):
        total_questions = get_answers_number(text)
        all_answers = get_all_answers(train_dict, index)
        for number in range(total_questions):
            small_text = get_text_from_data_item(train_dict['data'][index], 
                                                 max_num_questions=5,
                                                 question_number=number,
                                                 last_question=False)
            description = get_description_from_data_item(train_dict['data'][index])
            dialogue = get_dialogue_from_data_item(train_dict['data'][index],
                                       max_num_questions=5, 
                                       question_number=number,
                                       last_question=True)
            
            last_answer = dialogue.split('?')[-1].strip().lower()
            if stripped_answer(last_answer) not in ['yes', 'no']:
                continue
        
            claims.append(create_claim_from_description_and_dialogue(description, dialogue) + 'Y')
    return claims

In [14]:
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.cuda()
checkpoint = torch.load('save_small' + str(1))
model.load_state_dict(checkpoint['model_state_dict'])
_ = model.eval()

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.attn.masked_bias', 'h.1.attn.masked_bias', 'h.2.attn.masked_bias', 'h.3.attn.masked_bias', 'h.4.attn.masked_bias', 'h.5.attn.masked_bias', 'h.6.attn.masked_bias', 'h.7.attn.masked_bias', 'h.8.attn.masked_bias', 'h.9.attn.masked_bias', 'h.10.attn.masked_bias', 'h.11.attn.masked_bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
def generate_multiple_answers(model, prompt, num_replicas=2):
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 50
    tokens_length = tokens.shape[1]
    if tokens_length + _length > 1024:
        return ''
    
    outputs = []
    for _ in range(num_replicas):
        output = model.generate(
             tokens.cuda(),
             max_length=tokens_length + _length,
             pad_token_id=50256
        )
        output = tokenizer.decode(output[0], skip_special_tokens=True)
        offset = len(prompt)
        start = offset + 1
        end = output.find('\n', start)
        outputs.append(output[start:end].split(':')[-1].strip())
        
    return list(set(outputs))

In [16]:
def stripped_answer(text):
    return text.lower().strip().replace('.', '')

In [17]:
def answers_are_consistent(lhs, rhs):
    prediction = stripped_answer(lhs)
    label = stripped_answer(rhs)
    if prediction.lower() == label.lower():
        return True
    elif prediction.lower() in label.lower():
        return True
    elif label.lower() in prediction.lower():
        return True
    return False

In [18]:
import random
random.seed(42)

def get_negative_list_of_claims():
    claims = []
    
    total_number_of_questions = 0
    correct_answers = 0
    wrong_predictions = []

    false_positives = []
    dlist = train_list
    random.shuffle(dlist)
    for index, text in tqdm(enumerate(dlist), total=len(dlist)):
        total_questions = get_answers_number(text)
        all_answers = get_all_answers(train_dict, index)
        for number in range(total_questions):
            small_text = get_text_from_data_item(train_dict['data'][index],
                                                 max_num_questions=5,
                                                 question_number=number,
                                                 last_question=False)
            description = get_description_from_data_item(train_dict['data'][index])
            dialogue = get_dialogue_from_data_item(train_dict['data'][index],
                           max_num_questions=5, 
                           question_number=number,
                           last_question=True)
            last_answer = dialogue.split('?')[-1].strip().lower()
            dialogue = '?'.join(dialogue.split('?')[:-1]) + '?'
            wrong_answer = ''
            if stripped_answer(last_answer) == 'yes':
                wrong_answer = 'no'
            elif stripped_answer(last_answer) == 'no':
                wrong_answer = 'yes'

            #else:
            #    generated_answers = generate_multiple_answers(model, small_text)
            #    for candidate in generated_answers:
            #        if not answers_are_consistent(candidate, last_answer):
            #            wrong_answer = candidate
            #            break

            if not wrong_answer:
                continue
            claims.append(create_claim_from_description_and_dialogue(description, dialogue + ' ' + wrong_answer) + 'N')
    return claims

In [19]:
positive_claims = get_positive_list_of_claims()

100%|██████████| 7199/7199 [00:01<00:00, 6272.77it/s]


In [20]:
negative_claims = get_negative_list_of_claims()

100%|██████████| 7199/7199 [00:01<00:00, 6238.40it/s]


In [21]:
print(positive_claims[7])

Evidence:
(CNN) -- The longest-running holiday special still has a very shiny nose. 
"Rudolph the Red-Nosed Reindeer" premiered on television December 6, 1964, and is now one of the holiday season's perennial favorites. The story of the reindeer who saves Christmas is beloved among children and adults alike. 
The Rankin-Bass animated film production company used Japanese puppets and stop motion to tell the tale, bolstered by a soundtrack featuring Burl Ives' rendition of the theme song. 
In the story, Santa's reindeer Donner and his wife have a son, Rudolph, who has the distinction of a nose that glows. He runs away after being made to feel an outcast and links up with an elf who dreams of becoming a dentist and an adventurer seeking silver and gold. 
After ending up on the Island of Misfit Toys and wandering for a while, Rudolph goes on to save his loved ones from the Abominable Snow Monster and guides Santa through a blizzard that threatens to ruin Christmas. 
In 2006, the New York T

In [22]:
print(negative_claims[5])

Evidence:
CHAPTER XXIV. THE INTERRUPTED MASS 
The morning of that Wednesday of Corpus Christi, fateful to all concerned in this chronicle, dawned misty and grey, and the air was chilled by the wind that blew from the sea. The chapel bell tinkled out its summons, and the garrison trooped faithfully to Mass. 
Presently came Monna Valentina, followed by her ladies, her pages, and lastly, Peppe, wearing under his thin mask of piety an air of eager anxiety and unrest. Valentina was very pale, and round her eyes there were dark circles that told of sleeplessness, and as she bowed her head in prayer, her ladies observed that tears were falling on the illuminated Mass-book over which she bent. And now came Fra Domenico from the sacristy in the white chasuble that the Church ordains for the Corpus Christi feast, followed by a page in a clerkly gown of black, and the Mass commenced. 
There were absent only from the gathering Gonzaga and Fortemani, besides a sentry and the three prisoners. France

In [23]:
import random
split = 0.8
all_list = positive_claims + negative_claims
random.shuffle(all_list)
train_list = all_list[:int(len(all_list) * split)]
dev_list = all_list[int(len(all_list) * split):]

In [24]:
json.dump(train_list, open('../data/train_fc_with_qa.json', 'w'))
json.dump(dev_list, open('../data/dev_fc_with_qa.json', 'w'))

In [25]:
print(train_list[2])

Evidence:
My summer hols wr CWOT. B4, we usd 2 go 2 NY 2C my bro, his GF & thr3:-@ kids FTF. ILNY, it's gr8. Can you understand this sentence? If you can't, don't feel too bad; neither could the middle school teacher in England who received this as homework. This is Netspeak: the language of computerized communication found on Internet or cell phones. To new comers, it can look like a completely foreign language. So, what is the translation of the sentence above? My summer holidays were a complete waste of time. Before, we used to go to New York to see my brother, his girlfriend, and their three screaming kids face to face. I love New York. It's great. School teachers and parents say this new form of writing is harming the English language. Increasing spelling and grammatical mistakes can be seen in students' writing. They fear the language could become corrupted . "Everyone should just relax", say linguists . They believe Netspeak is in fact more of a good thing. David Crystal, from t