## Using jax_tpu_embedding (Single Host Pjit)

This colab demonstrates how to use jax_tpu_embedding to train large embeddings in JAX.
This example feeds the embedding lookup activations into a dense model and trains it on the targets.


In [None]:
from absl import logging
import functools
from typing import Union

from flax.training import common_utils
from flax.training.train_state import TrainState
import jax
from jax.experimental import jax2tf
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow as tf

from jax_tpu_embedding import input_utils
from jax_tpu_embedding import tpu_embedding_utils
from jax_tpu_embedding import tpu_embedding as jte


#### 0. Initialize TPU system for jax_tpu_embedding prerequisites

In [None]:
# Note: TPUEmbedding users need to call init_tpu_system at the beginning of the
# program; in this notebook it runs before the TPUEmbedding layer is built.
tpu_embedding_utils.init_tpu_system()

#### 1. Define Example Dense Model

Dense model on TPU device is a simple MLP layer.

In [None]:
import flax.linen as nn

# Array type used in annotations throughout this notebook. `jnp.DeviceArray`
# is not available in recent JAX releases (accessing it raises at import
# time); `jax.Array` is the unified array type that replaces it.
Array = jax.Array
# Type of the initializer callables accepted by the dense layers.
Initializer = jax.nn.initializers.Initializer

class MLPLayers(nn.Module):
  """Stack of ReLU-activated dense layers followed by a linear output layer.

  Attributes:
    hidden_dim: Number of features of each hidden dense layer.
    num_hidden_layers: Number of hidden dense layers to stack.
    dropout: Dropout rate applied before the output layer when training.
    num_classes: Number of features produced by the output layer.
    kernel_init: Kernel initializer of the hidden dense layers.
    bias_init: Bias initializer of every dense layer.
  """
  hidden_dim: int
  num_hidden_layers: int
  dropout: float
  num_classes: int
  kernel_init: Initializer = nn.initializers.glorot_uniform()
  bias_init: Initializer = nn.initializers.zeros

  @nn.compact
  def __call__(self, x: Array, is_training: bool = False) -> Array:
    """Apply the hidden layers, optional dropout, and the output layer.

    Args:
      x: Input activations.
      is_training: When True, dropout is applied (non-deterministically)
        before the output layer.

    Returns:
      The output of the final dense layer with `num_classes` features.
    """
    hidden = x
    for _ in range(self.num_hidden_layers):
      dense = nn.Dense(
          features=self.hidden_dim,
          kernel_init=self.kernel_init,
          bias_init=self.bias_init)
      hidden = nn.relu(dense(hidden))

    if is_training:
      # The dropout module is only created on the training path.
      hidden = nn.Dropout(rate=self.dropout)(hidden, deterministic=False)

    # The output layer keeps the default kernel initializer of nn.Dense.
    output_layer = nn.Dense(features=self.num_classes, bias_init=self.bias_init)
    return output_layer(hidden)

##### Define one hot targets conversion

In [None]:
def compute_one_hot_targets(targets: Array, num_classes: int,
                            on_value: float) -> Array:
  """Encode targets as one-hot vectors and sum them over axis 1.

  Args:
    targets: An array of target values.
    num_classes: Number of classes used for the one-hot encoding.
    on_value: Value written at the non-zero locations of each one-hot vector.

  Returns:
    The one-hot encoded targets, summed along axis 1.
  """
  encoded = common_utils.onehot(targets, num_classes, on_value=on_value)
  # Collapse axis 1 so that the one-hot vectors along it are accumulated.
  return jnp.sum(encoded, axis=1)


@jax.vmap
def categorical_cross_entropy_loss(logits: Array, one_hot_targets: Array):
  """Categorical cross entropy between logits and (soft) one-hot targets.

  The function is wrapped in `jax.vmap`, so it is mapped over the leading
  axis of both arguments.

  Args:
    logits: Unnormalized scores; log-softmax is applied to them.
    one_hot_targets: Target weights multiplied with the log-probabilities.

  Returns:
    The negated sum over the last axis of `one_hot_targets * log_softmax`.
  """
  log_probs = nn.log_softmax(logits)
  weighted_log_probs = one_hot_targets * log_probs
  return -jnp.sum(weighted_log_probs, axis=-1)

#### 2. Create dummy sample inputs

The dummy inputs contain two features: `watches` and `watches_target`.
* The dense model takes the embedding lookup results of `watches` as input and is trained against the one-hot targets built from `watches_target`.

