1. Connect to a GPU runtime.
2. Mount Google Drive.
3. Clone the repository into Google Drive (so that the path in step 4 exists).
4. `cd` into `drive/My Drive/flax/examples/sst`.

In [None]:
# cd drive/My Drive/flax/examples/sst

In [None]:
# Note: In Colab, the `cd` cell above (once uncommented) changes the working
# directory. Print it to confirm we are in the `sst` example directory, since
# later cells use relative paths (requirements.txt, data/trees/...).
!pwd

In [None]:
# Install SST-2 dependencies.
!pip install -q -r requirements.txt
# Then upgrade Flax to the current GitHub version and JAX/jaxlib to the latest
# release. NOTE(review): this may override any versions requirements.txt pins;
# confirm the notebook still works with whatever versions end up installed.
!pip install --upgrade git+https://github.com/google/flax.git
!pip install --upgrade jax jaxlib

In [None]:
# Copyright 2022 Google LLC.
# SPDX-License-Identifier: Apache-2.0
import re
import tensorflow as tf

def load_sst_data(path, lower=True):
  """Loads an SST tree file (one bracketed parse tree per line) as a TF dataset.

  Each non-empty line is expected to look like `(3 (2 word) (4 ...))`: the
  digit right after the leading `(` is used as the label, and the text is the
  space-joined leaf tokens, i.e. the `(<digit> <token>)` nodes.

  Args:
    path: Path of the tree file, e.g. `data/trees/train.txt`.
    lower: If True, lowercase each line before extracting tokens.

  Returns:
    A `tf.data.Dataset` of dicts with keys `'text'` (space-joined leaf tokens)
    and `'label'` (int).
  """
  # Matches a leaf node `(<digit> <token>)` and captures the token.
  leaf_pattern = re.compile(r"\([0-9] ([^\(\)]+)\)")
  data = {'text': [], 'label': []}
  with open(path, 'r', encoding='utf8') as f:
    for line in f:
      line = line.strip()
      if not line:
        # Skip blank lines (e.g. a trailing newline at EOF); `line[1]` below
        # would otherwise raise IndexError.
        continue
      line = line.lower() if lower else line
      # Further text fixes (such as removing backslashes) are deliberately
      # skipped, since they weren't used by the paper.
      tokens = leaf_pattern.findall(line)
      label = int(line[1])
      data['text'].append(' '.join(tokens))
      data['label'].append(label)
  # The file is closed before building the dataset from the in-memory lists.
  return tf.data.Dataset.from_tensor_slices(data)

In [None]:
# Parse the train / dev / test tree files (in that order) into tf.data datasets.
train_dataset, val_dataset, test_dataset = (
    load_sst_data(path)
    for path in ('data/trees/train.txt', 'data/trees/dev.txt',
                 'data/trees/test.txt'))
# Show the first training example as a quick check of the parsing.
next(iter(train_dataset))

In [None]:
# Copyright 2022 Google LLC.
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Tuple
from jax import numpy as jnp
from flax.linen import recurrent
from sst2.models import flip_sequences
import jax 
from flax import linen as nn
import functools

# Type aliases used in the annotations of the modules below.
PRNGKey = Any
# A shape is a tuple of any length; `Tuple[int]` would mean exactly one int.
Shape = Tuple[int, ...]
Dtype = Any
Array = Any

# Initializers used by the LSTM cell below. RECURRENT_INIT is the orthogonal
# initializer *factory* (call it to obtain an init function); KERNEL_INIT is
# Flax's default Dense kernel initializer.
RECURRENT_INIT = nn.initializers.orthogonal
KERNEL_INIT = nn.linear.default_kernel_init

