In [1]:
# https://github.com/AMontgomerie/question_generator/blob/master/training/qg_training.ipynb

In [2]:
import os
import sys
import math
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import spacy
from tqdm.notebook import tqdm

In [3]:
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config

In [4]:
from datasets import load_dataset

dataset = load_dataset("race", "high")

Reusing dataset race (C:\Users\Kevin\.cache\huggingface\datasets\race\high\0.1.0\5a80ba2d003e023fdce95d01c1b02f5a70d5eb2375465bee162baf9824c91474)


  0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
next(iter(dataset["train"]))

{'example_id': 'high1.txt',
 'article': 'My husband is a born shopper. He loves to look at things and to touch them. He likes to compare prices between the same items in different shops. He would never think of buying anything without looking around in several different shops. On the other hand, I\'m not a shopper. I think shopping is boring and unpleasant. If I like something and I have enough money to take it, I buy it at once. I never look around for a good price or a better deal. Of course my husband and I never go shopping together. Doing shopping together would be too painful for both of us. When it comes to shopping, we go our different ways.\nSometimes I ask my son Jimmy to buy some food in the shop not far from our home. But he is always absent-minded. This was his story.\nOne day I said to him, " I hope you won\'t forget what I have told you to buy." " No," said Jimmy. "I won\'t forget. You want three oranges , six eggs and a pound of meat."\nHe went running down the street t

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

Using device: cuda


In [7]:
PRETRAINED_MODEL = 't5-base'
DIR = "question_generator/"
BATCH_SIZE = 1
SEQ_LENGTH = 512

tokenizer = T5Tokenizer.from_pretrained(PRETRAINED_MODEL)
tokenizer.add_special_tokens(
    {'additional_special_tokens': ['<answer>', '<context>']}
)

# class QGDataset(Dataset):
#     def __init__(self, csv):
#         self.df = pd.read_csv(csv, engine='python')

#     def __len__(self):
#          return len(self.df)

#     def __getitem__(self, idx):   
#         if torch.is_tensor(idx):
#             idx = idx.tolist()
#         row = self.df.iloc[idx, 1:]       

#         encoded_text = tokenizer(
#             row['text'], 
#             pad_to_max_length=True, 
#             max_length=SEQ_LENGTH,
#             truncation=True,
#             return_tensors="pt"
#         )
#         encoded_text['input_ids'] = torch.squeeze(encoded_text['input_ids'])
#         encoded_text['attention_mask'] = torch.squeeze(encoded_text['attention_mask'])

#         encoded_question = tokenizer(
#             row['question'],
#             pad_to_max_length=True,
#             max_length=SEQ_LENGTH,
#             truncation=True,
#             return_tensors='pt'
#         )
#         encoded_question['input_ids'] = torch.squeeze(encoded_question['input_ids'])

#         return (encoded_text.to(device), encoded_question.to(device))

# train_set = QGDataset(os.path.join(DIR, 'question_generator/datasets/qg_train.csv'))
# train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
# valid_set = QGDataset(os.path.join(DIR, 'question_generator/datasets/qg_valid.csv')) 
# valid_loader = DataLoader(valid_set, batch_size=BATCH_SIZE, shuffle=False)



2

In [8]:
def make_text(row):    
    encoded = {}
    encoded_text = tokenizer(
        row['answer'] + row['article'], 
        pad_to_max_length=True, 
        max_length=SEQ_LENGTH,
        truncation=True,
        return_tensors="pt"
    )
    encoded['input_ids'] = torch.squeeze(encoded_text['input_ids'])
    encoded['attention_mask'] = torch.squeeze(encoded_text['attention_mask'])

    encoded_question = tokenizer(
        row['question'],
        pad_to_max_length=True,
        max_length=SEQ_LENGTH,
        truncation=True,
        return_tensors='pt'
    )
    encoded['input_ids_question'] = torch.squeeze(encoded_question['input_ids'])
    return encoded

dataset = dataset.map(make_text)
dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'input_ids_question'])
train_loader = DataLoader(dataset["train"], batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset["validation"], batch_size=BATCH_SIZE, shuffle=True)

Loading cached processed dataset at C:\Users\Kevin\.cache\huggingface\datasets\race\high\0.1.0\5a80ba2d003e023fdce95d01c1b02f5a70d5eb2375465bee162baf9824c91474\cache-7069e3a27b93c19d.arrow
Loading cached processed dataset at C:\Users\Kevin\.cache\huggingface\datasets\race\high\0.1.0\5a80ba2d003e023fdce95d01c1b02f5a70d5eb2375465bee162baf9824c91474\cache-dc6faf71a15bd577.arrow
Loading cached processed dataset at C:\Users\Kevin\.cache\huggingface\datasets\race\high\0.1.0\5a80ba2d003e023fdce95d01c1b02f5a70d5eb2375465bee162baf9824c91474\cache-d8c2afa6b62f65ed.arrow


