<a href="https://colab.research.google.com/github/gauravjain14/All-about-JAX/blob/main/Building_a_Seq2Seq.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# NOTE(review): versions are unpinned. The model code below was written against
# the flax 0.6-era LSTM API (`nn.LSTMCell()` with no arguments and the static
# `nn.LSTMCell.initialize_carry(rng, batch_dims, size)`); the recorded install
# output shows flax 0.6.0 and jax 0.3.16 being installed. Consider pinning
# those versions for reproducibility — confirm a compatible jaxlib first.
!pip install jax
!pip install flax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting flax
  Downloading flax-0.6.0-py3-none-any.whl (180 kB)
[K     |████████████████████████████████| 180 kB 7.5 MB/s 
Collecting optax
  Downloading optax-0.1.3-py3-none-any.whl (145 kB)
[K     |████████████████████████████████| 145 kB 51.5 MB/s 
[?25hCollecting PyYAML>=5.4.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 48.4 MB/s 
Collecting rich~=11.1
  Downloading rich-11.2.0-py3-none-any.whl (217 kB)
[K     |████████████████████████████████| 217 kB 55.7 MB/s 
Collecting jax>=0.3.16
  Downloading jax-0.3.16.tar.gz (1.0 MB)
[K     |████████████████████████████████| 1.0 MB 35.6 MB/s 
Collecting commonmark<0.10.0,>=0.9.0
  Downloading commonmark-

In [2]:
import functools
from typing import Any, Tuple

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from flax import linen as nn
import flax
import numpy as np
import tensorflow as tf

In [3]:
# Loose type alias used in the annotations below for JAX/NumPy arrays. It is
# plain `Any`, so it documents intent but is not enforced by a type checker.
Array = Any

In [4]:
class EncoderLSTM(nn.Module):
  """One LSTM encoder step, lifted with `nn.scan` to run over a whole sequence.

  The scan iterates over axis 1 of the input (the time axis), reuses a single
  set of `LSTMCell` parameters for every step, and threads the carry
  `(lstm_state, is_eos)` from one step to the next. Once an example has seen
  its EOS token, its LSTM state stops changing, so the final carry holds the
  state reached at EOS.

  Attributes:
    eos_id: vocabulary index of the EOS token; `x[:, eos_id]` is read as the
      per-example "this step is EOS" flag (assumes one-hot inputs — TODO
      confirm).
  """
  eos_id: int

  # Without this lift, `__call__` would only handle a single time step, while
  # `Encoder` passes the full (batch, seq, vocab) sequence in one call.
  # `variable_broadcast='params'` + `split_rngs={'params': False}` share one
  # set of weights across all time steps; `in_axes=1` scans over the sequence
  # axis of `x` and `out_axes=1` stacks the per-step outputs along that axis.
  @functools.partial(
      nn.scan,
      variable_broadcast='params',
      in_axes=1,
      out_axes=1,
      split_rngs={'params': False})
  @nn.compact
  def __call__(self, carry: Tuple[Array, Array],
               x: Array) -> Tuple[Tuple[Array, Array], Array]:
    """Applies one LSTM step, freezing the state of already-finished examples.

    Args:
      carry: `(lstm_state, is_eos)`, where `lstm_state` is the LSTMCell carry
        `(c, h)` and `is_eos` is a boolean flag per example, shape (batch,).
      x: input for this time step, presumably (batch, vocab_size) one-hot.

    Returns:
      `((carried_lstm_state, is_eos), y)`, where `y` is the LSTMCell output.
    """
    lstm_state, is_eos = carry
    new_lstm_state, y = nn.LSTMCell()(lstm_state, x)

    # Pass forward the previous state if EOS has already been reached.
    # `is_eos[:, np.newaxis]` broadcasts the (batch,) flag over the state axis.
    def select_carried_state(new_state, old_state):
      return jnp.where(is_eos[:, np.newaxis], old_state, new_state)
    # LSTM state is a tuple (c, h).
    carried_lstm_state = tuple(
        select_carried_state(*s) for s in zip(new_lstm_state, lstm_state))
    # Update `is_eos` only after selecting the state, so the EOS step itself
    # still updates the LSTM state and only later steps are frozen.
    is_eos = jnp.logical_or(is_eos, x[:, self.eos_id])
    return (carried_lstm_state, is_eos), y

  @staticmethod
  def initialize_carry(batch_size: int, hidden_size: int):
    """Returns the initial LSTM carry `(c, h)` for `batch_size` examples."""
    # Use a dummy key since the default state init fn is just zeros.
    return nn.LSTMCell.initialize_carry(
        jax.random.PRNGKey(0), (batch_size,), hidden_size)


In [8]:
class Encoder(nn.Module):
  """Runs `EncoderLSTM` over an input batch and returns its final LSTM state.

  Attributes:
    hidden_size: number of hidden units in the encoder LSTM.
    eos_id: vocabulary index of the EOS token, forwarded to `EncoderLSTM`.
  """
  hidden_size: int
  eos_id: int

  @nn.compact
  def __call__(self, inputs: Array):
    """Encodes `inputs` into an LSTM carry `(c, h)`.

    Args:
      inputs: (batch_size, seq_length, vocab_size) array; only the leading
        batch dimension is read directly here.

    Returns:
      The LSTM state taken from the final carry of the `encoder_lstm`
      submodule. Since `EncoderLSTM` stops updating an example's state once
      its EOS token has been seen, this is the per-example state at EOS.
    """
    n_examples = inputs.shape[0]
    step = EncoderLSTM(name='encoder_lstm', eos_id=self.eos_id)

    # Initial carry: a fresh LSTM state plus an all-False `is_eos` vector. The
    # flags decide, per example, whether the LSTM keeps updating its state or
    # just carries the previous one forward.
    zero_state = step.initialize_carry(n_examples, self.hidden_size)
    nobody_done = jnp.zeros(n_examples, dtype=bool)

    # The per-step outputs and the final `is_eos` flags are discarded: the
    # encoder only needs to produce the state that (presumably) seeds the
    # decoder.
    (lstm_state, _), _ = step((zero_state, nobody_done), inputs)
    return lstm_state

In [10]:
class Seq2Seq(nn.Module):
  """Sequence-to-sequence model built from an LSTM encoder and decoder.

  Work in progress: only the encoder is wired in, so calling the module
  currently returns the encoder's final LSTM state.

  Attributes:
    teacher_force: whether to use `decoder_inputs` as input to the decoder at
      every step. If False, only the first input (i.e., the "=" token) is used,
      followed by samples taken from the previous output logits.
    hidden_size: int, the number of hidden dimensions in the encoder and decoder
      LSTMs.
    vocab_size: the size of the vocabulary.
    eos_id: EOS id.
  """
  teacher_force: bool
  hidden_size: int
  vocab_size: int
  eos_id: int = 1

  @nn.compact
  def __call__(self, encoder_inputs: Array, decoder_inputs: Array):
    """Runs the model on one batch.

    Args:
      encoder_inputs: [batch_size, max_input_length, vocab_size] padded batch
        of input sequences to encode.
      decoder_inputs: [batch_size, max_output_length, vocab_size] padded batch
        of expected decoded sequences for teacher forcing. Not used yet.

    Returns:
      For now, whatever `Encoder` returns (its final LSTM state). Once a
      decoder is added, this is intended to become `(logits, predictions)`.
    """
    encoder = Encoder(hidden_size=self.hidden_size, eos_id=self.eos_id)
    final_encoder_state = encoder(encoder_inputs)

    # TODO: decode once a `Decoder` module is defined, e.g.
    #   logits, predictions = Decoder(
    #       init_state=final_encoder_state,
    #       teacher_force=self.teacher_force,
    #       vocab_size=self.vocab_size)(decoder_inputs[:, :-1])
    #   return logits, predictions
    # NOTE(review): the earlier draft sliced `decoder_inputs[:, -1]`, which
    # keeps only the last time step and drops the time axis; `[:, :-1]` (all
    # steps but the last) looks like the intended teacher-forcing input — confirm.
    return final_encoder_state

In [7]:
# Quick check of `np.newaxis`: indexing a (4,) vector with `[:, np.newaxis]`
# (equivalently `[:, None]`) yields a (4, 1) column. `EncoderLSTM` uses the
# same trick to broadcast its per-example `is_eos` flags against the LSTM
# state arrays.
x = np.arange(1, 5)
y = x[:, None]
print(x.shape, y.shape, sep='\n')

(4,)
(4, 1)


## Training

In [17]:
# Copy the seq2seq input-pipeline helper from Google Drive into the working
# directory. NOTE(review): this hard-coded relative path assumes Google Drive
# is mounted at ./drive in Colab — confirm before running elsewhere.
!cp drive/MyDrive/Work/Deep\ Learning/JAX/seq2seq/input_pipeline.py .
# `clu` provides `metric_writers`, which is imported in a later cell.
!pip install clu

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting clu
  Downloading clu-0.0.7-py3-none-any.whl (92 kB)
[K     |████████████████████████████████| 92 kB 735 kB/s 
Collecting ml-collections
  Downloading ml_collections-0.1.1.tar.gz (77 kB)
[K     |████████████████████████████████| 77 kB 5.6 MB/s 
Building wheels for collected packages: ml-collections
  Building wheel for ml-collections (setup.py) ... [?25l[?25hdone
  Created wheel for ml-collections: filename=ml_collections-0.1.1-py3-none-any.whl size=94524 sha256=6e1e2e6f11834eebfa895e2c95a46d82ca5e1063d799fbff8c828310d2b0e409
  Stored in directory: /root/.cache/pip/wheels/b7/da/64/33c926a1b10ff19791081b705879561b715a8341a856a3bbd2
Successfully built ml-collections
Installing collected packages: ml-collections, clu
Successfully installed clu-0.0.7 ml-collections-0.1.1


In [25]:
import functools
from typing import Any, Dict, Tuple

from absl import app
from absl import flags
from absl import logging
from clu import metric_writers
from flax import linen as nn
from flax.training import train_state
import jax
import jax.numpy as jnp
import optax

# NOTE(review): `!python input_pipeline.py` runs the script in a separate
# subprocess, so nothing it defines becomes available in this kernel. The
# helpers can only be used here after importing them, e.g. via the
# commented-out imports below (presumably disabled while debugging
# input_pipeline.py — confirm).
!python input_pipeline.py
# from input_pipeline import CharacterTable as CTable
# from input_pipeline import get_sequence_length
# from input_pipeline import mask_sequences