In [1]:
import json
import torch
import torch.nn as nn
import random
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [2]:
train_list = json.load(open('../data/train_fc_with_qa.json'))[:3000]
dev_list = json.load(open('../data/dev_fc_with_qa.json'))

In [3]:
print(train_list[0])

Evidence:
CHAPTER XII 
Königstein 
Phineas Finn and Lady Laura Kennedy sat together discussing the affairs of the past till the servant told them that "My Lord" was in the next room, and ready to receive Mr. Finn. "You will find him much altered," said Lady Laura, "even more than I am." 
"I do not find you altered at all." 
"Yes, you do,--in appearance. I am a middle-aged woman, and conscious that I may use my privileges as such. But he has become quite an old man,--not in health so much as in manner. But he will be very glad to see you." So saying she led him into a room, in which he found the Earl seated near the fireplace, and wrapped in furs. He got up to receive his guest, and Phineas saw at once that during the two years of his exile from England Lord Brentford had passed from manhood to senility. He almost tottered as he came forward, and he wrapped his coat around him with that air of studious self-preservation which belongs only to the infirm. 
"It is very good of you to come 

In [4]:
def chunks(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

def batchify(data, n):
    len_dict = {}
    for item in data:
        length = item.shape[1]
        try:
            len_dict[length].append(item)
        except:
            len_dict[length] = [item]

    batch_chunks = []
    for k in len_dict.keys():
        vectors = len_dict[k]
        batch_chunks += chunks(vectors, n)

    batches = []
    for chunk in batch_chunks:
        inputs = torch.stack([item[0] for item in chunk])
        batches.append((inputs))

    return batches

In [5]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

In [6]:
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.cuda()
checkpoint = torch.load('save_fever4')
model.load_state_dict(checkpoint['model_state_dict'])
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-6)

In [7]:
_limit = 1024
data = []
total_skipped = 0
for item in train_list:
    tokens = tokenizer.encode(item, return_tensors='pt')
    if tokens.shape[1] > _limit:
        total_skipped += 1
        continue
    data.append(tokens)
print(f'Skipped {total_skipped} out of {len(train_list)}')

Token indices sequence length is longer than the specified maximum sequence length for this model (1044 > 1024). Running this sequence through the model will result in indexing errors


Skipped 6 out of 3000


In [8]:
train_batches = batchify(data, 2)

In [9]:
def train(train_model, batches, optimizer, criterion):
    total_loss = 0.
    for i, batch in tqdm(enumerate(batches), total=len(batches)):
        model.train()
        inputs = batch
        optimizer.zero_grad()
        loss = train_model(inputs.cuda(), labels=inputs.cuda())[0]
        loss.backward()
        torch.nn.utils.clip_grad_norm_(train_model.parameters(), 0.5)
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(batches)

In [10]:
from torch.optim.lr_scheduler import StepLR

random.shuffle(train_batches)
scheduler = StepLR(optimizer, step_size=2, gamma=0.8)
for epoch in range(40):
    random.shuffle(train_batches)
    loss = train(model, train_batches, optimizer, criterion)
    #test(model, dev_list[:2000])
    print('Epoch:', epoch, 'Loss:', loss)
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict()},
                'save_fever_with_qa_data_' + str(epoch))
    scheduler.step()

100%|██████████| 1612/1612 [03:14<00:00,  8.27it/s]


Epoch: 0 Loss: 3.6659687318222103


100%|██████████| 1612/1612 [03:18<00:00,  8.13it/s]


Epoch: 1 Loss: 3.360404967818012


100%|██████████| 1612/1612 [03:18<00:00,  8.14it/s]


Epoch: 2 Loss: 3.255894891233657


100%|██████████| 1612/1612 [03:18<00:00,  8.14it/s]


Epoch: 3 Loss: 3.1890507712035854


100%|██████████| 1612/1612 [03:18<00:00,  8.14it/s]


Epoch: 4 Loss: 3.134766007438487


