In [1]:
import json
import torch
import torch.nn as nn
import random
import pandas as pd

from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [2]:
train_list = json.load(open('../data/question_to_statement_train.json'))
dev_list = json.load(open('../data/question_to_statement_dev.json'))

In [3]:
print(train_list[1])

Discussion:
Q: What statutory policy minimizes outward expansion of urban London?
A: the Metropolitan Green Belt
Q: Greater it is divided into what two groups of boroughs?
A: Inner London and Outer London
Q: Where is the centre of it said to be located?
A: Eleanor Cross at Charing Cross near the junction of Trafalgar Square and Whitehall

Statement:
The centre of London is said to be located by the Eleanor Cross in Charing Cross near the junction of Trafalgar Square and Whitehall.


In [25]:
def create_statement_text(discussion, statement):
    text = 'Discussion:\n'
    text += discussion
    text += '\n\n'
    text += 'Statement:\n'
    text += statement
    if text[-1] != '.':
        text += '.'
    return text

In [26]:
yes_df = pd.read_csv('../data/yes_list.csv')
no_df = pd.read_csv('../data/no_list.csv')

In [27]:
yes_list = []
for index, row in yes_df.iterrows():
    if index > 99:
        break
    yes_list.append(create_statement_text(row[0], row[1]))
no_list = []
for index, row in no_df.iterrows():
    if index > 99:
        break
    no_list.append(create_statement_text(row[0], row[1]))

In [30]:
train_list += yes_list
train_list += no_list
random.shuffle(train_list)

In [31]:
def chunks(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

def batchify(data, n):
    len_dict = {}
    for item in data:
        length = item.shape[1]
        try:
            len_dict[length].append(item)
        except:
            len_dict[length] = [item]

    batch_chunks = []
    for k in len_dict.keys():
        vectors = len_dict[k]
        batch_chunks += chunks(vectors, n)

    batches = []
    for chunk in batch_chunks:
        inputs = torch.stack([item[0] for item in chunk])
        batches.append((inputs))

    return batches

In [32]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

In [33]:
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.cuda()
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

In [34]:
_limit = 1024
data = []
total_skipped = 0
for item in train_list:
    tokens = tokenizer.encode(item, return_tensors='pt')
    if tokens.shape[1] > _limit:
        total_skipped += 1
        continue
    data.append(tokens)
print(f'Skipped {total_skipped} out of {len(train_list)}')

Skipped 0 out of 57120


In [35]:
train_batches = batchify(data, 2)

In [36]:
def train(train_model, batches, optimizer, criterion):
    total_loss = 0.
    for i, batch in tqdm(enumerate(batches), total=len(batches)):
        model.train()
        inputs = batch
        optimizer.zero_grad()
        loss = train_model(inputs.cuda(), labels=inputs.cuda())[0]
        loss.backward()
        torch.nn.utils.clip_grad_norm_(train_model.parameters(), 0.5)
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(batches)

In [37]:
from torch.optim.lr_scheduler import StepLR

random.shuffle(train_batches)
scheduler = StepLR(optimizer, step_size=2, gamma=0.8)
for epoch in range(10):
    random.shuffle(train_batches)
    loss = train(model, train_batches, optimizer, criterion)
    #test(model, dev_list[:2000])
    print('Epoch:', epoch, 'Loss:', loss)
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict()},
                'save_statement' + str(epoch))
    scheduler.step()

100%|██████████| 28625/28625 [26:54<00:00, 17.73it/s]


Epoch: 0 Loss: 1.8245072971389804


100%|██████████| 28625/28625 [26:46<00:00, 17.81it/s]


Epoch: 1 Loss: 1.538909419853094


100%|██████████| 28625/28625 [27:11<00:00, 17.55it/s]


Epoch: 2 Loss: 1.3339170875434792


100%|██████████| 28625/28625 [27:00<00:00, 17.67it/s]


Epoch: 3 Loss: 1.1856001912704202


 23%|██▎       | 6591/28625 [06:10<20:38, 17.79it/s]


KeyboardInterrupt: 

