In [1]:
%matplotlib inline



---


# Lexical word order with Transformers





## Data loading and dataset creation




In [2]:

# Source word list: one entry per line of the file.
paths = ["./web2.txt"]

# Strip trailing whitespace and lower-case every entry while reading.
lines = []
with open(paths[0]) as file:
    for raw_line in file:
        lines.append(raw_line.rstrip().lower())

In [3]:
import random
max_seq_len = 32
min_seq_len = 2  # fewer than two words leaves nothing to reorder
random.seed(0)   # make the generated dataset reproducible across runs
train_data = []


def _make_example(words):
    """Return an ``(unordered, ordered)`` pair of space-joined strings for ``words``.

    ``words`` itself is left untouched; a copy is shuffled to build the
    unordered side of the pair.
    """
    ordered = " ".join(words)
    shuffled = list(words)
    random.shuffle(shuffled)
    return " ".join(shuffled), ordered


# Part 1: eight passes over the word list, each cut into consecutive chunks
# of random length (min_seq_len .. max_seq_len-1 words).
for _ in range(8):
    start = 0
    seq_length = random.randint(min_seq_len, max_seq_len-1)
    while start+seq_length < len(lines):
        train_data.append(_make_example(lines[start:start+seq_length]))

        start += seq_length
        seq_length = random.randint(min_seq_len, max_seq_len-1)


# Part 2: random subsets of the word list, kept in their original file order.
num_examples = 60000
# random.sample cannot draw more items than the word list contains.
longest = min(max_seq_len-1, len(lines))
if longest >= min_seq_len:
    for _ in range(num_examples):
        # randomise sequence length (same minimum as part 1, so no empty or
        # single-word examples are generated)
        seq_length = random.randint(min_seq_len, longest)
        indices = random.sample(range(len(lines)), seq_length)
        train_data.append(_make_example([lines[idx] for idx in sorted(indices)]))


random.shuffle(train_data)

In [4]:
# Total number of (unordered, ordered) pairs generated by the cell above.
len(train_data)

174988

In [5]:

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# The vocabulary is built from the raw word list (one entry per line).
tokenizer = get_tokenizer('basic_english')
train_iter = iter(lines)

# Special symbols and the indices they must occupy in the vocabulary.
# The list order below has to match those indices.
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

vocab = build_vocab_from_iterator(
    (tokenizer(entry) for entry in train_iter),
    min_freq=1,
    specials=special_symbols,
    special_first=True,
)
# Tokens missing from the vocabulary are looked up as '<unk>'.
vocab.set_default_index(vocab['<unk>'])

train_iter = iter(train_data)

## Seq2Seq Network using Transformer