class MyLSTMCell(recurrent.RNNCellBase):
  """An LSTM cell with PyTorch-style separate input and hidden projections.

  As in PyTorch, the input projection (`i`) and the hidden projection (`h`)
  each have their own bias. The fused 4*features projection is, however,
  split in (i, g, f, o) gate order rather than PyTorch's (i, f, g, o).

  NOTE(review): the (i, g, f, o) order is presumably chosen to match
  externally trained weights (e.g. the ArrasL/LRP_for_LSTM repo referenced
  elsewhere in this notebook). Confirm before changing it: reordering would
  silently scramble any loaded parameters.
  """
  # Nonlinearity applied to the input, forget and output gates.
  gate_fn: Callable[..., Any] = nn.sigmoid
  # Nonlinearity applied to the candidate `g` and to the new cell state.
  activation_fn: Callable[..., Any] = nn.tanh
  # Initializer for the input-to-hidden kernel (Dense 'i').
  kernel_init: Callable[..., Array] = KERNEL_INIT
  # Initializer for the hidden-to-hidden kernel (Dense 'h'). The initializer
  # function is created once, when the class body is evaluated.
  recurrent_kernel_init: Callable[..., Array] = RECURRENT_INIT()
  # Initializer for both biases.
  bias_init: Callable[..., Array] = nn.initializers.zeros
  # Computation dtype of the Dense layers.
  dtype: Dtype = jnp.float32
  # Dtype of the Dense parameters.
  param_dtype: Dtype = jnp.float32

  @nn.compact
  def __call__(self, carry, inputs):
    """Performs a single time step of the cell.

    Args:
      carry: the hidden state of the LSTM cell, a tuple (c, h),
        initialized using `MyLSTMCell.initialize_carry`. The number of
        features is taken from the last dimension of `h`.
      inputs: an ndarray with the input for the current time step.
        All dimensions except the final are considered batch dimensions.

    Returns:
      A tuple with the new carry (c', h') and the output (h').
    """
    c, h = carry
    features = h.shape[-1]

    # Project h to all four gates at once for better performance; the output
    # is laid out as [h_i, h_g, h_f, h_o] (see the split below).
    dense_h = nn.Dense(
        features=features * 4,
        use_bias=True,
        kernel_init=self.recurrent_kernel_init,
        bias_init=self.bias_init,
        name='h',
        dtype=self.dtype,
        param_dtype=self.param_dtype)(h)

    # Project the input to all four gates at once, laid out as
    # [i_i, i_g, i_f, i_o].
    dense_i = nn.Dense(
        features=features * 4,
        use_bias=True,  # dense_h already has a bias, but we follow PyTorch.
        kernel_init=self.kernel_init,
        bias_init=self.bias_init,
        name='i',
        dtype=self.dtype,
        param_dtype=self.param_dtype)(inputs)

    # We sum each h_{i,g,f,o} with each i_{i,g,f,o} already now for performance.
    summed_combined_projections = dense_i + dense_h

    # Split into i = i_i + h_i, g = i_g + h_g, f = i_f + h_f, o = i_o + h_o.
    # The order is (i, g, f, o), not PyTorch's (i, f, g, o); see class docstring.
    i, g, f, o = jnp.split(summed_combined_projections, 4, axis=-1)

    i = self.gate_fn(i)
    f = self.gate_fn(f)
    g = self.activation_fn(g)
    o = self.gate_fn(o)

    # Standard LSTM update: c' = f * c + i * g, h' = o * activation_fn(c').
    new_c = f * c + i * g
    new_h = o * self.activation_fn(new_c)
    return (new_c, new_h), new_h

  @staticmethod
  def initialize_carry(rng, batch_dims, size, init_fn=nn.initializers.zeros):
    """Initializes the RNN cell carry.

    Args:
      rng: random number generator passed to the init_fn. It is split so that
        c and h receive different keys.
      batch_dims: a tuple providing the shape of the batch dimensions. It is
        concatenated with `(size,)`, so it must be a tuple.
      size: the size or number of features of the memory.
      init_fn: initializer function for the carry, called as
        `init_fn(key, shape)`.

    Returns:
      An initialized carry `(c, h)` for the given RNN cell, each of shape
      `batch_dims + (size,)`.
    """
    key1, key2 = jax.random.split(rng)
    mem_shape = batch_dims + (size,)
    return init_fn(key1, mem_shape), init_fn(key2, mem_shape)


