In [1]:
import json
import torch
import torch.nn as nn
import random
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [2]:
# Load the fact-checking examples. Context managers make sure the file handles
# are closed, and the encoding is stated explicitly (the other JSON loads in
# this notebook use utf8 as well).
with open('../data/train_fc_with_qa.json', encoding='utf8') as train_file:
    # Only the first 3000 training examples are used.
    train_list = json.load(train_file)[:3000]
with open('../data/dev_fc_with_qa.json', encoding='utf8') as dev_file:
    dev_list = json.load(dev_file)

In [3]:
# Show the first training example to check the text format of an item.
print(train_list[0])

Evidence:
CHAPTER XII 
Königstein 
Phineas Finn and Lady Laura Kennedy sat together discussing the affairs of the past till the servant told them that "My Lord" was in the next room, and ready to receive Mr. Finn. "You will find him much altered," said Lady Laura, "even more than I am." 
"I do not find you altered at all." 
"Yes, you do,--in appearance. I am a middle-aged woman, and conscious that I may use my privileges as such. But he has become quite an old man,--not in health so much as in manner. But he will be very glad to see you." So saying she led him into a room, in which he found the Earl seated near the fireplace, and wrapped in furs. He got up to receive his guest, and Phineas saw at once that during the two years of his exile from England Lord Brentford had passed from manhood to senility. He almost tottered as he came forward, and he wrapped his coat around him with that air of studious self-preservation which belongs only to the infirm. 
"It is very good of you to come 

In [4]:
def chunks(lst, n):
    """Yield consecutive slices of ``lst`` holding ``n`` elements each.

    The final slice is shorter when ``len(lst)`` is not a multiple of ``n``.
    """
    yield from (lst[start:start + n] for start in range(0, len(lst), n))

def batchify(data, n):
    """Group equally long items into stacked batches of at most ``n`` rows.

    Items are bucketed by ``item.shape[1]`` so that every batch can be stacked
    without padding. Within a bucket the original order is kept, and buckets
    are emitted in order of first appearance.

    Args:
        data: iterable of tensors; ``item.shape[1]`` is used as the length and
            ``item[0]`` is the row that gets stacked
            (assumes items are shaped (1, seq_len) -- TODO confirm).
        n: maximum number of rows per batch.

    Returns:
        List of tensors, each the result of ``torch.stack`` over up to ``n``
        rows of identical length.
    """
    by_length = {}
    for item in data:
        # setdefault replaces the former try / bare-except, which would also
        # have swallowed unrelated errors such as KeyboardInterrupt.
        by_length.setdefault(item.shape[1], []).append(item)

    batches = []
    for same_length in by_length.values():
        for start in range(0, len(same_length), n):
            group = same_length[start:start + n]
            batches.append(torch.stack([item[0] for item in group]))

    return batches

In [5]:
# Load the pretrained 'gpt2' tokenizer; it is used below to encode the examples.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

In [6]:
# Start from the pretrained 'gpt2' language model and move it to the GPU.
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.cuda()
# Replace the weights with the ones stored under 'model_state_dict' in the
# 'save_fever2' checkpoint file.
checkpoint = torch.load('save_fever2')
model.load_state_dict(checkpoint['model_state_dict'])
# NOTE(review): `criterion` is passed to train() but train() never uses it; the
# loss is taken from the model's own output when `labels=` is supplied.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-6)

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.attn.masked_bias', 'h.1.attn.masked_bias', 'h.2.attn.masked_bias', 'h.3.attn.masked_bias', 'h.4.attn.masked_bias', 'h.5.attn.masked_bias', 'h.6.attn.masked_bias', 'h.7.attn.masked_bias', 'h.8.attn.masked_bias', 'h.9.attn.masked_bias', 'h.10.attn.masked_bias', 'h.11.attn.masked_bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
# Encode every training example, keeping only those whose token count does not
# exceed `_limit`; longer ones are counted and dropped.
_limit = 1024
data = []
total_skipped = 0
for item in train_list:
    tokens = tokenizer.encode(item, return_tensors='pt')
    if tokens.shape[1] <= _limit:
        data.append(tokens)
    else:
        total_skipped += 1
print(f'Skipped {total_skipped} out of {len(train_list)}')