In [None]:
# Number of target ids generated per example by `dummy_dataset`.
NUM_TARGET_IDS = 5
# Number of watch ids generated per example by `dummy_dataset`.
NUM_WATCHES = 10

def dummy_dataset(global_batch_size: int, vocab_size: int, num_classes: int, seed: int = 123):
  """Build an endlessly repeating dataset holding one random global batch.

  Args:
    global_batch_size: Number of examples in the generated batch.
    vocab_size: Exclusive upper bound of the random watch ids; also the size
      of the second dimension the sparse `watches` tensor is reset to.
    num_classes: Exclusive upper bound of the random target ids.
    seed: Seed of the NumPy random state used to draw the ids.

  Returns:
    A `tf.data.Dataset` repeating a single `(features, targets)` tuple, where
    `features` is `{'watches': SparseTensor}` and `targets` is
    `{'watches_target': float32 Tensor}`.
  """
  rng = np.random.RandomState(seed=seed)

  def _random_ids(upper_bound, ids_per_example):
    # Draw a flat vector of ids, then fold it into one row per example.
    flat_ids = rng.randint(low=0, high=upper_bound,
                           size=ids_per_example * global_batch_size)
    return flat_ids.reshape([global_batch_size, ids_per_example])

  # Watches are drawn before targets so the random stream is consumed in a
  # fixed order for a given seed.
  # NOTE(review): tf.sparse.from_dense keeps only non-zero entries, so id 0
  # presumably never appears in the sparse feature -- confirm this is intended.
  sparse_watches = tf.sparse.from_dense(_random_ids(vocab_size, NUM_WATCHES))
  dense_targets = tf.convert_to_tensor(
      _random_ids(num_classes, NUM_TARGET_IDS))

  features = {
      'watches': tf.sparse.reset_shape(
          sparse_watches, new_shape=[global_batch_size, vocab_size]),
  }
  targets = {
      'watches_target': tf.cast(dense_targets, dtype=tf.float32),
  }
  return tf.data.Dataset.from_tensors((features, targets)).repeat()

#### 3. Create Embedding Layer

Users need to define a feature configuration for each feature. It requires:
* table to lookup for given feature.
* output_shape

##### Function to build feature configuration

In [None]:
import math

def build_embedding_configs(batch_size_per_device: int,
                            embedding_dimension: int,
                            vocab_size: int):
  """Create the feature configurations for the embedding layer.

  Args:
    batch_size_per_device: Batch size of the inputs to enqueue; used as the
      output shape of the `watches` feature.
    embedding_dimension: Dimension size of the embedding table.
    vocab_size: Vocabulary size of the embedding table.

  Returns:
    A dictionary mapping the feature name `'watches'` to its feature
    configuration.
  """
  embedding_api = tf.tpu.experimental.embedding

  # Truncated normal initializer scaled by 1 / sqrt(embedding_dimension).
  table_initializer = tf.initializers.TruncatedNormal(
      mean=0.0, stddev=1 / math.sqrt(embedding_dimension))
  watches_table = embedding_api.TableConfig(
      vocabulary_size=vocab_size,
      dim=embedding_dimension,
      initializer=table_initializer,
      combiner='mean')

  return {
      'watches': embedding_api.FeatureConfig(
          table=watches_table,
          output_shape=[batch_size_per_device]),
  }

##### Setup flags

In [None]:
# Configuration shared by the embedding layer, the dense model and the
# training loops in this notebook.
flags = {
    'global_batch_size': 16,
    'embedding_dimension': 64,
    'hidden_layer_dimension': 32,
    'num_hidden_layers': 1,
    'vocab_size': 16,
    'num_classes': 4,
    'learning_rate': 1.0,
    'dropout': 0.5,
    # Number of target ids per example, as generated by the dummy dataset.
    'num_targets': NUM_TARGET_IDS,
    'is_training': True,
}

##### Create and Initialize A TPUEmbedding Layer

* Why do pjit users need TPUEmbedding SPMD?
`pjit` is the API exposed for the XLA SPMD partitioner. To align with that, pjit users need TPUEmbedding SPMD to enable XLA sharding annotations.
* What should `cores_per_replica` be set to?
For `pjit` model parallelism users, it is the number of tensor cores for each model replica. For `pjit` data parallelism users, it needs to be set to `jax.device_count()`, or to `jax.local_device_count()` for single-host users only.

