In [None]:
# Number of examples per mini-batch; passed as batch_size to both DataLoaders below.
BATCH_SIZE = 256
# Number of passes over the training loader performed by the training loop.
EPOCHS = 25

In [None]:
import torch
from torch.utils.data import Dataset

class SpellingDataset(Dataset):
  """Dataset of (source word, correct word) string pairs read from a text file.

  Each non-empty line is split on single spaces. The first token, with any ":"
  removed, is taken as the correct spelling; every remaining token is treated
  as a source spelling that should map to it. The correct word is also added
  as its own source so the model sees identity examples.
  (Presumably the file uses a "correct: variant variant ..." layout — confirm
  against the actual data file.)

  Args:
    path: location of the spelling file. Defaults to "spelling.txt" in the
      current working directory.
  """

  def __init__(self, path="spelling.txt"):
    # Read through a context manager so the file handle is always closed.
    with open(path, "r") as handle:
      self.raw_data = handle.readlines()
    self.raw_dataset = []

    for line in self.raw_data:
      self.create_raw_examples(line)

  def create_raw_examples(self, line):
    """Append the examples described by one line of the file to raw_dataset."""
    # Drop empty tokens so blank lines and doubled spaces do not create
    # examples whose source or target is the empty string.
    split_line = [token for token in line.strip().split(" ") if token]
    if not split_line:
      return

    correct = split_line[0].replace(":", "")
    self.raw_dataset.append({"src": correct, "trg": correct})
    for data in split_line[1:]:
      self.raw_dataset.append({"src": data, "trg": correct})

  def __len__(self):
    """Number of (src, trg) examples."""
    return len(self.raw_dataset)

  def __getitem__(self, index):
    """Return the example at `index` as a (src, trg) tuple of strings."""
    example = self.raw_dataset[index]
    return example["src"], example["trg"]

In [None]:
class CharTokenizer:
  """Character-level tokenizer over ASCII letters plus four special tokens.

  Ids 0-3 are reserved for <pad>, <s>, <e> and <u>; the 52 letters of
  `alphabet` follow from id 4 onward. Any character outside the alphabet is
  encoded as the unknown token.
  """

  def __init__(self):
    self.alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

    # Special tokens and their fixed ids.
    self.pad_token, self.pad_token_id = "<pad>", 0
    self.sos_token, self.sos_token_id = "<s>", 1
    self.eos_token, self.eos_token_id = "<e>", 2
    self.unk_token, self.unk_token_id = "<u>", 3

    self.char_to_idx = self.create_char_to_idx()
    self.idx_to_char = {idx: char for char, idx in self.char_to_idx.items()}

  def vocab_size(self):
    """Total number of entries (letters plus special tokens)."""
    return len(self.char_to_idx)

  def create_char_to_idx(self):
    """Build the token -> id table: letters first (ids 4+), then specials."""
    table = {}
    for offset, letter in enumerate(self.alphabet):
      table[letter] = offset + 4
    table.update({
        self.pad_token: self.pad_token_id,
        self.sos_token: self.sos_token_id,
        self.eos_token: self.eos_token_id,
        self.unk_token: self.unk_token_id,
    })
    return table

  def attempt_encode(self, c):
    """Return the id of `c`, or the unknown id when `c` is not in the alphabet."""
    return self.char_to_idx[c] if c in self.alphabet else self.unk_token_id

  def encode_sequence(self, seq):
    """Encode a string as [<s>, ids of each character..., <e>]."""
    return [self.sos_token_id, *map(self.attempt_encode, seq), self.eos_token_id]

  def decode_sequence(self, seq, ignore_special_tokens=True):
    """Map ids back to text, optionally stripping all special-token markers."""
    text = "".join(self.idx_to_char[i] for i in seq)
    if not ignore_special_tokens:
      return text

    for marker in (self.pad_token, self.sos_token, self.eos_token, self.unk_token):
      text = text.replace(marker, "")
    return text


In [None]:
import torch.nn.functional as F

# Shared tokenizer instance used by collate and by the masking/decoding helpers.
tokenizer = CharTokenizer()

