# NMT Example --- Traditional ML Approach

Adapted from https://www.tensorflow.org/tutorials/text/nmt_with_attention

## Set Up Environment

In [2]:
!pip install --upgrade https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.51-cp36-none-manylinux2010_x86_64.whl
!pip install --upgrade jax
!pip install git+https://github.com/deepmind/dm-haiku

import tensorflow as tf

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from sklearn.model_selection import train_test_split

import jax
import jax.numpy as jnp
from jax.experimental import optimizers

import haiku as hk

import unicodedata
import re
import numpy as np
import os
import io
import time

Requirement already up-to-date: jax in /usr/local/lib/python3.6/dist-packages (0.1.75)
Collecting git+https://github.com/deepmind/dm-haiku
  Cloning https://github.com/deepmind/dm-haiku to /tmp/pip-req-build-n8f8jxjj
  Running command git clone -q https://github.com/deepmind/dm-haiku /tmp/pip-req-build-n8f8jxjj
Building wheels for collected packages: dm-haiku
  Building wheel for dm-haiku (setup.py) ... [?25l[?25hdone
  Created wheel for dm-haiku: filename=dm_haiku-0.0.2-cp36-none-any.whl size=289739 sha256=31a1f3bf7c0bc62f063c1630283257dac52b38679e60d2ef754b5cf2192cf32c
  Stored in directory: /tmp/pip-ephem-wheel-cache-0la00c1v/wheels/97/0f/e9/17f34e377f8d4060fa88a7e82bee5d8afbf7972384768a5499
Successfully built dm-haiku
Installing collected packages: dm-haiku
Successfully installed dm-haiku-0.0.2


## Dataset Processing & NLP-specific Functions

In [3]:
# Download the file
path_to_zip = tf.keras.utils.get_file(
    'spa-eng.zip', origin='http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip',
    extract=True)

path_to_file = os.path.dirname(path_to_zip)+"/spa-eng/spa.txt"

Downloading data from http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip


In [4]:
# ========= DATA PROCESSING =============
# Converts the unicode file to ascii
def unicode_to_ascii(s):
  return ''.join(c for c in unicodedata.normalize('NFD', s)
      if unicodedata.category(c) != 'Mn')


def preprocess_sentence(w):
  w = unicode_to_ascii(w.lower().strip())

  # creating a space between a word and the punctuation following it
  # eg: "he is a boy." => "he is a boy ."
  # Reference:- https://stackoverflow.com/questions/3645931/python-padding-punctuation-with-white-spaces-keeping-punctuation
  w = re.sub(r"([?.!,¿])", r" \1 ", w)
  w = re.sub(r'[" "]+', " ", w)

  # replacing everything with space except (a-z, A-Z, ".", "?", "!", ",")
  w = re.sub(r"[^a-zA-Z?.!,¿]+", " ", w)

  w = w.strip()

  # adding a start and an end token to the sentence
  # so that the model know when to start and stop predicting.
  w = '<start> ' + w + ' <end>'
  return w

# 1. Remove the accents
# 2. Clean the sentences
# 3. Return word pairs in the format: [ENGLISH, SPANISH]
def create_dataset(path, num_examples):
  lines = io.open(path, encoding='UTF-8').read().strip().split('\n')

  word_pairs = [[preprocess_sentence(w) for w in l.split('\t')]  for l in lines[:num_examples]]

  return zip(*word_pairs)

def tokenize(lang):
  lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(
      filters='')
  lang_tokenizer.fit_on_texts(lang)

  tensor = lang_tokenizer.texts_to_sequences(lang)

  tensor = tf.keras.preprocessing.sequence.pad_sequences(tensor,
                                                         padding='post')

  return tensor, lang_tokenizer

def load_dataset(path, num_examples=None):
  # creating cleaned input, output pairs
  targ_lang, inp_lang = create_dataset(path, num_examples)

  input_tensor, inp_lang_tokenizer = tokenize(inp_lang)
  target_tensor, targ_lang_tokenizer = tokenize(targ_lang)

  return input_tensor, target_tensor, inp_lang_tokenizer, targ_lang_tokenizer

In [52]:
# Try experimenting with the size of that dataset
num_examples = 30000
# num_examples = -1
input_tensor, target_tensor, inp_lang, targ_lang = load_dataset(path_to_file, num_examples)

# Calculate max_length of the target tensors
max_length_targ, max_length_inp = target_tensor.shape[1], input_tensor.shape[1]

# Creating training and validation sets using an 80-20 split
input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.2)

# Show length
print(len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val))

24000 24000 6000 6000