class LSTM(nn.Module):
  """Unidirectional LSTM that scans `MyLSTMCell` over the time axis.

  `nn.transforms.scan` maps the cell over axis 1 of `x` and stacks the
  per-step outputs along axis 1 of the result. The cell parameters are
  broadcast, i.e. shared by every time step.
  """

  @functools.partial(
      nn.transforms.scan,
      variable_broadcast='params',
      in_axes=1,
      out_axes=1,
      split_rngs={'params': False})
  @nn.compact
  def __call__(self, carry, x):
    """Runs one cell step; `scan` lifts this over all time steps."""
    cell = MyLSTMCell(name='cell')
    next_carry, step_output = cell(carry, x)
    return next_carry, step_output

  @staticmethod
  def initialize_carry(batch_dims, hidden_size):
    """Returns an initial `(c, h)` carry for a batch of shape `batch_dims`.

    A fixed PRNG key is used. With the cell's default zeros initializer the
    key does not affect the result.
    """
    fixed_key = jax.random.PRNGKey(0)
    return MyLSTMCell.initialize_carry(fixed_key, batch_dims, hidden_size)


class BiLSTM(nn.Module):
  """A simple bi-directional LSTM over padded, variable-length sequences."""
  hidden_size: int

  @nn.compact
  def __call__(self, inputs, lengths):
    """Encodes `inputs` left-to-right and right-to-left.

    Args:
      inputs: batch-major inputs. Axis 0 is the batch and axis 1 is time
        (the axis `LSTM` scans over).
      lengths: per-example sequence lengths. `lengths - 1` indexes the last
        valid time step of each example.

    Returns:
      `(outputs, (forward_final, backward_final))`. `outputs` concatenates
      both directions' hidden states at every time step along the last axis.
      `forward_final` and `backward_final` are each direction's hidden state
      at its last valid step.
    """
    batch_size = inputs.shape[0]
    batch_index = jnp.arange(batch_size)
    last_step = lengths - 1

    # Forward direction.
    fwd_carry = LSTM.initialize_carry((batch_size,), self.hidden_size)
    _, forward_outputs = LSTM(name='lstm_fwd')(fwd_carry, inputs)
    forward_final = forward_outputs[batch_index, last_step]

    # Backward direction: run on reversed sequences. `flip_sequences` is
    # assumed to reverse only the first `lengths[b]` steps of each row, so the
    # last valid step is still at `lengths - 1` (TODO confirm in sst2.models).
    bwd_carry = LSTM.initialize_carry((batch_size,), self.hidden_size)
    _, backward_outputs = LSTM(name='lstm_bwd')(
        bwd_carry, flip_sequences(inputs, lengths))
    backward_final = backward_outputs[batch_index, last_step]

    # Re-align the backward outputs with the original time order, then join
    # both directions feature-wise: `outputs` is [B, T, 2 * hidden_size].
    outputs = jnp.concatenate(
        [forward_outputs, flip_sequences(backward_outputs, lengths)], axis=-1)
    return outputs, (forward_final, backward_final)


class BiLSTMClassifier(nn.Module):
  """Embeds token IDs, encodes them with a BiLSTM, and predicts class logits.

  The logits are the sum of two bias-free linear read-outs: one applied to
  the final forward state and one to the final backward state.
  """
  hidden_size: int
  embedding_size: int
  vocab_size: int
  output_size: int

  @nn.compact
  def __call__(self, inputs, lengths):
    """Embeds and encodes the inputs, and then predicts.

    Args:
      inputs: integer token IDs, batch-major.
      lengths: per-example sequence lengths.

    Returns:
      Logits with `output_size` entries per example.
    """
    embedder = nn.Embed(
        self.vocab_size, features=self.embedding_size, name='embedder')
    encoder = BiLSTM(self.hidden_size, name='bilstm')
    _, (forward_final, backward_final) = encoder(embedder(inputs), lengths)

    read_out_fwd = nn.Dense(
        self.output_size, use_bias=False, name='output_layer_fwd')
    read_out_bwd = nn.Dense(
        self.output_size, use_bias=False, name='output_layer_bwd')
    logits = read_out_fwd(forward_final) + read_out_bwd(backward_final)
    return logits

In [None]:
# SANITY CHECK: run a tiny batch through an untrained model and inspect shapes.
import numpy as np  # `np` is not imported by any of the cells above.

model = BiLSTMClassifier(
    hidden_size=60, embedding_size=60, vocab_size=19538, output_size=5)