def collate(batch):
  """Convert a list of (src, trg) string pairs into padded id tensors.

  Every string is encoded with `tokenizer.encode_sequence`, right-padded to
  the longest sequence of its side within the batch, and the sequences are
  stacked along dim 1, so each tensor has shape (longest_sequence, batch_size).

  Args:
    batch: list of (src, trg) string tuples.

  Returns:
    dict with LongTensor entries "src" and "trg".
  """
  def pad_and_stack(words):
    encoded = [torch.LongTensor(tokenizer.encode_sequence(word)) for word in words]
    longest = max(tensor.shape[0] for tensor in encoded)
    # Pad explicitly with the tokenizer's pad id rather than relying on
    # F.pad's implicit fill of 0, so padding stays in sync with the padding
    # masks and the loss's ignore_index if the id ever changes.
    padded = [
        F.pad(tensor, (0, longest - tensor.shape[0]), value=tokenizer.pad_token_id)
        for tensor in encoded
    ]
    return torch.stack(padded, dim=1)

  src = pad_and_stack([ex[0] for ex in batch])
  trg = pad_and_stack([ex[1] for ex in batch])

  return {"src": src, "trg": trg}

In [None]:
from torch.utils.data import random_split, DataLoader

ds = SpellingDataset()

# Dedicated seeded generator for the split. It must be passed to random_split
# explicitly; otherwise the split is drawn from the (not yet seeded) global RNG
# and train/val/test membership changes from run to run.
generator = torch.Generator().manual_seed(42)
train_ds, val_ds, test_ds = random_split(ds, [0.8, 0.1, 0.1], generator=generator)

# Only the training loader is shuffled; validation order stays fixed.
train_dl = DataLoader(train_ds, shuffle=True, batch_size=BATCH_SIZE, collate_fn=collate)
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, collate_fn=collate)

In [None]:
import torch.nn as nn
from torch.nn import Transformer
import math

class PositionalEncoding(nn.Module):
  """Adds a fixed sinusoidal position table to embeddings, then applies dropout.

  The table is precomputed for `maxlen` positions and stored as the
  non-trainable buffer "pos_embedding" with shape (maxlen, 1, embedding_dim),
  so it broadcasts over dim 1 of the input.
  """

  def __init__(self, embedding_dim, dropout, maxlen=500):
    super().__init__()
    # One inverse frequency per (sin, cos) column pair.
    inv_freq = torch.exp(-torch.arange(0, embedding_dim, 2) * math.log(10000) / embedding_dim)
    positions = torch.arange(0, maxlen).reshape(maxlen, 1)
    angles = positions * inv_freq

    table = torch.zeros((maxlen, embedding_dim))
    table[:, 0::2] = torch.sin(angles)  # even columns
    table[:, 1::2] = torch.cos(angles)  # odd columns

    self.dropout = nn.Dropout(dropout)
    self.register_buffer("pos_embedding", table.unsqueeze(-2))

  def forward(self, token_embedding):
    """Add the first `token_embedding.size(0)` position rows and apply dropout."""
    seq_len = token_embedding.size(0)
    positioned = token_embedding + self.pos_embedding[:seq_len, :]
    return self.dropout(positioned)

In [None]:
class TokenEmbedding(nn.Module):
  """Embedding lookup whose output is scaled by sqrt(embed_dim).

  Args:
    vocab_size: number of rows in the embedding table.
    embed_dim: size of each embedding vector.
  """

  def __init__(self, vocab_size, embed_dim):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, embed_dim)
    self.embed_dim = embed_dim

  def forward(self, tokens):
    """Look up `tokens` and scale the resulting vectors by sqrt(embed_dim)."""
    # The scale must be applied to the embedding vectors, not to the integer
    # token ids: scaling the ids yields float / out-of-range indices.
    return self.embedding(tokens.long()) * math.sqrt(self.embed_dim)