In [53]:
#make the dataset
BUFFER_SIZE = len(input_tensor_train)
BATCH_SIZE = 64
steps_per_epoch = len(input_tensor_train)//BATCH_SIZE
embedding_dim = 256
units = 1024
vocab_inp_size = len(inp_lang.word_index)+1
vocab_tar_size = len(targ_lang.word_index)+1

dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)

import tensorflow_datasets as tfds
dataset = tfds.as_numpy(dataset)

example_input_batch, example_target_batch = next(iter(dataset))
example_input_batch.shape, example_target_batch.shape

((64, 16), (64, 11))

## Define the Encoder-Decoder Model

This is a standard encoder-decoder architecture with attentional decoding. See the paper https://arxiv.org/pdf/1409.0473.pdf for details. The attention mechanism allows the model to selectively *attend* to the encoded inputs, allowing the model to focus on the most important inputs in the source language for each prediction in the target language.

We use a GRU-based recurrent model with scaled dot-product attention (which is different from the paper above). 

In [47]:
class Encoder(hk.Module):
  def __init__(self, vocab_size, d_model):
    super(Encoder, self).__init__()
    #is it better to keep the embedding outside?
    self.embedding = hk.Embed(vocab_size=vocab_size, embed_dim=d_model)
    self.gru = hk.GRU(hidden_size=d_model, 
                      w_i_init=hk.initializers.VarianceScaling(1.0, "fan_avg", "uniform"),
                      w_h_init=hk.initializers.VarianceScaling(1.0, "fan_avg", "uniform"),
                      b_init=hk.initializers.Constant(0.0))
  
  def initial_state(self, batch_size):
    return self.gru.initial_state(batch_size)
  
  def __call__(self, tokens, init_state):
    inputs = self.embedding(tokens)
    return hk.dynamic_unroll(self.gru, inputs, init_state)

class ScaledDotAttention(hk.Module):
  """ Implements single-headed scaled dot-product attention """ 
  def __init__(self, d_model):
    super(ScaledDotAttention, self).__init__()
    self.W_Q = hk.Linear(d_model, w_init = hk.initializers.VarianceScaling(1.0, "fan_avg", "uniform"),
                                  b_init = hk.initializers.Constant(0.0))
    self.W_K = hk.Linear(d_model, w_init = hk.initializers.VarianceScaling(1.0, "fan_avg", "uniform"),
                                  b_init = hk.initializers.Constant(0.0))
    self.W_V = hk.Linear(d_model, w_init = hk.initializers.VarianceScaling(1.0, "fan_avg", "uniform"),
                                  b_init = hk.initializers.Constant(0.0))
    self.d_model = d_model
    self.root_d_model = np.sqrt(self.d_model)
  
  def __call__(self, Q, K, V):
    #apply linear projections to the Queries, Keys, and Values
    Q = self.W_Q(Q)
    K = self.W_K(K)
    V = self.W_V(V)

    #batch-dimension last...this is weird
    scores = jnp.einsum('...bd,tbd->...tb', Q, K)/self.root_d_model

    #normalize the scores
    probs = jax.nn.softmax(scores, axis=-2)
    
    #average the values w.r.t. the probs
    return  jnp.einsum('...tb,tbd->...bd', probs, V)
    
class BhadanauAttention(hk.Module):
  def __init__(self, d_model):
    super(BhadanauAttention, self).__init__()
    self.W_Q = hk.Linear(d_model, w_init = hk.initializers.VarianceScaling(1.0, "fan_avg", "uniform"),
                                  b_init = hk.initializers.Constant(0.0))
    self.W_K = hk.Linear(d_model, w_init = hk.initializers.VarianceScaling(1.0, "fan_avg", "uniform"),
                                  b_init = hk.initializers.Constant(0.0))
    self.W_score = hk.Linear(1, w_init = hk.initializers.VarianceScaling(1.0, "fan_avg", "uniform"),
                                  b_init = hk.initializers.Constant(0.0))

  def __call__(self, Q, K, V):
    Q = jnp.expand_dims(Q, 0)
    #project the inputs
    Q = self.W_Q(Q)
    K = self.W_K(K)

    # compute the scores using the Bhadanau attention mechanism
    scores = self.W_score(jnp.tanh(Q + K))

    # normalize the scores into probs
    probs = jax.nn.softmax(scores, axis=0) #0 is time axis

    # average the values w.r.t. the probs
    return jnp.einsum('tbd,tbd->bd', probs, V)

