In [13]:
import random
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
%matplotlib inline

Grammatical Error Correction with Transformers
======================================================

This notebook shows how to train a grammatical error correction (GEC) model based on the Transformer architecture.

Table of Contents

1) [Data Sourcing and Processing](#1)
    - Tokenizing and Embedding
    - Collation
2) [Seq2Seq Network using Transformer](#2)
    - Positional encoding
    - Multi-head attention
3) [Model definition](#3)
4) [Training](#4)
5) [Evaluation](#5)
6) [References](#6)

Data Sourcing and Processing
----------------------------

<a id="1"></a>This notebook uses the C4_200M dataset from Google Research. More information about C4_200M is available in Google Research's [BEA 2021 paper](https://aclanthology.org/2021.bea-1.4/) (Stahlberg and Kumar, 2021).

The [preprocessed dataset](https://huggingface.co/datasets/liweili/c4_200m) was downloaded from Hugging Face in CSV format and then converted to HDF5 for easier handling. The conversion is implemented in ``utils.py`` and is based on this [notebook](https://github.com/rasbt/deeplearning-models/blob/master/pytorch_ipynb/mechanics/custom-data-loader-csv.ipynb).
The final version of the dataset is available on [Kaggle](https://www.kaggle.com/datasets/dariocioni/c4200m).

A custom class ``Hdf5Dataset``, derived from ``torch.utils.data.Dataset``, returns pairs of raw source (incorrect) and target (corrected) sentences, for example:

| source                                             | target                                                  |
|----------------------------------------------------|---------------------------------------------------------|
| Much many brands and sellers still in the market.  | Many brands and sellers still in the market.            |
| She likes playing in park and come here every week | She likes playing in the park and comes here every week |

In [14]:
# Import libraries
import torch
import pandas as pd
import numpy as np
import pathlib as pl

In [15]:
import h5py
from torch.utils.data import Dataset
random.seed(42)

class Hdf5Dataset(Dataset):
    """Custom Dataset for loading (source, target) sentence pairs from an HDF5 file.

    The file is expected to contain two string datasets: 'input' (incorrect
    sentences) and 'labels' (corrected sentences).
    """

    def __init__(self, h5_path, transform=None, num_entries=None, randomized=False):
        self.h5f = h5py.File(h5_path, 'r')
        self.size = self.h5f['labels'].shape[0]
        self.transform = transform
        self.randomized = randomized
        self.max_index = num_entries if num_entries is not None else self.size
        # When using a subset of the file, start it at a random block-aligned offset
        self.offset = self._random_offset() if randomized else 0

    def _random_offset(self):
        return random.choice(range(0, self.size // self.max_index)) * self.max_index

    def __getitem__(self, index):
        # A map-style Dataset signals out-of-range indices with IndexError, not StopIteration
        if index < 0 or index >= self.max_index:
            raise IndexError(index)
        source = self.h5f['input'][self.offset + index].decode('utf-8')
        target = self.h5f['labels'][self.offset + index].decode('utf-8')
        if self.transform is not None:
            source = self.transform(source)
        return source, target

    def __len__(self):
        return self.max_index

    def reshuffle(self):
        """Move the subset window to a new random block of the file."""
        if self.randomized:
            self.offset = self._random_offset()
        else:
            print("Please set randomized=True")

In [16]:
from typing import Iterable, List
from tqdm import tqdm
import pathlib as pl
from torchtext.data import get_tokenizer

SRC_LANGUAGE = 'incorrect'
TGT_LANGUAGE = 'correct'
# MAX_LENGTH = 512
VOCAB_SIZE = 20000
N_SAMPLES = 100000

# Tokenizer (the vocabulary is loaded further below)
token_transform = get_tokenizer('basic_english')
vocab_transform = None

# Raw strings keep the Windows backslashes from being parsed as escape sequences
folder = r'D:\Datasets\c4_200m\data\hdf5'
train_filename = 'C4_200M.hf5-00000-of-00010'
valid_filename = 'C4_200M.hf5-00001-of-00010'
vocab_path = 'vocab/vocab_20K.pth'
embedding_path = 'vocab/glove_42B_300d_20K.pth'
checkpoint_folder = r'D:\Datasets\c4_200m\checkpoints'

##COLAB
# folder = '/content/drive/MyDrive/Datasets/c4_200m/hdf5'
# train_filename = 'C4_200M.hf5-00000-of-00010'
# valid_filename = 'C4_200M.hf5-00001-of-00010'
# embedding_path = '/content/drive/MyDrive/Colab Notebooks/GEC_Soft_Masked_BERT/vocab/glove_42B_300d_20K.pth'
# vocab_path = '/content/drive/MyDrive/Colab Notebooks/GEC_Soft_Masked_BERT/vocab/vocab_20K.pth'
# checkpoint_folder = '/content/drive/MyDrive/Colab Notebooks/GEC_Soft_Masked_BERT/checkpoints'

### Tokenizing and Embedding
The data is then tokenized with the ``basic_english`` tokenizer from the ``torchtext`` library, which normalizes the text and then splits it on whitespace. The normalization steps are:
- lowercasing;
- adding spaces before and after `'`, `.`, `,`, `(`, `)`, `!` and `?`;
- removing `"`;
- replacing `<br />`, `;` and `:` with a single space;
- collapsing runs of whitespace into a single space.
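To make the normalization steps concrete, here is a small pure-Python re-implementation of them, written for illustration only (it is not torchtext's own code, and `basic_english_sketch` is a hypothetical name). It lowercases the line, applies the substitutions in order and splits on whitespace:

```python
import re
from typing import List

# (pattern, replacement) pairs mirroring the normalization steps listed above
_RULES = [
    (re.compile(r"'"), " '  "),
    (re.compile(r'"'), ""),
    (re.compile(r"\."), " . "),
    (re.compile(r"<br \/>"), " "),
    (re.compile(r","), " , "),
    (re.compile(r"\("), " ( "),
    (re.compile(r"\)"), " ) "),
    (re.compile(r"!"), " ! "),
    (re.compile(r"\?"), " ? "),
    (re.compile(r";"), " "),
    (re.compile(r":"), " "),
    (re.compile(r"\s+"), " "),
]

def basic_english_sketch(line: str) -> List[str]:
    """Lowercase, apply the substitutions in order, then split on whitespace."""
    line = line.lower()
    for pattern, replacement in _RULES:
        line = pattern.sub(replacement, line)
    return line.split()

print(basic_english_sketch("She likes playing in park, and come here every week!"))
# ['she', 'likes', 'playing', 'in', 'park', ',', 'and', 'come', 'here', 'every', 'week', '!']
```

Note that punctuation becomes separate tokens, so contractions are split too (`"don't"` becomes `['don', "'", 't']`).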

A vocabulary was built from 1M samples of the training set using the ``build_vocab`` function in ``vocab.py``.
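``vocab.py`` is not shown here, so the following is only a hypothetical sketch of what a frequency-based vocabulary builder could look like, assuming the special tokens are inserted first at indices 0-3 (matching `UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX` below) and the list is cut at a maximum size such as 20K. Unknown words fall back to `UNK_IDX`:

```python
from collections import Counter
from typing import Dict, Iterable, List

SPECIALS = ["<UNK>", "<PAD>", "<BOS>", "<EOS>"]  # indices 0..3, as in this notebook
UNK_IDX = 0

def build_vocab_sketch(token_lists: Iterable[List[str]], max_size: int) -> Dict[str, int]:
    """Hypothetical stand-in for build_vocab: keep the most frequent tokens."""
    counts = Counter(tok for tokens in token_lists for tok in tokens)
    # Sort by descending frequency, breaking ties alphabetically for determinism
    ranked = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
    itos = SPECIALS + [tok for tok, _ in ranked if tok not in SPECIALS]
    itos = itos[:max_size]
    return {tok: idx for idx, tok in enumerate(itos)}

def lookup(stoi: Dict[str, int], tokens: List[str]) -> List[int]:
    """Map tokens to indices, using UNK_IDX for out-of-vocabulary words."""
    return [stoi.get(tok, UNK_IDX) for tok in tokens]

corpus = [["the", "cat", "sat"], ["the", "dog", "sat"], ["the", "end"]]
stoi = build_vocab_sketch(corpus, max_size=6)
print(lookup(stoi, ["the", "bird", "sat"]))  # [4, 0, 5]
```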

I then compared pretrained embeddings with a plain ``nn.Embedding`` layer learned from scratch, both using 300-dimensional vectors.
- The ``GloVe`` embeddings were trained on Common Crawl (42B tokens, 1.9M vocabulary, uncased, 300-dimensional vectors).
- The pretrained embeddings were aligned with the vocabulary using the ``load_pretrained_embs`` function in ``vocab.py``.
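Aligning pretrained vectors with a vocabulary means building a matrix with one row per vocabulary index. The sketch below is a hypothetical version of that step (``load_pretrained_embs`` itself is not shown), assuming GloVe is available as a word-to-vector dictionary and that words missing from GloVe, including the special tokens, get small random vectors. It returns a `float32` NumPy array, the kind of input ``TokenEmbedding`` passes to `torch.from_numpy`:

```python
import numpy as np
from typing import Dict, List

def align_embeddings_sketch(itos: List[str], glove: Dict[str, np.ndarray], dim: int,
                            seed: int = 0) -> np.ndarray:
    """Hypothetical stand-in for load_pretrained_embs: row idx holds the vector of itos[idx]."""
    rng = np.random.default_rng(seed)
    # Assumption: out-of-GloVe words are initialized randomly rather than with zeros
    weights = rng.normal(scale=0.1, size=(len(itos), dim)).astype(np.float32)
    hits = 0
    for idx, word in enumerate(itos):
        vec = glove.get(word)
        if vec is not None:
            weights[idx] = vec
            hits += 1
    print(f"{hits}/{len(itos)} vocabulary entries found in the pretrained vectors")
    return weights

itos = ["<UNK>", "<PAD>", "<BOS>", "<EOS>", "the", "zyx"]
W = align_embeddings_sketch(itos, {"the": np.ones(4, dtype=np.float32)}, dim=4)
print(W.shape)  # (6, 4)
```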

In [17]:
# Define special symbols and indices
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
# Make sure the tokens are in order of their indices to properly insert them in vocab
special_symbols = ['<UNK>', '<PAD>', '<BOS>', '<EOS>']

vocab_transform = torch.load(vocab_path)
embeddings = torch.load(embedding_path)

### Collation

Iterating over ``Hdf5Dataset`` yields pairs of raw strings.
We need to convert these string pairs into batched tensors that our ``Seq2Seq`` network can process.
The collate function below converts a batch of raw strings into batched tensors that can be fed directly to the model.
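The padding layout matters because ``nn.Transformer`` (with its default `batch_first=False`) expects tensors of shape `(seq_len, batch)`. The pure-Python sketch below reproduces that layout on plain lists, assuming the index values used in this notebook (`PAD_IDX=1`, `BOS_IDX=2`, `EOS_IDX=3`); `pad_time_major` is a hypothetical helper, not part of the notebook:

```python
from typing import List

PAD_IDX, BOS_IDX, EOS_IDX = 1, 2, 3

def pad_time_major(seqs: List[List[int]], pad: int = PAD_IDX) -> List[List[int]]:
    """Wrap each sequence in BOS/EOS and pad into a (max_len, batch) grid,
    the same time-major layout pad_sequence returns with batch_first=False."""
    framed = [[BOS_IDX] + s + [EOS_IDX] for s in seqs]
    max_len = max(len(s) for s in framed)
    return [[s[t] if t < len(s) else pad for s in framed] for t in range(max_len)]

for row in pad_time_major([[10, 11, 12], [20]]):
    print(row)
# [2, 2]
# [10, 20]
# [11, 3]
# [12, 1]
# [3, 1]
```

Each row is one time step and each column one sentence; the shorter sentence is filled with `PAD_IDX` after its `EOS_IDX`.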

In [18]:
from torch.nn.utils.rnn import pad_sequence

# helper function to club together sequential operations
def sequential_transforms(*transforms):
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func

# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids: List[int]):
    return torch.cat((torch.tensor([BOS_IDX]),
                      torch.tensor(token_ids),
                      torch.tensor([EOS_IDX])))

# text transform shared by source and target: raw string -> tokens -> vocabulary indices -> tensor
text_transform = sequential_transforms(token_transform,
                                       vocab_transform,
                                       tensor_transform) # Add BOS/EOS and create tensor


# function to collate data samples into batch tensors
# (outputs have shape (max_seq_len, batch_size) because pad_sequence defaults to batch_first=False)
def collate_fn(batch):
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        src_batch.append(text_transform(src_sample.rstrip("\n")))
        tgt_batch.append(text_transform(tgt_sample.rstrip("\n")))

    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
    return src_batch, tgt_batch

Seq2Seq Network using Transformer
---------------------------------

<a id="2"></a>The Transformer is a Seq2Seq model introduced in the [“Attention is all you need”](<https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf>) paper to solve machine translation tasks.
Below, we create a Seq2Seq network based on the Transformer. The network consists of three parts:
1) The embedding layer, which converts the tensor of input indices into the corresponding tensor of input embeddings.
    These embeddings are then augmented with ``Positional Encodings`` to give the model information about the position of each input token.
2) The actual [Transformer](<https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html>) model.
3) A final linear layer, which maps the Transformer output to unnormalized scores (logits) for every token in the target vocabulary.

### Positional Encoding
Unlike RNNs, Transformers have no built-in notion of the relative or absolute position of the tokens in the input.
One solution is to add to each input embedding a positional embedding specific to its position in the sequence.
The original Transformer uses fixed sine and cosine functions of different frequencies ([Vaswani et al., 2017](<https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf>)). For any fixed offset $k$, $PE_{pos+k}$ is a linear function of $PE_{pos}$, which should make it easy for the model to attend by relative position, and the encoding may let the model extrapolate to sequences longer than those seen during training.

Given an embedding of size $d$, a position $pos$ in the sequence and a dimension-pair index $i$ (dimensions $2i$ and $2i+1$ form a pair), the positional encoding is

$$PE_{(pos,2i)} = \sin\left(pos/10000^{2i/d}\right)\quad,\quad PE_{(pos,2i+1)} = \cos\left(pos/10000^{2i/d}\right)$$

Dropout is also applied to the sums of the embeddings and the positional encodings, in both the encoder and the decoder.
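The formula can be checked on a single position with a few lines of pure Python. This is an illustrative sketch, not the notebook's module: `positional_encoding` is a hypothetical helper. It also verifies the identity that the ``PositionalEncoding`` class in the next code cell relies on, which computes the frequencies in log space as $\exp(-2i\ln(10000)/d) = 1/10000^{2i/d}$:

```python
import math
from typing import List

def positional_encoding(pos: int, d: int) -> List[float]:
    """Sinusoidal encoding of one position: sin on even dims, cos on odd dims."""
    pe = [0.0] * d
    for j in range(0, d, 2):                 # j = 2i in the formula
        angle = pos / (10000 ** (j / d))     # pos / 10000^{2i/d}
        pe[j] = math.sin(angle)              # PE(pos, 2i)
        if j + 1 < d:
            pe[j + 1] = math.cos(angle)      # PE(pos, 2i+1)
    return pe

# Log-space frequency used by the PyTorch module equals the direct formula
d, i = 300, 7
assert math.isclose(math.exp(-2 * i * math.log(10000) / d), 1 / 10000 ** (2 * i / d))

print(positional_encoding(0, 6))  # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
```

Position 0 always encodes to alternating 0s and 1s; higher dimension pairs oscillate more slowly, so they distinguish positions on coarser scales.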

<img src="img/pos_enc.png">

### Multi-head attention
A single attention head cannot capture all the different kinds of relations among its inputs at once.
To address this, Transformers use multiple self-attention heads that run in parallel, each with its own set of parameters.
Each head $i$ has its own query, key and value matrices $W_i^Q, W_i^K, W_i^V$, and therefore projects the inputs into a different representation subspace.
The outputs of the heads are concatenated and projected back to the original input dimension by a trainable linear projection $W^O$:

$$\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)\,W^O,\quad \text{head}_i = \text{softmax}\left(\frac{QW_i^Q\,(KW_i^K)^\top}{\sqrt{d_k}}\right)VW_i^V$$
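The following NumPy sketch spells out multi-head self-attention for one unbatched sequence, using this notebook's sizes (`EMB_SIZE=300`, `NHEAD=2`, so $d_k = 150$ per head). It is an illustration only: the weights are random, and masking, biases and dropout are omitted, whereas ``nn.Transformer`` handles all of these internally:

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_self_attention(X, W_Q, W_K, W_V, W_O):
    """X: (seq_len, d_model); W_Q/W_K/W_V: (h, d_model, d_k); W_O: (h * d_k, d_model)."""
    heads = []
    for Wq, Wk, Wv in zip(W_Q, W_K, W_V):     # one iteration per head
        Q, K, V = X @ Wq, X @ Wk, X @ Wv       # each (seq_len, d_k)
        scores = Q @ K.T / np.sqrt(K.shape[-1])  # (seq_len, seq_len)
        heads.append(softmax(scores) @ V)      # (seq_len, d_k)
    return np.concatenate(heads, axis=-1) @ W_O  # back to (seq_len, d_model)

rng = np.random.default_rng(0)
seq_len, d_model, h = 5, 300, 2
d_k = d_model // h                             # 150 dimensions per head
X = rng.normal(size=(seq_len, d_model))
W_Q, W_K, W_V = (rng.normal(scale=d_model ** -0.5, size=(h, d_model, d_k)) for _ in range(3))
W_O = rng.normal(scale=(h * d_k) ** -0.5, size=(h * d_k, d_model))
out = multi_head_self_attention(X, W_Q, W_K, W_V, W_O)
print(out.shape)  # (5, 300)
```

Because the per-head size is `d_model // h`, the embedding size must be divisible by the number of heads, which holds for 300 and 2.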

<img src="img/multihead.png">




In [19]:
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

# helper Module to convert tensor of input indices into corresponding tensor of token embeddings
class TokenEmbedding(nn.Module):
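    def __init__(self, vocab_size: int, emb_size: int, embedding_weights=None):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        if embedding_weights is not None:
            # Copy the pretrained matrix (vocab_size x emb_size) so the source and target
            # embeddings do not share storage with each other or with the original array
            weights = torch.as_tensor(embedding_weights, dtype=torch.float).clone()
            self.embedding.weight = nn.Parameter(weights)
            # self.embedding.weight.requires_grad = False  # uncomment to freeze
        self.emb_size = emb_size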

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

# Seq2Seq Network 
class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 100,
                 dropout: float = 0.1,
                 embedding_weights = None):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size,embedding_weights)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size,embedding_weights)
        self.positional_encoding = PositionalEncoding(
            emb_size, dropout=dropout)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None, 
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)

During training, we need a subsequent-word mask that prevents the model from looking at future words when making predictions. We also need masks that hide the padding tokens in the source and in the target. The function below takes care of both.
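The two kinds of masks can be illustrated with plain lists. In this sketch (hypothetical helpers, not the notebook's functions) the causal mask is additive, with `0.0` where query position $i$ may attend to key position $j \le i$ and `-inf` elsewhere, matching what `generate_square_subsequent_mask` builds. The padding mask turns a `(seq_len, batch)` grid into `(batch, seq_len)` booleans that are `True` at padding positions, as `create_mask` does with `(src == PAD_IDX).transpose(0, 1)`:

```python
from typing import List

def subsequent_mask(sz: int) -> List[List[float]]:
    """Additive causal mask: 0.0 where query i may attend to key j (j <= i), -inf otherwise."""
    return [[0.0 if j <= i else float("-inf") for j in range(sz)] for i in range(sz)]

def padding_mask(grid: List[List[int]], pad_idx: int = 1) -> List[List[bool]]:
    """(seq_len, batch) token grid -> (batch, seq_len) booleans, True at padding positions."""
    seq_len, batch = len(grid), len(grid[0])
    return [[grid[t][b] == pad_idx for t in range(seq_len)] for b in range(batch)]

for row in subsequent_mask(3):
    print(row)
# [0.0, -inf, -inf]
# [0.0, 0.0, -inf]
# [0.0, 0.0, 0.0]
print(padding_mask([[2, 2], [10, 3], [3, 1]]))  # [[False, False, False], [False, False, True]]
```

Adding `-inf` before the softmax drives those attention weights to exactly zero.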




In [20]:
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len),device=DEVICE).type(torch.bool)

    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

## Model definition
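<a id="3"></a>The model is instantiated with the ``Seq2SeqTransformer`` wrapper.
Training uses the cross-entropy loss (ignoring padding positions) and the Adam optimizer. The Adam settings ($\beta_1=0.9$, $\beta_2=0.98$, $\epsilon=10^{-9}$) follow [Vaswani et al., 2017](<https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf>), but the network is much smaller than theirs (2 encoder and 2 decoder layers, 2 attention heads, 300-dimensional embeddings) and the learning rate is fixed at $10^{-4}$.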
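In `train_epoch` below, the decoder receives `tgt[:-1, :]` and its logits are scored against `tgt[1:, :]` (teacher forcing), and `CrossEntropyLoss(ignore_index=PAD_IDX)` averages only over non-padding targets. The pure-Python sketch below shows both ideas on one sequence; `shift_for_teacher_forcing` and `cross_entropy_ignore_pad` are hypothetical names used only for illustration:

```python
import math
from typing import List

PAD_IDX, BOS_IDX, EOS_IDX = 1, 2, 3

def shift_for_teacher_forcing(tgt: List[int]):
    """Decoder input drops the last token; the prediction target drops the first (BOS)."""
    return tgt[:-1], tgt[1:]

def cross_entropy_ignore_pad(logits: List[List[float]], targets: List[int],
                             ignore_index: int = PAD_IDX) -> float:
    """Mean negative log-softmax of the gold token, skipping padded target positions."""
    total, count = 0.0, 0
    for row, gold in zip(logits, targets):
        if gold == ignore_index:
            continue
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))  # log-sum-exp
        total += log_z - row[gold]
        count += 1
    return total / count

tgt = [BOS_IDX, 4, 5, EOS_IDX, PAD_IDX]
dec_in, dec_out = shift_for_teacher_forcing(tgt)
print(dec_in, dec_out)  # [2, 4, 5, 3] [4, 5, 3, 1]
uniform = [[0.0] * 6 for _ in dec_out]  # uniform logits over a 6-token vocabulary
print(cross_entropy_ignore_pad(uniform, dec_out))  # ln(6), about 1.7918; the PAD target is skipped
```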

In [21]:
torch.manual_seed(0)

VOCAB_SIZE = len(vocab_transform.vocab.itos_)
EMB_SIZE = 300
NHEAD = 2
FFN_HID_DIM = 512
BATCH_SIZE = 16
NUM_ENCODER_LAYERS = 2
NUM_DECODER_LAYERS = 2

transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE, 
                                 NHEAD, VOCAB_SIZE, VOCAB_SIZE, FFN_HID_DIM,embedding_weights=embeddings)

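# Xavier-initialize weight matrices, but keep the pretrained token embeddings when provided
skip_embeddings = embeddings is not None
for name, p in transformer.named_parameters():
    if p.dim() > 1 and not (skip_embeddings and 'tok_emb' in name):
        nn.init.xavier_uniform_(p)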

transformer = transformer.to(DEVICE)

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

In [22]:
from torch.utils.data import DataLoader
from torch.utils.data import IterableDataset

def train_epoch(model, optimizer):
    model.train()
    losses = 0
    train_iter = Hdf5Dataset(pl.Path(folder)/train_filename,num_entries=N_SAMPLES,randomized=True)
    train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    for src, tgt in tqdm(train_dataloader):
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)

        optimizer.zero_grad()

        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()

    return losses / len(train_dataloader)


def evaluate(model):
    model.eval()
    losses = 0

    val_iter = Hdf5Dataset(pl.Path(folder)/valid_filename,num_entries=N_SAMPLES)
    val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    for src, tgt in tqdm(val_dataloader):
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
        with torch.no_grad():
            logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)
        
        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        losses += loss.item()

    return losses / len(val_dataloader)

## Training
<a id="4"></a>Now we have all the ingredients to train our model.

In [23]:
from timeit import default_timer as timer
NUM_EPOCHS = 10

train_losses = []
val_losses = []

for epoch in range(1, NUM_EPOCHS+1):
    start_time = timer()
    train_loss = train_epoch(transformer, optimizer)
    end_time = timer()
    val_loss = evaluate(transformer)
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    torch.save({
        'epoch': epoch,
        'model_state_dict': transformer.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': val_loss,
    }, pl.Path(checkpoint_folder)/"transformer_model_glove.pt")


In [None]:
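plt.plot(train_losses, label='train')
plt.plot(val_losses, label='validation')
plt.xlabel('epoch')
plt.ylabel('cross-entropy loss')
plt.legend()
plt.show()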

## Evaluation
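<a id="5"></a>Corrections are generated with the ``greedy_decode`` function, which at each step appends the most probable next token (falling back to the most probable non-``<UNK>`` token among the top 5) until it emits ``<EOS>`` or reaches the maximum length. The ``correct`` function wraps it for a single input sentence.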
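The decoding loop itself does not depend on PyTorch, so it can be sketched with a stand-in scoring function. Here `toy_scores` is a hypothetical model over a 7-token vocabulary whose first step ranks `<UNK>` highest; the fallback picks the highest-ranked non-`<UNK>` candidate, not the last one found:

```python
from typing import Callable, List

UNK_IDX, BOS_IDX, EOS_IDX = 0, 2, 3

def greedy_decode_sketch(next_scores: Callable[[List[int]], List[float]],
                         max_len: int, k: int = 5) -> List[int]:
    """Greedy decoding with a fallback to the best non-<UNK> token among the top k."""
    ys = [BOS_IDX]
    for _ in range(max_len - 1):
        scores = next_scores(ys)
        ranked = sorted(range(len(scores)), key=lambda idx: scores[idx], reverse=True)
        next_word = ranked[0]
        if next_word == UNK_IDX:
            # First (i.e. most probable) non-<UNK> candidate; keep <UNK> if none exists
            next_word = next((idx for idx in ranked[:k] if idx != UNK_IDX), UNK_IDX)
        ys.append(next_word)
        if next_word == EOS_IDX:
            break
    return ys

def toy_scores(prefix: List[int]) -> List[float]:
    """Hypothetical model: step 1 prefers <UNK> (token 5 is runner-up), step 2 token 6, then <EOS>."""
    table = {
        1: [0.9, 0.0, 0.0, 0.0, 0.1, 0.6, 0.2],
        2: [0.0, 0.0, 0.0, 0.1, 0.0, 0.2, 0.8],
    }
    return table.get(len(prefix), [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])

print(greedy_decode_sketch(toy_scores, max_len=10))  # [2, 5, 6, 3]
```

Greedy decoding is cheap but commits to one token per step; beam search is a common alternative when quality matters more than speed.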

In [None]:
# Generate an output sequence with greedy decoding
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)

    memory = model.encode(src, src_mask).to(DEVICE)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
    for _ in range(max_len - 1):
        # Boolean causal mask: True marks the future positions that may not be attended
        tgt_mask = generate_square_subsequent_mask(ys.size(0)).type(torch.bool).to(DEVICE)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()
        if next_word == UNK_IDX:
            # Fall back to the most probable non-<UNK> token among the top 5
            _, candidates = prob.topk(k=5, dim=1)
            for candidate in candidates[0].tolist():
                if candidate != UNK_IDX:
                    next_word = candidate
                    break

        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == EOS_IDX:
            break
    return ys


# Correct a single input sentence
def correct(src_sentence: str, model: torch.nn.Module):
    model.eval()
    src = text_transform(src_sentence).view(-1, 1)  # (seq_len, batch=1)
    num_tokens = src.shape[0]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    with torch.no_grad():
        tgt_tokens = greedy_decode(
            model, src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()
    itos = vocab_transform.vocab.itos_
    return ' '.join(itos[i] for i in tgt_tokens.tolist() if i not in (PAD_IDX, BOS_IDX, EOS_IDX))

In [None]:
# Note: the training loop above saves its checkpoint as "transformer_model_glove.pt"
checkpoint = torch.load(pl.Path(checkpoint_folder)/"transformer_model_nopt.pt", map_location=DEVICE)
transformer.load_state_dict(checkpoint['model_state_dict'])

transformer.eval()

# Pick one of the ~18M examples in the validation shard
val_iter = Hdf5Dataset(pl.Path(folder)/valid_filename, num_entries=None)

src, trg = random.choice(val_iter)

print(f'input: "{src}"')
print(f'target: "{trg}"')
print(f'prediction: "{correct(src, transformer)}"')

<a id="6">References</a>
----------

1. [Attention is all you need](https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)
2. [The annotated transformer](http://nlp.seas.harvard.edu/annotated-transformer)
3. [PyTorch tutorial on NMT with transformers](https://pytorch.org/tutorials/beginner/translation_transformer.html)
4. [Speech and Language Processing, Jurafsky and Martin](https://web.stanford.edu/~jurafsky/slp3/)