In [None]:
class SeqToSeqTransformer(nn.Module):
  """Encoder-decoder transformer with token embeddings and an output projection.

  Wraps `torch.nn.Transformer` (constructed without `batch_first`, i.e. with
  sequence-first inputs), separate source/target `nn.Embedding` tables, a
  shared `PositionalEncoding`, and a linear layer (`generator`) that maps
  decoder states to target-vocabulary logits.
  """

  def __init__(self,
               num_encoder_layers,
               num_decoder_layers,
               embedding_dim,
               n_heads,
               src_vocab_size,
               trg_vocab_size,
               dim_feedforward=512,
               dropout=0.1):
    super().__init__()
    # Submodules are created in a fixed order so seeded initialisation is stable.
    self.transformer = Transformer(
        d_model=embedding_dim,
        nhead=n_heads,
        num_encoder_layers=num_encoder_layers,
        num_decoder_layers=num_decoder_layers,
        dim_feedforward=dim_feedforward,
        dropout=dropout,
    )
    self.generator = nn.Linear(embedding_dim, trg_vocab_size)
    self.src_tok_emb = nn.Embedding(src_vocab_size, embedding_dim)
    self.trg_tok_emb = nn.Embedding(trg_vocab_size, embedding_dim)
    self.positional_encoding = PositionalEncoding(embedding_dim, dropout=dropout)

  def forward(self,
              src,
              trg,
              src_mask,
              trg_mask,
              src_padding_mask,
              trg_padding_mask,
              memory_key_padding_mask):
    """Run encoder and decoder and return logits over the target vocabulary.

    `src`/`trg` are token-id tensors; the mask arguments are forwarded to
    `nn.Transformer` as the attention masks and key-padding masks of the same
    roles. No memory (cross-attention) mask is used.
    """
    src_embed = self.positional_encoding(self.src_tok_emb(src))
    trg_embed = self.positional_encoding(self.trg_tok_emb(trg))
    decoder_states = self.transformer(
        src_embed,
        trg_embed,
        src_mask=src_mask,
        tgt_mask=trg_mask,
        memory_mask=None,
        src_key_padding_mask=src_padding_mask,
        tgt_key_padding_mask=trg_padding_mask,
        memory_key_padding_mask=memory_key_padding_mask,
    )
    return self.generator(decoder_states)

  def encode(self, src, src_mask):
    """Embed `src` and run only the encoder stack; returns the encoder memory."""
    src_embed = self.positional_encoding(self.src_tok_emb(src))
    return self.transformer.encoder(src_embed, mask=src_mask)

  def decode(self, tgt, memory, tgt_mask):
    """Embed `tgt` and run only the decoder stack against `memory`."""
    tgt_embed = self.positional_encoding(self.trg_tok_emb(tgt))
    return self.transformer.decoder(tgt_embed, memory, tgt_mask=tgt_mask)

In [None]:
def generate_square_subsequent_mask(size, device="cuda"):
  """Build a (size, size) additive causal attention mask.

  Entries on and below the diagonal are 0.0; entries above the diagonal are
  -inf, so position i cannot attend to positions after i.

  Args:
    size: sequence length.
    device: device on which to allocate the mask. Defaults to "cuda" to match
      the previous hard-coded behaviour.
  """
  # triu(..., diagonal=1) keeps -inf strictly above the diagonal and zeroes the rest.
  return torch.triu(torch.full((size, size), float("-inf"), device=device), diagonal=1)

def create_mask(src, trg):
  """Create attention and padding masks for a sequence-first (src, trg) batch.

  Masks are allocated on the same device as the corresponding input tensor
  instead of a hard-coded "cuda", so the helper also works on CPU.

  Returns:
    (src_mask, trg_mask, src_padding_mask, trg_padding_mask):
    src_mask is an all-False bool matrix (no source position is blocked),
    trg_mask is the causal mask, and the padding masks are True where the
    token equals the tokenizer's pad id, transposed to put dim 1 first.
  """
  src_seq_len = src.shape[0]
  trg_seq_len = trg.shape[0]

  trg_mask = generate_square_subsequent_mask(trg_seq_len, device=trg.device)
  src_mask = torch.zeros((src_seq_len, src_seq_len), dtype=torch.bool, device=src.device)

  src_padding_mask = (src == tokenizer.pad_token_id).transpose(0, 1)
  trg_padding_mask = (trg == tokenizer.pad_token_id).transpose(0, 1)

  return src_mask, trg_mask, src_padding_mask, trg_padding_mask