In [None]:
# Per-device batch size: the global batch size integer-divided by the total
# number of devices.
batch_size_per_device = flags['global_batch_size'] // jax.device_count()

# Feature configuration for the single `watches` feature.
feature_configs = build_embedding_configs(
      batch_size_per_device=batch_size_per_device,
      embedding_dimension=flags['embedding_dimension'],
      vocab_size=flags['vocab_size'],
      )

# Adagrad optimizer handed to the TPUEmbedding layer; it uses the same
# learning rate flag as the rest of the notebook.
embedding_optimizer = tf.tpu.experimental.embedding.Adagrad(
    learning_rate=flags['learning_rate'])
tpu_embedding_layer = jte.TPUEmbedding(
    feature_configs=feature_configs, optimizer=embedding_optimizer,
    # Pjit user must set `cores_per_replica` to enable TPUEmbedding SPMD.
    # This single-host example uses the local device count.
    cores_per_replica=jax.local_device_count())

In [None]:
# Must call initialize_tpu_embedding to configure TPUEmbedding. It is called
# here before the embedding tables are loaded.
tpu_embedding_layer.initialize_tpu_embedding()

# Call load_embedding_tables to initialize embedding tables.
# NOTE(review): called without arguments; presumably the tables are then
# created from the initializers in the table configs -- confirm.
tpu_embedding_layer.load_embedding_tables()

#### Input pipeline

We have two inputs, `watches` and `watches_target`. Users may want to use data parallelism that aligns with the TensorCores.

For this example, `watches` is the input data enqueued on the CPU to be processed by
the TPUEmbedding host software. `watches_target` is the input data sent to the TensorCores
for the dense model.

Therefore we split input data for host and devices by `split_and_prefetch_to_host_and_devices`.

##### Create Global Mesh and Partition Specs

In [None]:
from jax.sharding import Mesh
from typing import Any, Dict, Sequence, Tuple


def create_global_mesh(mesh_shape: Tuple[int, ...],
                       axis_names: Sequence[str]) -> Mesh:
  """Create a device mesh of the requested shape from the global devices.

  Devices are ordered by their `id` and the first `prod(mesh_shape)` of them
  are arranged into an array of shape `mesh_shape`.

  Args:
    mesh_shape: Size of every mesh dimension.
    axis_names: Name of every mesh axis, one per entry of `mesh_shape`.

  Returns:
    A `Mesh` over the selected devices with the given axis names.

  Raises:
    ValueError: If fewer global devices are available than the mesh needs.
  """
  # np.prod yields a NumPy scalar (a float for an empty shape); convert it to
  # a plain int so it is always a valid slice bound below.
  size = int(np.prod(mesh_shape))
  devices = sorted(jax.devices(), key=lambda d: d.id)
  if len(devices) < size:
    raise ValueError(f'Test requires {size} global devices.')
  mesh_devices = np.array(devices[:size]).reshape(mesh_shape)
  return Mesh(mesh_devices, axis_names)

In [None]:
# Total number of devices; the mesh below spans all of them.
num_devices = jax.device_count()
# The mesh has a single axis, named 'x'.
mesh_axis_names = ('x',)

# One-dimensional mesh of shape (num_devices,) over axis 'x'.
global_mesh = create_global_mesh((num_devices,), mesh_axis_names)

In [None]:
from jax.sharding import PartitionSpec

# `mesh_axis_names` is the tuple ('x',), so this spec assigns the first array
# dimension to mesh axis 'x'.
partition_spec = PartitionSpec(mesh_axis_names)

##### Build Dummy Input Data Iterator with Data Parallelism

In [None]:
# Function applied to the device-side inputs by the input pipeline; presumably
# it builds arrays laid out on `global_mesh` according to `partition_spec`.
# NOTE(review): `(partition_spec)` is just `partition_spec` -- the parentheses
# do not create a tuple. Confirm a single spec is what make_pjit_array_fn
# expects.
device_input_fn = input_utils.make_pjit_array_fn(global_mesh, (partition_spec))

In [None]:
# Endless dummy dataset yielding (features, targets) tuples.
ds = dummy_dataset(global_batch_size=flags['global_batch_size'],
                   vocab_size=flags['vocab_size'],
                   num_classes=flags['num_classes'])