Token indices sequence length is longer than the specified maximum sequence length for this model (1044 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1148 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1189 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1423 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1036 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence leng

Skipped 6 out of 3000


In [8]:
# Group the encoded examples into same-length batches of at most 2 sequences.
train_batches = batchify(data, 2)

In [9]:
def train(train_model, batches, optimizer, criterion):
    """Run one pass over ``batches`` and return the mean training loss.

    Args:
        train_model: model to optimise. It is called as
            ``train_model(inputs, labels=inputs)`` and the first element of
            the result is used as the loss.
        batches: sequence of input tensors; each one is moved to the GPU and
            used as both the input and the label.
        optimizer: optimizer whose ``zero_grad()`` / ``step()`` are called
            around every backward pass.
        criterion: unused; kept so that existing callers keep working.

    Returns:
        The average of the per-batch losses, or ``0.`` when ``batches`` is
        empty (previously this case raised ZeroDivisionError).
    """
    if len(batches) == 0:
        return 0.

    # Switch the model that is actually being trained into training mode. The
    # earlier code called `model.train()` on the notebook-level global, which
    # only worked when the caller happened to pass that same object.
    train_model.train()

    total_loss = 0.
    for batch in tqdm(batches, total=len(batches)):
        # Transfer the batch once and reuse it as input and label.
        inputs = batch.cuda()
        optimizer.zero_grad()
        loss = train_model(inputs, labels=inputs)[0]
        loss.backward()
        # Clip the gradient norm to 0.5 before the update.
        torch.nn.utils.clip_grad_norm_(train_model.parameters(), 0.5)
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(batches)

In [10]:
from torch.optim.lr_scheduler import StepLR

# NOTE(review): the batches are shuffled again at the start of every epoch
# below, so this extra shuffle only advances the RNG state. No seed is set
# anywhere in the visible cells, so the batch order differs between runs.
random.shuffle(train_batches)
# StepLR multiplies the learning rate by gamma (0.8) every step_size (2) calls
# to scheduler.step(), i.e. every two epochs here.
scheduler = StepLR(optimizer, step_size=2, gamma=0.8)
for epoch in range(40):
    random.shuffle(train_batches)
    loss = train(model, train_batches, optimizer, criterion)
    #test(model, dev_list[:2000])
    print('Epoch:', epoch, 'Loss:', loss)
    # One checkpoint file per epoch; only the epoch number and the model
    # weights are stored (no optimizer or scheduler state).
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict()},
                'save_fever_with_qa_data_' + str(epoch))
    scheduler.step()

100%|██████████| 1612/1612 [03:20<00:00,  8.03it/s]


Epoch: 0 Loss: 3.392703871351318


100%|██████████| 1612/1612 [03:27<00:00,  7.77it/s]


Epoch: 1 Loss: 3.167952400505099


100%|██████████| 1612/1612 [03:27<00:00,  7.77it/s]


Epoch: 2 Loss: 3.0796982589046062


100%|██████████| 1612/1612 [03:27<00:00,  7.78it/s]


Epoch: 3 Loss: 3.0223146847844418


100%|██████████| 1612/1612 [03:29<00:00,  7.68it/s]


Epoch: 4 Loss: 2.974780977216015


 19%|█▉        | 308/1612 [00:40<02:52,  7.54it/s]


KeyboardInterrupt: 

In [11]:
# Re-create the tokenizer and a fresh pretrained 'gpt2' model on the GPU. This
# rebinds `model`, so the object fine-tuned above is no longer referenced by
# that name.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.cuda()
# Assigning to `_` keeps the notebook from echoing the model repr.
_ = model.eval()

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.attn.masked_bias', 'h.1.attn.masked_bias', 'h.2.attn.masked_bias', 'h.3.attn.masked_bias', 'h.4.attn.masked_bias', 'h.5.attn.masked_bias', 'h.6.attn.masked_bias', 'h.7.attn.masked_bias', 'h.8.attn.masked_bias', 'h.9.attn.masked_bias', 'h.10.attn.masked_bias', 'h.11.attn.masked_bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
def get_text_up_to_question(text):
    """Return ``text`` truncated right after the claim question line.

    The result ends with ``'The evidence supports the claim:\\n'`` so that it
    can be used as a prompt whose continuation is the answer.

    Args:
        text: a full example containing the claim question line.

    Returns:
        The prefix of ``text`` up to and including the first occurrence of
        the question line.

    Raises:
        ValueError: if ``text`` does not contain the question line. Previously
            ``str.find`` returned -1 in that case and an unrelated 32-character
            prefix of ``text`` was returned silently.
    """
    _claim_yn = 'The evidence supports the claim:\n'
    start = text.find(_claim_yn)
    if start < 0:
        raise ValueError('Text does not contain the claim question line')
    return text[:start + len(_claim_yn)]

