<a href="https://colab.research.google.com/github/gauravjain14/All-about-JAX/blob/main/Building_a_Seq2Seq.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# What this Notebook is about?


This notebook is based on the [Seq2Seq](https://github.com/google/flax/tree/main/examples/seq2seq) notebook example available from [Google's Flax](https://github.com/google/flax) repo.


The overall agenda of the notebook is to train a simple LSTM on a sequence-to-sequence addition task using an encoder-decoder architecture. Data is generated on the fly and the overall task is to predict the sum of two numbers (upto 3-digits each).
```
Input Data Format:
"123+456<EOS>"

Output Format:
"=579<EOS>"
```

### Important Learnings
1. To do some old-school print-based debugging, I had disabled @jax.jit decorators but then forgot to re-enable them before I ran my training loops. That was a bad move! The difference in the training time with and without @jax.jit is significant with the former being at least 5x faster.

2. It took me a good amount of time and a lot of code-copying from the original repository to finally get this implementation right. Implementing a model in JAX/Flax is still going to take me a considerable amount of practice. Nevertheless, we'll get there.

In [13]:
!pip install clu
!pip install jax
!pip install flax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [14]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from flax import linen as nn
import flax
import numpy as np

import functools
from typing import Any, Dict, Tuple

import os
from absl import app
from absl import flags
from absl import logging
from clu import metric_writers
from flax import linen as nn
from flax.training import train_state
import optax

import tensorflow as tf

In [15]:
Array = Any

In [16]:
class EncoderLSTM(nn.Module):
  eos_id: int

  @functools.partial(
      nn.scan,
      variable_broadcast='params',
      in_axes=1,
      out_axes=1,
      split_rngs={'params': False})
  @nn.compact
  def __call__(self, carry: Tuple[Array, Array],
               x: Array) -> Tuple[Tuple[Array, Array], Array]:
    """Applies the module."""
    lstm_state, is_eos = carry
    new_lstm_state, y = nn.LSTMCell()(lstm_state, x)
    
    # Pass forward the previous state if EOS has already been reached.
    def select_carried_state(new_state, old_state):
      return jnp.where(is_eos[:, np.newaxis], old_state, new_state)
    # LSTM state is a tuple (c, h).
    carried_lstm_state = tuple(
      select_carried_state(*s) for s in zip(new_lstm_state, lstm_state))
    # Update `is_eos`.
    is_eos = jnp.logical_or(is_eos, x[:, self.eos_id])
    return (carried_lstm_state, is_eos), y

  @staticmethod
  def initialize_carry(batch_size: int, hidden_size: int):
    return nn.LSTMCell.initialize_carry(
        jax.random.PRNGKey(0), (batch_size,), hidden_size)


In [17]:
class Encoder(nn.Module):
  """LSTM encoder, returning state after finding the EOS token in the input."""
  hidden_size: int
  eos_id: int

  @nn.compact
  def __call__(self, inputs: Array):
    # inputs.shape = (batch_size, seq_length, vocab_size).
    batch_size = inputs.shape[0]
    lstm = EncoderLSTM(name='encoder_lstm', eos_id=self.eos_id)
    init_lstm_state = lstm.initialize_carry(batch_size, self.hidden_size)
    # We use the `is_eos` array to determine whether the encoder should carry
    # over the last lstm state, or apply the LSTM cell on the previous state.
    init_is_eos = jnp.zeros(batch_size, dtype=bool)
    init_carry = (init_lstm_state, init_is_eos)

    # Why don't we need the output?
    (final_state, _), _ = lstm(init_carry, inputs)
    return final_state

In [18]:
class DecoderLSTM(nn.Module):
  """DecoderLSTM Module wrapped in a lifted scan transform.
  Attributes:
    teacher_force: See docstring on Seq2seq module.
    vocab_size: Size of the vocabulary.
  """
  teacher_force: bool
  vocab_size: int

  @functools.partial(
      nn.scan,
      variable_broadcast='params',
      in_axes=1,
      out_axes=1,
      split_rngs={'params': False, 'lstm': True})
  @nn.compact
  def __call__(self, carry: Tuple[Array, Array], x: Array) -> Array:
    """Applies the DecoderLSTM model."""
    lstm_state, last_prediction = carry
    if not self.teacher_force:
      x = last_prediction
    lstm_state, y = nn.LSTMCell()(lstm_state, x)
    logits = nn.Dense(features=self.vocab_size)(y)
    # Sample the predicted token using a categorical distribution over the
    # logits.
    categorical_rng = self.make_rng('lstm')
    predicted_token = jax.random.categorical(categorical_rng, logits)
    # Convert to one-hot encoding.
    prediction = jax.nn.one_hot(
        predicted_token, self.vocab_size, dtype=jnp.float32)

    return (lstm_state, prediction), (logits, prediction)


class Decoder(nn.Module):
  """LSTM decoder.
  Attributes:
    init_state: [batch_size, hidden_size]
      Initial state of the decoder (i.e., the final state of the encoder).
    teacher_force: See docstring on Seq2seq module.
    vocab_size: Size of the vocabulary.
  """
  init_state: Tuple[Any]
  teacher_force: bool
  vocab_size: int

  @nn.compact
  def __call__(self, inputs: Array) -> Tuple[Array, Array]:
    """Applies the decoder model.
    Args:
      inputs: [batch_size, max_output_len-1, vocab_size]
        Contains the inputs to the decoder at each time step (only used when not
        using teacher forcing). Since each token at position i is fed as input
        to the decoder at position i+1, the last token is not provided.
    Returns:
      Pair (logits, predictions), which are two arrays of respectively decoded
      logits and predictions (in one hot-encoding format).
    """
    lstm = DecoderLSTM(teacher_force=self.teacher_force,
                       vocab_size=self.vocab_size)
    init_carry = (self.init_state, inputs[:, 0])
    _, (logits, predictions) = lstm(init_carry, inputs)
    return logits, predictions

In [19]:
class Seq2Seq(nn.Module):
  """Sequence-to-sequence class using encoder/decoder architecture.
  Attributes:
    teacher_force: whether to use `decoder_inputs` as input to the decoder at
      every step. If False, only the first input (i.e., the "=" token) is used,
      followed by samples taken from the previous output logits.
    hidden_size: int, the number of hidden dimensions in the encoder and decoder
      LSTMs.
    vocab_size: the size of the vocabulary.
    eos_id: EOS id.
  """
  teacher_force: bool
  hidden_size: int
  vocab_size: int
  eos_id: int = 1

  @nn.compact
  def __call__(self, encoder_inputs: Array,
               decoder_inputs: Array): # -> Tuple[Array, Array]:
    """
    encoder_inputs: [batch_size, max_input_length, vocab_size] padded batch
    of input sequences to encode
    
    decoder_inputs: [batch_size, max_output_length, vocab_size] padded batch
    of expected decoded sequences for teacher forcing."""
    # Encode inputs
    final_encoder_state = Encoder(
        hidden_size=self.hidden_size, eos_id=self.eos_id)(encoder_inputs)
    # Decode outputs
    # decoder_inputs[:, -1] implies from (1, 6, 15), take (1, 5, 15) and
    # use them as inputs.
    # Note: I spent a week trying to find an error about shape mismatch
    # between the context input (c) and LSTM output.
    # A renewed walkthrough of the code revealed I was using
    # decoder_inputs[:, -1] instead of decoder_inputs[:, :-1]. 
    # If you find the error in a jiffy, you deserve a raise!
    logits, predictions = Decoder(
       init_state=final_encoder_state,
       teacher_force=self.teacher_force,
       vocab_size=self.vocab_size)(decoder_inputs[:, :-1])
    return logits, predictions

## Training

In [20]:
!cp drive/MyDrive/Work/Deep\ Learning/JAX/seq2seq/input_pipeline.py .
!pip install clu

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [21]:
!cp drive/MyDrive/Work/Deep\ Learning/JAX/seq2seq/input_pipeline.py .
import input_pipeline
from input_pipeline import CharacterTable as CTable
from input_pipeline import get_sequence_lengths
from input_pipeline import mask_sequences

Array = Any
FLAGS = flags.FLAGS
PRNGKey = Any

In [22]:
hidden_size = 512
num_train_steps = 10000
batch_size = 128
learning_rate = 0.003
decode_frequency = 100

In [23]:
def train_and_evaluate(workdir: str, learning_rate: jnp.float32 = 1e-6):
  """ As per comments in the Flax source code
  https://github.com/google/flax/blob/main/examples/seq2seq/input_pipeline.py
  max_len_query_digit defaults to 3. This implies the input cannot exceed
  999. Thus, for computing maximum input sequence length, we can have two
  up to 3-digit numbers and a plus sign and an <eos> character """
  ctable = CTable('0123456789+= ', 3)
  rng = jax.random.PRNGKey(0)

  def init_model_params(rng: PRNGKey):
    """ Using model.init() to initialize weights of the model;
    This uses the same approach as shown in  
    https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html for
    Variable initialization, in a lazy manner such that the params are
    initialized without a forward pass. """
    rng1, rng2 = jax.random.split(rng)
    variables = model.init(
        {'params': rng1, 'lstm': rng2},
        jnp.ones(ctable.encoder_input_shape, jnp.float32),
        jnp.ones(ctable.decoder_input_shape, jnp.float32)
    )
    return variables['params']

  def get_model(ctable: CTable, *, teacher_force: bool = False) -> Seq2Seq:
    return Seq2Seq(teacher_force=teacher_force,
                        hidden_size=hidden_size, eos_id=ctable.eos_id,
                        vocab_size=ctable.vocab_size)

  def cross_entropy_loss(logits: Array, labels: Array,
                         lengths: Array) -> float:
    """ Returns corss-entropy loss """
    cross_entropy = jnp.sum(nn.log_softmax(logits) * labels, axis=-1)
    ## Example in the github repository uses some masked cross entropy.
    # Not using that here now; Let's see if we need it.
    # https://github.com/google/flax/blob/07e513f639cfc4a946d1d20cebb4bd2ff4f94a24/examples/seq2seq/train.py#L101
    # Thus, we are also not using the lengths Array
    masked_cross_entropy = jnp.mean(mask_sequences(cross_entropy, lengths))
    return -jnp.mean(cross_entropy)

  ## Defining compute_metrics because the tutorial does. I think we can
  # get away with this repeated function call
  def compute_metrics(logits: Array, labels: Array,
                    eos_id: int) -> Dict[str, float]:
    # If train_step is not tagged with @jax.jit, arrays in this func
    # are treated as device arrays and we can porint their values!!
    # lengths - for the batch input, we can have the output as upto 3-digit
    # sum output.
    lengths = get_sequence_lengths(labels, eos_id)
    loss = cross_entropy_loss(logits, labels, lengths)
    # Computes sequence accuracy, which is the same as the accuracy during
    # inference, since teacher forcing is irrelevant when all output are correct.
    token_accuracy = jnp.argmax(logits, -1) == jnp.argmax(labels, -1)
    # always find using np.newaxis is interesting.
    sequence_accuracy = (
        jnp.sum(mask_sequences(token_accuracy, lengths), axis=-1) == lengths
    )
    accuracy = jnp.mean(sequence_accuracy)
    metrics = {
        'loss': loss,
        'accuracy': accuracy,
    }
    return metrics

  def log_decode(question: str, inferred: str, golden: str):
    """Logs the given question, inferred query, and correct query."""
    suffix = '(CORRECT)' if inferred == golden else (f'(INCORRECT) '
                                                    f'correct={golden}')
    print(f'DECODE: {question} = {inferred} {suffix}')


  @functools.partial(jax.jit, static_argnums=3)
  def decode(params: Dict[str, Any], inputs: Array, decode_rng: PRNGKey,
            ctable: CTable) -> Array:
    """Decodes inputs."""
    init_decoder_input = ctable.one_hot(ctable.encode('=')[0:1])
    init_decoder_inputs = jnp.tile(init_decoder_input,
                                  (inputs.shape[0], ctable.max_output_len, 1))
    model = Seq2Seq(teacher_force=False,
                        hidden_size=hidden_size, eos_id=ctable.eos_id,
                        vocab_size=ctable.vocab_size)
    model = get_model(ctable, teacher_force=False)
    _, predictions = model.apply({'params': params},
                                inputs,
                                init_decoder_inputs,
                                rngs={'lstm': decode_rng})
    return predictions    

  def decode_batch(state: train_state.TrainState, batch: Dict[str, Array],
                   decode_rng: PRNGKey, ctable: CTable):
    """ Decodes anbd logs results for a batch """
    inputs, outputs = batch['query'], batch['answer'][:, 1:]
    decode_rng = jax.random.fold_in(decode_rng, state.step)
    inferred = decode(state.params, inputs, decode_rng, ctable)
    questions = ctable.decode_onehot(inputs)
    infers = ctable.decode_onehot(inferred)
    goldens = ctable.decode_onehot(outputs)

    for question, inferred, golden in zip(questions, infers, goldens):
      print("Logging Decode")
      log_decode(question, inferred, golden)

  # Disabling jit for now. I want to keep a track of the variables and do
  # some traditional print-based debugging
  @jax.jit
  def train_step(state: train_state.TrainState, batch: Array,
                 lstm_rng: PRNGKey, eos_id: int) -> Tuple[
                     train_state.TrainState, Dict[str, float]]:
    """ state: State storing the model, params, and updates to the params
      batch: Input batch
      
    """
    labels = batch['answer'][:, 1:]
    lstm_key = jax.random.fold_in(lstm_rng, state.step)

    def loss_fn(params):
      # We are not using the predictions from the Decoder output.
      # Why do the categorical computations at all?
      logits, _ = state.apply_fn(
          # first arg corresponds to variables (from linen.Module apply)
          {'params': params},
          batch['query'],
          batch['answer'], # inputs to Seq2Seq __call__
          rngs={'lstm': lstm_key}) # this looks like metadata as well.
      loss = cross_entropy_loss(logits, labels,
                                get_sequence_lengths(labels, eos_id))
      return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    # Why do we return the loss if it's not being used at all?
    (_, logits), grads = grad_fn(state.params)
    # This does the backprop?
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits, labels, eos_id)
    
    return state, metrics

  model = get_model(ctable)
  params = init_model_params(rng)
  tx = optax.adam(learning_rate)

  # model, params, and tx are used to fill the TrainingState
  # (described in flax.training)
  # Side note: create() is a @classmethod. It's essentially Python's way of
  # defining a static method
  state = train_state.TrainState.create(apply_fn=model.apply, params=params,
                                        tx=tx)
  ## Let's train
  for step in range(num_train_steps):
    batch = ctable.get_batch(batch_size)
    state, metrics = train_step(state, batch, rng, ctable.eos_id)
    if step and step % decode_frequency == 0:
      print(f'Accuracy {metrics["accuracy"]} and loss {metrics["loss"]}')
      batch = ctable.get_batch(5)
      decode_batch(state, batch, rng, ctable)
  print(model)

