In [153]:
import random
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
%matplotlib inline
plt.switch_backend('agg')

Grammar Error Correction with nn.Transformer and torchtext
======================================================

This notebook shows how to train a GEC model based on transformers.

Data Sourcing and Processing
----------------------------

C4 200M dataset from Google Research is used in this notebook. You can find more information about the C4 200M dataset on GR's [BEA 2021 paper](https://aclanthology.org/2021.bea-1.4/).
The already [processed dataset](https://huggingface.co/datasets/liweili/c4_200m) was extracted from Huggingface, then was transformed to HDF5 format for better manageability. The conversion process was based on this [notebook](https://github.com/rasbt/deeplearning-models/blob/master/pytorch_ipynb/mechanics/custom-data-loader-csv.ipynb).
The final version of the dataset is uploaded on [Kaggle](https://www.kaggle.com/datasets/dariocioni/c4200m).

A custom class ``Hdf5Dataset`` based on ``torch.utils.data.Dataset`` is developed, which yields a pair of source-target raw sentences.

| source                                             | target                                                  |
|----------------------------------------------------|---------------------------------------------------------|
| Much many brands and sellers still in the market.  | Many brands and sellers still in the market.            |
| She likes playing in park and come here every week | She likes playing in the park and comes here every week |

In [154]:
# Import libraries
import torch
import pandas as pd
import numpy as np
import pathlib as pl

In [155]:
import h5py
from torch.utils.data import Dataset

class Hdf5Dataset(Dataset):
    """Custom Dataset for loading entries from HDF5 databases"""

    def __init__(self, h5_path, transform=None,num_entries = None):

        self.h5f = h5py.File(h5_path, 'r')
        if num_entries:
            self.num_entries = num_entries
        else:
            self.num_entries = self.h5f['labels'].shape[0]
        self.transform = transform

    def __getitem__(self, index):
        if index > self.num_entries:
            raise StopIteration
        input = self.h5f['input'][index].decode('utf-8')
        label = self.h5f['labels'][index].decode('utf-8')
        if self.transform is not None:
            features = self.transform(input)
        return input, label

    def __len__(self):
        return self.num_entries

In [156]:
from typing import Iterable, List
from tqdm import tqdm
import pathlib as pl
from torchtext.data import get_tokenizer

# helper function to yield list of tokens
def yield_tokens(data_iter: Iterable, index: int) -> List[str]:
    language_index = {SRC_LANGUAGE: 0, TGT_LANGUAGE: 1}
    for data_sample in tqdm(data_iter):
        if data_sample[index] and isinstance(data_sample[index],str):
            yield token_transform(data_sample[index])

SRC_LANGUAGE = 'incorrect'
TGT_LANGUAGE = 'correct'
# MAX_LENGTH = 512
VOCAB_SIZE = 20000
N_SAMPLES = 100000

# # Define special symbols and indices
UNK_IDX,PAD_IDX, BOS_IDX, EOS_IDX = 0,1,2,3
# # Make sure the tokens are in order of their indices to properly insert them in vocab
special_symbols = ['<pad>','<unk>', '[CLS]', '[SEP]']

# Place-holders
token_transform =get_tokenizer('basic_english')
vocab_transform = None

folder = 'D:\Datasets\c4_200m\data\hdf5'
train_filename = 'C4_200M.hf5-00000-of-00010'
valid_filename = 'C4_200M.hf5-00001-of-00010'
embedding_path = 'D:\Datasets\glove\glove.42B.300d.txt'
checkpoint_folder = 'D:\Datasets\c4_200m\checkpoints'

## Tokenizing and Embedding
Data is then tokenized by a pre-trained ``BertTokenizer`` from HuggingFace's ``transformers`` library, based on a Wordpiece tokenization.
The BERT model was pretrained on [BookCorpus](https://yknzhu.wixsite.com/mbweb), a dataset consisting of 11,038 unpublished books and [English Wikipedia](https://en.wikipedia.org/wiki/English_Wikipedia) (excluding lists, tables and headers).

In [157]:
# Define special symbols and indices
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
# Make sure the tokens are in order of their indices to properly insert them in vocab
special_symbols = ['<UNK>', '<PAD>', '<BOS>', '<EOS>']

vocab_transform = torch.load('vocab/vocab_20K.pth')
embeddings = torch.load('vocab/glove_42B_300d_20K.pth')

Collation
---------

An iterator over ``Hdf5dataset`` yields a pair of raw strings.
We need to convert these string pairs into the batched tensors that can be processed by our ``Seq2Seq`` network.
Below we define our collate function that convert batch of raw strings into batch tensors that can be fed directly into our model.

In [158]:
from torch.nn.utils.rnn import pad_sequence

# helper function to club together sequential operations
def sequential_transforms(*transforms):
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func

# def glove_transform(tokens: List[str]):



# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids: List[int]):
    return torch.cat((torch.tensor([BOS_IDX]),
                      torch.tensor(token_ids),
                      torch.tensor([EOS_IDX])))

# src and tgt language text transforms to convert raw strings into tensors indices
text_transform = sequential_transforms(token_transform,
                                       vocab_transform,
                                       tensor_transform) # Add BOS/EOS and create tensor


# function to collate data samples into batch tesors
def collate_fn(batch):
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        src_batch.append(text_transform(src_sample.rstrip("\n")))
        tgt_batch.append(text_transform(tgt_sample.rstrip("\n")))

    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
    return src_batch, tgt_batch

Seq2Seq Network using Transformer
---------------------------------

Transformer is a Seq2Seq model introduced in [“Attention is all you need”](<https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf>) paper for solving machine translation tasks.
Below, we will create a Seq2Seq network that uses Transformer. The network consists of three parts. First part is the embedding layer. This layer converts tensor of input indices into corresponding tensor of input embeddings. These embedding are further augmented with positionalencodings to provide position information of input tokens to the model. The second part is the actual [Transformer](<https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html>) model. Finally, the output of Transformer model is passed through linear layer that give un-normalized probabilities for each token in the target language.




In [159]:
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

# helper Module to convert tensor of input indices into corresponding tensor of token embeddings
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size,embedding_weights=None):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        if embedding_weights is not None:
            self.embedding.weight = torch.nn.Parameter(torch.from_numpy(embedding_weights).float())
            # self.embedding.weight.requires_grad =False
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

# Seq2Seq Network 
class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 100,
                 dropout: float = 0.1,
                 embedding_weights = None):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size,embedding_weights)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size,embedding_weights)
        self.positional_encoding = PositionalEncoding(
            emb_size, dropout=dropout)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None, 
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)

