<a href="https://colab.research.google.com/github/gauravjain14/All-about-JAX/blob/main/Building_a_Seq2Seq.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install jax
!pip install flax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from flax import linen as nn
import flax
import numpy as np

import tensorflow as tf
from typing import Any, Tuple

In [3]:
# Loose type alias for JAX/NumPy arrays used in the signatures below; being
# `Any`, it documents intent but performs no static checking.
Array = Any

In [4]:
class EncoderLSTM(nn.Module):
  """One encoder time step: a single LSTM cell update that freezes at EOS.

  The module applies the LSTM cell once, so encoding a whole sequence requires
  lifting it over the time axis (e.g. with `nn.scan`).

  Attributes:
    eos_id: column of `x` that flags the EOS token (read as `x[:, eos_id]`).
  """
  eos_id: int

  @nn.compact
  def __call__(self, carry: Tuple[Array, Array],
               x: Array) -> Tuple[Tuple[Array, Array], Array]:
    """Runs one LSTM step, keeping the previous state for rows past EOS.

    Args:
      carry: `(lstm_state, is_eos)`; `lstm_state` is the LSTM `(c, h)` tuple
        and `is_eos` is a 1-D per-row flag (it is indexed as `[:, None]`).
      x: inputs for this step; presumably (batch, vocab_size) one-hot, since
        `x[:, eos_id]` is treated as the EOS indicator — TODO confirm.

    Returns:
      `((lstm_state, is_eos), y)`: the updated carry and the cell output.
    """
    prev_state, eos_seen = carry
    next_state, y = nn.LSTMCell()(prev_state, x)

    # Rows that already reached EOS keep their old (c, h); the rest advance.
    # The mask is taken from the flags *before* this step's input is checked.
    freeze = eos_seen[:, None]
    kept_state = tuple(
        jnp.where(freeze, old, new) for new, old in zip(next_state, prev_state))

    # A row stays "done" once EOS shows up at this step or any earlier one.
    eos_seen = jnp.logical_or(eos_seen, x[:, self.eos_id])
    return (kept_state, eos_seen), y

  @staticmethod
  def initialize_carry(batch_size: int, hidden_size: int):
    """Returns the initial LSTM `(c, h)` carry for `batch_size` rows."""
    # The default LSTM carry init is zeros, so the key's value is irrelevant.
    dummy_key = jax.random.PRNGKey(0)
    return nn.LSTMCell.initialize_carry(dummy_key, (batch_size,), hidden_size)


In [5]:
class Encoder(nn.Module):
  """LSTM encoder, returning state after finding the EOS token in the input.

  Attributes:
    hidden_size: number of hidden units in the encoder LSTM cell.
    eos_id: vocabulary index of the EOS token.
  """
  hidden_size: int
  eos_id: int

  @nn.compact
  def __call__(self, inputs: Array):
    """Encodes a batch of sequences and returns the final LSTM carry.

    Args:
      inputs: (batch_size, seq_length, vocab_size) batch of input sequences.

    Returns:
      The LSTM `(c, h)` state after the last time step. Rows that contained
      EOS keep the state reached right after consuming the EOS token.
    """
    batch_size = inputs.shape[0]
    # EncoderLSTM performs a single time step, so lift it over the time axis
    # (axis 1) with `nn.scan`. Calling it directly on the 3-D `inputs` would
    # apply one cell update to the whole sequence and make `x[:, eos_id]`
    # index the time axis instead of the vocabulary axis. Parameters are
    # broadcast (shared) across steps and the 'params' RNG is not split, so
    # every step uses the same weights.
    ScanEncoderLSTM = nn.scan(
        EncoderLSTM,
        variable_broadcast='params',
        split_rngs={'params': False},
        in_axes=1,
        out_axes=1)
    lstm = ScanEncoderLSTM(name='encoder_lstm', eos_id=self.eos_id)
    init_lstm_state = EncoderLSTM.initialize_carry(batch_size, self.hidden_size)
    # We use the `is_eos` array to determine whether the encoder should carry
    # over the last lstm state, or apply the LSTM cell on the previous state.
    init_is_eos = jnp.zeros(batch_size, dtype=bool)
    init_carry = (init_lstm_state, init_is_eos)

    # The per-step outputs are discarded: only the final carry is needed (it
    # is what a decoder would use as its initial state).
    (final_state, _), _ = lstm(init_carry, inputs)
    return final_state