In [None]:
train_and_evaluate(None, learning_rate=learning_rate)

Accuracy 0.0078125 and loss 0.7900912761688232
Logging Decode
DECODE: 17+536 = 434 (INCORRECT) correct=553
Logging Decode
DECODE: 15+80 = 195__ (INCORRECT) correct=95
Logging Decode
DECODE: 38+569 = 582 (INCORRECT) correct=607
Logging Decode
DECODE: 34+477 = 516 (INCORRECT) correct=511
Logging Decode
DECODE: 73+435 = 542 (INCORRECT) correct=508
Accuracy 0.0390625 and loss 0.5850690603256226
Logging Decode
DECODE: 92+834 = 919 (INCORRECT) correct=926
Logging Decode
DECODE: 66+924 = 968 (INCORRECT) correct=990
Logging Decode
DECODE: 18+39 = 50 (INCORRECT) correct=57
Logging Decode
DECODE: 65+343 = 301 (INCORRECT) correct=408
Logging Decode
DECODE: 33+711 = 740 (INCORRECT) correct=744
Accuracy 0.0625 and loss 0.462399959564209
Logging Decode
DECODE: 78+887 = 959 (INCORRECT) correct=965
Logging Decode
DECODE: 30+789 = 812 (INCORRECT) correct=819
Logging Decode
DECODE: 2+398 = 394 (INCORRECT) correct=400
Logging Decode
DECODE: 74+459 = 522 (INCORRECT) correct=533
Logging Decode
DECODE: 27+9

In [None]:
#rng = random.PRNGKey(0)
#key1, key2 = random.split(rng)
#x = random.normal(key1, (2, 3))
#c0, h0 = nn.LSTMCell.initialize_carry(rng, (2,), 4)
#lstm = nn.LSTMCell()
#(carry, y), initial_params = lstm.init_with_output(key2, (c0, h0), x)
#param_shapes = jax.tree_util.tree_map(np.shape, initial_params['params'])
#'''self.assertEqual(param_shapes, {
#    'ii': {'kernel': (3, 4)},
#    'if': {'kernel': (3, 4)},
#    'ig': {'kernel': (3, 4)},
#    'io': {'kernel': (3, 4)},
#    'hi': {'kernel': (4, 4), 'bias': (4,)},
#    'hf': {'kernel': (4, 4), 'bias': (4,)},
#    'hg': {'kernel': (4, 4), 'bias': (4,)},
#    'ho': {'kernel': (4, 4), 'bias': (4,)},
#})'''