# Three token-ID sequences, right-padded with 0, with lengths 1, 2 and 3.
x = np.array([[1, 0, 0], [2, 3, 0], [4, 5, 6]])
lengths = np.array([1, 2, 3])
variables = model.init(jax.random.PRNGKey(0), x, lengths)
outputs = model.apply(variables, x, lengths)
# `jax.tree_map` is deprecated and removed in recent JAX releases (the install
# cell upgrades JAX), so use `jax.tree_util.tree_map`, which works on all versions.
print('outputs shape:', jax.tree_util.tree_map(np.shape, outputs))
print('outputs:', outputs)

In [None]:
from sst2.train import *
# Redefine train_and_evaluate so it takes custom datasets and a vocabulary as input.
def train_and_evaluate(train_dataset, val_dataset, vocab, config: ml_collections.ConfigDict,
                       workdir: str) -> TrainState:
  """Execute model training and evaluation loop on the given datasets.

  Args:
    train_dataset: Training data exposing `get_batches(batch_size=...)`,
      e.g. a `TextDatasetSST`.
    val_dataset: Validation data exposing the same `get_batches` method.
    vocab: Vocabulary the datasets were encoded with. Only `len(vocab)` is
      used, to tell the embedder the vocabulary size.
    config: Hyperparameter configuration for training and evaluation. Its
      `vocab_size` field is overwritten with `len(vocab)`.
    workdir: Directory where the tensorboard summaries are written to.

  Returns:
    The final train state that includes the trained parameters.
  """
  # Use datasets.
  train_batches = train_dataset.get_batches(batch_size=config.batch_size)
  eval_batches = val_dataset.get_batches(batch_size=config.batch_size)

  # Keep track of vocab size in the config so that the embedder knows it.
  config.vocab_size = len(vocab)

  # Compile step functions.
  train_step_fn = jax.jit(train_step)
  eval_step_fn = jax.jit(eval_step)

  # Create model and a state that contains the parameters.
  rng = jax.random.PRNGKey(config.seed)
  model = model_from_config(config)
  state = create_train_state(rng, config, model)

  summary_writer = tensorboard.SummaryWriter(workdir)
  summary_writer.hparams(dict(config))

  try:
    # Main training loop.
    logging.info('Starting training...')
    for epoch in range(1, config.num_epochs + 1):

      # Train for one epoch, with a fresh dropout RNG each epoch.
      rng, epoch_rng = jax.random.split(rng)
      rngs = {'dropout': epoch_rng}
      state, train_metrics = train_epoch(
          train_step_fn, state, train_batches, epoch, rngs)

      # Evaluate current model on the validation data.
      eval_metrics = evaluate_model(eval_step_fn, state, eval_batches, epoch)

      # Write metrics to TensorBoard (accuracies as percentages).
      summary_writer.scalar('train_loss', train_metrics.loss, epoch)
      summary_writer.scalar(
          'train_accuracy',
          train_metrics.accuracy * 100,
          epoch)
      summary_writer.scalar('eval_loss', eval_metrics.loss, epoch)
      summary_writer.scalar(
          'eval_accuracy',
          eval_metrics.accuracy * 100,
          epoch)
  finally:
    # Flush pending summaries even if training is interrupted or fails.
    summary_writer.flush()
  return state

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_text as tftext
from typing import Iterable, Sequence
from itertools import chain
from sst2 import vocabulary
from sst2.input_pipeline import TextDataset, vocab_to_hashtable, AUTOTUNE, text