During training, we need a subsequent word mask that will prevent model to look into the future words when making predictions. We will also need masks to hide source and target padding tokens. Below, let's define a function that will take care of both.




In [160]:
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len),device=DEVICE).type(torch.bool)

    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

Let's now define the parameters of our model and instantiate the same. Below, we also define our loss function which is the cross-entropy loss and the optmizer used for training.




In [161]:
torch.manual_seed(0)

VOCAB_SIZE = len(vocab_transform.vocab.itos_)
EMB_SIZE = 300
NHEAD = 2
FFN_HID_DIM = 512
BATCH_SIZE = 16
NUM_ENCODER_LAYERS = 2
NUM_DECODER_LAYERS = 2

transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE, 
                                 NHEAD, VOCAB_SIZE, VOCAB_SIZE, FFN_HID_DIM,embedding_weights=embeddings)

for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(DEVICE)

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

In [162]:
from torch.utils.data import DataLoader
from torch.utils.data import IterableDataset

def train_epoch(model, optimizer):
    model.train()
    losses = 0
    train_iter = Hdf5Dataset(pl.Path(folder)/train_filename,num_entries=N_SAMPLES)
    train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    for src, tgt in tqdm(train_dataloader):
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)

        optimizer.zero_grad()

        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()

    return losses / len(train_dataloader)


def evaluate(model):
    model.eval()
    losses = 0

    val_iter = Hdf5Dataset(pl.Path(folder)/valid_filename,num_entries=N_SAMPLES)
    val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    for src, tgt in tqdm(val_dataloader):
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
        with torch.no_grad():
            logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)
        
        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        losses += loss.item()

    return losses / len(val_dataloader)

Now we have all the ingredients to train our model. Let's do it!