In [9]:
LR = 0.001
EPOCHS = 20
LOG_INTERVAL = 5000

config = T5Config(decoder_start_token_id=tokenizer.pad_token_id)
model = T5ForConditionalGeneration(config).from_pretrained(PRETRAINED_MODEL)
model.resize_token_embeddings(len(tokenizer)) # to account for new special tokens
model = model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=LR)

In [10]:
SAVED_MODEL_PATH = "question_generator/qg_pretrained_t5_model_trained.pth"

def train(epoch, best_val_loss):
    model.train()
    total_loss = 0.
    for batch_index, batch in tqdm(enumerate(train_loader)):
        target = {
            'input_ids': batch['input_ids_question'].to(device)
        }
        data = {
            'input_ids': batch['input_ids'].to(device),
            'attention_mask': batch['attention_mask'].to(device)
        }
        optimizer.zero_grad()
        masked_labels = mask_label_padding(target['input_ids'])
        output = model(
            input_ids=data['input_ids'],
            attention_mask=data['attention_mask'],
            labels=masked_labels
        )
        loss = output[0]
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        if batch_index % LOG_INTERVAL == 0 and batch_index > 0:
            cur_loss = total_loss / LOG_INTERVAL
            print('| epoch {:3d} | ' 
                  '{:5d}/{:5d} batches | '
                  'loss {:5.2f}'.format(
                    epoch, 
                    batch_index, len(train_loader), 
                    cur_loss))
            save(
                SAVED_MODEL_PATH,
                epoch, 
                model.state_dict(), 
                optimizer.state_dict(), 
                best_val_loss
            )
            total_loss = 0

def evaluate(eval_model, data_loader):
    eval_model.eval()
    total_loss = 0.
    with torch.no_grad():
        for batch_index, batch in tqdm(enumerate(data_loader)):
            target = {
                'input_ids': batch['input_ids_question'].to(device)
            }
            data = {
                'input_ids': batch['input_ids'].to(device),
                'attention_mask': batch['attention_mask'].to(device)
            }
            masked_labels = mask_label_padding(target['input_ids'])
            output = eval_model(
                input_ids=data['input_ids'],
                attention_mask=data['attention_mask'],
                labels=masked_labels
            )
            total_loss += output[0].item()
    return total_loss / len(data_loader)

def mask_label_padding(labels):
    MASK_ID = -100
    labels[labels==tokenizer.pad_token_id] = MASK_ID
    return labels

def save(path, epoch, model_state_dict, optimizer_state_dict, loss):
    torch.save({
            'epoch': epoch,
            'model_state_dict': model_state_dict,
            'optimizer_state_dict': optimizer_state_dict,
            'best_loss': loss,
            }, path)

def load(path):
    return torch.load(path)

def print_line():
    LINE_WIDTH = 60
    print('-' * LINE_WIDTH)

In [None]:
best_val_loss = float("inf")
best_model = None

saved = load(SAVED_MODEL_PATH)
model.load_state_dict(saved["model_state_dict"])
optimizer.load_state_dict(saved["optimizer_state_dict"])

# val_loss = evaluate(model, valid_loader)
# print_line()
# print('| Before training | valid loss {:5.2f}'.format(
#     val_loss)
# )
print_line()

for epoch in range(1, EPOCHS + 1):

    train(epoch, best_val_loss)
    torch.cuda.empty_cache()
    val_loss = evaluate(model, valid_loader)
    torch.cuda.empty_cache()
    print_line()
    print('| end of epoch {:3d} | valid loss {:5.2f}'.format(
        epoch,
        val_loss)
    )
    print_line()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = model
        save(
             SAVED_MODEL_PATH,
             epoch, 
             model.state_dict(), 
             optimizer.state_dict(), 
             best_val_loss
        )
        print("| Model saved.")
        print_line()

------------------------------------------------------------


0it [00:00, ?it/s]

| epoch   1 |  5000/62445 batches | loss  2.30
| epoch   1 | 10000/62445 batches | loss  2.28
| epoch   1 | 15000/62445 batches | loss  2.25
| epoch   1 | 20000/62445 batches | loss  2.24
| epoch   1 | 25000/62445 batches | loss  2.21
| epoch   1 | 30000/62445 batches | loss  2.21
| epoch   1 | 35000/62445 batches | loss  2.21
| epoch   1 | 40000/62445 batches | loss  2.18
| epoch   1 | 45000/62445 batches | loss  2.16
| epoch   1 | 50000/62445 batches | loss  2.15
| epoch   1 | 55000/62445 batches | loss  2.15
| epoch   1 | 60000/62445 batches | loss  2.14