In [6]:
class Seq2Seq(nn.Module):
  """Sequence-to-sequence model using an encoder/decoder architecture.

  NOTE: only the encoder is wired in so far; calling the module currently
  returns the encoder's final LSTM state, not decoder logits/predictions.

  Attributes:
    teacher_force: whether to use `decoder_inputs` as input to the decoder at
      every step. If False, only the first input (i.e., the "=" token) is used,
      followed by samples taken from the previous output logits. Not used yet.
    hidden_size: int, the number of hidden dimensions in the encoder and decoder
      LSTMs.
    vocab_size: the size of the vocabulary. Not used yet.
    eos_id: EOS id.
  """
  teacher_force: bool
  hidden_size: int
  vocab_size: int
  eos_id: int = 1

  @nn.compact
  def __call__(self, encoder_inputs: Array, decoder_inputs: Array):
    """Encodes `encoder_inputs`; the decoding stage is still a TODO.

    Args:
      encoder_inputs: [batch_size, max_input_length, vocab_size] padded batch
        of input sequences to encode.
      decoder_inputs: [batch_size, max_output_length, vocab_size] padded batch
        of expected decoded sequences for teacher forcing. Currently ignored.

    Returns:
      The encoder's final state, exactly as returned by `Encoder`.
    """
    encoder = Encoder(hidden_size=self.hidden_size, eos_id=self.eos_id)
    encoder_state = encoder(encoder_inputs)
    # TODO: pass `encoder_state` to a Decoder (init_state=encoder_state,
    # teacher_force=self.teacher_force, vocab_size=self.vocab_size) and return
    # `(logits, predictions)`. Note that `decoder_inputs[:, -1]` would select
    # only the final time step; teacher forcing normally feeds every target
    # step but the last, i.e. `decoder_inputs[:, :-1]`.
    return encoder_state

In [7]:
# Quick check of `[:, np.newaxis]`: it appends a length-1 axis, turning a (4,)
# vector into a (4, 1) column. This is the kind of broadcast EncoderLSTM
# applies to its per-row `is_eos` mask.
x = np.arange(1, 5)
y = x[:, np.newaxis]
print(x.shape, y.shape, sep='\n')

(4,)
(4, 1)


## Training

In [8]:
!cp drive/MyDrive/Work/Deep\ Learning/JAX/seq2seq/input_pipeline.py .
!pip install clu

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [9]:
import functools
from typing import Any, Dict, Tuple

import os
from absl import app
from absl import flags
from absl import logging
from clu import metric_writers
from flax import linen as nn
from flax.training import train_state
import jax
import jax.numpy as jnp
import optax

!cp drive/MyDrive/Work/Deep\ Learning/JAX/seq2seq/input_pipeline.py .
import input_pipeline
from input_pipeline import CharacterTable as CTable
from input_pipeline import get_sequence_lengths
from input_pipeline import mask_sequences

# Loose type aliases used in annotations (no static checking, they are `Any`).
Array = PRNGKey = Any
# absl flag registry; not referenced in the cells shown here.
FLAGS = flags.FLAGS

In [10]:
# Model / training settings, read as notebook globals by `train_and_evaluate`.
hidden_size, num_train_steps, batch_size = 512, 10_000, 128

