<a href="https://colab.research.google.com/github/edwardcdy/deep-learning-notebooks/blob/main/Machine_Translation_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn, Tensor
from torch.utils.data import dataset, DataLoader
import torch.nn.functional as F

from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

In [None]:
!python -m spacy download en_core_web_sm
!python -m spacy download de_core_news_sm

In [None]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import Multi30k
from typing import Iterable, List


# Translation direction: German source -> English target.
SRC_LANGUAGE = 'de'
TGT_LANGUAGE = 'en'

# Vocabularies are filled in later, once they have been built from the training split.
vocab_transform = {}

# spaCy tokenizers keyed by language code. The spaCy models must already be
# installed (the `spacy download` commands earlier in the notebook do this).
token_transform = {
    SRC_LANGUAGE: get_tokenizer('spacy', language='de_core_news_sm'),
    TGT_LANGUAGE: get_tokenizer('spacy', language='en_core_web_sm'),
}


# helper function to yield list of tokens
def yield_tokens(data_iter: Iterable, language: str) -> Iterable[List[str]]:
    """Yield the token list for one side of every (source, target) sample.

    Args:
        data_iter: iterable of ``(src_text, tgt_text)`` pairs, e.g. a Multi30k
            split created with ``language_pair=(SRC_LANGUAGE, TGT_LANGUAGE)``.
        language: ``SRC_LANGUAGE`` or ``TGT_LANGUAGE``; selects which element
            of each pair is tokenized.

    Yields:
        The list of tokens produced by ``token_transform[language]`` for each sample.
    """
    language_index = {SRC_LANGUAGE: 0, TGT_LANGUAGE: 1}
    # Resolve both lookups once instead of once per sample.
    column = language_index[language]
    tokenize = token_transform[language]

    for data_sample in data_iter:
        yield tokenize(data_sample[column])

# Special-token indices. With special_first=True, build_vocab_from_iterator
# assigns indices 0..3 in the order of `special_symbols`, so these two
# definitions must stay in sync.
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    # A fresh Multi30k object per language, because each vocabulary pass reads the data once.
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    vocab_transform[ln] = build_vocab_from_iterator(
        yield_tokens(train_iter, ln),
        min_freq=1,
        specials=special_symbols,
        special_first=True,
    )
    # Out-of-vocabulary lookups fall back to <unk>. Without a default index
    # the Vocab raises RuntimeError for unknown tokens.
    vocab_transform[ln].set_default_index(UNK_IDX)

In [None]:
from torch.nn.utils.rnn import pad_sequence

# Compose several text-processing steps into one callable.
def sequential_transforms(*transforms):
    """Chain ``transforms`` left-to-right into a single callable.

    The returned function passes its argument to the first transform, feeds
    that result to the second, and so on, returning the last result. With no
    transforms it returns its argument unchanged.
    """
    def apply_all(value):
        result = value
        for step in transforms:
            result = step(result)
        return result

    return apply_all

# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids: List[int]):
    """Wrap token indices with BOS/EOS and return them as a 1-D integer tensor.

    The dtype is pinned to ``torch.long``. ``torch.tensor([])`` would otherwise
    be float32, so an empty sentence would yield a non-integer index tensor.
    Downstream consumers (``pad_sequence``, ``nn.Embedding``) expect integer ids.

    Args:
        token_ids: vocabulary indices of one tokenized sentence (may be empty).

    Returns:
        ``torch.long`` tensor ``[BOS_IDX, *token_ids, EOS_IDX]``.
    """
    return torch.cat((torch.tensor([BOS_IDX], dtype=torch.long),
                      torch.tensor(token_ids, dtype=torch.long),
                      torch.tensor([EOS_IDX], dtype=torch.long)))

# Raw-string -> index-tensor pipeline for each language:
# tokenize (spaCy) -> numericalize (vocab lookup) -> add BOS/EOS and tensorize.
text_transform = {
    lang: sequential_transforms(token_transform[lang],
                                vocab_transform[lang],
                                tensor_transform)
    for lang in [SRC_LANGUAGE, TGT_LANGUAGE]
}