In [13]:
def get_answer_from_text(text):
    """Return the single character that follows the claim question line.

    Args:
        text: an example (or generated output) containing
            ``'The evidence supports the claim:\\n'`` followed by the answer.

    Returns:
        The first character after the first occurrence of the question line.

    Raises:
        IndexError: if nothing follows the question line, or if the question
            line is missing altogether. The missing-line case used to return
            an arbitrary character (index 32 of ``text``) because
            ``str.find`` returned -1; it now fails the same way as the
            "no answer" case so callers that already handle IndexError treat
            both alike.
    """
    _claim_yn = 'The evidence supports the claim:\n'
    start = text.find(_claim_yn)
    if start < 0:
        raise IndexError('Text does not contain the claim question line')
    pos = start + len(_claim_yn)
    if pos >= len(text):
        raise IndexError('No answer follows the claim question line')
    return text[pos]

In [14]:
def generate_answer(model, text):
    """Generate one token after the claim question and return the answer character.

    The prompt is ``text`` cut right after the question line. It is encoded
    with the notebook-level ``tokenizer``, extended by a single generated
    token on the GPU, decoded again, and the character following the question
    line is returned.

    Raises:
        RuntimeError: if the prompt plus the generated token would reach 1024
            tokens.
    """
    answer_tokens = 1
    prompt_ids = tokenizer.encode(get_text_up_to_question(text), return_tensors='pt')
    prompt_length = prompt_ids.shape[1]
    total_length = prompt_length + answer_tokens
    if total_length >= 1024:
        raise RuntimeError('Text is longer than 1024')

    generated = model.generate(
        prompt_ids.cuda(),
        max_length=total_length,
        pad_token_id=50256,
    )
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    return get_answer_from_text(decoded)

In [15]:
import sys
import traceback

def test(model, data):
    """Evaluate ``model`` on ``data`` and print precision, recall and F1.

    For every item the expected answer character is read from the item itself
    and compared with the character produced by ``generate_answer``.

    Counting scheme (unchanged):
        * ``tp``: predicted character equals the expected one (for any label);
        * ``fp``: expected 'N' but predicted 'Y', or any other mismatch that
          is not the 'Y' -> 'N' case;
        * ``fn``: expected 'Y' but predicted 'N'.
    NOTE(review): because ``tp`` counts every exact match, the printed
    "Precision"/"Recall" are not the textbook per-class definitions -- confirm
    this is intended before comparing against other work.

    Items for which the answer cannot be extracted or generated
    (IndexError / RuntimeError / ValueError) are counted as skipped.

    Args:
        model: model handed to ``generate_answer``; it is put in eval mode.
        data: iterable of example texts.

    Returns:
        None. The metrics are printed.
    """
    model.eval()
    tp = 0
    fp = 0
    fn = 0

    skipped = 0

    for item in tqdm(data):
        try:
            # Extraction of the expected answer is inside the try block as
            # well, so a malformed item is skipped instead of aborting the
            # whole evaluation.
            expected = get_answer_from_text(item)
            predicted = generate_answer(model, item)
        except (IndexError, RuntimeError, ValueError):
            skipped += 1
            continue
        if expected == predicted:
            tp += 1
        elif expected == 'N' and predicted == 'Y':
            fp += 1
        elif expected == 'Y' and predicted == 'N':
            fn += 1
        else:
            fp += 1

    # Guard every ratio: with no usable items (or no matches at all) the
    # denominators are zero and the original code raised ZeroDivisionError.
    precision = tp / (tp + fp) if (tp + fp) else 0.
    recall = tp / (tp + fn) if (tp + fn) else 0.
    if precision + recall:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.
    print('Precision:', precision)
    print('Recall:', recall)
    print('F1:', f1)
    print('Skipped:', skipped)

In [16]:
# Evaluate the per-epoch checkpoints on the dev set, one after another.
# NOTE(review): the range is hard-coded to epochs 0..10; a checkpoint file that
# does not exist on disk will make torch.load fail -- confirm the files exist.
for epoch in range(11):
    checkpoint = torch.load('save_fever_with_qa_data_' + str(epoch))
    model.load_state_dict(checkpoint['model_state_dict'])
    print('epoch:', epoch)
    test(model, dev_list)

  0%|          | 3/7782 [00:00<04:25, 29.30it/s]

