# Flax seq2seq Example

<a href="https://colab.research.google.com/github/google/flax/blob/main/examples/seq2seq/seq2seq.ipynb" ><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Demonstration notebook for
https://github.com/google/flax/tree/main/examples/seq2seq


The **Flax Notebook Workflow**:

1. Run the entire notebook end-to-end and check out the outputs.
   - This will open Python files in the right-hand editor!
   - You'll be able to interactively explore metrics in TensorBoard.
2. Change `config` and train for different hyperparameters. Check out the
   updated TensorBoard plots.
3. Update the code in `train.py`. Thanks to `%autoreload`, any changes you
   make in the file will automatically appear in the notebook. Some ideas to
   get you started:
   - Change the model.
   - Log some per-batch metrics during training.
   - Add new hyperparameters to `configs/default.py` and use them in
     `train.py`.
4. At any time, feel free to paste code from `train.py` into the notebook
   and modify it directly there!

## Setup

In [1]:
# Install CLU & latest Flax version from GitHub.
!pip install -q clu git+https://github.com/google/flax


In [2]:
example_directory = 'examples/seq2seq'
editor_relpaths = ('train.py',)

repo, branch = 'https://github.com/google/flax', 'main'

In [3]:
# (If you run this code in Jupyter[lab], then you're already in the
#  example directory and nothing needs to be done.)

#@markdown **Fetch newest Flax, copy example code**
#@markdown
#@markdown **If you select no** below, then the files will be stored on the
#@markdown *ephemeral* Colab VM. **After some time of inactivity, this VM will
#@markdown be restarted and any changes will be lost**.
#@markdown
#@markdown **If you select yes** below, then you will be asked for your
#@markdown credentials to mount your personal Google Drive. In this case, all
#@markdown changes you make will be *persisted*, and even if you re-run the
#@markdown Colab later on, the files will still be the same (you can of course
#@markdown remove directories inside your Drive's `flax/` root if you want to
#@markdown manually revert these files).

if 'google.colab' in str(get_ipython()):
  import os
  os.chdir('/content')
  # Download Flax repo from GitHub.
  if not os.path.isdir('flaxrepo'):
    !git clone --depth=1 -b $branch $repo flaxrepo
  # Copy example files & change directory.
  mount_gdrive = 'no' #@param ['yes', 'no']
  if mount_gdrive == 'yes':
    DISCLAIMER = 'Note : Editing in your Google Drive, changes will persist.'
    from google.colab import drive
    drive.mount('/content/gdrive')
    example_root_path = f'/content/gdrive/My Drive/flax/{example_directory}'
  else:
    DISCLAIMER = 'WARNING : Editing in VM - changes lost after reboot!!'
    example_root_path = f'/content/{example_directory}'
    from IPython import display
    display.display(display.HTML(
        f'<h1 style="color:red;" class="blink">{DISCLAIMER}</h1>'))
  if not os.path.isdir(example_root_path):
    os.makedirs(example_root_path)
    !cp -r flaxrepo/$example_directory/* "$example_root_path"
  os.chdir(example_root_path)
  from google.colab import files
  for relpath in editor_relpaths:
    # Prepend the disclaimer banner to the file; close handles explicitly.
    with open(f'{example_root_path}/{relpath}') as f:
      s = f.read()
    with open(f'{example_root_path}/{relpath}', 'w') as f:
      f.write(
          f'## {DISCLAIMER}\n' + '#' * (len(DISCLAIMER) + 3) + '\n\n' + s)
    files.view(f'{example_root_path}/{relpath}')

Cloning into 'flaxrepo'...
remote: Enumerating objects: 329, done.
remote: Counting objects: 100% (329/329), done.
remote: Compressing objects: 100% (291/291), done.
remote: Total 329 (delta 58), reused 137 (delta 21), pack-reused 0
Receiving objects: 100% (329/329), 1.76 MiB | 8.95 MiB/s, done.
Resolving deltas: 100% (58/58), done.



In [4]:
# Note: In Colab, the cell above changed the working directory.
!pwd

/content/examples/seq2seq


## Imports

In [5]:
from absl import app
app.parse_flags_with_usage(['seq2seq'])

from absl import logging
logging.set_verbosity(logging.INFO)

import jax

In [6]:
# Local imports from current directory - auto reload.
# Any changes you make to train.py will appear automatically.
%load_ext autoreload
%autoreload 2
import train

## Dataset

In [7]:
# Examples are generated on the fly.
list(train.get_examples(5))

[('38+892', '=930'),
 ('19+70', '=89'),
 ('90+293', '=383'),
 ('31+198', '=229'),
 ('43+345', '=388')]
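
As a rough illustration of what "generated on the fly" can look like, the sketch below yields random addition problems in the same `(query, answer)` string format as the output above. The function name `make_examples`, its parameters, and the operand ranges are assumptions made for this illustration; the actual `train.get_examples` may sample differently.

```python
import random


def make_examples(num_examples, max_digits=3, seed=None):
  """Yields (query, answer) string pairs such as ('38+892', '=930').

  Hypothetical stand-in for train.get_examples; operand ranges are guesses.
  """
  rng = random.Random(seed)
  for _ in range(num_examples):
    a = rng.randrange(0, 100)               # assumed: up to two digits
    b = rng.randrange(0, 10 ** max_digits)  # assumed: up to max_digits digits
    yield f'{a}+{b}', f'={a + b}'


examples = list(make_examples(5, seed=0))
for query, answer in examples:
  print(query, answer)
```

Because each pair is computed when requested, no dataset needs to be stored, and every call can produce fresh problems.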

In [8]:
batch = train.get_batch(5)
# A single query (or answer) is one-hot encoded.
batch['query'][0]

array([[0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
      dtype=float32)

In [9]:
# Note how CTABLE encodes PAD=0, EOS=1, '0'=2, '1'=3, ...
train.decode_onehot(batch['query'][:1])

array(['38+293'], dtype='<U6')
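
The index layout described in the comment above can be reproduced with a small character table. This is a minimal sketch, not the code in `train.py`: the names `CHARS`, `encode_onehot_sketch`, and `decode_onehot_sketch` are hypothetical, and the positions of the non-digit characters are an assumption (the notebook only states PAD=0, EOS=1, and the digit offsets).

```python
import numpy as np

# Assumed vocabulary order; only the digit positions are stated in the notebook.
# With PAD and EOS this gives 15 entries, matching the 15 columns shown above.
CHARS = '0123456789+= '
PAD_ID, EOS_ID = 0, 1
CHAR_TO_ID = {ch: i + 2 for i, ch in enumerate(CHARS)}
ID_TO_CHAR = {i: ch for ch, i in CHAR_TO_ID.items()}
VOCAB_SIZE = len(CHARS) + 2


def encode_onehot_sketch(text, max_len):
  """Encodes text as a (max_len, VOCAB_SIZE) one-hot array, EOS-terminated."""
  ids = [CHAR_TO_ID[ch] for ch in text] + [EOS_ID]
  if len(ids) > max_len:
    raise ValueError(f'{text!r} does not fit into {max_len} positions')
  ids += [PAD_ID] * (max_len - len(ids))
  # Row i of the identity matrix is the one-hot vector for token id i.
  return np.eye(VOCAB_SIZE, dtype=np.float32)[ids]


def decode_onehot_sketch(onehot):
  """Inverts encode_onehot_sketch, stopping at the first EOS token."""
  chars = []
  for token_id in onehot.argmax(axis=-1):
    if token_id == EOS_ID:
      break
    # '_' is a placeholder for ids without a character (here only PAD).
    chars.append(ID_TO_CHAR.get(int(token_id), '_'))
  return ''.join(chars)


encoded = encode_onehot_sketch('2+40', max_len=8)
print(encoded.shape)
print(decode_onehot_sketch(encoded))
```

Each row holds a single 1 at the token's index; positions after the EOS row are padding rows with the 1 in column 0.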

## Training

In [None]:
# Get a live update during training - use the "refresh" button!
# (In Jupyter[lab] start "tensorboard" in the local directory instead.)
if 'google.colab' in str(get_ipython()):
  %load_ext tensorboard
  %tensorboard --logdir=./workdirs

In [11]:
import time
workdir = f'./workdirs/{int(time.time())}'

In [12]:
# Train 2k steps & log 20 times.
app.parse_flags_with_usage([
    'seq2seq',
    '--num_train_steps=2000',
    '--decode_frequency=100',
])

['seq2seq']

In [13]:
state = train.train_model(workdir=workdir)

INFO:absl:[0] accuracy=0.000000, loss=1.550575
INFO:absl:DECODE: 10+730 = _3292 (INCORRECT) correct=740
INFO:absl:DECODE: 21+207 =  7 (INCORRECT) correct=228
INFO:absl:DECODE: 41+918 =  (INCORRECT) correct=959
INFO:absl:DECODE: 58+383 = 0673_ (INCORRECT) correct=441
INFO:absl:DECODE: 32+846 = 64576 (INCORRECT) correct=878
INFO:absl:[100] accuracy=0.015625, loss=0.722793
INFO:absl:DECODE: 28+867 = 996 (INCORRECT) correct=895
INFO:absl:DECODE: 47+564 = 661 (INCORRECT) correct=611
INFO:absl:DECODE: 2+948 = 944 (INCORRECT) correct=950
INFO:absl:DECODE: 73+803 = 826 (INCORRECT) correct=876
INFO:absl:DECODE: 99+742 = 824 (INCORRECT) correct=841
INFO:absl:[200] accuracy=0.046875, loss=0.532240
INFO:absl:DECODE: 73+398 = 462 (INCORRECT) correct=471
INFO:absl:DECODE: 98+644 = 732 (INCORRECT) correct=742
INFO:absl:DECODE: 65+893 = 966 (INCORRECT) correct=958
INFO:absl:DECODE: 9+301 = 331 (INCORRECT) correct=310
INFO:absl:DECODE: 52+292 = 331 (INCORRECT) correct=344
INFO:absl:[300] accuracy=0.031

In [42]:
if 'google.colab' in str(get_ipython()):
  #@markdown You can upload the training results directly to https://tensorboard.dev
  #@markdown
  #@markdown Note that everybody with the link will be able to see the data.
  upload_data = 'no' #@param ['yes', 'no']
  if upload_data == 'yes':
    !tensorboard dev upload --one_shot --logdir ./workdirs --name 'Flax examples/seq2seq (Colab)'

Data for the "text" plugin is now uploaded to TensorBoard.dev! Note that uploaded data is public. If you do not want to upload data for this plugin, use the "--plugins" command line argument.

New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/h81jpOlgS5iBJv4MVdznRQ/

[2021-06-29T21:09:10] Started scanning logdir.
[2021-06-29T21:09:11] Total uploaded: 40 scalars, 0 tensors, 0 binary objects
[2021-06-29T21:09:11] Done scanning logdir.


Done. View your TensorBoard at https://tensorboard.dev/experiment/h81jpOlgS5iBJv4MVdznRQ/


## Inference

In [34]:
inputs = train.encode_onehot(['2+40'])
# batch, max_length, vocab_size
inputs.shape

(1, 8, 15)

In [35]:
# Using different random seeds generates different samples.
preds = train.decode(state.params, inputs, jax.random.PRNGKey(0))

In [36]:
train.decode_onehot(preds)

array(['42'], dtype='<U2')
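
The comment about random seeds suggests that decoding samples each output token rather than always taking the most likely one. How `train.decode` does this is not shown in the notebook; the NumPy sketch below only illustrates the general idea of seeded categorical sampling, using a hypothetical `sample_tokens` helper and made-up logits.

```python
import numpy as np


def sample_tokens(logits, seed):
  """Draws one token id per row from softmax(logits), reproducibly per seed."""
  rng = np.random.default_rng(seed)
  # Gumbel-max trick: argmax(logits + Gumbel noise) is a sample from the
  # categorical distribution softmax(logits).
  noise = rng.gumbel(size=logits.shape)
  return (logits + noise).argmax(axis=-1)


# Made-up logits: 8 positions over a 15-token vocabulary.
uncertain = np.zeros((8, 15))        # uniform distribution over tokens
confident = np.full((8, 15), -1e9)   # nearly all probability on token 4
confident[:, 4] = 0.0

same_seed_a = sample_tokens(uncertain, seed=0)
same_seed_b = sample_tokens(uncertain, seed=0)
other_seed = sample_tokens(uncertain, seed=1)
peaked = sample_tokens(confident, seed=1)
print(same_seed_a, other_seed, peaked)
```

Repeating a seed reproduces the sample, while a different seed will usually give a different one when the distribution is spread out. When nearly all probability sits on one token, as is likely for a well-trained model, the seed has little effect on the result.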