In [1]:
import json
import torch
import torch.nn as nn
import random
import pandas as pd

from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [2]:
# Load the train/dev example lists. Context managers make sure the file
# handles are closed as soon as parsing finishes (the bare open() calls
# previously left them to the garbage collector).
with open('../data/question_to_statement_train.json') as train_file:
    train_list = json.load(train_file)
with open('../data/question_to_statement_dev.json') as dev_file:
    dev_list = json.load(dev_file)

In [3]:
# Show one training example to eyeball the "Discussion: ... Statement: ..." text layout.
print(train_list[1])

Discussion:
Q: What statutory policy minimizes outward expansion of urban London?
A: the Metropolitan Green Belt
Q: Greater it is divided into what two groups of boroughs?
A: Inner London and Outer London
Q: Where is the centre of it said to be located?
A: Eleanor Cross at Charing Cross near the junction of Trafalgar Square and Whitehall

Statement:
The centre of London is said to be located by the Eleanor Cross in Charing Cross near the junction of Trafalgar Square and Whitehall.


In [4]:
def create_statement_text(discussion, statement):
    """Format a discussion/statement pair as a single training string.

    The result is a 'Discussion:' header followed by `discussion`, a blank
    line, then a 'Statement:' header followed by `statement`. A trailing
    period is appended when the assembled text does not already end in one.

    Args:
        discussion: text placed under the 'Discussion:' header.
        statement: text placed under the 'Statement:' header.

    Returns:
        The assembled string, always ending in '.'.
    """
    assembled = 'Discussion:\n' + discussion + '\n\n' + 'Statement:\n' + statement
    # The headers guarantee the string is never empty, so endswith() is safe.
    return assembled if assembled.endswith('.') else assembled + '.'

In [5]:
# Read the two auxiliary CSV tables; their first two columns are later fed to
# create_statement_text as (discussion, statement).
yes_df, no_df = (
    pd.read_csv(csv_path)
    for csv_path in ('../data/yes_list.csv', '../data/no_list.csv')
)

In [6]:
# Number of leading rows taken from each CSV table.
_MAX_ROWS_PER_LIST = 100


def _statements_from(df, limit=_MAX_ROWS_PER_LIST):
    """Format the first `limit` rows of `df` with create_statement_text.

    The first column is passed as the discussion and the second column as the
    statement. Columns are read by position with `.iloc`, because integer keys
    on a labelled Series (`row[0]`) are deprecated in pandas.
    """
    return [
        create_statement_text(row.iloc[0], row.iloc[1])
        for _, row in df.head(limit).iterrows()
    ]


yes_list = _statements_from(yes_df)
no_list = _statements_from(no_df)

In [7]:
# Fold the extra yes/no examples into the training set, then randomise the order.
# train_list is modified in place, so re-running this cell appends the extras again.
train_list.extend(yes_list)
train_list.extend(no_list)
random.shuffle(train_list)

In [8]:
def chunks(lst, n):
    """Lazily yield consecutive slices of `lst`, each holding at most `n` items.

    The final slice is shorter when `len(lst)` is not a multiple of `n`.
    """
    yield from (lst[start:start + n] for start in range(0, len(lst), n))

def batchify(data, n):
    """Group equal-length tensors and stack them into batches of at most `n`.

    Items are bucketed by `item.shape[1]`; within a bucket, consecutive runs of
    up to `n` items are stacked (using `item[0]` of each) into one tensor.
    Buckets are emitted in the order their length was first seen, and items
    keep their input order inside a bucket.

    Args:
        data: iterable of tensors with at least two dimensions. The callers in
            this notebook presumably pass `(1, seq_len)` token-id tensors.
        n: maximum number of sequences per batch.

    Returns:
        list of stacked tensors, one per batch. Empty when `data` is empty.

    Raises:
        ValueError: if `n` is 0 (zero step for `range`).
    """
    # setdefault replaces the old bare `except:` grouping, which would also
    # have swallowed unrelated errors.
    by_length = {}
    for item in data:
        by_length.setdefault(item.shape[1], []).append(item)

    batches = []
    for same_length in by_length.values():
        for start in range(0, len(same_length), n):
            group = same_length[start:start + n]
            batches.append(torch.stack([item[0] for item in group]))

    return batches

In [9]:
# Pretrained GPT-2 tokenizer used below to encode every example into token ids.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

In [10]:
# Fine-tuning setup: pretrained GPT-2 LM head model moved to the GPU, optimised with Adam.
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.cuda()
# NOTE(review): `criterion` is handed to train() but train() never reads it; the
# training loss is the first output of the model call made with `labels=`.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)

In [11]:
# Encode every training example; anything longer than `_limit` tokens is
# dropped and counted instead of being truncated.
_limit = 1024
data, total_skipped = [], 0
for item in train_list:
    tokens = tokenizer.encode(item, return_tensors='pt')
    if tokens.shape[1] <= _limit:
        data.append(tokens)
    else:
        total_skipped += 1
print(f'Skipped {total_skipped} out of {len(train_list)}')

Skipped 0 out of 57120