In [13]:
def train_and_evaluate(workdir: str, learning_rate: float = 1e-6):
  """Builds the Seq2Seq model and its optimizer state, then runs the loop.

  The training loop is still a skeleton: batches are drawn but `train_step`
  is not called yet, so `state` is never updated. Reads the notebook-level
  `hidden_size`, `num_train_steps` and `batch_size` settings.

  Args:
    workdir: intended output directory (checkpoints/metrics); currently
      unused.
    learning_rate: Adam learning rate. NOTE(review): 1e-6 is far below common
      Adam rates (the upstream Flax seq2seq example uses 3e-3) — confirm this
      is intentional.
  """
  # As per the Flax seq2seq input pipeline
  # (https://github.com/google/flax/blob/main/examples/seq2seq/input_pipeline.py)
  # the second argument is max_len_query_digit (3), so a query is at most two
  # 3-digit numbers plus a '+' sign and an <eos> character.
  ctable = CTable('0123456789+= ', 3)
  rng = jax.random.PRNGKey(0)
  # `split` returns a stacked array of two keys; unpack it so each consumer
  # receives a single key (passing the stacked array where one key is
  # expected is invalid).
  init_rng, lstm_rng = jax.random.split(rng)

  model = Seq2Seq(teacher_force=False, hidden_size=hidden_size,
                  eos_id=ctable.eos_id, vocab_size=ctable.vocab_size)

  def init_model_params(rng: PRNGKey):
    """Initializes the model parameters via `model.init`.

    `model.init` runs the model's forward pass on dummy all-ones inputs shaped
    like real encoder/decoder batches and records the variables it creates;
    only the 'params' collection is returned.
    """
    rng1, rng2 = jax.random.split(rng)
    variables = model.init(
        {'params': rng1, 'lstm': rng2},
        jnp.ones(ctable.encoder_input_shape, jnp.float32),
        jnp.ones(ctable.decoder_input_shape, jnp.float32)
    )
    return variables['params']

  def train_step(state: train_state.TrainState, batch: Array,
                 lstm_rng: PRNGKey, eos_id: int) -> Tuple[
                     train_state.TrainState, Dict[str, float]]:
    """Performs one optimization step — not implemented yet.

    Args:
      state: TrainState holding the apply function, params and optimizer
        state.
      batch: mapping with 'query' and 'answer' arrays.
      lstm_rng: base key for the 'lstm' RNG stream.
      eos_id: EOS token id; presumably for masking the loss after EOS — TODO.

    Raises:
      NotImplementedError: always; the stub previously returned None, which
        contradicts the annotated `(state, metrics)` return type.
    """
    # `fold_in` mixes the step counter into the base key, giving each step a
    # distinct yet reproducible key without threading a new key through the
    # loop.
    lstm_key = jax.random.fold_in(lstm_rng, state.step)

    def loss_fn(params):
      # TODO: requires the model to return (logits, predictions). Seq2Seq
      # currently returns the encoder state, which this unpacking would
      # silently misinterpret as logits.
      logits, _ = state.apply_fn({'params': params},
                                 batch['query'], batch['answer'],
                                 rngs={'lstm': lstm_key})
      del logits  # TODO: compute a (masked) cross-entropy loss from logits.

    raise NotImplementedError(
        'train_step is a stub: the loss needs decoder logits.')

  params = init_model_params(init_rng)
  tx = optax.adam(learning_rate)

  # model, params, and tx are used to fill the TrainState
  # (described in flax.training).
  state = train_state.TrainState.create(apply_fn=model.apply, params=params,
                                        tx=tx)

  ## Let's train (skeleton: batches are generated but not consumed yet).
  for _ in range(num_train_steps):
    batch = ctable.get_batch(batch_size)
    # TODO: state, metrics = train_step(state, batch, lstm_rng, ctable.eos_id)
  print(model)

In [12]:
# Build the model/optimizer state and run the training loop; no workdir yet.
train_and_evaluate(workdir=None)




Seq2Seq(
    # attributes
    teacher_force = False
    hidden_size = 512
    vocab_size = 15
    eos_id = 1
)