# Split every dataset element into a host part and a device part:
#   * 'host'   <- element[0], the sparse `watches` features, which are passed
#                 to the TPUEmbedding layer's `enqueue`;
#   * 'device' <- element[1], the targets, which go through `device_input_fn`.
dummy_iter = input_utils.split_and_prefetch_to_host_and_devices(
    iterator=iter(ds),
    split_fn=lambda xs: {'host': xs[0], 'device': xs[1]},
    host_input_fn=input_utils.enqueue_prefetch(
        enqueue_fn=functools.partial(
            tpu_embedding_layer.enqueue, is_training=flags['is_training'])),
    device_input_fn=device_input_fn,
    buffer_size=2)

#### Training Loop

In [None]:
# Create TrainState
# Dense model configured from the notebook flags.
mlp_model = MLPLayers(
    hidden_dim=flags['hidden_layer_dimension'],
    num_hidden_layers=flags['num_hidden_layers'],
    dropout=flags['dropout'],
    num_classes=flags['num_classes'])

# Initialize the parameters with an all-ones input of shape
# (batch_size_per_device, embedding_dimension). `is_training` keeps its
# default (False), so the dropout branch is not traced during init.
init_params = mlp_model.init(
    jax.random.PRNGKey(123),
    jnp.ones((batch_size_per_device, flags['embedding_dimension'])))
# Adagrad optimizer for the dense model parameters.
tx = optax.adagrad(learning_rate=flags['learning_rate'])

train_state = TrainState.create(apply_fn=mlp_model.apply, params=init_params,
                                tx=tx)

##### Build Train/Eval Step

In [None]:
def build_step(embedding_layer: jte.TPUEmbedding,
               train_state: TrainState,
               config_flags: Dict[str, Union[int, float]],
               is_training: bool,
               use_pjit: bool):
  """Create a train or eval step function that uses a TPU embedding layer.

  Args:
    embedding_layer: TPUEmbedding layer; its activations are dequeued in every
      step and it receives the embedding gradients when training.
    train_state: TrainState whose `apply_fn` runs the dense model.
    config_flags: Configuration values; `num_classes`, `num_targets` and
      `global_batch_size` are read from it.
    is_training: If True the step computes gradients and applies them,
      otherwise it only evaluates the loss.
    use_pjit: If False, the dense gradients and the loss are averaged with
      `jax.lax.pmean` over the axis named 'devices'.

  Returns:
    A function `step_fn(train_state, watches_targets)` returning the tuple
    `(loss, train_state)`.
  """

  def forward(inputs):
    """Compute the scalar loss from activations, params and targets."""
    watches_activations = inputs['embedding_actv']['watches']
    # NOTE(review): apply_fn is called without an `is_training` argument.
    logits = train_state.apply_fn(inputs['params'], watches_activations)
    one_hot_targets = compute_one_hot_targets(
        inputs['watches_targets'],
        num_classes=config_flags['num_classes'],
        on_value=1.0 / config_flags['num_targets'])
    per_example_loss = categorical_cross_entropy_loss(logits, one_hot_targets)
    # Sum over the leading axis and scale by the global batch size.
    return jnp.sum(per_example_loss, axis=0) * (
        1.0 / config_flags['global_batch_size'])

  def step_fn(train_state, watches_targets):
    """Run one step on the dequeued activations and the given targets."""
    inputs = {
        'embedding_actv': embedding_layer.dequeue(),
        'params': train_state.params,
        'watches_targets': watches_targets,
    }
    if not is_training:
      # Eval: forward pass only, the state is returned untouched.
      return forward(inputs), train_state

    # Differentiate w.r.t. the whole input dict, then pick out the gradients
    # of the dense params and of the embedding activations.
    loss, grads = jax.value_and_grad(forward)(inputs)
    params_grads = grads['params']
    if not use_pjit:
      params_grads = jax.lax.pmean(params_grads, axis_name='devices')
      loss = jax.lax.pmean(loss, axis_name='devices')
    train_state = train_state.apply_gradients(grads=params_grads)
    embedding_layer.apply_gradients(grads['embedding_actv'])
    return loss, train_state

  return step_fn

##### Run with pjit

In [None]:
from jax.experimental import pjit

# Build the step function for the pjit runs. With use_pjit=True the step
# function skips its `jax.lax.pmean` calls.
train_step_fn = build_step(
    embedding_layer=tpu_embedding_layer,
    train_state=train_state,
    config_flags=flags,
    is_training=flags['is_training'],
    use_pjit=True)