In [12]:
# Group the encoded training examples into same-length batches of at most 2 sequences.
train_batches = batchify(data, 2)

In [13]:
def train(train_model, batches, optimizer, criterion):
    """Run one training pass over `batches` and return the mean batch loss.

    Args:
        train_model: module called as `train_model(inputs, labels=inputs)`;
            the first element of its output is used as the loss.
        batches: sized collection of input tensors, one tensor per batch.
        optimizer: optimizer stepping `train_model`'s parameters.
        criterion: unused; kept so existing callers keep working.

    Returns:
        Mean of the per-batch loss values as a float, or 0.0 when `batches`
        is empty (previously a ZeroDivisionError).
    """
    if len(batches) == 0:
        return 0.

    # Put the model that was passed in into training mode. The previous code
    # called `model.train()` on the notebook-global `model` instead.
    train_model.train()
    # Follow the model's own device instead of hard-coding `.cuda()`, so each
    # batch is transferred once and lands wherever the weights live.
    device = next(train_model.parameters()).device

    total_loss = 0.
    for batch in tqdm(batches, total=len(batches)):
        inputs = batch.to(device)
        optimizer.zero_grad()
        loss = train_model(inputs, labels=inputs)[0]
        loss.backward()
        # Clip the global gradient norm to 0.5 before the update.
        torch.nn.utils.clip_grad_norm_(train_model.parameters(), 0.5)
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(batches)

In [14]:
# Number of passes over the training batches; the evaluation loop also uses it
# to decide how many saved checkpoints to load.
epochs = 5

In [15]:
from torch.optim.lr_scheduler import StepLR

# Train for `epochs` passes, reshuffling the batch order before each one and
# writing one checkpoint file per epoch. The learning rate is multiplied by
# 0.8 every 2 scheduler steps.
random.shuffle(train_batches)
scheduler = StepLR(optimizer, step_size=2, gamma=0.8)
for epoch in range(epochs):
    random.shuffle(train_batches)
    loss = train(model, train_batches, optimizer, criterion)
    #test(model, dev_list[:2000])
    print('Epoch:', epoch, 'Loss:', loss)
    checkpoint_state = {'epoch': epoch, 'model_state_dict': model.state_dict()}
    torch.save(checkpoint_state, f'save_statement{epoch}')
    scheduler.step()

100%|█████████████████████████████████████████████████████████████████████████████████████████████| 28625/28625 [27:00<00:00, 17.67it/s]


Epoch: 0 Loss: 1.794938656511265


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 28625/28625 [26:59<00:00, 17.67it/s]


Epoch: 1 Loss: 1.4444590707210474


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 28625/28625 [27:00<00:00, 17.67it/s]


Epoch: 2 Loss: 1.183250183196047


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 28625/28625 [26:59<00:00, 17.67it/s]


Epoch: 3 Loss: 1.009052302759287


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 28625/28625 [26:58<00:00, 17.68it/s]


Epoch: 4 Loss: 0.8660318956614582


### Testing

In [16]:
# Encode the dev examples the same way as the training set. This rebinds
# `data`, `_limit` and `total_skipped`, so the training encodings held in
# `data` are replaced from here on.
_limit = 1024
data, total_skipped = [], 0
for item in dev_list:
    tokens = tokenizer.encode(item, return_tensors='pt')
    if tokens.shape[1] <= _limit:
        data.append(tokens)
    else:
        total_skipped += 1
print(f'Skipped {total_skipped} out of {len(dev_list)}')

dev_batches = batchify(data, 2)

Skipped 0 out of 8275


In [19]:
def test(model, batches):
    total_loss = 0.
    for i, batch in tqdm(enumerate(batches), total=len(batches)):
        model.eval()
        inputs = batch
        loss = model(inputs.cuda(), labels=inputs.cuda())[0]
        total_loss += loss.item()

    return total_loss / len(batches)

In [20]:
# Evaluate every per-epoch checkpoint on the dev batches.
# A fresh pretrained tokenizer/model pair is created (rebinding the globals used
# during training) and the saved weights of each epoch are loaded into it in turn.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.cuda()
for epoch in range(epochs):
    # File names match the ones written by the training loop ('save_statement<epoch>').
    # NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
    checkpoint = torch.load(f'save_statement{epoch}')
    model.load_state_dict(checkpoint['model_state_dict'])
    _ = model.eval()
    print(f'Epoch {epoch}')
    loss = test(model, dev_batches)
    print("Loss", loss)

Epoch 0


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4182/4182 [00:41<00:00, 99.82it/s]


Loss 2.0136151553083983
Epoch 1


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4182/4182 [00:42<00:00, 99.04it/s]


Loss 2.1651509036748733
Epoch 2


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4182/4182 [00:42<00:00, 98.04it/s]


Loss 2.355712597461935
Epoch 3


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4182/4182 [00:42<00:00, 98.41it/s]


Loss 2.4818756203797374
Epoch 4


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4182/4182 [00:42<00:00, 98.28it/s]

Loss 2.594299618320451





Note: for checkpoint `save 0`, the dev loss was ~3.