epoch: 0


 15%|█▌        | 1187/7782 [00:37<03:33, 30.91it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1043 > 1024). Running this sequence through the model will result in indexing errors
 41%|████▏     | 3217/7782 [01:42<02:25, 31.43it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1406 > 1024). Running this sequence through the model will result in indexing errors
 43%|████▎     | 3366/7782 [01:47<02:16, 32.40it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1075 > 1024). Running this sequence through the model will result in indexing errors
 54%|█████▍    | 4232/7782 [02:14<01:48, 32.71it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1314 > 1024). Running this sequence through the model will result in indexing errors
 57%|█████▋    | 4413/7782 [02:20<01:45, 31.88it/s]Token ind

Precision: 0.7073329007138222
Recall: 0.7304406098173899
F1: 0.7187010632160225
Skipped: 9
epoch: 1


 15%|█▌        | 1190/7782 [00:37<03:37, 30.28it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1043 > 1024). Running this sequence through the model will result in indexing errors
 41%|████▏     | 3217/7782 [01:43<02:26, 31.16it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1406 > 1024). Running this sequence through the model will result in indexing errors
 43%|████▎     | 3367/7782 [01:48<02:16, 32.39it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1075 > 1024). Running this sequence through the model will result in indexing errors
 54%|█████▍    | 4234/7782 [02:15<01:47, 32.85it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1314 > 1024). Running this sequence through the model will result in indexing errors
 57%|█████▋    | 4413/7782 [02:21<01:48, 31.04it/s]Token ind

Precision: 0.767374681393373
Recall: 0.7051842598376015
F1: 0.7349662299617544
Skipped: 9
epoch: 2


 15%|█▌        | 1190/7782 [00:37<03:28, 31.57it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1043 > 1024). Running this sequence through the model will result in indexing errors
 41%|████▏     | 3215/7782 [01:41<02:22, 31.97it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1406 > 1024). Running this sequence through the model will result in indexing errors
 43%|████▎     | 3366/7782 [01:46<02:14, 32.92it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1075 > 1024). Running this sequence through the model will result in indexing errors
 54%|█████▍    | 4234/7782 [02:14<01:44, 33.92it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1314 > 1024). Running this sequence through the model will result in indexing errors
 57%|█████▋    | 4415/7782 [02:20<01:46, 31.51it/s]Token ind

Precision: 0.7406927808352217
Recall: 0.7415329768270945
F1: 0.7411126406996519
Skipped: 9
epoch: 3


 15%|█▌        | 1190/7782 [00:37<03:30, 31.32it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1043 > 1024). Running this sequence through the model will result in indexing errors
 41%|████▏     | 3215/7782 [01:41<02:22, 32.08it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1406 > 1024). Running this sequence through the model will result in indexing errors
 43%|████▎     | 3364/7782 [01:46<02:12, 33.38it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1075 > 1024). Running this sequence through the model will result in indexing errors
 54%|█████▍    | 4234/7782 [02:13<01:47, 33.13it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1314 > 1024). Running this sequence through the model will result in indexing errors
 57%|█████▋    | 4414/7782 [02:19<01:48, 31.06it/s]Token ind

Precision: 0.7383804504072832
Recall: 0.7535452322738386
F1: 0.7458857696030978
Skipped: 9
epoch: 4


 15%|█▌        | 1187/7782 [00:37<03:35, 30.57it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1043 > 1024). Running this sequence through the model will result in indexing errors
 41%|████▏     | 3215/7782 [01:43<02:26, 31.26it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1406 > 1024). Running this sequence through the model will result in indexing errors
 43%|████▎     | 3367/7782 [01:47<02:20, 31.53it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1075 > 1024). Running this sequence through the model will result in indexing errors
 54%|█████▍    | 4234/7782 [02:15<01:47, 33.00it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1314 > 1024). Running this sequence through the model will result in indexing errors
 57%|█████▋    | 4414/7782 [02:20<01:49, 30.71it/s]Token ind

Precision: 0.7104009832539561
Recall: 0.7853260869565217
F1: 0.745986932322336
Skipped: 9


  0%|          | 4/7782 [00:00<03:27, 37.45it/s]