class Decoder(hk.Module):
  def __init__(self, attn, vocab_size, d_model):
    super(Decoder, self).__init__()
    self.embedding = hk.Embed(vocab_size=vocab_size, embed_dim=d_model)
    self.attn = attn
    self.gru = hk.GRU(hidden_size=d_model, 
                      w_i_init=hk.initializers.VarianceScaling(1.0, "fan_avg", "uniform"),
                      w_h_init=hk.initializers.VarianceScaling(1.0, "fan_avg", "uniform"),
                      b_init=hk.initializers.Constant(0.0))
    
    self.proj = hk.Linear(vocab_size)
  
  def initial_state(self, batch_size):
    return self.gru.initial_state(batch_size)

  def __call__(self, tokens, enc_outputs, hidden_state):
    """ do attention with queries = hidden state, keys = enc_outputs, 
        values = enc_outputs to select the most 'relevant' encoded outputs 
        to the hidden state."""
    
    # hidden_state = np.expand_dims(hidden_state, 0)
    ctx_vector = self.attn(hidden_state, enc_outputs, enc_outputs)

    # embed the tokens with the target embedding
    inputs = self.embedding(tokens)

    # concat the ctx_vector to the embeddings
    inputs = jnp.concatenate([ctx_vector, inputs], axis=-1)

    #apply the decoder to the context + inputs
    outputs, hidden_state = self.gru(inputs, hidden_state)

    # project outputs into logit space and return (logits, hidden_state)
    return self.proj(outputs), hidden_state
  

## Define the Encoder and Decoder 'Forward' functions

We define these separately since we need to run the encoder once and the decoder multiple times for autoregressive decoding.

In [48]:
def encoder_fn(input_seqs):
  """ assumes input_seqs is time-first
      args:
        input_seqs: an input sequence of tokens
      
      returns:
         a tuple of arrays for the the encoded outputs and the final hidden state of the encoder 
  """
  
  encoder = Encoder(vocab_size=vocab_inp_size, d_model = embedding_dim)
  batch_size = input_seqs.shape[1]

  #initialize the hidden state
  enc_initial_state = encoder.initial_state(batch_size)
  
  #apply the encoder to the full sequence using hk.dynamic_unroll(...)
  enc_outputs, enc_hidden = encoder(input_seqs, enc_initial_state)

  return enc_outputs, enc_hidden

def decoder_fn(dec_inputs, hidden_state, enc_outputs):
  """ assumes dec_inputs are time-first """
  attn = ScaledDotAttention(d_model = embedding_dim)
  # attn = BhadanauAttention(d_model = embedding_dim) # uncomment for Bhadanau attention
  
  decoder = Decoder(attn, vocab_size = vocab_tar_size, d_model = embedding_dim)

  # apply the decoder to a single input (i.e. not unrolled) since we need 
  # to autoregressively generate the translation.
  outputs, hidden_state = decoder(dec_inputs, enc_outputs, hidden_state)

  return outputs, hidden_state

def init_params(key, batch):
  test_inputs, test_targets = batch

  #transpose inputs to be time-first
  test_inputs = test_inputs.transpose(1,0)
  test_targets = test_targets.transpose(1,0)

  encoder = hk.transform(encoder_fn, apply_rng = True)
  enc_params = encoder.init(jax.random.PRNGKey(42), test_inputs)
  enc_outputs, enc_hiddens = encoder.apply(enc_params, jax.random.PRNGKey(0), test_inputs)

  decoder = hk.transform(decoder_fn, apply_rng = True)
  dec_params = decoder.init(jax.random.PRNGKey(42), test_targets[0], 
                            enc_hiddens, enc_outputs)

  return enc_params, dec_params

## Define the Loss Function and Train Step

In [49]:
lr = 1e-3
opt_init, opt_update, opt_get_params = optimizers.adam(lr)

def masked_crossent(logits, targets):
  one_hot_labels = jax.nn.one_hot(targets, vocab_tar_size)

  # since we have padded the batch with 0s, to make them uniform length, 
  # we need to mask out the padding tokens, which are index-zero tokens
  mask = jnp.expand_dims(targets > 0,1)

  #do masked mean, ensuring that length-zero batches don't give nan.
  denom = jnp.max(jnp.array([jnp.sum(mask), 1]))
  crossent = - jnp.sum(one_hot_labels * jax.nn.log_softmax(logits) * mask) / denom

  return crossent