# Collate function: turn a list of raw (src, tgt) string pairs into padded batch tensors.
def collate_fn(batch):
    """Convert a batch of raw ``(src_text, tgt_text)`` pairs into two padded tensors.

    Each string has trailing newlines stripped and goes through the matching
    ``text_transform`` pipeline. The resulting sequences are padded with
    ``PAD_IDX`` by ``pad_sequence`` (default ``batch_first=False``), so both
    returned tensors are laid out as ``(longest_len_in_batch, batch_size)``.
    """
    encode_src = text_transform[SRC_LANGUAGE]
    encode_tgt = text_transform[TGT_LANGUAGE]

    # One pass over the batch, encoding source then target for each pair.
    encoded = [(encode_src(src.rstrip("\n")), encode_tgt(tgt.rstrip("\n")))
               for src, tgt in batch]

    src_padded = pad_sequence([pair[0] for pair in encoded], padding_value=PAD_IDX)
    tgt_padded = pad_sequence([pair[1] for pair in encoded], padding_value=PAD_IDX)
    return src_padded, tgt_padded

In [None]:
BATCH_SIZE = 32

# Materialise the training split as a map-style dataset, for two reasons:
#  * it can be re-iterated every epoch (some torchtext versions return a
#    one-shot iterator from Multi30k, which would be empty after epoch 1);
#  * the DataLoader can reshuffle it each epoch, which plain SGD training relies on.
from torchtext.data.functional import to_map_style_dataset

train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
train_dataloader = DataLoader(to_map_style_dataset(train_iter), batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_fn)

# Simple feedforward model

In [None]:
class FF(nn.Module):
  """Naive fixed-length feed-forward translator.

  Every source position is embedded, and all embeddings are flattened into
  one vector per example. That vector passes through two ReLU hidden layers
  and is projected to a ``(max_seq_len, target_vocab_size)`` grid of
  log-probabilities, i.e. an independent prediction for every target
  position. Inputs must be padded/truncated to exactly ``max_seq_len`` tokens.

  Args:
    input_vocab_size: size of the source vocabulary.
    target_vocab_size: size of the target vocabulary.
    max_seq_len: fixed length of both the input and the output sequence.
    embed_dim: width of each token embedding (default 10, the previous hard-coded value).
    hidden_per_token: hidden units per sequence position (default 20, the previous hard-coded value).
  """

  def __init__(self, input_vocab_size: int, target_vocab_size: int, max_seq_len: int = 30,
               embed_dim: int = 10, hidden_per_token: int = 20):
    super().__init__()
    self.max_seq_len = max_seq_len
    self.embed_dim = embed_dim
    hidden = max_seq_len * hidden_per_token
    # Layers are created in the same order and with the same names as before,
    # so default-configured models keep identical state_dict keys and init.
    self.lookup = nn.Embedding(input_vocab_size, embed_dim)
    self.fc1 = nn.Linear(max_seq_len * embed_dim, hidden)
    self.fc2 = nn.Linear(hidden, hidden)
    self.fc3 = nn.Linear(hidden, max_seq_len * target_vocab_size)

  def forward(self, x):
    """Map token ids ``(batch, max_seq_len)`` to log-probs ``(batch, max_seq_len, target_vocab_size)``."""
    batch_size, seq_len = x.shape
    embedded = self.lookup(x)                      # (batch, seq_len, embed_dim)
    # Flatten all positions into one vector per example. flatten(1) avoids
    # hard-coding the embedding width a second time.
    h = embedded.flatten(1)
    h = F.relu(self.fc1(h))
    h = F.relu(self.fc2(h))
    logits = self.fc3(h).reshape(batch_size, seq_len, -1)
    return F.log_softmax(logits, dim=2)


In [None]:
# Vocabulary sizes determine the model's embedding table and output projection.
in_vocab, out_vocab = (len(vocab_transform[lang]) for lang in (SRC_LANGUAGE, TGT_LANGUAGE))

net = FF(input_vocab_size=in_vocab, target_vocab_size=out_vocab)

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def train(net: nn.Module, epochs: int = 20, dataloader=None):
  """Train ``net`` with Adam on (source, target) batches.

  ``scale_transpose_seqs`` pads/truncates each batch to a fixed length and
  transposes it to ``(batch, seq)``. The model's per-position
  log-probabilities are then scored with NLL loss, ignoring ``<pad>`` targets.

  Args:
    net: model mapping ``(batch, seq)`` token ids to ``(batch, seq, vocab)`` log-probs.
    epochs: number of passes over the data.
    dataloader: optional iterable of ``(src, tgt)`` batches; defaults to the
      module-level ``train_dataloader``.

  Returns:
    A list with the mean training loss of each epoch.
  """
  if dataloader is None:
    dataloader = train_dataloader

  # scale_transpose_seqs moves every batch to `device`, so the model has to
  # live there too, or the first forward pass fails on a CUDA machine.
  net.to(device)
  net.train()
  opt = torch.optim.Adam(net.parameters(), lr=3e-4)
  # Targets are padded with PAD_IDX up to the fixed length, so many positions
  # are padding. Without ignore_index the model is rewarded for emitting <pad>.
  criterion = torch.nn.NLLLoss(ignore_index=PAD_IDX)

  epoch_losses = []
  for _ in range(epochs):
    total, n_batches = 0.0, 0
    for x, y in dataloader:
      opt.zero_grad()

      x, y = scale_transpose_seqs(x, y)
      out = net(x)                                  # (batch, seq, vocab)

      # NLLLoss wants the class dimension second: (batch, vocab, seq).
      batch_loss = criterion(out.swapaxes(1, 2), y)
      batch_loss.backward()
      opt.step()

      total += batch_loss.item()
      n_batches += 1
    epoch_losses.append(total / max(n_batches, 1))
  return epoch_losses


