# Flax SST-2 Example

<a href="https://colab.research.google.com/github/google/flax/blob/main/examples/sst2/sst2.ipynb" ><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Demonstration notebook for
https://github.com/google/flax/tree/main/examples/sst2

**Before you start:** Select Runtime -> Change runtime type -> GPU.

The **Flax Notebook Workflow**:

1. Run the entire notebook end-to-end and check out the outputs.
   - This will open Python files in the right-hand editor!
   - You'll be able to interactively explore metrics in TensorBoard.
2. Change `config` and train for different hyperparameters. Check out the
   updated TensorBoard plots.
3. Update the code in `train.py`. Thanks to `%autoreload`, any changes you
   make in the file will automatically appear in the notebook. Some ideas to
   get you started:
   - Change the model.
   - Log some per-batch metrics during training.
   - Add new hyperparameters to `configs/default.py` and use them in
     `train.py`.
4. At any time, feel free to paste code from `train.py` into the notebook
   and modify it directly there!
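Step 3 relies on IPython's `%autoreload 2`, which reloads modified modules before each cell runs. The following minimal sketch does the same thing by hand with the standard library's `importlib.reload`. It writes a module, edits its source, and reloads it. The module `toy_train` and its `NUM_EPOCHS` constant are hypothetical stand-ins for `train.py`.

```python
import importlib
import pathlib
import sys
import tempfile


def demo_reload() -> tuple:
    """Imports a throwaway module, edits its source on disk, and reloads it."""
    saved_flag = sys.dont_write_bytecode
    # Skip .pyc caching so the reload always recompiles the edited source.
    sys.dont_write_bytecode = True
    with tempfile.TemporaryDirectory() as tmp:
        source = pathlib.Path(tmp) / "toy_train.py"  # hypothetical stand-in for train.py
        source.write_text("NUM_EPOCHS = 10\n")
        sys.path.insert(0, tmp)
        try:
            importlib.invalidate_caches()
            toy_train = importlib.import_module("toy_train")
            before = toy_train.NUM_EPOCHS
            # Simulate an edit made in the right-hand editor.
            source.write_text("NUM_EPOCHS = 20\n")
            toy_train = importlib.reload(toy_train)
            after = toy_train.NUM_EPOCHS
        finally:
            sys.path.remove(tmp)
            sys.modules.pop("toy_train", None)
            sys.dont_write_bytecode = saved_flag
    return before, after


print(demo_reload())  # (10, 20)
```

In the notebook you do not need to call `importlib.reload` yourself, because `%autoreload 2` does this for every imported module whose file has changed.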

## Setup

In [None]:
example_directory = 'examples/sst2'
editor_relpaths = ('configs/default.py', 'train.py', 'models.py')

In [None]:
# (If you run this code in Jupyter[lab], then you're already in the
#  example directory and nothing needs to be done.)

#@markdown **Fetch newest Flax, copy example code**
#@markdown
#@markdown **If you select no** below, then the files will be stored on the
#@markdown *ephemeral* Colab VM. **After some time of inactivity, this VM will
#@markdown be restarted and any changes will be lost**.
#@markdown
#@markdown **If you select yes** below, then you will be asked for your
#@markdown credentials to mount your personal Google Drive. In this case, all
#@markdown changes you make will be *persisted*, and even if you re-run the
#@markdown Colab later on, the files will still be the same (you can of course
#@markdown remove directories inside your Drive's `flax/` root if you want to
#@markdown manually revert these files).

if 'google.colab' in str(get_ipython()):
  import os
  os.chdir('/content')
  # Download Flax repo from Github.
  if not os.path.isdir('flaxrepo'):
    !git clone --depth=1 https://github.com/google/flax flaxrepo
  # Copy example files & change directory.
  mount_gdrive = 'no' #@param ['yes', 'no']
  if mount_gdrive == 'yes':
    DISCLAIMER = 'Note: Editing in your Google Drive, changes will persist.'
    from google.colab import drive
    drive.mount('/content/gdrive')
    example_root_path = f'/content/gdrive/My Drive/flax/{example_directory}'
  else:
    DISCLAIMER = 'WARNING: Editing in VM - changes lost after reboot!!'
    example_root_path = f'/content/{example_directory}'
    from IPython import display
    display.display(display.HTML(
        f'<h1 style="color:red;" class="blink">{DISCLAIMER}</h1>'))
  if not os.path.isdir(example_root_path):
    os.makedirs(example_root_path)
    !cp -r flaxrepo/$example_directory/* "$example_root_path"
  os.chdir(example_root_path)
  from google.colab import files
  for relpath in editor_relpaths:
    s = open(f'{example_root_path}/{relpath}').read()
    open(f'{example_root_path}/{relpath}', 'w').write(
        f'## {DISCLAIMER}\n' + '#' * (len(DISCLAIMER) + 3) + '\n\n' + s)
    files.view(f'{example_root_path}/{relpath}')
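The loop above prefixes each editor file with a two-line banner: `## {DISCLAIMER}` followed by a row of `#` characters. The `+ 3` accounts for the `## ` prefix, so the rule is exactly as wide as the title. Only the copy step is guarded by `os.path.isdir`, and the loop runs every time the cell runs. Re-running the cell therefore adds another banner each time, and these banners accumulate across sessions when the files live on Google Drive. Below is a minimal sketch of an idempotent variant. `disclaimer_header` and `prepend_disclaimer_once` are hypothetical helpers, not part of the example code.

```python
import os
import tempfile


def disclaimer_header(disclaimer: str) -> str:
    """Builds the banner: '## <disclaimer>' plus a '#' rule of the same width."""
    title = f"## {disclaimer}"
    # len("## ") == 3, which is where the "+ 3" in the notebook cell comes from.
    return title + "\n" + "#" * len(title) + "\n\n"


def prepend_disclaimer_once(path: str, disclaimer: str) -> None:
    """Prepends the banner unless the file already starts with it."""
    header = disclaimer_header(disclaimer)
    with open(path, encoding="utf-8") as f:
        source = f.read()
    if not source.startswith(header):
        with open(path, "w", encoding="utf-8") as f:
            f.write(header + source)


# Usage: running the step twice still leaves exactly one banner.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "train.py")
    with open(demo_path, "w", encoding="utf-8") as f:
        f.write("import jax\n")
    for _ in range(2):
        prepend_disclaimer_once(demo_path, "WARNING: Editing in VM - changes lost after reboot!!")
    with open(demo_path, encoding="utf-8") as f:
        final_text = f.read()
print(final_text.count("## WARNING"))  # 1
```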

In [None]:
# Note: In Colab, the cell above changed the working directory.
!pwd

In [None]:
# Install SST-2 dependencies.
!pip install -q -r requirements.txt

## Imports / Helpers

In [None]:
# On a TPU runtime, this connects JAX to the TPU. Otherwise, JAX uses the CPU or
# GPU, and XLA is told to expose 8 (virtual) CPU devices.
try:
  import jax.tools.colab_tpu
  jax.tools.colab_tpu.setup_tpu()
except KeyError:
  print('\n### NO TPU CONNECTED - USING CPU or GPU ###\n')
  import os
  os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
jax.devices()
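The `except` branch above overwrites `XLA_FLAGS` with `--xla_force_host_platform_device_count=8`. This flag makes XLA split the host CPU into 8 devices and has no effect on GPUs or TPUs. It only takes effect if it is set before JAX initializes its backends. If you already rely on other XLA flags, the following minimal sketch appends the flag instead of replacing everything. `add_xla_flag` is a hypothetical helper, not part of JAX.

```python
import os
from typing import MutableMapping, Optional


def add_xla_flag(flag: str, env: Optional[MutableMapping[str, str]] = None) -> str:
    """Adds `flag` to XLA_FLAGS, keeping any other flags that are already set."""
    env = os.environ if env is None else env
    name = flag.split("=", 1)[0]
    # Drop an earlier value of the same flag so the new value is unambiguous.
    kept = [f for f in env.get("XLA_FLAGS", "").split() if f.split("=", 1)[0] != name]
    env["XLA_FLAGS"] = " ".join(kept + [flag])
    return env["XLA_FLAGS"]


# Usage with a plain dict instead of os.environ, so nothing global is modified.
fake_env = {"XLA_FLAGS": "--xla_dump_to=/tmp/xla --xla_force_host_platform_device_count=2"}
print(add_xla_flag("--xla_force_host_platform_device_count=8", fake_env))
# --xla_dump_to=/tmp/xla --xla_force_host_platform_device_count=8
```

In the cell above, calling `add_xla_flag('--xla_force_host_platform_device_count=8')` with the default environment would replace the `os.environ['XLA_FLAGS'] = ...` line.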

In [None]:
from absl import logging
import flax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import time
logging.set_verbosity(logging.INFO)

# Hide the GPU from TensorFlow so it does not reserve GPU memory that JAX needs.
tf.config.experimental.set_visible_devices([], 'GPU')

In [None]:
# Local imports from current directory - auto reload.
# Any changes you make to train.py will appear automatically.
%load_ext autoreload
%autoreload 2
import train
import models
import vocabulary
import input_pipeline
from configs import default as config_lib
config = config_lib.get_config()

## Dataset

In [None]:
# Get datasets.
# If you get an error, you may need to install tensorflow_datasets from GitHub.
train_dataset = input_pipeline.TextDataset(split='train')
eval_dataset = input_pipeline.TextDataset(split='validation')

## Training

In [None]:
# Get a live update during training - use the "refresh" button!
# (In Jupyter[lab] start "tensorboard" in the local directory instead.)
if 'google.colab' in str(get_ipython()):
  %load_ext tensorboard
  %tensorboard --logdir=.

In [None]:
config.num_epochs = 10
model_name = 'bilstm'
start_time = time.time()
state = train.train_and_evaluate(config, workdir=f'./models/{model_name}')
logging.info('Walltime: %f s', time.time() - start_time)
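The training cell writes everything for a run to `./models/{model_name}`, and `%tensorboard --logdir=.` shows each subdirectory containing event files as a separate run. A second run with the same `model_name` but different hyperparameters therefore writes into the same directory as the first. Below is a minimal sketch of a naming scheme that gives each hyperparameter setting its own workdir. `run_workdir` is a hypothetical helper. `learning_rate` is assumed, not confirmed by this notebook, to be a field of the SST-2 config.

```python
def run_workdir(model_name: str, overrides: dict, root: str = "./models") -> str:
    """Builds a distinct, deterministic workdir per hyperparameter setting."""
    if not overrides:
        return f"{root}/{model_name}"
    # Sort keys so the same overrides always map to the same directory.
    suffix = ",".join(f"{key}={overrides[key]}" for key in sorted(overrides))
    return f"{root}/{model_name}/{suffix}"


sweep = [{"num_epochs": 10}, {"num_epochs": 10, "learning_rate": 0.05}]
for overrides in sweep:
    print(run_workdir("bilstm", overrides))
# ./models/bilstm/num_epochs=10
# ./models/bilstm/learning_rate=0.05,num_epochs=10
```

In the notebook, you would apply each `overrides` dict to `config` before calling `train.train_and_evaluate(config, workdir=run_workdir(model_name, overrides))`. If `config` is an `ml_collections.ConfigDict`, `config.update(overrides)` does this. `update` mutates `config` in place, so settings from one iteration carry over to the next unless you start each run from a fresh `config_lib.get_config()`.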

In [None]:
if 'google.colab' in str(get_ipython()):
  #@markdown You can upload the training results directly to https://tensorboard.dev
  #@markdown
  #@markdown Note that everybody with the link will be able to see the data.
  upload_data = 'yes' #@param ['yes', 'no']
  if upload_data == 'yes':
    !tensorboard dev upload --one_shot --logdir ./models --name 'Flax examples/sst2'