epoch: 5


 15%|█▌        | 1188/7782 [00:37<03:38, 30.21it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1043 > 1024). Running this sequence through the model will result in indexing errors
 41%|████▏     | 3217/7782 [01:43<02:25, 31.41it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1406 > 1024). Running this sequence through the model will result in indexing errors
 43%|████▎     | 3365/7782 [01:47<02:19, 31.74it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1075 > 1024). Running this sequence through the model will result in indexing errors
 54%|█████▍    | 4234/7782 [02:15<01:45, 33.53it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1314 > 1024). Running this sequence through the model will result in indexing errors
 57%|█████▋    | 4413/7782 [02:21<01:45, 31.81it/s]Token ind

Precision: 0.8885393258426967
Recall: 0.8437900128040973
F1: 0.8655866900175131
Skipped: 9


  0%|          | 4/7782 [00:00<03:28, 37.34it/s]

epoch: 6


 15%|█▌        | 1188/7782 [00:37<03:33, 30.92it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1043 > 1024). Running this sequence through the model will result in indexing errors
 41%|████▏     | 3215/7782 [01:43<02:23, 31.77it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1406 > 1024). Running this sequence through the model will result in indexing errors
 43%|████▎     | 3367/7782 [01:47<02:20, 31.35it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1075 > 1024). Running this sequence through the model will result in indexing errors
 54%|█████▍    | 4234/7782 [02:15<01:48, 32.61it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1314 > 1024). Running this sequence through the model will result in indexing errors
 57%|█████▋    | 4415/7782 [02:21<01:50, 30.39it/s]Token ind

Precision: 0.9538286457486588
Recall: 0.7834156763252771
F1: 0.8602639296187683
Skipped: 9


  0%|          | 4/7782 [00:00<03:29, 37.19it/s]

epoch: 7


 15%|█▌        | 1190/7782 [00:37<03:34, 30.69it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1043 > 1024). Running this sequence through the model will result in indexing errors
 41%|████▏     | 3215/7782 [01:43<02:26, 31.27it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1406 > 1024). Running this sequence through the model will result in indexing errors
 43%|████▎     | 3366/7782 [01:48<02:18, 31.87it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1075 > 1024). Running this sequence through the model will result in indexing errors
 54%|█████▍    | 4234/7782 [02:15<01:48, 32.59it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1314 > 1024). Running this sequence through the model will result in indexing errors
 57%|█████▋    | 4413/7782 [02:21<01:47, 31.39it/s]Token ind

Precision: 0.971004832527912
Recall: 0.7668114225555994
F1: 0.8569117647058823
Skipped: 9


  0%|          | 4/7782 [00:00<03:31, 36.82it/s]

epoch: 8


 15%|█▌        | 1188/7782 [00:37<03:38, 30.22it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1043 > 1024). Running this sequence through the model will result in indexing errors
 41%|████▏     | 3217/7782 [01:43<02:27, 31.02it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1406 > 1024). Running this sequence through the model will result in indexing errors
 43%|████▎     | 3365/7782 [01:47<02:15, 32.68it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1075 > 1024). Running this sequence through the model will result in indexing errors
 53%|█████▎    | 4133/7782 [02:12<01:57, 31.18it/s]


KeyboardInterrupt: 

In [None]:
# Show the first dev example.
print(dev_list[0])

In [None]:
# Run a single prediction on dev example 46; the cell output is the answer character.
generate_answer(model, dev_list[46])

In [None]:
# Load the checkpoint file with suffix 0 (the name pattern used when saving
# after each training epoch) into the current model.
checkpoint = torch.load('save_fever_with_qa_data_' + str(0))
model.load_state_dict(checkpoint['model_state_dict'])

In [None]:
# Load the CoQA dev file; the context manager closes the file handle.
with open('../data/coqa-dev-v1.0.json', encoding='utf8') as coqa_file:
    dev_dict = json.load(coqa_file)

In [None]:
# NOTE(review): this rebinds `dev_list`, which earlier cells loaded from
# '../data/dev_fc_with_qa.json'; cells run after this one see the QA list.
# The context manager closes the file handle.
with open('../data/qa_dev_list.json', encoding='utf8') as qa_dev_file:
    dev_list = json.load(qa_dev_file)

In [None]:
def get_description_from_data_item(item):
    """Return the value stored under the ``'story'`` key of ``item``."""
    story = item['story']
    return story