In [None]:
# Seed the global RNG before the model is built so weight init is repeatable.
torch.manual_seed(42)

# Model hyperparameters.
EMBED_DIM = 512
NHEAD = 8
FFN_HIDDEN_DIM = 512
N_ENCODER_LAYERS = 3
N_DECODER_LAYERS = 3

# Source and target share the same character vocabulary.
transformer = SeqToSeqTransformer(
    num_encoder_layers=N_ENCODER_LAYERS,
    num_decoder_layers=N_DECODER_LAYERS,
    embedding_dim=EMBED_DIM,
    n_heads=NHEAD,
    src_vocab_size=tokenizer.vocab_size(),
    trg_vocab_size=tokenizer.vocab_size(),
    dim_feedforward=FFN_HIDDEN_DIM,
).to("cuda")

# Padded target positions contribute nothing to the loss.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = torch.optim.AdamW(transformer.parameters(), lr=3e-4)

In [None]:
from tqdm.auto import tqdm

def train_epoch(model, optimizer):
  """Train `model` for one pass over `train_dl` and return the mean batch loss.

  Uses teacher forcing: the decoder input is the target without its last
  position and the loss target is the target without its first position.

  Args:
    model: module called as model(src, trg_input, src_mask, trg_mask,
      src_padding_mask, trg_padding_mask, memory_key_padding_mask).
    optimizer: optimizer over the model's parameters.

  Returns:
    Sum of per-batch losses divided by the number of batches.
  """
  model.train()
  losses = 0

  # Wrapping the loader lets tqdm take its total from len(train_dl) and close
  # the bar when iteration finishes.
  for batch in tqdm(train_dl):
    src = batch["src"].to("cuda")
    trg = batch["trg"].to("cuda")

    trg_input = trg[:-1, :]
    trg_out = trg[1:, :]
    src_mask, trg_mask, src_padding_mask, trg_padding_mask = create_mask(src, trg_input)

    optimizer.zero_grad()
    # The source padding mask doubles as the memory key-padding mask.
    logits = model(src, trg_input, src_mask, trg_mask, src_padding_mask, trg_padding_mask, src_padding_mask)

    loss = loss_fn(logits.reshape(-1, logits.shape[-1]), trg_out.reshape(-1))
    loss.backward()

    optimizer.step()
    losses += loss.item()

  # len(train_dl) is the batch count; materialising list(train_dl) would run a
  # second full (re-shuffled) pass over the data just to count batches.
  return losses / len(train_dl)

In [None]:
def evaluate(model):
  """Compute the mean per-batch loss of `model` over `val_dl`.

  The model is put in eval mode and the pass runs without gradient tracking.
  Inputs and targets are shifted the same way as in training (decoder input
  drops the last position, loss target drops the first).

  Returns:
    Sum of per-batch losses divided by the number of batches.
  """
  model.eval()
  losses = 0

  # No gradients are needed for validation; skipping autograd saves memory and time.
  with torch.no_grad():
    for batch in val_dl:
      src = batch["src"].to("cuda")
      trg = batch["trg"].to("cuda")

      trg_input = trg[:-1, :]
      trg_out = trg[1:, :]
      src_mask, trg_mask, src_padding_mask, trg_padding_mask = create_mask(src, trg_input)

      logits = model(src, trg_input, src_mask, trg_mask, src_padding_mask, trg_padding_mask, src_padding_mask)

      loss = loss_fn(logits.reshape(-1, logits.shape[-1]), trg_out.reshape(-1))
      losses += loss.item()

  # len(val_dl) already gives the batch count; no need to re-iterate the loader.
  return losses / len(val_dl)

In [None]:
# Main schedule: one training pass followed by one validation pass per epoch.
for epoch in range(EPOCHS):
  epoch_number = epoch + 1  # report epochs 1-based
  print("training epoch", epoch_number)

  train_loss = train_epoch(transformer, optimizer)
  eval_loss = evaluate(transformer)

  summary = f"Train loss: {train_loss:.3f}, Val loss: {eval_loss:.3f}"
  print(summary)

