In [1]:
import json
import torch
import torch.nn as nn
import random
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [2]:
supporting_texts = json.load(open('../data/supporting.json'))
refuting_texts = json.load(open('../data/refuting.json'))

In [3]:
import random
split = 0.8

_prompt = '\n\n\nThe evidence supports the claim:\n'
all_list = [item + _prompt + 'Yes.' for item in supporting_texts]
all_list += [item + _prompt + 'Nope.' for item in refuting_texts]
random.shuffle(all_list)
train_list = all_list[:int(len(all_list) * split)]
dev_list = all_list[int(len(all_list) * split):]

In [4]:
del all_list

In [5]:
json.dump(train_list, open('../data/train_list.json', 'w'))
json.dump(dev_list, open('../data/dev_list.json', 'w'))

In [6]:
train_list = json.load(open('../data/train_list.json'))
dev_list = json.load(open('../data/dev_list.json'))

In [7]:
print(train_list[12])

Evidence:
Cristiano Ronaldo dos Santos Aveiro, ([kɾiʃ ` tjɐnu ʁuˈnaɫdu], born 5 February 1985) is a Portuguese professional footballer for Spanish club Real Madrid and the Portugal national team. He primarily plays as a forward, but has also been deployed as a winger and serves as captain for Portugal. In 2008, he won his first Ballon d'Or and FIFA World Player of the Year awards. Ronaldo then won the FIFA Ballon d'Or in 2013 and 2014. In 2016, he received his fourth Ballon d'Or, the most for a European player in the history of the award, and the inaugural Best FIFA Men 's Player. In 2015, Ronaldo scored his 500th senior career goal for club and country. Two years later, he surpassed Jimmy Greaves as the all-time top-scorer in the top five European leagues. Often ranked the best player in the world and widely regarded as one of the greatest of all time, Ronaldo was named the best Portuguese player of all time by the Portuguese Football Federation, during its 100th anniversary celebrati

In [8]:
def chunks(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

def batchify(data, n):
    len_dict = {}
    for item in data:
        length = item.shape[1]
        try:
            len_dict[length].append(item)
        except:
            len_dict[length] = [item]

    batch_chunks = []
    for k in len_dict.keys():
        vectors = len_dict[k]
        batch_chunks += chunks(vectors, n)

    batches = []
    for chunk in batch_chunks:
        inputs = torch.stack([item[0] for item in chunk])
        batches.append((inputs))

    return batches

In [9]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

In [10]:
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.cuda()
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

In [11]:
_limit = 1024
data = []
total_skipped = 0
for item in train_list:
    tokens = tokenizer.encode(item, return_tensors='pt')
    if tokens.shape[1] > _limit:
        total_skipped += 1
        continue
    data.append(tokens)
print(f'Skipped {total_skipped} out of {len(train_list)}')

Token indices sequence length is longer than the specified maximum sequence length for this model (1088 > 1024). Running this sequence through the model will result in indexing errors


Skipped 269 out of 176138


In [12]:
train_batches = batchify(data, 2)

In [13]:
def train(train_model, batches, optimizer, criterion):
    total_loss = 0.
    for i, batch in tqdm(enumerate(batches), total=len(batches)):
        model.train()
        inputs = batch
        optimizer.zero_grad()
        loss = train_model(inputs.cuda(), labels=inputs.cuda())[0]
        loss.backward()
        torch.nn.utils.clip_grad_norm_(train_model.parameters(), 0.5)
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(batches)

In [14]:
from torch.optim.lr_scheduler import StepLR

random.shuffle(train_batches)
scheduler = StepLR(optimizer, step_size=2, gamma=0.8)
for epoch in range(10):
    random.shuffle(train_batches)
    loss = train(model, train_batches, optimizer, criterion)
    #test(model, dev_list[:2000])
    print('Epoch:', epoch, 'Loss:', loss)
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict()},
                'save_fever' + str(epoch))
    scheduler.step()

100%|██████████| 88167/88167 [2:18:30<00:00, 10.61it/s]  


Epoch: 0 Loss: 1.0872557028125505


100%|██████████| 88167/88167 [2:16:54<00:00, 10.73it/s]  


Epoch: 1 Loss: 0.35215188638462147


100%|██████████| 88167/88167 [2:16:22<00:00, 10.77it/s]  


Epoch: 2 Loss: 0.24236899872764053


100%|██████████| 88167/88167 [2:16:07<00:00, 10.80it/s]  


Epoch: 3 Loss: 0.21050083176142717


100%|██████████| 88167/88167 [2:15:55<00:00, 10.81it/s]  


Epoch: 4 Loss: 0.18995982896558686


 22%|██▏       | 19640/88167 [30:19<1:45:48, 10.79it/s]


KeyboardInterrupt: 

## Testing

In [42]:
import sys
import traceback

def test(model, data):
    model.eval()
    tp = 0
    fp = 0
    fn = 0

    skipped = 0

    for item in tqdm(data):
        expected = get_answer_from_text(item)
        predicted = ''
        try:
            predicted = generate_answer(model, item)
        except (IndexError, RuntimeError) as e:
            #print(str(e))
            #exc_type, exc_value, exc_traceback = sys.exc_info()
            #print(repr(traceback.extract_tb(exc_traceback)))
            skipped += 1
            continue

        if expected == predicted:
            tp += 1
        if expected == 'N' and predicted == 'Y':
            fp += 1
        if expected == 'Y' and predicted == 'N':
            fn += 1

    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    print('Precision:', precision)
    print('Recall:', recall)
    print('F1:', f1)
    print('Skipped:', skipped)

In [16]:
def get_text_up_to_question(text):
    _claim_yn = 'The evidence supports the claim:\n'
    return text[:text.find(_claim_yn) + len(_claim_yn)]

In [17]:
def get_answer_from_text(text):
    _claim_yn = 'The evidence supports the claim:\n'
    pos = text.find(_claim_yn) + len(_claim_yn)
    return text[pos]

In [None]:
def generate_answer(model, text):
    prompt = get_text_up_to_question(text)
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 1
    tokens_length = tokens.shape[1]
    if tokens_length + _length >= 1024:
        raise RuntimeError('Text is longer than 1024')
    output = model.generate(
             tokens.cuda(),
             max_length=tokens_length + _length, 
             pad_token_id=50256
    )
    output = tokenizer.decode(output[0], skip_special_tokens=True)
    return get_answer_from_text(output)

In [131]:
def generate_full_answer(model, text):
    prompt = get_text_up_to_question(text)
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 3
    tokens_length = tokens.shape[1]
    if tokens_length + _length >= 1024:
        raise RuntimeError('Text is longer than 1024')
    output = model.generate(
             tokens.cuda(),
             max_length=tokens_length + _length, 
             pad_token_id=50256
    )
    score = model(output, labels=output)[0]
    out_text = tokenizer.decode(output[0][tokens_length:], skip_special_tokens=True)

    return out_text, float(score)

In [19]:
get_answer_from_text(dev_list[0])

'N'

In [22]:
generate_answer(model, dev_list[0])

'N'

In [43]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.cuda()
for epoch in range(3, 5):
    checkpoint = torch.load(f'save_fever{epoch}')
    model.load_state_dict(checkpoint['model_state_dict'])
    _ = model.eval()
    print(f'Epoch {epoch}')
    test(model, dev_list)

  0%|          | 4/44035 [00:00<21:25, 34.24it/s]

Epoch 3


  0%|          | 47/44035 [00:00<13:16, 55.25it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1515 > 1024). Running this sequence through the model will result in indexing errors
100%|██████████| 44035/44035 [12:13<00:00, 60.02it/s]
  0%|          | 0/44035 [00:00<?, ?it/s]

Precision: 0.9659403378237863
Recall: 0.9838546447315776
F1: 0.974815194832451
Skipped: 72
Epoch 4


100%|██████████| 44035/44035 [12:06<00:00, 60.61it/s]

Precision: 0.9719283593170007
Recall: 0.9799078427244872
F1: 0.975901790186007
Skipped: 72





## Tests with dev from COQA

In [81]:
epoch = 4
checkpoint = torch.load(f'save_fever{epoch}')
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [82]:
dev_dict = json.load(open('../data/coqa-dev-v1.0.json', encoding='utf8'))

In [83]:
print(dev_dict['data'][0]['story'])

Once upon a time, in a barn near a farm house, there lived a little white kitten named Cotton. Cotton lived high up in a nice warm place above the barn where all of the farmer's horses slept. But Cotton wasn't alone in her little home above the barn, oh no. She shared her hay bed with her mommy and 5 other sisters. All of her sisters were cute and fluffy, like Cotton. But she was the only white one in the bunch. The rest of her sisters were all orange with beautiful white tiger stripes like Cotton's mommy. Being different made Cotton quite sad. She often wished she looked like the rest of her family. So one day, when Cotton found a can of the old farmer's orange paint, she used it to paint herself like them. When her mommy and sisters found her they started laughing. 

"What are you doing, Cotton?!" 

"I only wanted to be more like you". 

Cotton's mommy rubbed her face on Cotton's and said "Oh Cotton, but your fur is so pretty and special, like you. We would never want you to be any o

In [153]:
story = dev_dict['data'][0]['story'].replace('\n\n', '\n').replace('\n', '')

text = f"""
Evidence:
{story}

Claim:
What color was Cotton? white.

The evidence supports the claim:
"""[1:]


In [154]:
print(text)

Evidence:
Once upon a time, in a barn near a farm house, there lived a little white kitten named Cotton. Cotton lived high up in a nice warm place above the barn where all of the farmer's horses slept. But Cotton wasn't alone in her little home above the barn, oh no. She shared her hay bed with her mommy and 5 other sisters. All of her sisters were cute and fluffy, like Cotton. But she was the only white one in the bunch. The rest of her sisters were all orange with beautiful white tiger stripes like Cotton's mommy. Being different made Cotton quite sad. She often wished she looked like the rest of her family. So one day, when Cotton found a can of the old farmer's orange paint, she used it to paint herself like them. When her mommy and sisters found her they started laughing. "What are you doing, Cotton?!" "I only wanted to be more like you". Cotton's mommy rubbed her face on Cotton's and said "Oh Cotton, but your fur is so pretty and special, like you. We would never want you to be a

In [155]:
generate_full_answer(model, text)

('Yes.\n\n', 7.758946418762207)