# TODO: TESTING ON THE GENERATED STATEMENTS BELOW!!!

In [38]:
print(dev_list[0])

Discussion:
Q: Who built the famous decorated havelis in Rajasthan?
A: Rajput kings
Q: Jaipur is also known as what city?
A: the Pink City
Q: What are the notable houses in it made from?
A: a type of sandstone dominated by a pink hue

Statement:
Notable houses in Jaipur made from a type of sandstone dominated by a pink hue


In [39]:
dev_dict = json.load(open('../data/coqa-dev-v1.0.json', encoding='utf8'))

In [54]:
def get_statement(item, max_num_questions=0, question_number=-1):
    text = 'Discussion:\n'
    text += '\n'.join(['Q: ' + q['input_text'] 
                       + '\nA: ' + a['input_text'] 
                       for q, a in zip(item['questions'][max(0,question_number-max_num_questions):question_number+1], 
                                       item['answers'][max(0,question_number-max_num_questions):question_number+1]) 
                      ])
    text += '\n\nStatement:\n'
    return text

In [55]:
def generate_answer(model, prompt):
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 50
    tokens_length = tokens.shape[1]
    if tokens_length + _length > 1024:
        return ''
    output = model.generate(
             tokens.cuda(),
             max_length=tokens_length + _length,
             #temperature=0,
             pad_token_id=50256
    )
    output = tokenizer.decode(output[0], skip_special_tokens=True)
    offset = len(prompt)
    start = offset
    end = output.find('\n', start)
    return output[start:end].split(':')[-1].strip()

In [61]:
checkpoint = torch.load('save_statement' + str(2))
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [80]:
index=0
number=11

small_text = get_statement(dev_dict['data'][index], 
                           max_num_questions=5,
                           question_number=number)
print(small_text)

Discussion:
Q: What did she do to try to make herself the same color as her sisters?
A: she painted herself
Q: Whose paint was it?
A: the farmer
Q: What did Cotton's mother and siblings do when they saw her painted orange?
A: they started laughing
Q: Where did Cotton's mother put her to clean the paint off?
A: a bucket of water
Q: What did the other cats do when Cotton emerged from the bucket of water?
A: licked her face
Q: Did they want Cotton to change the color of her fur?
A: no

Statement:



In [81]:
generate_answer(model, small_text)

'The other cats did not want Cotton to change the color of her fur.'

## Testing

In [42]:
import sys
import traceback

def test(model, data):
    model.eval()
    tp = 0
    fp = 0
    fn = 0

    skipped = 0

    for item in tqdm(data):
        expected = get_answer_from_text(item)
        predicted = ''
        try:
            predicted = generate_answer(model, item)
        except (IndexError, RuntimeError) as e:
            #print(str(e))
            #exc_type, exc_value, exc_traceback = sys.exc_info()
            #print(repr(traceback.extract_tb(exc_traceback)))
            skipped += 1
            continue

        if expected == predicted:
            tp += 1
        if expected == 'N' and predicted == 'Y':
            fp += 1
        if expected == 'Y' and predicted == 'N':
            fn += 1

    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    print('Precision:', precision)
    print('Recall:', recall)
    print('F1:', f1)
    print('Skipped:', skipped)

In [16]:
def get_text_up_to_question(text):
    _claim_yn = 'The evidence supports the claim:\n'
    return text[:text.find(_claim_yn) + len(_claim_yn)]

In [17]:
def get_answer_from_text(text):
    _claim_yn = 'The evidence supports the claim:\n'
    pos = text.find(_claim_yn) + len(_claim_yn)
    return text[pos]

In [None]:
def generate_answer(model, text):
    prompt = get_text_up_to_question(text)
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 1
    tokens_length = tokens.shape[1]
    if tokens_length + _length >= 1024:
        raise RuntimeError('Text is longer than 1024')
    output = model.generate(
             tokens.cuda(),
             max_length=tokens_length + _length, 
             pad_token_id=50256
    )
    output = tokenizer.decode(output[0], skip_special_tokens=True)
    return get_answer_from_text(output)