100%|██████████| 1612/1612 [03:18<00:00,  8.12it/s]


Epoch: 5 Loss: 3.0933316490120686


100%|██████████| 1612/1612 [03:38<00:00,  7.39it/s]


Epoch: 6 Loss: 3.0584815394124085


100%|██████████| 1612/1612 [03:38<00:00,  7.39it/s]


Epoch: 7 Loss: 3.028677450885844


100%|██████████| 1612/1612 [03:38<00:00,  7.39it/s]


Epoch: 8 Loss: 3.003732809357844


100%|██████████| 1612/1612 [03:38<00:00,  7.39it/s]


Epoch: 9 Loss: 2.9828132283746753


100%|██████████| 1612/1612 [03:37<00:00,  7.40it/s]


Epoch: 10 Loss: 2.9628975910155413


100%|██████████| 1612/1612 [03:38<00:00,  7.39it/s]


Epoch: 11 Loss: 2.9466610160418245


100%|██████████| 1612/1612 [03:37<00:00,  7.40it/s]


Epoch: 12 Loss: 2.9318274722871354


100%|██████████| 1612/1612 [03:37<00:00,  7.40it/s]


Epoch: 13 Loss: 2.920623479166043


 59%|█████▊    | 947/1612 [02:07<01:29,  7.41it/s]


KeyboardInterrupt: 

In [11]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.cuda()
_ = model.eval()

In [12]:
def get_text_up_to_question(text):
    _claim_yn = 'The evidence supports the claim:\n'
    return text[:text.find(_claim_yn) + len(_claim_yn)]

In [13]:
def get_answer_from_text(text):
    _claim_yn = 'The evidence supports the claim:\n'
    pos = text.find(_claim_yn) + len(_claim_yn)
    return text[pos]

In [14]:
def generate_answer(model, text):
    prompt = get_text_up_to_question(text)
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 1
    tokens_length = tokens.shape[1]
    if tokens_length + _length >= 1024:
        raise RuntimeError('Text is longer than 1024')
    output = model.generate(
             tokens.cuda(),
             max_length=tokens_length + _length, 
             pad_token_id=50256
    )
    output = tokenizer.decode(output[0], skip_special_tokens=True)
    return get_answer_from_text(output)

In [15]:
import sys
import traceback

def test(model, data):
    model.eval()
    tp = 0
    fp = 0
    fn = 0

    skipped = 0

    for item in tqdm(data):
        expected = get_answer_from_text(item)
        predicted = ''
        try:
            predicted = generate_answer(model, item)
        except (IndexError, RuntimeError) as e:
            skipped += 1
            continue
        if expected == predicted:
            tp += 1
        elif expected == 'N' and predicted == 'Y':
            fp += 1
        elif expected == 'Y' and predicted == 'N':
            fn += 1
        else:
            fp += 1

    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    print('Precision:', precision)
    print('Recall:', recall)
    print('F1:', f1)
    print('Skipped:', skipped)

In [16]:
for epoch in range(14):
    checkpoint = torch.load('save_fever_with_qa_data_' + str(epoch))
    model.load_state_dict(checkpoint['model_state_dict'])
    print('epoch:', epoch)
    test(model, dev_list)

  0%|          | 4/7782 [00:00<03:22, 38.44it/s]

epoch: 0


 15%|█▌        | 1188/7782 [00:29<02:48, 39.13it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1043 > 1024). Running this sequence through the model will result in indexing errors
100%|██████████| 7782/7782 [03:15<00:00, 39.91it/s]


Precision: 0.6151031757646714
Recall: 0.8172272992416877
F1: 0.7019038076152304
Skipped: 9


  0%|          | 5/7782 [00:00<02:48, 46.06it/s]

epoch: 1


100%|██████████| 7782/7782 [03:13<00:00, 40.31it/s]


