In [1]:
import json
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR

import random
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [2]:
num_epochs = 8

In [3]:
supporting_texts = json.load(open('../data/supporting.json'))
refuting_texts = json.load(open('../data/refuting.json'))

dev_supporting_texts = json.load(open('../data/dev_supporting.json'))
dev_refuting_texts = json.load(open('../data/dev_refuting.json'))

In [4]:
import random
split = 0.8

_prompt = '\n\n\nThe evidence supports the claim:\n'
train_list = [item + _prompt + 'Yes.' for item in supporting_texts]
train_list += [item + _prompt + 'Nope.' for item in refuting_texts]
random.shuffle(train_list)

dev_list = [item + _prompt + 'Yes.' for item in dev_supporting_texts]
dev_list += [item + _prompt + 'Nope.' for item in dev_refuting_texts]
random.shuffle(dev_list)


In [5]:
json.dump(train_list, open('../data/train_list.json', 'w'))
json.dump(dev_list, open('../data/dev_list.json', 'w'))

In [6]:
train_list = json.load(open('../data/train_list.json'))
dev_list = json.load(open('../data/dev_list.json'))

In [9]:
print(dev_list[0])

Evidence:
James Andrew Jones (born October 4, 1980) is an American professional basketball player for the Cleveland Cavaliers of the National Basketball Association (NBA). He currently serves as the secretary-treasurer of the National Basketball Players Association. Jones was a four-year letterman at American High School in Hialeah, Florida. He averaged 25 points per game as a senior, earning Class 6A Player of the Year and First Team All-State honors . He then played college basketball for the Miami Hurricanes of the University of Miami, where he was a three-year starter and finished his career averaging 11 points per game. He was named Third Team All-Big East his junior year and Second Team Verizon Academic All-American his senior year. He was inducted into the University of Miami Sports Hall of Fame in 2014.


Claim:
James Jones has been referred to as the "Champ".


The evidence supports the claim:
Yes.


In [10]:
def chunks(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

def batchify(data, n):
    len_dict = {}
    for item in data:
        length = item.shape[1]
        try:
            len_dict[length].append(item)
        except:
            len_dict[length] = [item]

    batch_chunks = []
    for k in len_dict.keys():
        vectors = len_dict[k]
        batch_chunks += chunks(vectors, n)

    batches = []
    for chunk in batch_chunks:
        inputs = torch.stack([item[0] for item in chunk])
        batches.append((inputs))

    return batches

In [11]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

In [12]:
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.cuda()
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

In [None]:
_limit = 1024
data = []
total_skipped = 0
for item in train_list:
    tokens = tokenizer.encode(item, return_tensors='pt')
    if tokens.shape[1] > _limit:
        total_skipped += 1
        continue
    data.append(tokens)
print(f'Skipped {total_skipped} out of {len(train_list)}')

In [None]:
train_batches = batchify(data, 1)

In [None]:
def train(train_model, batches, optimizer, criterion):
    total_loss = 0.
    for i, batch in tqdm(enumerate(batches), total=len(batches)):
        model.train()
        inputs = batch
        optimizer.zero_grad()
        loss = train_model(inputs.cuda(), labels=inputs.cuda())[0]
        loss.backward()
        torch.nn.utils.clip_grad_norm_(train_model.parameters(), 0.5)
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(batches)

In [16]:
random.shuffle(train_batches)
scheduler = StepLR(optimizer, step_size=2, gamma=0.8)
for epoch in range(num_epochs):
    random.shuffle(train_batches)
    loss = train(model, train_batches, optimizer, criterion)
    print('Epoch:', epoch, 'Loss:', loss)
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict()},
                'save_fever' + str(epoch))
    scheduler.step()

100%|██████████| 219832/219832 [4:00:53<00:00, 15.21it/s]  


Epoch: 0 Loss: 0.8349866022318717


100%|██████████| 219832/219832 [4:02:03<00:00, 15.14it/s]  


Epoch: 1 Loss: 0.27282461847403583


100%|██████████| 219832/219832 [3:59:59<00:00, 15.27it/s]  


Epoch: 2 Loss: 0.2118869653484744


100%|██████████| 219832/219832 [3:59:33<00:00, 15.29it/s]  


Epoch: 3 Loss: 0.1911627285202235


100%|██████████| 219832/219832 [4:04:37<00:00, 14.98it/s]    


Epoch: 4 Loss: 0.17449108368149846


100%|██████████| 219832/219832 [3:58:51<00:00, 15.34it/s]  


Epoch: 5 Loss: 0.1654880905636952