In [6]:
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to token embeddings, then apply dropout.

    Args:
        emb_size: size of the embedding (last) dimension.
        dropout: dropout probability applied to the summed embeddings.
        maxlen: number of positions for which the table is precomputed.
    """

    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super().__init__()
        positions = torch.arange(0, maxlen).reshape(maxlen, 1)
        inv_freq = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        angles = positions * inv_freq

        # Even embedding columns carry the sine, odd columns the cosine.
        table = torch.zeros((maxlen, emb_size))
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        self.dropout = nn.Dropout(dropout)
        # Stored as a (maxlen, 1, emb_size) buffer so it broadcasts over dim 1
        # of the input and follows the module across devices.
        self.register_buffer('pos_embedding', table.unsqueeze(1))

    def forward(self, token_embedding: Tensor):
        """Return ``token_embedding`` plus the first ``size(0)`` position rows, after dropout."""
        seq_len = token_embedding.size(0)
        return self.dropout(token_embedding + self.pos_embedding[:seq_len, :])

# helper Module to convert tensor of input indices into corresponding tensor of token embeddings
class TokenEmbedding(nn.Module):
    """Embedding lookup whose output is scaled by ``sqrt(emb_size)``.

    Args:
        vocab_size: number of rows in the embedding table.
        emb_size: size of each embedding vector.
    """

    def __init__(self, vocab_size: int, emb_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        """Look up ``tokens`` (cast to int64) and scale the resulting vectors."""
        scale = math.sqrt(self.emb_size)
        looked_up = self.embedding(tokens.long())
        return looked_up * scale

# Seq2Seq Network 
class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder Transformer with token embeddings and an output projection.

    Args:
        num_encoder_layers: number of encoder layers.
        num_decoder_layers: number of decoder layers.
        emb_size: embedding size, used as the Transformer ``d_model``.
        nhead: number of attention heads.
        src_vocab_size: vocabulary size of the source embedding table.
        tgt_vocab_size: vocabulary size of the target embedding table and of
            the output projection.
        dim_feedforward: hidden size of the feed-forward sub-layers.
        dropout: dropout probability shared by the Transformer and the
            positional encoding.
    """

    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super().__init__()
        # The creation order below fixes the order in which parameters are
        # registered, so keep it stable.
        self.transformer = Transformer(
            d_model=emb_size,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def _embed_src(self, src: Tensor):
        """Source token embedding followed by positional encoding."""
        return self.positional_encoding(self.src_tok_emb(src))

    def _embed_tgt(self, tgt: Tensor):
        """Target token embedding followed by positional encoding."""
        return self.positional_encoding(self.tgt_tok_emb(tgt))

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        """Run the full encoder-decoder pass and return vocabulary logits.

        No memory mask is used; the three padding masks are forwarded as the
        Transformer's key-padding masks.
        """
        outs = self.transformer(
            self._embed_src(src),
            self._embed_tgt(trg),
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            memory_mask=None,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        """Return the encoder output (memory) for ``src``."""
        return self.transformer.encoder(self._embed_src(src), mask=src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        """Return decoder hidden states for ``tgt`` given ``memory`` (no output projection)."""
        return self.transformer.decoder(self._embed_tgt(tgt), memory, tgt_mask=tgt_mask)

During training, we need a subsequent-word mask that prevents the model from looking at
future words when making predictions. We also need masks to hide
the source and target padding tokens. The functions below take care of that.


In [7]:
def generate_square_subsequent_mask(sz):
    """Return an ``(sz, sz)`` float mask: 0.0 on and below the diagonal, ``-inf`` above it.

    Row ``i`` therefore leaves positions ``0..i`` open and blocks every later
    position. The tensor is created on ``DEVICE``.
    """
    blocked = torch.full((sz, sz), float('-inf'), device=DEVICE)
    # Keep -inf strictly above the diagonal; everything else becomes 0.0.
    return torch.triu(blocked, diagonal=1)


def create_mask(src, tgt):
    """Build the attention and padding masks for one batch.

    Dimension 0 of ``src`` and ``tgt`` is taken as the sequence length.

    Returns:
        ``(src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)`` where
        ``src_mask`` is an all-False square bool mask, ``tgt_mask`` is the
        subsequent-position mask, and the padding masks are True wherever the
        input equals ``PAD_IDX`` (with dimensions 0 and 1 swapped).
    """
    src_len, tgt_len = src.shape[0], tgt.shape[0]

    # The source may attend everywhere; the target only to earlier positions.
    src_mask = torch.zeros((src_len, src_len), dtype=torch.bool, device=DEVICE)
    tgt_mask = generate_square_subsequent_mask(tgt_len)

    src_is_pad = src == PAD_IDX
    tgt_is_pad = tgt == PAD_IDX
    return src_mask, tgt_mask, src_is_pad.transpose(0, 1), tgt_is_pad.transpose(0, 1)

Now we define the parameters of our model and instantiate it. We also 
define our loss function, which is the cross-entropy loss, and the optimizer used for training.




In [8]:
# Seed PyTorch so the weight initialisation below is repeatable.
torch.manual_seed(0)

# Model hyper-parameters
VOCAB_SIZE = len(vocab)
EMB_SIZE = 240
NHEAD = 8
FFN_HID_DIM = 265
BATCH_SIZE = 64
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3

# Source and target share the same vocabulary.
transformer = Seq2SeqTransformer(
    num_encoder_layers=NUM_ENCODER_LAYERS,
    num_decoder_layers=NUM_DECODER_LAYERS,
    emb_size=EMB_SIZE,
    nhead=NHEAD,
    src_vocab_size=VOCAB_SIZE,
    tgt_vocab_size=VOCAB_SIZE,
    dim_feedforward=FFN_HID_DIM,
)

# Xavier-initialise every parameter with more than one dimension
# (one-dimensional parameters keep their default initialisation).
for weight in transformer.parameters():
    if weight.dim() > 1:
        nn.init.xavier_uniform_(weight)

transformer = transformer.to(DEVICE)

# Positions holding PAD_IDX do not contribute to the loss.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(
    transformer.parameters(),
    lr=0.0005,
    betas=(0.9, 0.98),
    eps=1e-9,
)

## Collation

Our data iterator yields a pair of raw strings. 
We need to convert these string pairs into the batched tensors that can be processed by our ``Seq2Seq`` network. Now we define our collate function, which converts a batch of raw strings into batch tensors that
can be fed directly into our model.   




In [9]:
from torch.nn.utils.rnn import pad_sequence

# helper function to club together sequential operations
def sequential_transforms(*transforms):
    """Compose ``transforms`` into one callable applied left to right.

    The returned function feeds its argument to the first transform, passes
    each result on to the next one, and returns the final result. With no
    transforms the input is returned unchanged.
    """
    def composed(txt_input):
        result = txt_input
        for step in transforms:
            result = step(result)
        return result
    return composed

# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids):
    """Return a 1-D int64 tensor: ``BOS_IDX``, then ``token_ids``, then ``EOS_IDX``."""
    bos = torch.tensor([BOS_IDX], dtype=torch.long)
    body = torch.tensor(token_ids, dtype=torch.long)
    eos = torch.tensor([EOS_IDX], dtype=torch.long)
    return torch.cat((bos, body, eos))

# src and tgt language text transforms to convert raw strings into tensors indices
# Each side gets its own pipeline: tokenization -> numericalization ->
# BOS/EOS wrapping and tensor creation.
text_transform = {
    ln: sequential_transforms(tokenizer, vocab, tensor_transform)
    for ln in ("input", "target")
}


# function to collate data samples into batch tensors
def collate_fn(batch):
    """Turn a list of ``(source, target)`` string pairs into two padded index tensors.

    Each string has trailing newlines stripped and is run through its text
    transform; the resulting sequences are padded with ``PAD_IDX``.
    """
    to_src = text_transform["input"]
    to_tgt = text_transform["target"]

    src_batch = [to_src(src_sample.rstrip("\n")) for src_sample, _ in batch]
    tgt_batch = [to_tgt(tgt_sample.rstrip("\n")) for _, tgt_sample in batch]

    return (pad_sequence(src_batch, padding_value=PAD_IDX),
            pad_sequence(tgt_batch, padding_value=PAD_IDX))

## Training 

Define the training and evaluation loops that will be called for each 
epoch.




In [10]:
class Data:
    """Map-style dataset over the train or validation split of a list of examples.

    The last ``val_size`` examples form the validation split; everything
    before them is the training split, so the two never overlap.

    Args:
        t: which split to expose, ``"train"`` or ``"val"``.
        data: sequence of examples; defaults to the notebook-level
            ``train_data`` list.
        val_size: number of trailing examples held out for validation
            (clamped to the range ``0..len(data)``).

    Raises:
        ValueError: if ``t`` is neither ``"train"`` nor ``"val"``.
    """

    def __init__(self, t="train", data=None, val_size=500):
        if t not in ("train", "val"):
            raise ValueError(f"unknown split {t!r}; expected 'train' or 'val'")
        self.t = t
        self._data = train_data if data is None else data

        total = len(self._data)
        held_out = max(0, min(val_size, total))
        split_at = total - held_out
        # Half-open index range [start, stop) of this split inside the data.
        if t == "train":
            self._start, self._stop = 0, split_at
        else:
            self._start, self._stop = split_at, total

    def __iter__(self):
        """Yield the examples of this split in order."""
        for position in range(self._start, self._stop):
            yield self._data[position]

    def __getitem__(self, item):
        """Return example ``item`` of this split; negative indices count from the split's end."""
        size = len(self)
        if not -size <= item < size:
            raise IndexError(f"index {item} out of range for split of size {size}")
        if item < 0:
            item += size
        # Offset by the split start so "val" really reads the held-out tail.
        return self._data[self._start + item]

    def __len__(self):
        """Number of examples in this split."""
        return self._stop - self._start


In [11]:
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

def train_epoch(model, optimizer):
    """Run one optimisation pass over the training split.

    Returns:
        The sum of the per-batch losses divided by the number of batches.
    """
    model.train()
    loader = DataLoader(Data("train"), batch_size=BATCH_SIZE, collate_fn=collate_fn)

    running_loss = 0
    for src, tgt in tqdm(loader):
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)

        # The decoder sees the target without its last position and is
        # scored against the target without its first position.
        decoder_in = tgt[:-1, :]
        expected = tgt[1:, :]

        src_mask, tgt_mask, src_pad, tgt_pad = create_mask(src, decoder_in)

        # The source padding mask is also used for the encoder memory.
        logits = model(src, decoder_in, src_mask, tgt_mask, src_pad, tgt_pad, src_pad)

        optimizer.zero_grad()

        flat_logits = logits.reshape(-1, logits.shape[-1])
        loss = loss_fn(flat_logits, expected.reshape(-1))
        loss.backward()

        optimizer.step()
        running_loss += loss.item()

    return running_loss / len(loader)


def evaluate(model):
    """Compute the mean per-batch loss of ``model`` on the validation split.

    The model is switched to eval mode and the whole pass runs under
    ``torch.no_grad()``, so no autograd graph is built for these forward
    passes.

    Returns:
        The sum of the per-batch losses divided by the number of batches.
    """
    model.eval()
    losses = 0

    val_iter = Data("val")
    val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    with torch.no_grad():
        for src, tgt in val_dataloader:
            src = src.to(DEVICE)
            tgt = tgt.to(DEVICE)

            # Decoder input drops the last position; the expected output
            # drops the first one.
            tgt_input = tgt[:-1, :]

            src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

            # The source padding mask is also used for the encoder memory.
            logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

            tgt_out = tgt[1:, :]
            loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
            losses += loss.item()

    return losses / len(val_dataloader)

Start training

In [12]:
# Lowest validation loss seen so far and the file the matching weights go to.
best_val_loss = float('inf')
best_model_params_path = "best_model_params.pt"


def do_train():
    """Train for ``NUM_EPOCHS`` epochs, checkpointing whenever validation loss improves.

    After each epoch the train loss, validation loss and the time spent in
    the training pass are printed. When the validation loss beats the
    module-level ``best_val_loss``, that variable is updated and the model's
    state dict is saved to ``best_model_params_path``.
    """
    # best_val_loss is rebound below; without this declaration the comparison
    # would hit an unassigned local and raise UnboundLocalError.
    global best_val_loss

    for epoch in range(1, NUM_EPOCHS+1):
        start_time = timer()
        print(f"starting epoch {epoch}")
        train_loss = train_epoch(transformer, optimizer)
        end_time = timer()
        val_loss = evaluate(transformer)
        print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(transformer.state_dict(), best_model_params_path)


In [13]:
# Load the best saved model weights. map_location remaps the stored tensors to
# DEVICE, so a checkpoint written on a GPU machine also loads on a CPU-only one.
transformer.load_state_dict(torch.load(best_model_params_path, map_location=DEVICE))

<All keys matched successfully>

In [14]:
from timeit import default_timer as timer
# Number of epochs do_train() loops over.
NUM_EPOCHS = 3


# Start training.
do_train()


starting epoch 1


  0%|          | 0/2727 [00:00<?, ?it/s]

KeyboardInterrupt: 

## Inference

In [None]:

# function to generate output sequence using greedy algorithm 
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Decode one sequence by repeatedly picking the highest-scoring next token.

    Args:
        model: network exposing ``encode``, ``decode`` and ``generator``.
        src: source index tensor for a single sequence (the chosen token is
            read with ``.item()``, so only one sequence can be decoded at a time).
        src_mask: attention mask passed to ``model.encode``.
        max_len: maximum length of the returned sequence, start symbol included.
        start_symbol: index the output sequence starts with.

    Returns:
        An int64 tensor on ``DEVICE`` with one row per generated position,
        beginning with ``start_symbol``. Generation stops early once
        ``EOS_IDX`` has been appended.
    """
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        # The encoder output does not change during decoding, so compute and
        # place it once instead of on every step.
        memory = model.encode(src, src_mask).to(DEVICE)
        ys = torch.full((1, 1), start_symbol, dtype=torch.long, device=DEVICE)
        for _ in range(max_len-1):
            tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                        .type(torch.bool)).to(DEVICE)
            out = model.decode(ys, memory, tgt_mask)
            out = out.transpose(0, 1)
            # Score the vocabulary from the last decoded position only.
            prob = model.generator(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.item()

            # The appended token keeps the dtype and device of ``src``.
            next_token = torch.full((1, 1), next_word, dtype=src.dtype, device=src.device)
            ys = torch.cat([ys, next_token], dim=0)
            if next_word == EOS_IDX:
                break
    return ys



# actual function to order input words by running inference with the model
def order(model: torch.nn.Module, src_sentence: str, max_len: int = 32):
    """Return the model's ordering of the words in ``src_sentence``.

    Args:
        model: trained network, switched to eval mode here.
        src_sentence: space-separated words in arbitrary order.
        max_len: maximum number of tokens to generate, start symbol included
            (defaults to 32).

    Returns:
        The decoded tokens joined by single spaces. The start token, and the
        end token if one was generated, are part of the string.
    """
    model.eval()
    src = text_transform["input"](src_sentence).view(-1, 1)
    num_tokens = src.shape[0]
    # All-False mask: every source position may attend to every other one.
    src_mask = torch.zeros(num_tokens, num_tokens, dtype=torch.bool)
    tgt_tokens = greedy_decode(
        model, src, src_mask, max_len=max_len, start_symbol=BOS_IDX).flatten()
    # tolist() hands plain Python ints to the vocabulary lookup.
    return " ".join(vocab.lookup_tokens(tgt_tokens.cpu().tolist()))

In [None]:
# Show one example: shuffled input, reference ordering, then the model's output.
sample_unordered, sample_ordered = train_data[3]
print(sample_unordered)
print(sample_ordered)

print(order(transformer, sample_unordered))