Precision: 0.5623746490172483
Recall: 0.9346666666666666
F1: 0.7022289005760081
Skipped: 9


  0%|          | 4/7782 [00:00<03:23, 38.25it/s]

epoch: 2


100%|██████████| 7782/7782 [03:12<00:00, 40.45it/s]


Precision: 0.5816591499248326
Recall: 0.9032258064516129
F1: 0.707623243827417
Skipped: 9


  0%|          | 5/7782 [00:00<02:43, 47.45it/s]

epoch: 3


100%|██████████| 7782/7782 [03:12<00:00, 40.50it/s]


Precision: 0.6177093630843101
Recall: 0.8634030793525463
F1: 0.7201778216843665
Skipped: 9


  0%|          | 5/7782 [00:00<02:46, 46.73it/s]

epoch: 4


100%|██████████| 7782/7782 [03:11<00:00, 40.57it/s]


Precision: 0.6156328936289643
Recall: 0.8714739769566945
F1: 0.7215460526315789
Skipped: 9


  0%|          | 5/7782 [00:00<02:47, 46.39it/s]

epoch: 5


100%|██████████| 7782/7782 [03:12<00:00, 40.50it/s]


Precision: 0.6838246296573395
Recall: 0.8074204946996466
F1: 0.7405006886494369
Skipped: 9


  0%|          | 5/7782 [00:00<02:50, 45.64it/s]

epoch: 6


100%|██████████| 7782/7782 [03:12<00:00, 40.36it/s]


Precision: 0.6156302054414214
Recall: 0.886290967226219
F1: 0.726572739187418
Skipped: 9


  0%|          | 5/7782 [00:00<02:44, 47.40it/s]

epoch: 7


100%|██████████| 7782/7782 [03:12<00:00, 40.46it/s]


Precision: 0.6044980800877674
Recall: 0.9016158723665372
F1: 0.7237501026188327
Skipped: 9


  0%|          | 5/7782 [00:00<02:43, 47.47it/s]

epoch: 8


100%|██████████| 7782/7782 [03:12<00:00, 40.36it/s]


Precision: 0.640299985849724
Recall: 0.8650353660867903
F1: 0.7358920149617825
Skipped: 9


  0%|          | 5/7782 [00:00<02:43, 47.68it/s]

epoch: 9


100%|██████████| 7782/7782 [03:12<00:00, 40.52it/s]


Precision: 0.6397629462395936
Recall: 0.8685823754789272
F1: 0.7368164459250832
Skipped: 9


  0%|          | 5/7782 [00:00<02:43, 47.44it/s]

epoch: 10


100%|██████████| 7782/7782 [03:12<00:00, 40.35it/s]


Precision: 0.624012200194094
Recall: 0.8893499308437067
F1: 0.7334202379012545
Skipped: 9


  0%|          | 5/7782 [00:00<02:45, 46.98it/s]

epoch: 11


100%|██████████| 7782/7782 [03:11<00:00, 40.58it/s]


Precision: 0.6496029495178672
Recall: 0.8640135797812146
F1: 0.7416221466731424
Skipped: 9


  0%|          | 5/7782 [00:00<02:44, 47.34it/s]

epoch: 12


100%|██████████| 7782/7782 [03:11<00:00, 40.54it/s]


Precision: 0.649370846882511
Recall: 0.8677498583034197
F1: 0.7428432799611839
Skipped: 9


  0%|          | 5/7782 [00:00<02:53, 44.75it/s]

epoch: 13


100%|██████████| 7782/7782 [03:10<00:00, 40.75it/s]

Precision: 0.6372398379661964
Recall: 0.8813755795981453
F1: 0.739683826509931
Skipped: 9





In [21]:
print(dev_list[43])

