In [1]:
import os
import json

In [2]:
# Load the adversarial-claims splits. Each file decodes to a list of raw
# story/discussion strings that the claim builders iterate over.
# `with` closes the handles deterministically; JSON text is UTF-8 by spec
# (RFC 8259), so don't depend on the platform's locale encoding.
with open('../data/adversarial_claims_train.json', encoding='utf-8') as train_file:
    train_list = json.load(train_file)
with open('../data/adversarial_claims_dev.json', encoding='utf-8') as dev_file:
    dev_list = json.load(dev_file)

In [4]:
_question_prompt = '\nQ: '
_correct_answer_prompt = '\nCA: '
_wrong_answer_prompt = '\nWA: '


def _find_nth_prompt(text, prompt, number):
    """Return the index of the ``number``-th (0-based) occurrence of ``prompt`` in ``text``.

    Returns -1 when ``text`` contains fewer than ``number + 1`` occurrences. The
    search stops at the first miss instead of restarting from the beginning of
    the text, which would silently return an earlier match.
    """
    pos = text.find(prompt)
    for _ in range(number):
        if pos == -1:
            break
        pos = text.find(prompt, pos + 1)
    return pos


def _line_after_nth_prompt(text, prompt, number):
    """Return the rest of the line after the ``number``-th ``prompt``, or '' if there is none."""
    pos = _find_nth_prompt(text, prompt, number)
    if pos == -1:
        return ''
    start = pos + len(prompt)
    end = text.find('\n', start)
    if end == -1:  # last line of the text has no trailing newline: keep it whole
        end = len(text)
    return text[start:end]


def get_text_up_to_question_number(text, number):
    """Return ``text`` up to just before the ``number``-th (0-based) correct answer.

    The newline that precedes that answer's 'CA: ' tag is kept. Returns '' when
    the answer does not exist.
    """
    # NOTE(review): the original loop searched for an undefined `_answer_prompt`
    # (NameError). It is assumed to mean the correct-answer prompt used for the
    # first search.
    pos = _find_nth_prompt(text, _correct_answer_prompt, number)
    return text[0:pos + 1]


def get_answers_number(text):
    """Return how many correct-answer ('CA: ') lines ``text`` contains."""
    return text.count(_correct_answer_prompt)


def get_correct_answer_number(text, number):
    """Return the ``number``-th (0-based) correct answer line without its tag, or ''."""
    return _line_after_nth_prompt(text, _correct_answer_prompt, number)


def get_wrong_answer_number(text, number):
    """Return the ``number``-th (0-based) wrong answer line without its tag, or ''."""
    return _line_after_nth_prompt(text, _wrong_answer_prompt, number)


def get_question_number(text, number):
    """Return the ``number``-th (0-based) question line without its tag, or ''."""
    return _line_after_nth_prompt(text, _question_prompt, number)

def get_description_from_text(text):
    """Return the story text between the 'Story:' and 'Discussion:' markers.

    Surrounding whitespace is preserved. If 'Story:' is missing, the description
    starts at the beginning of ``text``. If no 'Discussion:' follows it, the
    description runs to the end of ``text``.
    """
    start_prompt = 'Story:'
    end_prompt = 'Discussion:'
    start = text.find(start_prompt)
    # find() returns -1 when absent; adding len(start_prompt) to that would skip characters.
    start = 0 if start == -1 else start + len(start_prompt)
    end = text.find(end_prompt, start)
    if end == -1:  # a -1 slice end would silently drop the last character
        end = len(text)
    return text[start:end]

def get_discussion_from_text(text):
    """Return everything after the 'Discussion:' marker, stripped of surrounding whitespace.

    Returns '' when ``text`` has no 'Discussion:' marker, rather than an
    arbitrary suffix of the text.
    """
    start_prompt = 'Discussion:'
    start = text.find(start_prompt)
    if start == -1:
        return ''
    return text[start + len(start_prompt):].strip()

def get_statement_prompt_from_text(full_text, number, max_questions=5):
    """Build a statement prompt from a window of raw discussion lines.

    The discussion is split into lines and grouped three lines per turn
    (presumably Q / CA / WA -- TODO confirm against the data). The window holds
    the turns ending at turn ``number`` (0-based), at most ``max_questions`` of
    them. Returns 'Discussion:\\n' + window lines + '\\nStatement:\\n'.
    """
    lines_per_turn = 3
    discussion_lines = get_discussion_from_text(full_text).split('\n')
    last_line = (number + 1) * lines_per_turn
    first_line = max(0, last_line - max_questions * lines_per_turn)
    window = discussion_lines[first_line:last_line]
    return 'Discussion:\n' + '\n'.join(window) + '\nStatement:\n'