In [163]:
from timeit import default_timer as timer
NUM_EPOCHS = 10

train_losses = []
val_losses = []

for epoch in range(1, NUM_EPOCHS+1):
    start_time = timer()
    train_loss = train_epoch(transformer, optimizer)
    end_time = timer()
    val_loss = evaluate(transformer)
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    torch.save({
        'epoch': epoch,
        'model_state_dict': transformer.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': val_loss,
    }, pl.Path('checkpoints')/"transformer_model_glove.pt")

100%|██████████| 6250/6250 [07:15<00:00, 14.34it/s]
100%|██████████| 6250/6250 [04:12<00:00, 24.78it/s]


Epoch: 1, Train loss: 4.725, Val loss: 3.231, Epoch time = 435.777s


100%|██████████| 6250/6250 [07:13<00:00, 14.43it/s]
100%|██████████| 6250/6250 [04:05<00:00, 25.42it/s]


Epoch: 2, Train loss: 2.974, Val loss: 2.419, Epoch time = 433.229s


100%|██████████| 6250/6250 [07:03<00:00, 14.76it/s]
100%|██████████| 6250/6250 [03:58<00:00, 26.18it/s]


Epoch: 3, Train loss: 2.354, Val loss: 2.074, Epoch time = 423.424s


100%|██████████| 6250/6250 [07:11<00:00, 14.47it/s]
100%|██████████| 6250/6250 [04:05<00:00, 25.46it/s]


Epoch: 4, Train loss: 2.068, Val loss: 1.940, Epoch time = 431.982s


100%|██████████| 6250/6250 [07:07<00:00, 14.61it/s]
100%|██████████| 6250/6250 [04:11<00:00, 24.84it/s]


Epoch: 5, Train loss: 1.936, Val loss: 1.882, Epoch time = 427.724s


100%|██████████| 6250/6250 [07:51<00:00, 13.26it/s]
100%|██████████| 6250/6250 [04:48<00:00, 21.68it/s]


Epoch: 6, Train loss: 1.852, Val loss: 1.842, Epoch time = 471.273s


  1%|          | 35/6250 [00:02<08:25, 12.30it/s]


KeyboardInterrupt: 

In [None]:
_ = plt.plot(train_losses)
_ = plt.plot(val_losses)

In [248]:
import re
# function to generate output sequence using greedy algorithm 
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)

    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
    for i in range(max_len-1):
        memory = memory.to(DEVICE)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                    .type(torch.bool)).to(DEVICE)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob,dim=1)
        next_word = next_word.item()
        if next_word == UNK_IDX:
            _, next_words = prob.topk(k=5,dim=1)
            for i in next_words.cpu().detach().numpy()[0]:
                if i != UNK_IDX:
                    next_word = i.item()

        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == EOS_IDX:
            break
    return ys


# actual function to correct input sentence
def correct(src_sentence: str, model: torch.nn.Module):
    model.eval()
    src = text_transform(src_sentence).view(-1, 1)
    num_tokens = src.shape[0]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    tgt_tokens = greedy_decode(
        model,  src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()
    return ' '.join([vocab_transform.vocab.itos_[i] for i in tgt_tokens if i not in [PAD_IDX,BOS_IDX,EOS_IDX]])

In [250]:
checkpoint = torch.load(pl.Path('checkpoints')/"transformer_model_glove.pt")
transformer.load_state_dict(checkpoint['model_state_dict'])

transformer.eval()

# Pick one in 18M examples
val_iter = Hdf5Dataset(pl.Path(folder)/valid_filename,num_entries=None)

src,trg = random.choice(val_iter)

print("input: \"",src,"\"")
print("target: \"",trg,"\"")

print("prediction: \"",correct(src,transformer),"\"")

input: " In Windows, I have to used my mouse to navigate a list of folder. "
target: " On Windows, I have to use the mouse to navigate a folder list. "
prediction: " in windows , i have to use my mouse to navigate a list of folder . "


References
----------

1. [Attention is all you need](https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)
2. [The annotated transformer](https://nlp.seas.harvard.edu/2018/04/03/attention.html#positional-encoding)
3. [Pytorch tutorial on NMT with transformers](https://pytorch.org/tutorials/beginner/translation_transformer.html)