Evidence:
Copyright infringement is the use of works protected by copyright law without permission, infringing certain exclusive rights granted to the copyright holder, such as the right to reproduce, distribute, display or perform the protected work, or to make derivative works. The copyright holder is typically the work's creator, or a publisher or other business to whom copyright has been assigned. Copyright holders routinely invoke legal and technological measures to prevent and penalize copyright infringement. 
Copyright infringement disputes are usually resolved through direct negotiation, a notice and take down process, or litigation in civil court. Egregious or large-scale commercial infringement, especially when it involves counterfeiting, is sometimes prosecuted via the criminal justice system. Shifting public expectations, advances in digital technology, and the increasing reach of the Internet have led to such widespread, anonymous infringement that copyright-dependent indu

In [22]:
generate_answer(model, dev_list[43])

'N'

In [23]:
checkpoint = torch.load('save_fever_with_qa_data_' + str(12))
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [24]:
dev_dict = json.load(open('../data/coqa-dev-v1.0.json', encoding='utf8'))

In [25]:
dev_list = json.load(open('../data/qa_dev_list.json', encoding='utf8'))

In [26]:
def get_description_from_data_item(item):
    return item['story']

def get_dialogue_from_data_item(item, max_num_questions=0, question_number=-1, last_question=True):
    text = ''
    text += ' '.join([q['input_text'] + ' ' + a['input_text'] + '.'
                       for q, a in zip(item['questions'][max(0,question_number-max_num_questions):question_number+1], 
                                       item['answers'][max(0,question_number-max_num_questions):question_number+1]) 
                      ])
    if not last_question:
        text = '?'.join(text.split('?')[:-1]) + '?'
    return text

In [27]:
def create_claim_from_description_and_dialogue(description, dialogue):
    if dialogue[-1] == '.':
        dialogue = dialogue[:-1]    
    text = 'Evidence:\n'
    text += description.replace('\n\n', '\n') + '\n\n'
    text += 'Claim:\n'
    text += dialogue + '\n\n'
    text += 'The evidence supports the claim:\n'
    return text

In [28]:
def get_text_from_data_item(item, max_num_questions=0, question_number=-1, last_question=True):
    text = 'In the text below two people are discussing a story.\n\n'
    text += 'Story:\n' + item['story'] + '\n\n'
    text += 'Discussion:\n'
    text += '\n'.join(['Q: ' + q['input_text'] 
                       + '\nA: ' + a['input_text'] 
                       for q, a in zip(item['questions'][max(0,question_number-max_num_questions):question_number+1], 
                                       item['answers'][max(0,question_number-max_num_questions):question_number+1]) 
                      ])
    if not last_question:
        text = '\n'.join(text.split('\n')[:-1]) + '\n'
    return text

In [50]:
doc = 0
number = 0
description = get_description_from_data_item(dev_dict['data'][doc])
small_text = get_text_from_data_item(dev_dict['data'][doc], 
                                     max_num_questions=5, 
                                     question_number=number,
                                     last_question=False)
dialogue = get_dialogue_from_data_item(dev_dict['data'][doc],
                                       max_num_questions=5, 
                                       question_number=number,
                                       last_question=False)
claim = create_claim_from_description_and_dialogue(description, dialogue + ' blue.')

In [51]:
print(claim)

Evidence:
Once upon a time, in a barn near a farm house, there lived a little white kitten named Cotton. Cotton lived high up in a nice warm place above the barn where all of the farmer's horses slept. But Cotton wasn't alone in her little home above the barn, oh no. She shared her hay bed with her mommy and 5 other sisters. All of her sisters were cute and fluffy, like Cotton. But she was the only white one in the bunch. The rest of her sisters were all orange with beautiful white tiger stripes like Cotton's mommy. Being different made Cotton quite sad. She often wished she looked like the rest of her family. So one day, when Cotton found a can of the old farmer's orange paint, she used it to paint herself like them. When her mommy and sisters found her they started laughing. 
"What are you doing, Cotton?!" 
"I only wanted to be more like you". 
Cotton's mommy rubbed her face on Cotton's and said "Oh Cotton, but your fur is so pretty and special, like you. We would never want you to b

In [52]:
generate_answer(model, claim)

'Y'