# Flax seq2seq Example

<a href="https://colab.research.google.com/github/google/flax/blob/main/examples/seq2seq/seq2seq.ipynb" ><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Demonstration notebook for
https://github.com/google/flax/tree/main/examples/seq2seq


The **Flax Notebook Workflow**:

1. Run the entire notebook end-to-end and check out the outputs.
   - This will open Python files in the right-hand editor!
   - You'll be able to interactively explore metrics in TensorBoard.
2. Change some of the hyperparameters via the command-line flags defined in `train.py`, then check out the updated TensorBoard plots.
3. Update the code in `train.py`, `models.py`, and `input_pipeline.py`.
   Thanks to `%autoreload`, any changes you make in these files will
   automatically appear in the notebook. Some ideas to get you started:
   - Change the model.
   - Log some per-batch metrics during training.
   - Add new hyperparameters to `models.py` and use them in `train.py`.
   - Train on a different vocabulary by initializing `CharacterTable` with a
     different character set.
4. At any time, feel free to paste code from the source code into the notebook
   and modify it directly there!

## Setup

In [None]:
# Install CLU & Flax.
!pip install -q clu flax


In [None]:
example_directory = 'examples/seq2seq'
editor_relpaths = ('train.py', 'input_pipeline.py', 'models.py')

repo, branch = 'https://github.com/google/flax', 'main'

In [None]:
# (If you run this code in Jupyter[lab], then you're already in the
#  example directory and nothing needs to be done.)

#@markdown **Fetch newest Flax, copy example code**
#@markdown
#@markdown **If you select no** below, then the files will be stored on the
#@markdown *ephemeral* Colab VM. **After some time of inactivity, this VM will
#@markdown be restarted and any changes are lost**.
#@markdown
#@markdown **If you select yes** below, then you will be asked for your
#@markdown credentials to mount your personal Google Drive. In this case, all
#@markdown changes you make will be *persisted*, and even if you re-run the
#@markdown Colab later on, the files will still be the same (you can of course
#@markdown remove directories inside your Drive's `flax/` root if you want to
#@markdown manually revert these files).

if 'google.colab' in str(get_ipython()):
  import os
  os.chdir('/content')
  # Download Flax repo from Github.
  if not os.path.isdir('flaxrepo'):
    !git clone --depth=1 -b $branch $repo flaxrepo
  # Copy example files & change directory.
  mount_gdrive = 'no' #@param ['yes', 'no']
  if mount_gdrive == 'yes':
    DISCLAIMER = 'Note : Editing in your Google Drive, changes will persist.'
    from google.colab import drive
    drive.mount('/content/gdrive')
    example_root_path = f'/content/gdrive/My Drive/flax/{example_directory}'
  else:
    DISCLAIMER = 'WARNING : Editing in VM - changes lost after reboot!!'
    example_root_path = f'/content/{example_directory}'
    from IPython import display
    display.display(display.HTML(
        f'<h1 style="color:red;" class="blink">{DISCLAIMER}</h1>'))
  if not os.path.isdir(example_root_path):
    os.makedirs(example_root_path)
    !cp -r flaxrepo/$example_directory/* "$example_root_path"
  os.chdir(example_root_path)
  from google.colab import files
  for relpath in editor_relpaths:
    s = open(f'{example_root_path}/{relpath}').read()
    open(f'{example_root_path}/{relpath}', 'w').write(
        f'## {DISCLAIMER}\n' + '#' * (len(DISCLAIMER) + 3) + '\n\n' + s)
    files.view(f'{example_root_path}/{relpath}')

Cloning into 'flaxrepo'...
remote: Enumerating objects: 349, done.
remote: Counting objects: 100% (349/349), done.
remote: Compressing objects: 100% (286/286), done.
remote: Total 349 (delta 63), reused 220 (delta 51), pack-reused 0
Receiving objects: 100% (349/349), 2.12 MiB | 13.39 MiB/s, done.
Resolving deltas: 100% (63/63), done.


In [None]:
# Note: In Colab, the cell above changed the working directory.
!pwd

/content/examples/seq2seq


## Imports

In [None]:
from absl import app
app.parse_flags_with_usage(['seq2seq'])

from absl import logging
logging.set_verbosity(logging.INFO)

import jax

In [None]:
# Local imports from current directory - auto reload.
# Any changes you make to the three imported files will appear automatically.
%load_ext autoreload
%autoreload 2
import input_pipeline
import models
import train

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Dataset

In [None]:
# Examples are generated on the fly.
ctable = input_pipeline.CharacterTable('0123456789+= ')
list(ctable.generate_examples(5))

[('72+789', '=861'),
 ('58+858', '=916'),
 ('77+358', '=435'),
 ('99+264', '=363'),
 ('94+115', '=209')]
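
The pairs above can be reproduced in spirit with a few lines of plain Python. This is a minimal sketch, not the code in `input_pipeline.py`: the function name `generate_addition_examples` and its `max_digits` and `seed` parameters are hypothetical, and the only thing taken from the output above is the `'a+b'` / `'=sum'` string format.

```python
import random


def generate_addition_examples(num_examples, max_digits=3, seed=None):
    """Yield (query, answer) string pairs such as ('72+789', '=861').

    Hypothetical stand-in for on-the-fly example generation; the operand
    range (up to `max_digits` digits each) is an assumption.
    """
    rng = random.Random(seed)
    for _ in range(num_examples):
        a = rng.randrange(10 ** max_digits)
        b = rng.randrange(10 ** max_digits)
        # The answer string starts with '=' to match the format shown above.
        yield f'{a}+{b}', f'={a + b}'


examples = list(generate_addition_examples(5, seed=0))
for query, answer in examples:
    print(query, answer)
```

Because nothing is stored on disk, every call produces fresh training data, which is why the notebook never downloads a dataset.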

In [None]:
batch = ctable.get_batch(5)
# Each query (and likewise each answer) is one-hot encoded.
batch['query'][0]

array([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
      dtype=float32)

In [None]:
# Note how CTABLE encodes PAD=0, EOS=1, '+'=3, '0'=4, '1'=5, ...
# (compare the rows of the one-hot array above with the decoded string).
ctable.decode_onehot(batch['query'][:1])

array(['1+243'], dtype='<U5')
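
To make the mapping between characters, ids, and one-hot rows concrete, here is a small pure-Python sketch. `TinyCharTable` is a hypothetical stand-in, not the `CharacterTable` from `input_pipeline.py`; it assumes two reserved ids (PAD=0, EOS=1) followed by the characters in sorted order, which is consistent with the array and the decoded string shown above.

```python
class TinyCharTable:
    """Hypothetical, simplified character table (not Flax's implementation)."""

    PAD_ID = 0
    EOS_ID = 1

    def __init__(self, chars):
        # Assumption: characters are indexed in sorted order after PAD and EOS.
        self.chars = sorted(set(chars))
        self.char_to_id = {ch: i + 2 for i, ch in enumerate(self.chars)}
        self.id_to_char = {i: ch for ch, i in self.char_to_id.items()}
        self.vocab_size = len(self.chars) + 2

    def encode(self, text, max_len):
        """Map text to ids, append EOS, then pad with PAD up to max_len."""
        ids = [self.char_to_id[ch] for ch in text] + [self.EOS_ID]
        if len(ids) > max_len:
            raise ValueError(f'{text!r} does not fit into {max_len} tokens')
        return ids + [self.PAD_ID] * (max_len - len(ids))

    def one_hot(self, ids):
        """Turn a list of ids into a list of one-hot rows."""
        return [[1.0 if j == i else 0.0 for j in range(self.vocab_size)]
                for i in ids]

    def decode(self, ids):
        """Map ids back to text, stopping at the first EOS or PAD."""
        out = []
        for i in ids:
            if i in (self.EOS_ID, self.PAD_ID):
                break
            out.append(self.id_to_char[i])
        return ''.join(out)


table = TinyCharTable('0123456789+= ')
ids = table.encode('1+243', max_len=8)
print(ids)                # [5, 3, 6, 8, 7, 1, 0, 0]
print(table.vocab_size)   # 15
print(table.decode(ids))  # 1+243
```

The positions of the ones in the printed array above (5, 3, 6, 8, 7, 1, 0, 0) match these ids, and the 13 characters plus the two reserved ids give the vocabulary size of 15.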

## Training

In [None]:
# Get a live update during training - use the "refresh" button!
# (In Jupyter[lab] start "tensorboard" in the local directory instead.)
if 'google.colab' in str(get_ipython()):
  %load_ext tensorboard
  %tensorboard --logdir=./workdirs

In [None]:
import time
workdir = f'./workdirs/{int(time.time())}'

In [None]:
# Train 2k steps & log 20 times.
app.parse_flags_with_usage([
    'seq2seq',
    '--num_train_steps=2000',
    '--decode_frequency=100',
])

['seq2seq']

In [None]:
state = train.train_and_evaluate(workdir=workdir)

INFO:absl:[100] accuracy=0.015625, loss=0.6936355233192444
INFO:absl:DECODE: 96+964 = 1002 (INCORRECT) correct=1060
INFO:absl:DECODE: 71+545 = 608 (INCORRECT) correct=616
INFO:absl:DECODE: 42+729 = 730 (INCORRECT) correct=771
INFO:absl:DECODE: 28+588 = 684 (INCORRECT) correct=616
INFO:absl:DECODE: 39+648 = 618 (INCORRECT) correct=687
INFO:absl:[200] accuracy=0.03125, loss=0.5422528982162476
INFO:absl:DECODE: 18+70 = 70 (INCORRECT) correct=88
INFO:absl:DECODE: 43+123 = 145 (INCORRECT) correct=166
INFO:absl:DECODE: 72+406 = 460 (INCORRECT) correct=478
INFO:absl:DECODE: 53+443 = 492 (INCORRECT) correct=496
INFO:absl:DECODE: 74+844 = 936 (INCORRECT) correct=918
INFO:absl:[300] accuracy=0.0703125, loss=0.462927907705307
INFO:absl:DECODE: 40+598 = 643 (INCORRECT) correct=638
INFO:absl:DECODE: 2+72 = 75 (INCORRECT) correct=74
INFO:absl:DECODE: 70+742 = 814 (INCORRECT) correct=812
INFO:absl:DECODE: 22+943 = 963 (INCORRECT) correct=965
INFO:absl:DECODE: 85+890 = 975 (CORRECT)
INFO:absl:[400] ac

In [None]:
if 'google.colab' in str(get_ipython()):
  #@markdown You can upload the training results directly to https://tensorboard.dev
  #@markdown
  #@markdown Note that everybody with the link will be able to see the data.
  upload_data = 'yes' #@param ['yes', 'no']
  if upload_data == 'yes':
    !tensorboard dev upload --one_shot --logdir ./workdirs --name 'Flax examples/seq2seq (Colab)'


***** TensorBoard Uploader *****

This will upload your TensorBoard logs to https://tensorboard.dev/ from
the following directory:

./workdirs

This TensorBoard will be visible to everyone. Do not upload sensitive
data.

Your use of this service is subject to Google's Terms of Service
<https://policies.google.com/terms> and Privacy Policy
<https://policies.google.com/privacy>, and TensorBoard.dev's Terms of Service
<https://tensorboard.dev/policy/terms/>.

This notice will not be shown again while you are logged into the uploader.
To log out, run `tensorboard dev auth revoke`.

Continue? (yes/NO) yes

Please visit this URL to authorize this application: https://accounts.google.com/o/oauth2/auth?response_type=code&client_id=373649185512-8v619h5kft38l4456nm2dj4ubeqsrvh6.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=openid+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email&state=IjociK9llsm6dSiC1TDFvFmJksFy49&prompt=consent&access_type=offline
En

## Inference

In [None]:
inputs = ctable.encode_onehot(['2+40'])
# batch, max_length, vocab_size
inputs.shape

(1, 8, 15)

In [None]:
# Using different random seeds generates different samples.
preds = train.decode(state.params, inputs, jax.random.PRNGKey(0), ctable)

In [None]:
ctable.decode_onehot(preds)

array(['42'], dtype='<U2')
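
Why the seed matters can be illustrated without JAX. The sketch below assumes that decoding draws each output token from a categorical distribution; `sample_tokens` and the probabilities in `step_probs` are hypothetical and only demonstrate that a fixed seed makes sampling reproducible, while a different seed may give a different sequence.

```python
import random


def sample_tokens(step_probs, seed):
    """Draw one token id per step from per-step categorical distributions."""
    rng = random.Random(seed)  # the same seed yields the same draws
    vocab = range(len(step_probs[0]))
    return [rng.choices(vocab, weights=probs, k=1)[0] for probs in step_probs]


# Hypothetical per-step probabilities over a 3-token vocabulary.
step_probs = [
    [0.5, 0.3, 0.2],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
    [0.25, 0.5, 0.25],
]

same_a = sample_tokens(step_probs, seed=0)
same_b = sample_tokens(step_probs, seed=0)
other = sample_tokens(step_probs, seed=1)

print(same_a == same_b)  # True: sampling is deterministic for a given seed
print(same_a, other)     # these may or may not differ
```

When a distribution puts nearly all of its mass on one token, as a well-trained model tends to do, every seed yields the same token, which is why the trained model answers `42` reliably.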