training epoch 1


  0%|          | 0/133 [00:00<?, ?it/s]

Train loss: 1.562, Val loss: 0.822
training epoch 2


  0%|          | 0/133 [00:00<?, ?it/s]

Train loss: 0.754, Val loss: 0.565
training epoch 3


  0%|          | 0/133 [00:00<?, ?it/s]

Train loss: 0.587, Val loss: 0.456
training epoch 4


  0%|          | 0/133 [00:00<?, ?it/s]

Train loss: 0.499, Val loss: 0.397
training epoch 5


  0%|          | 0/133 [00:00<?, ?it/s]

Train loss: 0.438, Val loss: 0.353
training epoch 6


  0%|          | 0/133 [00:00<?, ?it/s]

Train loss: 0.396, Val loss: 0.334
training epoch 7


  0%|          | 0/133 [00:00<?, ?it/s]

Train loss: 0.362, Val loss: 0.306
training epoch 8


  0%|          | 0/133 [00:00<?, ?it/s]

Train loss: 0.336, Val loss: 0.299
training epoch 9


  0%|          | 0/133 [00:00<?, ?it/s]

Train loss: 0.317, Val loss: 0.287
training epoch 10


  0%|          | 0/133 [00:00<?, ?it/s]

Train loss: 0.297, Val loss: 0.275
training epoch 11


  0%|          | 0/133 [00:00<?, ?it/s]

Train loss: 0.283, Val loss: 0.275
training epoch 12


  0%|          | 0/133 [00:00<?, ?it/s]

Train loss: 0.270, Val loss: 0.272
training epoch 13


  0%|          | 0/133 [00:00<?, ?it/s]

Train loss: 0.258, Val loss: 0.257
training epoch 14


  0%|          | 0/133 [00:00<?, ?it/s]

Train loss: 0.246, Val loss: 0.264
training epoch 15


  0%|          | 0/133 [00:00<?, ?it/s]

Train loss: 0.234, Val loss: 0.250
training epoch 16


  0%|          | 0/133 [00:00<?, ?it/s]

Train loss: 0.229, Val loss: 0.254
training epoch 17


  0%|          | 0/133 [00:00<?, ?it/s]

Train loss: 0.223, Val loss: 0.250
training epoch 18


  0%|          | 0/133 [00:00<?, ?it/s]

Train loss: 0.215, Val loss: 0.248
training epoch 19


  0%|          | 0/133 [00:00<?, ?it/s]

Train loss: 0.208, Val loss: 0.246
training epoch 20


  0%|          | 0/133 [00:00<?, ?it/s]

Train loss: 0.199, Val loss: 0.245
training epoch 21


  0%|          | 0/133 [00:00<?, ?it/s]

Train loss: 0.192, Val loss: 0.242
training epoch 22


  0%|          | 0/133 [00:00<?, ?it/s]

Train loss: 0.188, Val loss: 0.244
training epoch 23


  0%|          | 0/133 [00:00<?, ?it/s]

Train loss: 0.186, Val loss: 0.246
training epoch 24


  0%|          | 0/133 [00:00<?, ?it/s]

Train loss: 0.181, Val loss: 0.240
training epoch 25


  0%|          | 0/133 [00:00<?, ?it/s]

Train loss: 0.172, Val loss: 0.236