In [5]:
def generate_statement_from_dialogue(statement_model, prompt, max_new_tokens=50):
    """Generate one statement line from ``statement_model`` given ``prompt``.

    Uses the module-level ``tokenizer`` and the model's default generation settings.

    Args:
        statement_model: causal LM exposing ``generate``, ``config`` and
            ``device`` (a ``GPT2LMHeadModel`` in this notebook).
        prompt: prompt text, e.g. ending in 'Statement:\\n' as built by the
            statement-prompt helpers.
        max_new_tokens: number of tokens to generate after the prompt.

    Returns:
        The first generated line, with everything up to its last ':' removed
        and surrounding whitespace stripped. Returns '' when the prompt plus
        ``max_new_tokens`` does not fit in the model's context window.
    """
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    prompt_length = tokens.shape[1]
    # GPT-2 has 1024 positions; read the limit from the config when it is available.
    context_size = getattr(statement_model.config, 'n_positions', 1024)
    if prompt_length + max_new_tokens > context_size:
        return ''
    output = statement_model.generate(
        tokens.to(statement_model.device),
        max_length=prompt_length + max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,  # 50256 for GPT-2
    )
    # Decode only the continuation. Slicing the full decode at len(prompt) breaks
    # whenever decode() does not reproduce the prompt exactly (e.g. tokenization
    # space clean-up removing a space before punctuation).
    continuation = tokenizer.decode(output[0, prompt_length:], skip_special_tokens=True)
    end = continuation.find('\n')
    if end == -1:  # no newline generated: a -1 slice end would drop the last character
        end = len(continuation)
    return continuation[:end].split(':')[-1].strip()

In [6]:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [7]:
# Stock GPT-2 tokenizer; generate_statement_from_dialogue reads this global.
tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path='gpt2')

In [8]:
# Start from the base GPT-2 architecture on the GPU, then overwrite its weights
# with the checkpoint saved as 'save_statement2' (presumably the fine-tuned
# statement generator -- confirm against the training notebook).
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
statement_model = GPT2LMHeadModel.from_pretrained('gpt2').cuda()
checkpoint = torch.load('save_statement' + str(2))
statement_model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [17]:
def get_correct_statement_prompt_from_text(full_text, number, max_questions=5):
    """Build the statement-model prompt for the supporting (correct-answer) case.

    Emits a 'Q: ...' / 'A: ...' pair for each turn from
    ``max(0, number - max_questions)`` through ``number`` inclusive. That is up
    to ``max_questions + 1`` turns. Each answer is the correct answer with every
    '.' removed. Questions and answers are passed through ``str.capitalize``.
    The prompt ends with a blank line followed by 'Statement:\\n'.
    """
    first_turn = max(0, number - max_questions)
    pieces = ['Discussion:\n']
    for turn in range(first_turn, number + 1):
        question = get_question_number(full_text, turn).capitalize()
        answer = get_correct_answer_number(full_text, turn).replace(".", "").capitalize()
        pieces.append(f'Q: {question}\n')
        pieces.append(f'A: {answer}\n')
    pieces.append('\nStatement:\n')
    return ''.join(pieces)

In [56]:
def get_refuting_statement_prompt_from_text(full_text, number, max_questions=5):
    """Build the statement-model prompt for the refuting (wrong-answer) case.

    Covers turns ``max(0, number - max_questions)`` through ``number`` inclusive.
    Earlier turns carry their correct answer; the final turn ``number`` carries
    the wrong answer. Periods are stripped from every answer, as in
    get_correct_statement_prompt_from_text, so the supporting and refuting
    prompts differ only in the final answer. The prompt ends with a blank line
    followed by 'Statement:\\n'.
    """
    text = 'Discussion:\n'
    start = max(0, number - max_questions)
    for index in range(start, number + 1):
        text += f'Q: {get_question_number(full_text, index).capitalize()}\n'
        if index == number:
            answer = get_wrong_answer_number(full_text, index)
        else:
            answer = get_correct_answer_number(full_text, index)
        text += f'A: {answer.replace(".", "").capitalize()}\n'
    text += '\nStatement:\n'
    return text

In [60]:
def get_supporting_claim_from_questions(text, number):
    """Build a supporting example (label 'Yes.') for turn ``number`` of ``text``.

    The claim is a statement generated from the dialogue with the correct
    answer, and the evidence is the story description.

    Returns:
        The formatted evidence/claim prompt followed by 'Yes.'. Returns '' when
        no statement could be generated (e.g. the prompt was too long for the
        model), instead of a bare 'Yes.' with no evidence or claim.
    """
    description = get_description_from_text(text)
    statement_prompt = get_correct_statement_prompt_from_text(text, number)
    statement = generate_statement_from_dialogue(statement_model, statement_prompt)
    claim = create_claim_from_description_and_dialogue(description, statement)
    if not claim:
        return ''
    return claim + 'Yes.'

def get_refuting_claim_from_questions(text, number):
    """Build a refuting example (label 'Nope.') for turn ``number`` of ``text``.

    The claim is a statement generated from the dialogue whose final answer is
    the wrong one, and the evidence is the story description.

    Returns:
        The formatted evidence/claim prompt followed by 'Nope.'. Returns '' when
        no statement could be generated (e.g. the prompt was too long for the
        model), instead of a bare 'Nope.' with no evidence or claim.
    """
    description = get_description_from_text(text)
    statement_prompt = get_refuting_statement_prompt_from_text(text, number)
    statement = generate_statement_from_dialogue(statement_model, statement_prompt)
    claim = create_claim_from_description_and_dialogue(description, statement)
    if not claim:
        return ''
    return claim + 'Nope.'