def get_dialogue_from_data_item(item, max_num_questions=0, question_number=-1, last_question=True):
    """Render a window of question/answer turns as one line of text.

    The window covers indices
    ``max(0, question_number - max_num_questions)`` up to and including
    ``question_number`` of ``item['questions']`` / ``item['answers']``. Each
    turn is rendered as ``'<question> <answer>.'`` and turns are joined with a
    single space.

    With the default ``question_number=-1`` the window is empty, so the
    rendered text is ``''``.

    When ``last_question`` is false, everything after the last ``'?'`` is
    dropped and the result ends with ``'?'`` (a text without any ``'?'``
    becomes just ``'?'``).
    """
    first = max(0, question_number - max_num_questions)
    stop = question_number + 1

    turns = []
    for question, answer in zip(item['questions'][first:stop], item['answers'][first:stop]):
        turns.append(question['input_text'] + ' ' + answer['input_text'] + '.')
    text = ' '.join(turns)

    if last_question:
        return text
    # Keep only what precedes the final question mark, then restore the mark.
    before_last_mark = text.rpartition('?')[0]
    return before_last_mark + '?'

In [None]:
def create_claim_from_description_and_dialogue(description, dialogue):
    """Build an Evidence/Claim prompt that ends with the claim question line.

    Args:
        description: evidence text; every blank line (``'\\n\\n'``) in it is
            collapsed to a single newline.
        dialogue: claim text; one trailing ``'.'`` is removed if present. An
            empty string is accepted (indexing ``dialogue[-1]`` used to raise
            IndexError in that case).

    Returns:
        ``'Evidence:\\n<description>\\n\\nClaim:\\n<dialogue>\\n\\n'`` followed
        by ``'The evidence supports the claim:\\n'``.
    """
    # endswith() is safe on the empty string, unlike dialogue[-1].
    if dialogue.endswith('.'):
        dialogue = dialogue[:-1]
    text = 'Evidence:\n'
    text += description.replace('\n\n', '\n') + '\n\n'
    text += 'Claim:\n'
    text += dialogue + '\n\n'
    text += 'The evidence supports the claim:\n'
    return text

In [None]:
def get_text_from_data_item(item, max_num_questions=0, question_number=-1, last_question=True):
    """Render a story and a window of its Q/A turns as a discussion prompt.

    The prompt consists of a fixed introduction, ``item['story']`` and a
    ``Discussion:`` section with one ``'Q: ...'`` and one ``'A: ...'`` line per
    turn. The turns are those at indices
    ``max(0, question_number - max_num_questions)`` up to and including
    ``question_number``; with the default ``question_number=-1`` that window
    is empty.

    When ``last_question`` is false, the last line of the prompt is removed
    and the result ends with a newline.
    """
    header = 'In the text below two people are discussing a story.\n\n'
    header = header + 'Story:\n' + item['story'] + '\n\n'
    header = header + 'Discussion:\n'

    first = max(0, question_number - max_num_questions)
    stop = question_number + 1
    turn_lines = []
    for question, answer in zip(item['questions'][first:stop], item['answers'][first:stop]):
        turn_lines.append('Q: ' + question['input_text'] + '\nA: ' + answer['input_text'])
    text = header + '\n'.join(turn_lines)

    if last_question:
        return text
    # Cut at the final newline (dropping the last line) and re-append it.
    without_last_line = text.rpartition('\n')[0]
    return without_last_line + '\n'

In [None]:
# Choose which document of the CoQA dev file and which question index to use.
doc = 0
number = 0
description = get_description_from_data_item(dev_dict['data'][doc])
# Story plus discussion prompt for the chosen question, without its answer line.
small_text = get_text_from_data_item(dev_dict['data'][doc], 
                                     max_num_questions=5, 
                                     question_number=number,
                                     last_question=False)
# The same question window as one line of text, cut after the last question mark.
dialogue = get_dialogue_from_data_item(dev_dict['data'][doc],
                                       max_num_questions=5, 
                                       question_number=number,
                                       last_question=False)
# Append a hand-written candidate answer (' airplane') and wrap story and
# dialogue in the Evidence/Claim template.
claim = create_claim_from_description_and_dialogue(description, dialogue + ' airplane')

In [None]:
# Show the assembled Evidence/Claim prompt.
print(claim)

In [None]:
# Ask the model for its answer character on the hand-built claim.
generate_answer(model, claim)