In [None]:
def greedy_decode(model, src, src_mask, max_len, start_symbol):
  """Greedily decode one target sequence for a single source sequence.

  The source is encoded once; the decoder is then re-run on the growing
  prefix, each time appending the highest-scoring next token, until the
  tokenizer's end-of-sequence id is produced or `max_len` tokens exist.

  Args:
    model: module exposing encode(src, src_mask), decode(tgt, memory,
      tgt_mask) and a `generator` projection.
    src: source token ids, sequence-first with a single batch column.
    src_mask: source attention mask.
    max_len: maximum length of the returned sequence, start symbol included.
    start_symbol: id used to seed the output sequence.

  Returns:
    LongTensor of shape (decoded_length, 1) that starts with `start_symbol`.
  """
  # Follow the model's device instead of hard-coding "cuda".
  device = next(model.parameters()).device
  src = src.to(device)
  src_mask = src_mask.to(device)

  # Inference only: avoid building an autograd graph that grows every step.
  with torch.no_grad():
    memory = model.encode(src, src_mask)
    ys = torch.full((1, 1), start_symbol, dtype=torch.long, device=device)

    for _ in range(max_len - 1):
      prefix_len = ys.size(0)
      # Boolean causal mask: True marks future positions that must not be attended.
      tgt_mask = torch.ones((prefix_len, prefix_len), device=device).triu(diagonal=1).bool()

      out = model.decode(ys, memory, tgt_mask).transpose(0, 1)
      scores = model.generator(out[:, -1])
      _, next_word = torch.max(scores, dim=1)
      next_word = next_word.item()

      next_token = torch.full((1, 1), next_word, dtype=torch.long, device=device)
      ys = torch.cat([ys, next_token], dim=0)
      if next_word == tokenizer.eos_token_id:
        break

  return ys

def translate(model, src_sentence):
  """Run `model` on one source string and return the greedily decoded string.

  The model is switched to eval mode. Decoding is capped at five tokens more
  than the encoded source length, and special tokens are stripped from the
  returned text by the tokenizer's default decode settings.
  """
  model.eval()

  # collate expects (src, trg) pairs; the target side is unused here.
  batch = collate([(src_sentence, "")])
  src = batch["src"]
  n_tokens = src.shape[0]

  # All-False mask: every source position may attend to every other.
  src_mask = torch.zeros(n_tokens, n_tokens).type(torch.bool)

  decoded = greedy_decode(
      model,
      src,
      src_mask,
      max_len=n_tokens + 5,
      start_symbol=tokenizer.sos_token_id,
  )
  token_ids = decoded.flatten().cpu().tolist()
  return tokenizer.decode_sequence(token_ids)

In [None]:
# Exact-match accuracy of greedy decoding over the validation loader.
transformer.eval()
n_guessed_correctly = 0
n_total = 0

for batch in val_dl:
  src = batch["src"]
  trg = batch["trg"]

  # Dim 1 indexes the examples of the batch; each column is one padded sequence.
  for i in range(src.shape[1]):
    # Decode ids back to text (special tokens stripped) before translating.
    src_example = tokenizer.decode_sequence(src[:, i].tolist())
    gold = tokenizer.decode_sequence(trg[:, i].tolist())

    # Named `hypothesis` rather than `sys` so the stdlib module name is not
    # shadowed in the notebook's global namespace.
    hypothesis = translate(transformer, src_example)
    if hypothesis == gold:
      n_guessed_correctly += 1
    n_total += 1

    # Show the first example of every batch as a qualitative sample.
    if i == 0:
      print("SRC:", src_example)
      print("SYS:", hypothesis)
      print("GOLD:", gold)
      print()

# Guard against an empty validation set instead of raising ZeroDivisionError.
if n_total:
  print("Dev accuracy:", n_guessed_correctly / n_total)
else:
  print("Dev accuracy: n/a (no validation examples)")

SRC: ledser
SYS: ledser
GOLD: leisure

SRC: simpathetic
SYS: sympathetic
GOLD: sympathetic

SRC: amedy
SYS: annum
GOLD: immediately

SRC: dusic
SYS: dusic
GOLD: juice

SRC: possibly
SYS: possibly
GOLD: possible

SRC: miscellaneaises
SYS: miscellaneous
GOLD: miscellaneous

SRC: intelictual
SYS: intellectual
GOLD: intellectual

SRC: attach
SYS: attaching
GOLD: attack

SRC: rheumtizeum
SYS: rheumatism
GOLD: rheumatism

SRC: leeds
SYS: leeds
GOLD: leeds

SRC: went
SYS: went
GOLD: with

SRC: adequante
SYS: adequante
GOLD: adequate

SRC: trolly
SYS: truly
GOLD: trolley

SRC: sent
SYS: sent
GOLD: cent