In [131]:
def generate_full_answer(model, text):
    prompt = get_text_up_to_question(text)
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 3
    tokens_length = tokens.shape[1]
    if tokens_length + _length >= 1024:
        raise RuntimeError('Text is longer than 1024')
    output = model.generate(
             tokens.cuda(),
             max_length=tokens_length + _length, 
             pad_token_id=50256
    )
    score = model(output, labels=output)[0]
    out_text = tokenizer.decode(output[0][tokens_length:], skip_special_tokens=True)

    return out_text, float(score)

In [19]:
get_answer_from_text(dev_list[0])

'N'

In [22]:
generate_answer(model, dev_list[0])

'N'

In [43]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.cuda()
for epoch in range(3, 5):
    checkpoint = torch.load(f'save_fever{epoch}')
    model.load_state_dict(checkpoint['model_state_dict'])
    _ = model.eval()
    print(f'Epoch {epoch}')
    test(model, dev_list)

  0%|          | 4/44035 [00:00<21:25, 34.24it/s]

Epoch 3


  0%|          | 47/44035 [00:00<13:16, 55.25it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1515 > 1024). Running this sequence through the model will result in indexing errors
100%|██████████| 44035/44035 [12:13<00:00, 60.02it/s]
  0%|          | 0/44035 [00:00<?, ?it/s]

Precision: 0.9659403378237863
Recall: 0.9838546447315776
F1: 0.974815194832451
Skipped: 72
Epoch 4


100%|██████████| 44035/44035 [12:06<00:00, 60.61it/s]

Precision: 0.9719283593170007
Recall: 0.9799078427244872
F1: 0.975901790186007
Skipped: 72





## Tests with dev from COQA

In [81]:
epoch = 4
checkpoint = torch.load(f'save_fever{epoch}')
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [82]:
dev_dict = json.load(open('../data/coqa-dev-v1.0.json', encoding='utf8'))

In [83]:
print(dev_dict['data'][0]['story'])

Once upon a time, in a barn near a farm house, there lived a little white kitten named Cotton. Cotton lived high up in a nice warm place above the barn where all of the farmer's horses slept. But Cotton wasn't alone in her little home above the barn, oh no. She shared her hay bed with her mommy and 5 other sisters. All of her sisters were cute and fluffy, like Cotton. But she was the only white one in the bunch. The rest of her sisters were all orange with beautiful white tiger stripes like Cotton's mommy. Being different made Cotton quite sad. She often wished she looked like the rest of her family. So one day, when Cotton found a can of the old farmer's orange paint, she used it to paint herself like them. When her mommy and sisters found her they started laughing. 

"What are you doing, Cotton?!" 

"I only wanted to be more like you". 

Cotton's mommy rubbed her face on Cotton's and said "Oh Cotton, but your fur is so pretty and special, like you. We would never want you to be any o

In [153]:
story = dev_dict['data'][0]['story'].replace('\n\n', '\n').replace('\n', '')

text = f"""
Evidence:
{story}

Claim:
What color was Cotton? white.

The evidence supports the claim:
"""[1:]


In [154]:
print(text)

Evidence:
Once upon a time, in a barn near a farm house, there lived a little white kitten named Cotton. Cotton lived high up in a nice warm place above the barn where all of the farmer's horses slept. But Cotton wasn't alone in her little home above the barn, oh no. She shared her hay bed with her mommy and 5 other sisters. All of her sisters were cute and fluffy, like Cotton. But she was the only white one in the bunch. The rest of her sisters were all orange with beautiful white tiger stripes like Cotton's mommy. Being different made Cotton quite sad. She often wished she looked like the rest of her family. So one day, when Cotton found a can of the old farmer's orange paint, she used it to paint herself like them. When her mommy and sisters found her they started laughing. "What are you doing, Cotton?!" "I only wanted to be more like you". Cotton's mommy rubbed her face on Cotton's and said "Oh Cotton, but your fur is so pretty and special, like you. We would never want you to be a

In [155]:
generate_full_answer(model, text)

('Yes.\n\n', 7.758946418762207)