def scale_transpose_seqs(x: torch.Tensor, y: torch.Tensor, max_length: int = 30):
  """Pad/truncate a source and a target batch to ``max_length`` and make them batch-first.

  Args:
    x: source batch laid out ``(seq_len, batch)``, as produced by ``collate_fn``.
    y: target batch with the same layout (its seq_len may differ from ``x``'s).
    max_length: fixed sequence length expected by the model.

  Returns:
    ``(x, y)`` as ``long`` tensors of shape ``(batch, max_length)`` on ``device``.
  """
  return _fit_to_length(x, max_length), _fit_to_length(y, max_length)


def _fit_to_length(seq: torch.Tensor, max_length: int) -> torch.Tensor:
  """Pad (with PAD_IDX) or truncate ``seq`` along dim 0 to ``max_length``, then move, cast to long and transpose."""
  if seq.shape[0] < max_length:
    # Build the padding with the data's dtype and device, so torch.cat neither
    # promotes the dtype nor mixes devices.
    pad = torch.full((max_length - seq.shape[0], seq.shape[1]), PAD_IDX,
                     dtype=seq.dtype, device=seq.device)
    seq = torch.cat((seq, pad), dim=0)
  else:
    seq = seq[:max_length]  # no-op when already exactly max_length rows
  return seq.to(device).long().T

In [None]:
from torchtext.data.functional import to_map_style_dataset

# One sentence pair per batch, so examples can be printed one at a time.
# Note: this rebinds the global BATCH_SIZE; the training DataLoader was
# already built with the earlier value.
BATCH_SIZE = 1

valid_iter = Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
valid_dataloader = DataLoader(
    to_map_style_dataset(valid_iter),  # map-style so it can be shuffled
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)


def example_translations(net: nn.Module, n_samples: int):
  """Print ``n_samples`` random validation sentences with the model's translations.

  For each pair from ``valid_dataloader`` (batch size 1), this prints the
  German source, the English reference, and the model's greedy per-position
  argmax prediction, cut off after the first ``<eos>``.

  Args:
    net: model mapping ``(batch, seq)`` token ids to ``(batch, seq, vocab)`` log-probs.
    n_samples: number of sentence pairs to print; values <= 0 print nothing.
  """
  if n_samples <= 0:
    return

  english = vocab_transform[TGT_LANGUAGE].get_itos()
  german = vocab_transform[SRC_LANGUAGE].get_itos()

  def decode(ids, itos, stop_at_eos=False):
    # Map indices to strings, optionally stopping right after the first <eos>.
    words = []
    for idx in ids:
      words.append(itos[idx])
      if stop_at_eos and idx == EOS_IDX:
        break
    return " ".join(words)

  net.eval()
  with torch.no_grad():  # inference only: don't build an autograd graph
    for count, (x, y) in enumerate(valid_dataloader, start=1):
      # collate_fn yields (seq_len, batch); column 0 is the only sentence here.
      print(f'Source sentence: {decode(x[:, 0].tolist(), german)}')
      print(f'Reference sentence: {decode(y[:, 0].tolist(), english)}')

      x_in, _ = scale_transpose_seqs(x, y)
      out = net(x_in)                              # (1, max_length, vocab)
      predicted = out[0].argmax(dim=-1).tolist()
      print(f'Translated sentence: {decode(predicted, english, stop_at_eos=True)}')

      if count == n_samples:
        break


In [None]:
# Show a single random validation example with the model's current output.
example_translations(net=net, n_samples=1)

Source sentence: <bos> Zwei Menschen rennen auf dem Gipfel eines Berges . <eos>
Translated sentence: mounted natured popsicles learns Muzzled Half armor Jacket Worker lioness Ten mowing Inn Medieval wits tricks folk 94 Tournament 145 presume marching peoples effort awnings Room terrorizes toad returning resourceful