# Number of steps executed by a training loop.
num_steps = 10

###### Model Replicated

In [None]:
with global_mesh:
  # Replicated TrainState: an identity function run through pjit with no
  # sharding given for either the input or the output.
  replicated_train_state = pjit.pjit(
      lambda x: x,
      in_shardings=None,
      out_shardings=None,
      keep_unused=True)(train_state)

  # Wrap the step function once, outside the loop: the wrapper does not depend
  # on the step, so there is no need to rebuild it on every iteration.
  # The train state has no sharding given; the targets are split along mesh
  # axis 'x' on their first dimension.
  replicated_train_step = pjit.pjit(
      train_step_fn,
      in_shardings=(None, PartitionSpec('x',)),
      out_shardings=(None, None),
      keep_unused=True)

  for step in range(num_steps):
    # Fetch the next batch from the input iterator, then run one step on the
    # device-side targets.
    inputs = next(dummy_iter)
    loss, replicated_train_state = replicated_train_step(
        replicated_train_state, inputs['device']['watches_target'])
    print('train_step = ', step, 'loss = ', loss)

train_step =  0 loss =  2.7713807
train_step =  1 loss =  2.7181337
train_step =  2 loss =  2.686849
train_step =  3 loss =  2.6700463
train_step =  4 loss =  2.659828
train_step =  5 loss =  2.6517568
train_step =  6 loss =  2.6445704
train_step =  7 loss =  2.6370976
train_step =  8 loss =  2.6295438
train_step =  9 loss =  2.621694


###### Model Parallelism

* Prepare axis resources for train state sharding along devices

In [None]:
import optax
from flax.core import scope as flax_scope

# Partition specs for the dense model parameters, mirroring the parameter tree
# (`Dense_0` and `Dense_1`). Kernels use mesh axis 'x' for their first
# dimension and leave the second unsharded; biases use axis 'x' for their only
# dimension.
params_resources = flax_scope.FrozenDict({
    'params': {
        'Dense_0': {
            'kernel': PartitionSpec('x', None),
            'bias': PartitionSpec('x',),
        },
        'Dense_1': {
            'kernel': PartitionSpec('x', None),
            'bias': PartitionSpec('x',),
        },
    },
})

# A TrainState-shaped tree of partition specs used as the pjit sharding
# argument. `step` gets an empty spec, and `sum_of_squares` reuses the
# parameter specs.
# NOTE(review): the opt_state tuple presumably mirrors the state structure
# produced by `optax.adagrad` -- confirm against the installed optax version.
sharded_axis_resources = TrainState(
    step=PartitionSpec(), apply_fn=train_state.apply_fn,
    params=params_resources,
    tx=train_state.tx,
    opt_state=(
        optax.ScaleByRssState(
            sum_of_squares=params_resources), optax.EmptyState()))

In [None]:
with global_mesh:
  # Sharded TrainState: an identity function run through pjit whose output is
  # laid out according to `sharded_axis_resources`.
  sharded_train_state = pjit.pjit(
      lambda x: x,
      in_shardings=None,
      out_shardings=sharded_axis_resources,
      keep_unused=True)(train_state)

  # Wrap the step function once, outside the loop: the wrapper does not depend
  # on the step, so there is no need to rebuild it on every iteration.
  # The train state goes in and comes out with the sharded layout; the targets
  # are split along mesh axis 'x' on their first dimension.
  sharded_train_step = pjit.pjit(
      train_step_fn,
      in_shardings=(sharded_axis_resources, PartitionSpec('x',)),
      out_shardings=(None, sharded_axis_resources),
      keep_unused=True)

  for step in range(num_steps):
    # Fetch the next batch from the input iterator, then run one step on the
    # device-side targets.
    inputs = next(dummy_iter)
    loss, sharded_train_state = sharded_train_step(
        sharded_train_state, inputs['device']['watches_target'])
    print('train_step = ', step, 'loss = ', loss)

train_step =  0 loss =  2.7485476
train_step =  1 loss =  2.6802473
train_step =  2 loss =  2.652364
train_step =  3 loss =  2.6370351
train_step =  4 loss =  2.6250236
train_step =  5 loss =  2.6136596
train_step =  6 loss =  2.6020913
train_step =  7 loss =  2.5901182
train_step =  8 loss =  2.5761333
train_step =  9 loss =  2.559519