In [61]:
def create_claim_from_description_and_dialogue(description, dialogue):
    """Format an evidence/claim pair as a classification prompt.

    Returns '' when ``dialogue`` is empty. Otherwise a single trailing '.' is
    removed from the claim, and each double newline in the description is
    replaced by a single one. The result ends with
    'The evidence supports the claim:' plus a newline, ready for a label to be
    appended.
    """
    if not dialogue:
        return ''
    claim = dialogue[:-1] if dialogue.endswith('.') else dialogue
    evidence = description.replace('\n\n', '\n')
    return (
        'Evidence:\n'
        f'{evidence}\n\n'
        'Claim:\n'
        f'{claim}\n\n'
        'The evidence supports the claim:\n'
    )

In [70]:
from tqdm import tqdm

def _collect_claims(text_list, claim_builder, drop_empty=False, desc=None):
    """Call ``claim_builder(text, number)`` for every correct-answer turn of every text.

    Args:
        text_list: raw story/discussion texts.
        claim_builder: function ``(text, number) -> str`` producing one example.
        drop_empty: when True, skip turns for which ``claim_builder`` returned ''.
            Off by default, so the result keeps exactly one entry per turn.
        desc: optional label for the tqdm progress bar.

    Returns:
        A flat list of claim strings, ordered by text and then by turn.
    """
    claims = []
    for text in tqdm(text_list, desc=desc):
        for number in range(get_answers_number(text)):
            claim = claim_builder(text, number)
            if drop_empty and not claim:
                continue
            claims.append(claim)
    return claims


def get_supporting_claims(text_list, drop_empty=False):
    """Build supporting examples for every correct-answer turn of each text (see _collect_claims)."""
    return _collect_claims(text_list, get_supporting_claim_from_questions,
                           drop_empty=drop_empty, desc='supporting')


def get_refuting_claims(text_list, drop_empty=False):
    """Build refuting examples for every correct-answer turn of each text (see _collect_claims)."""
    return _collect_claims(text_list, get_refuting_claim_from_questions,
                           drop_empty=drop_empty, desc='refuting')

In [74]:
# Generate claims for both splits. On the recorded run each pass over the
# 3968 training texts took about 39 minutes and each dev pass about 3.5.
supporting_texts, refuting_texts = (
    get_supporting_claims(train_list),
    get_refuting_claims(train_list),
)
dev_supporting_texts, dev_refuting_texts = (
    get_supporting_claims(dev_list),
    get_refuting_claims(dev_list),
)

100%|██████████| 3968/3968 [38:49<00:00,  1.70it/s]  
100%|██████████| 3968/3968 [39:02<00:00,  1.69it/s]  
100%|██████████| 327/327 [03:26<00:00,  1.59it/s]
100%|██████████| 327/327 [03:26<00:00,  1.58it/s]


In [84]:
# Spot-check one generated refuting example; print() renders its newlines readably.
print(refuting_texts[6])

Evidence:

CHAPTER XXIV. THE INTERRUPTED MASS 
The morning of that Wednesday of Corpus Christi, fateful to all concerned in this chronicle, dawned misty and grey, and the air was chilled by the wind that blew from the sea. The chapel bell tinkled out its summons, and the garrison trooped faithfully to Mass. 
Presently came Monna Valentina, followed by her ladies, her pages, and lastly, Peppe, wearing under his thin mask of piety an air of eager anxiety and unrest. Valentina was very pale, and round her eyes there were dark circles that told of sleeplessness, and as she bowed her head in prayer, her ladies observed that tears were falling on the illuminated Mass-book over which she bent. And now came Fra Domenico from the sacristy in the white chasuble that the Church ordains for the Corpus Christi feast, followed by a page in a clerkly gown of black, and the Mass commenced. 
There were absent only from the gathering Gonzaga and Fortemani, besides a sentry and the three prisoners. Franc

In [76]:
# Number of refuting examples generated from the training split.
len(refuting_texts)

6862

In [77]:
# Persist the generated training claims. `with` guarantees each file is flushed
# and closed, instead of relying on garbage collection of an anonymous handle.
with open('../data/adv_supporting.json', 'w', encoding='utf-8') as supporting_file:
    json.dump(supporting_texts, supporting_file)
with open('../data/adv_refuting.json', 'w', encoding='utf-8') as refuting_file:
    json.dump(refuting_texts, refuting_file)

In [78]:
# Persist the generated dev claims. `with` guarantees each file is flushed and
# closed, instead of relying on garbage collection of an anonymous handle.
with open('../data/adv_dev_supporting.json', 'w', encoding='utf-8') as dev_supporting_file:
    json.dump(dev_supporting_texts, dev_supporting_file)
with open('../data/adv_dev_refuting.json', 'w', encoding='utf-8') as dev_refuting_file:
    json.dump(dev_refuting_texts, dev_refuting_file)