100%|██████████| 219832/219832 [4:08:54<00:00, 14.72it/s]  


Epoch: 6 Loss: 0.15662710911507668


100%|██████████| 219832/219832 [4:17:08<00:00, 14.25it/s]  


Epoch: 7 Loss: 0.15177448879509217


## Testing

In [13]:
import sys
import traceback

def test(model, data):
    model.eval()
    tp = 0
    fp = 0
    fn = 0

    skipped = 0

    for item in tqdm(data):
        expected = get_answer_from_text(item)
        predicted = ''
        try:
            predicted = generate_answer(model, item)
        except (IndexError, RuntimeError) as e:
            #print(str(e))
            #exc_type, exc_value, exc_traceback = sys.exc_info()
            #print(repr(traceback.extract_tb(exc_traceback)))
            skipped += 1
            continue

        if expected == predicted:
            tp += 1
        if expected == 'N' and predicted == 'Y':
            fp += 1
        if expected == 'Y' and predicted == 'N':
            fn += 1

    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    print('Precision:', precision)
    print('Recall:', recall)
    print('F1:', f1)
    print('Skipped:', skipped)

In [14]:
def get_text_up_to_question(text):
    _claim_yn = 'The evidence supports the claim:\n'
    return text[:text.find(_claim_yn) + len(_claim_yn)]

In [15]:
def get_answer_from_text(text):
    _claim_yn = 'The evidence supports the claim:\n'
    pos = text.find(_claim_yn) + len(_claim_yn)
    return text[pos]

In [16]:
def generate_answer(model, text):
    prompt = get_text_up_to_question(text)
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 1
    tokens_length = tokens.shape[1]
    if tokens_length + _length >= 1024:
        raise RuntimeError('Text is longer than 1024')
    output = model.generate(
             tokens.cuda(),
             max_length=tokens_length + _length, 
             pad_token_id=50256
    )
    output = tokenizer.decode(output[0], skip_special_tokens=True)
    return get_answer_from_text(output)

In [17]:
def generate_full_answer(model, text):
    prompt = get_text_up_to_question(text)
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 3
    tokens_length = tokens.shape[1]
    if tokens_length + _length >= 1024:
        raise RuntimeError('Text is longer than 1024')
    output = model.generate(
             tokens.cuda(),
             max_length=tokens_length + _length, 
             pad_token_id=50256
    )
    score = model(output, labels=output)[0]
    out_text = tokenizer.decode(output[0][tokens_length:], skip_special_tokens=True)

    return out_text, float(score)

In [20]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.cuda()
checkpoint = torch.load(f'save_fever5')
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [21]:
get_answer_from_text(dev_list[1])

'N'

In [22]:
generate_answer(model, dev_list[1])

'N'

In [25]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.cuda()
#for epoch in range(0, num_epochs):
for epoch in range(0, 3):
    checkpoint = torch.load(f'save_fact_check{epoch}')
    model.load_state_dict(checkpoint['model_state_dict'])
    _ = model.eval()
    print(f'Epoch {epoch}')
    test(model, dev_list)

Epoch 0


 11%|████▏                                 | 1338/12317 [00:19<02:40, 68.55it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1328 > 1024). Running this sequence through the model will result in indexing errors
100%|█████████████████████████████████████| 12317/12317 [02:55<00:00, 70.13it/s]


Precision: 0.8860457462563766
Recall: 0.9872570590392372
F1: 0.9339172664990026
Skipped: 24
Epoch 1


100%|█████████████████████████████████████| 12317/12317 [02:54<00:00, 70.69it/s]


Precision: 0.8957576758005943
Recall: 0.9839528558476881
F1: 0.9377862265618249
Skipped: 24
Epoch 2


100%|█████████████████████████████████████| 12317/12317 [02:55<00:00, 70.01it/s]

Precision: 0.9263289036544851
Recall: 0.9778186919165351
F1: 0.9513776337115073
Skipped: 24





### Uploading model 

In [5]:
model = GPT2LMHeadModel.from_pretrained('gpt2')

epoch = 5
checkpoint = torch.load(f'save_fever{epoch}')
model.load_state_dict(checkpoint['model_state_dict'])
model.push_to_hub("fractalego/fact-checking")

'https://huggingface.co/fractalego/fact-checker/commit/a3185c8c177d8866908ea46c6b40abe9c7afddcb'

In [3]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.push_to_hub("fractalego/fact-checking")

'https://huggingface.co/fractalego/fact-checker/commit/ef06b4530a000f7671efed80a4440e752f2da351'