0it [00:00, ?it/s]

------------------------------------------------------------
| end of epoch   1 | valid loss  1.99
------------------------------------------------------------
| Model saved.
------------------------------------------------------------


0it [00:00, ?it/s]

| epoch   2 |  5000/62445 batches | loss  2.12
| epoch   2 | 10000/62445 batches | loss  2.12
| epoch   2 | 15000/62445 batches | loss  2.12
| epoch   2 | 20000/62445 batches | loss  2.09
| epoch   2 | 25000/62445 batches | loss  2.09
| epoch   2 | 30000/62445 batches | loss  2.08
| epoch   2 | 35000/62445 batches | loss  2.07
| epoch   2 | 40000/62445 batches | loss  2.08
| epoch   2 | 45000/62445 batches | loss  2.07
| epoch   2 | 50000/62445 batches | loss  2.05
| epoch   2 | 55000/62445 batches | loss  2.05
| epoch   2 | 60000/62445 batches | loss  2.03


0it [00:00, ?it/s]

------------------------------------------------------------
| end of epoch   2 | valid loss  1.92
------------------------------------------------------------
| Model saved.
------------------------------------------------------------


0it [00:00, ?it/s]

| epoch   3 |  5000/62445 batches | loss  2.03
| epoch   3 | 10000/62445 batches | loss  2.03
| epoch   3 | 15000/62445 batches | loss  2.03
| epoch   3 | 20000/62445 batches | loss  2.02
| epoch   3 | 25000/62445 batches | loss  2.02
| epoch   3 | 30000/62445 batches | loss  2.01
| epoch   3 | 35000/62445 batches | loss  2.01
| epoch   3 | 40000/62445 batches | loss  2.00
| epoch   3 | 45000/62445 batches | loss  1.98
| epoch   3 | 50000/62445 batches | loss  1.98
| epoch   3 | 55000/62445 batches | loss  1.98
| epoch   3 | 60000/62445 batches | loss  1.97


0it [00:00, ?it/s]

------------------------------------------------------------
| end of epoch   3 | valid loss  1.87
------------------------------------------------------------
| Model saved.
------------------------------------------------------------


0it [00:00, ?it/s]

| epoch   4 |  5000/62445 batches | loss  1.97
| epoch   4 | 10000/62445 batches | loss  1.98
| epoch   4 | 15000/62445 batches | loss  1.96
| epoch   4 | 20000/62445 batches | loss  1.98
| epoch   4 | 25000/62445 batches | loss  1.98
| epoch   4 | 30000/62445 batches | loss  1.94
| epoch   4 | 35000/62445 batches | loss  1.95
| epoch   4 | 40000/62445 batches | loss  1.96
| epoch   4 | 45000/62445 batches | loss  1.94
| epoch   4 | 50000/62445 batches | loss  1.95
| epoch   4 | 55000/62445 batches | loss  1.95
| epoch   4 | 60000/62445 batches | loss  1.95


0it [00:00, ?it/s]

------------------------------------------------------------
| end of epoch   4 | valid loss  1.83
------------------------------------------------------------
| Model saved.
------------------------------------------------------------


0it [00:00, ?it/s]

| epoch   5 |  5000/62445 batches | loss  1.93
| epoch   5 | 10000/62445 batches | loss  1.94
| epoch   5 | 15000/62445 batches | loss  1.94
| epoch   5 | 20000/62445 batches | loss  1.93
| epoch   5 | 25000/62445 batches | loss  1.94
| epoch   5 | 30000/62445 batches | loss  1.92
| epoch   5 | 35000/62445 batches | loss  1.93
| epoch   5 | 40000/62445 batches | loss  1.91
| epoch   5 | 45000/62445 batches | loss  1.92
| epoch   5 | 50000/62445 batches | loss  1.90
| epoch   5 | 55000/62445 batches | loss  1.93
| epoch   5 | 60000/62445 batches | loss  1.91


0it [00:00, ?it/s]

------------------------------------------------------------
| end of epoch   5 | valid loss  1.81
------------------------------------------------------------
| Model saved.
------------------------------------------------------------


0it [00:00, ?it/s]

| epoch   6 |  5000/62445 batches | loss  1.90
| epoch   6 | 10000/62445 batches | loss  1.88