def loss(params, batch):
  enc_params, dec_params = params

  input_batch, target_batch = batch

  #transpose batch to be time-first
  input_batch = jnp.transpose(input_batch, (1,0))
  target_batch = jnp.transpose(target_batch, (1,0))

  #encode the batch once
  enc_outputs, enc_hidden = encoder.apply(enc_params, jax.random.PRNGKey(0), input_batch)

  #initalize the decoder's hidden state to be the encoder's hidden state
  dec_hidden = enc_hidden

  #start predicting with the <start> token
  dec_input = jnp.array([targ_lang.word_index['<start>']] * BATCH_SIZE)

  t_max = target_batch.shape[0]
  loss = 0.0
  for t in range(1, t_max):
    # iterate through the targets
    targets = target_batch[t]

    # compute logits over target vocabulary for the current word (targets)
    logits, dec_hidden = decoder.apply(dec_params, jax.random.PRNGKey(0), 
                                       dec_input, dec_hidden, enc_outputs)

    # accumulate the loss
    loss += masked_crossent(logits, targets)
    
    # use teacher forcing by providing the ground-truth input to the model at each timestep
    dec_input = targets

  return loss / t_max


@jax.jit
def train_step(i, opt_state, batch):
    params = opt_get_params(opt_state)
    # batch_loss_fn = lambda p: loss(p, batch)
    # fx, dx = jax.value_and_grad(batch_loss_fn)(params)
    fx, dx = jax.value_and_grad(loss)(params, batch)
    opt_state = opt_update(i, dx, opt_state)
    return fx, opt_state


def eval_step(params, sentence, max_len=32):
  """ decodes a single input sentence, provided as a string """
  enc_params, dec_params = params

  # tokenize input string
  sentence = preprocess_sentence(sentence)
  inputs = [inp_lang.word_index[token] for token in sentence.split(' ')]
  inputs = np.expand_dims(jnp.array(inputs), 1)

  # encode the inputs
  enc_outputs, enc_hidden = encoder.apply(enc_params, jax.random.PRNGKey(0), inputs)

  # initialize the decoder's hidden state with the encoder's hidden state
  dec_hidden = enc_hidden

  #start predicting with the <start> token
  dec_input = jnp.array([targ_lang.word_index['<start>']] * 1)

  result = []
  for t in range(1, max_len):
    # compute the logits for the current token
    logits, dec_hidden = decoder.apply(dec_params, jax.random.PRNGKey(0), 
                                       dec_input, dec_hidden, enc_outputs)

    # greedy-decode the prediction
    pred_idx = int(jnp.argmax(logits))
    result.append(targ_lang.index_word[pred_idx])

    #if the decoder says 'stop', return
    if targ_lang.index_word[pred_idx] == '<end>':
      break
    
    #otherwise, the prediction becomes the input (for autogregressive decoding)
    dec_input = jnp.array([pred_idx])
  
  return " ".join(result) + '.'

## Do the training

In [54]:
init_key = jax.random.PRNGKey(0)
params = init_params(init_key, next(dataset))
opt_state = opt_init(params)

train_dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train))\
                         .shuffle(BUFFER_SIZE)

for epoch in range(10):
  epoch_loss = 0.0
  dataset_iter = tfds.as_numpy(train_dataset.batch(BATCH_SIZE, drop_remainder=True))
  
  start = time.time()
  for b, batch in enumerate(dataset_iter):
    train_loss, opt_state = train_step(b, opt_state, batch)
    epoch_loss += train_loss
  
  print(f"epoch = {epoch}",
        f" | train loss = {epoch_loss / (b + 1):.5f}",
        f" | time per epoch = {time.time() - start:.2f}s")

epoch = 0  | train loss = 1.84770  | time per epoch = 13.29s
epoch = 1  | train loss = 1.17798  | time per epoch = 5.35s
epoch = 2  | train loss = 0.92113  | time per epoch = 5.38s
epoch = 3  | train loss = 0.74348  | time per epoch = 5.38s
epoch = 4  | train loss = 0.60680  | time per epoch = 5.41s
epoch = 5  | train loss = 0.50736  | time per epoch = 5.46s
epoch = 6  | train loss = 0.42633  | time per epoch = 5.46s
epoch = 7  | train loss = 0.34549  | time per epoch = 5.42s
epoch = 8  | train loss = 0.28708  | time per epoch = 5.41s
epoch = 9  | train loss = 0.24023  | time per epoch = 5.40s


## Evaluate on some sample sentences

Note: this is a simple model trained on a subset of the data. The translations are not perfect (below are some reasonable outputs).

In [63]:
params = opt_get_params(opt_state)
print(eval_step(params, u'hace mucho calor aqui.'))
print(eval_step(params, u'hola!'))
print(eval_step(params, u'¿cómo estás?'))

it s hot here . <end>.
hello ! <end>.
how are you ? <end>.