def get_tokenized_sequences(
        dataset: tf.data.Dataset,
        tokenizer: tftext.Tokenizer = tftext.WhitespaceTokenizer(),
        input_key: str = 'text') -> Iterable[Sequence[bytes]]:
  """Yields the tokenized `input_key` feature of every example in `dataset`.

  Intended as input for vocabulary building. The work is lazy: nothing is
  mapped or read until the first item is requested.
  """
  def tokenize_example(example):
    return tokenizer.tokenize(example[input_key])

  tokenized = dataset.map(
      tokenize_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  yield from tfds.as_numpy(tokenized)

# Collect every token that appears in the training, dev (validation), or test
# split; the splits are chained below in the order train, dev, test.
# You should end up with 19538 words if you do this correctly.
# Verify that your vocab matches the serialized vocab from the repo https://github.com/ArrasL/LRP_for_LSTM.
# NOTE(review): if `vocabulary.Vocabulary` assigns token IDs by insertion
# order (rather than, e.g., by frequency), confirm this split order matches
# the one used to build the reference vocab.
tokenized_sequences_train = get_tokenized_sequences(train_dataset)
tokenized_sequences_test = get_tokenized_sequences(test_dataset)
tokenized_sequences_val = get_tokenized_sequences(val_dataset)

generator = chain(tokenized_sequences_train, tokenized_sequences_val, tokenized_sequences_test)

# Build the vocabulary from the tokenized sequences. min_freq=1 keeps every
# token that occurs at least once (raising it would drop rare words).
# The vocabulary is saved to 'vocab.txt', the default path TextDatasetSST
# loads it from.
vocab = vocabulary.Vocabulary(
    tokenized_sequences=generator, min_freq=1)
vocab.save('vocab.txt')

class TextDatasetSST(TextDataset):
  """A `TextDataset` backed by an in-memory SST `tf.data.Dataset`.

  The base-class initializer is intentionally not called. This constructor
  sets the attributes the inherited methods are presumed to use (e.g.
  `prepare_example`, `get_batches`); TODO confirm against
  `sst2.input_pipeline.TextDataset`.
  """

  def __init__(self, dataset: tf.data.Dataset, vocab_path: str = 'vocab.txt',
               tokenizer: text.Tokenizer = text.WhitespaceTokenizer()):
    """Initializes the SST data source.

    Args:
      dataset: Dataset of examples with a 'text' and a 'label' feature.
      vocab_path: Path of the vocabulary file to load.
      tokenizer: Tokenizer used to split the 'text' feature into tokens.
    """
    self.dataset = dataset
    self.text_feature_name = 'text'
    self.label_feature_name = 'label'

    # Vocabulary and its TF lookup table, with `unk_idx` as the
    # out-of-vocabulary ID.
    self.vocab = vocabulary.Vocabulary(vocab_path=vocab_path)
    self.tokenizer = tokenizer
    self.tf_vocab = vocab_to_hashtable(self.vocab, unk_idx=self.vocab.unk_idx)

    # Convert the sentences to sequences of token IDs and compute lengths.
    # Everything `prepare_example` may read is set above; the mapped result
    # is cached after the first pass.
    self.examples = self.dataset.map(
        self.prepare_example, num_parallel_calls=AUTOTUNE).cache()

  @property
  def padded_shapes(self):
    """Padded shapes for batching: variable-length token_ids, scalar rest."""
    # [None] pads each batch to its longest sequence; [] marks a scalar.
    return {'token_ids': [None], 'label': [], 'length': []}

In [None]:
# Wrap the raw train / validation splits so they can be tokenized and batched.
train_ds, val_ds = (TextDatasetSST(split) for split in (train_dataset, val_dataset))

In [None]:
# Get a live update during training - use the "refresh" button!
# (In Jupyter[lab] start "tensorboard" in the local directory instead.)
# The `%` lines are IPython magics: they load the TensorBoard notebook
# extension and show every run under the current directory (training writes
# its summaries to ./models/<model_name>).
if 'google.colab' in str(get_ipython()):
  %load_ext tensorboard
  %tensorboard --logdir=.

In [None]:
import time
from configs import default as config_lib

config = config_lib.get_config()
model_name = 'sst_bilstm'
start_time = time.time()
# `train_and_evaluate` takes (train_dataset, val_dataset, vocab, config, ...).
# Pass the vocabulary the datasets were encoded with; its size is used to
# size the embedder.
# NOTE: despite its name, `optimizer` holds the final TrainState.
optimizer = train_and_evaluate(
    train_ds, val_ds, train_ds.vocab, config, workdir=f'./models/{model_name}')
logging.info('Walltime: %f s', time.time() - start_time)

In [None]:
if 'google.colab' in str(get_ipython()):
  #@markdown You can upload the training results directly to https://tensorboard.dev
  #@markdown
  #@markdown Note that everybody with the link will be able to see the data.
  # NOTE(review): the form default below is 'yes', so running this cell
  # publishes the logs. Also, TensorBoard.dev has reportedly been discontinued
  # upstream; confirm `tensorboard dev upload` still works before relying on it.
  upload_data = 'yes' #@param ['yes', 'no']
  if upload_data == 'yes':
    !tensorboard dev upload --one_shot --logdir ./models --name 'Flax examples/sst'