## Training Transformers on TPUs with the API

In this example, we will demonstrate code to train Transformers on a TPU using MEAD/Baseline in TensorFlow.  The basic outline of the program is based on the API example [pretrain-tlm-tf](https://github.com/dpressel/mead-baseline/blob/master/api-examples/pretrain-tlm-tf.py).

The data for this sample is a preprocessed version of [wikitext-2](https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/) available from a GCP bucket. It was preprocessed from the original data using the API example [preproc-tlm](https://github.com/dpressel/mead-baseline/blob/master/api-examples/pretrain-tlm-tf.py) with command-line args specified to generate [TFRecords](https://www.tensorflow.org/tutorials/load_data/tfrecord).  To access the records, we first need to authenticate our Colab user.

In [0]:
from google.colab import auth
auth.authenticate_user()

!gsutil ls gs://lm-sample


gs://lm-sample/wt2/


The metadata we need to process this example is publicly available on Dropbox. It was previously generated with [preproc-tlm](https://github.com/dpressel/mead-baseline/blob/master/api-examples/pretrain-tlm-tf.py) (which writes the `YAML` md files), and [fastBPE](https://github.com/glample/fastBPE) was run to generate the vocab and codes files.


In [0]:

!wget https://www.dropbox.com/s/yaqs2dx51kc4sb2/wt2-md.tar.gz?dl=1
!tar -xzf wt2-md.tar.gz?dl=1
!rm wt2-md.tar.gz?dl=1

--2020-06-01 16:37:32--  https://www.dropbox.com/s/yaqs2dx51kc4sb2/wt2-md.tar.gz?dl=1
Resolving www.dropbox.com (www.dropbox.com)... 162.125.3.1, 2620:100:6018:1::a27d:301
Connecting to www.dropbox.com (www.dropbox.com)|162.125.3.1|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: /s/dl/yaqs2dx51kc4sb2/wt2-md.tar.gz [following]
--2020-06-01 16:37:32--  https://www.dropbox.com/s/dl/yaqs2dx51kc4sb2/wt2-md.tar.gz
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://ucf5f62194856a58ee3953a2b5b2.dl.dropboxusercontent.com/cd/0/get/A40TQap6UgNNNi019U71VtWQf9wOnqlmBSlXi-R1MFqWO6ioTXYo3CiR7OpSXQbc4ya3tPxELhdwjxqaNYLoG0AWcLzAy5bZ3d7BZwfl7i5ZZHzkMcCcDbIzLXQ0ebTwfIs/file?dl=1# [following]
--2020-06-01 16:37:32--  https://ucf5f62194856a58ee3953a2b5b2.dl.dropboxusercontent.com/cd/0/get/A40TQap6UgNNNi019U71VtWQf9wOnqlmBSlXi-R1MFqWO6ioTXYo3CiR7OpSXQbc4ya3tPxELhdwjxqaNYLoG0AWcLzAy5bZ3d7BZwfl7i5ZZ

In [0]:
!ls


adc.json  mlm-bpe-1871	sample_data  wt2


To run our example, we need to install TensorFlow, [fastBPE](https://github.com/glample/fastBPE), and MEAD/Baseline with [TensorFlow addons](https://www.tensorflow.org/addons/overview).  If you get an error at the end of this command, run it a second time.

In [0]:
!pip install tensorflow
!pip install fastBPE
!pip install mead-baseline[tf2]



To run our example, we will need to import the `BPEVectorizer1D`, which is responsible for vectorizing text to BPE form, a few utilities from `8 mile` including the entire optimizer base and TF packages, and the language models we will be using from the `baseline.tf.lm` package.

In [0]:
import logging  # required by logging.getLogger() and logging.basicConfig() below
import time
import os
from argparse import ArgumentParser
import math
from typing import Tuple
import baseline
from eight_mile.utils import str2bool, write_json
import baseline.tf.embeddings
import baseline.embeddings
from baseline.vectorizers import BPEVectorizer1D
from eight_mile.utils import Average, get_num_gpus_multiworker, read_yaml
from eight_mile.optz import *
from eight_mile.tf.optz import *
from baseline.tf.lm import SET_TRAIN_FLAG, TransformerLanguageModel, TransformerMaskedLanguageModel
import tensorflow as tf
import glob
import json

logger = logging.getLogger("mead-transformers-tpu")

Here we define the loss as a function object: a class with an overloaded `__call__()` method.  We instantiate it with the normal constructor (which takes the BPE `vocab_size` and the context window size), but when it is called during optimization, the instance is invoked with parentheses (containing the `model`, `features` and `labels`) as though it were a normal function.

In [0]:

class Loss:
    def __init__(self, vocab_size, nctx):
        self.vocab_size = vocab_size
        self.nctx = nctx

    def __call__(self, model, features, labels):
        logits, _ = model(features, None)
        loss_mask = tf.cast(labels != 0, tf.float32)
        losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
        losses = losses * loss_mask
        losses = tf.reduce_sum(losses)
        non_zero = tf.reduce_sum(loss_mask)
        losses /= non_zero
        return losses
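
To make the masking arithmetic in `Loss.__call__` concrete, here is a minimal pure-Python sketch of the same computation, without TensorFlow. The function name `masked_mean_nll` and the toy logits are hypothetical and only for illustration; the sketch assumes, as the class above does, that a label of `0` marks a position that should not contribute to the loss.

```python
import math


def masked_mean_nll(logits, labels, ignore_id=0):
    """Mean cross-entropy over positions whose label differs from ignore_id.

    logits: list of per-position score lists (one score per vocabulary entry)
    labels: list of integer targets, same length as logits
    """
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        if label == ignore_id:
            continue  # masked out, exactly like multiplying by loss_mask
        peak = max(row)
        # Numerically stable log-sum-exp over the vocabulary axis
        log_z = peak + math.log(sum(math.exp(v - peak) for v in row))
        total += log_z - row[label]
        count += 1
    # Returning 0.0 when nothing is scored is a choice made for this sketch
    return total / count if count else 0.0


toy_logits = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
toy_labels = [0, 1, 2]  # position 0 is ignored
loss = masked_mean_nll(toy_logits, toy_labels)
print(loss)
```

Only the second and third positions contribute, so the result is the average of their two cross-entropy terms rather than an average over all three positions.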

Our preprocessed data files are written in shards of TFRecords (of size ~100MB each), with the feature `x` representing the integer values of the input BPE tokens and `y` representing the target integer values to recover during masked language modeling.

Here is an example of what these records would look like as a JSON object:

```json
{
  "x": [4, 10, 99, 1926, 21, 128, 11, 106, 5, 13288, 33, 7, 5409, 565, 399, 6, 26, 10, 20, 1191, 153, 283, 13, 39, 399, 57, 6, 10, 49, 1616, 5, 63, 1822, 11, 106, 58, 5, 10, 117, 183, 195, 283, 18, 7, 295, 11, 59, 6, 5, 5, 99, 1238, 13, 484, 5, 11723, 852, 18, 2652, 8, 330, 292, 5, 50, 5, 88, 11, 6, 168, 13, 75, 87, 5, 830, 7820, 3012, 18, 834, 9559, 13, 135, 1175, 19, 32, 49, 134, 13, 11, 53, 15, 7, 5939, 1069, 30, 2, 86, 1238, 5, 382, 5129, 8, 4882, 4211, 3383, 42, 5, 19213, 6, 7, 3594, 5, 33, 5129, 5, 221, 17, 140, 8, 38, 5, 49, 127, 19, 7, 11963, 39, 2549, 131, 6, 2, 100, 1238, 13, 39, 28, 18, 2562, 18, 4428, 14327, 7, 44, 12, 1610, 5, 979, 2021, 194, 221, 51, 6, 26, 10, 5, 8432, 7, 3594, 44, 379, 3010, 5, 495, 33, 19, 71, 6, 5, 47, 1713, 174, 16, 644, 15, 231, 7, 3442, 3378, 6, 5, 5, 37, 10, 52, 7, 632, 5, 13, 8, 13, 356, 52, 12, 5, 5, 5, 194, 7, 700, 15, 7, 5, 8, 9311, 28, 1090, 5, 84, 657, 11858, 14043, 8, 9, 294, 27, 28, 37, 63, 6772, 42, 2697, 37, 5, 2488, 19, 31, 347, 21, 7, 971, 41, 5, 2299, 5, 5, 20151, 1645, 37, 11, 66, 294, 5, 18, 7, 5, 15, 7, 5, 17, 233, 19, 92, 56, 13, 5, 86, 5],
  "y": [0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 12, 0, 0, 0, 0, 0, 36, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 10, 0, 0, 0, 0, 11, 0, 0, 0, 0, 0, 0, 0, 17, 0, 11, 0, 0, 0, 0, 0, 0, 0, 18408, 0, 0, 0, 0, 834, 0, 0, 0, 0, 0, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 17, 0, 0, 0, 0, 0, 0, 0, 15380, 0, 0, 0, 0, 20, 0, 0, 194, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 7, 0, 0, 0, 15, 0, 0, 0, 0, 0, 0, 0, 0, 83, 127, 0, 0, 0, 0, 0, 6605, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 37, 118, 0, 0, 0, 7, 0, 15, 0, 0, 0, 0, 0, 0, 878, 19, 416, 0, 0, 0, 0, 0, 342, 0, 46, 0, 0, 6002, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3442, 0, 0, 0, 0, 0, 0, 0, 0, 2839, 0, 6, 83, 72, 0, 37, 0, 0, 0, 28, 0, 0, 668, 0, 0, 3378, 0, 0, 0, 0, 0, 0, 30, 0, 644]
}
```
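
One way to read such a record: every position where `y` is non-zero is a prediction target, and the value of `y` there is the token the model should recover, while the `x` value at that position is what the model actually sees. The sketch below lists those positions for a toy record. The short lists and the helper name `target_positions` are made up for illustration and are not taken from the dataset.

```python
def target_positions(x, y, ignore_id=0):
    """Return (position, input_token, target_token) for every scored position."""
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    return [(i, xi, yi) for i, (xi, yi) in enumerate(zip(x, y)) if yi != ignore_id]


# Toy record: positions 3 and 5 are targets, everything else is ignored by the loss
toy_x = [4, 10, 99, 5, 21, 5]
toy_y = [0, 0, 0, 7, 0, 12]
targets = target_positions(toy_x, toy_y)
print(targets)
```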

To read the TFRecords, we need a descriptor of these features, and we need to use the [tf.io.parse_single_example()](https://www.tensorflow.org/api_docs/python/tf/io/parse_single_example) function, which we will apply to each record.

These files are stored on a GCP bucket, so when we call `get_dataset()` below, we will `glob` the bucket for `*.tfrecord` files, then map each record retrieval to a pipeline that includes parsing the example.

In [0]:


feature_description = {
    'x': tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True, default_value=0),
    'y': tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True, default_value=0),
}


def _parse_tf_record(example_proto):
    record = tf.io.parse_single_example(example_proto, feature_description)
    return record['x'], record['y']


def get_dataset(directory):
    pattern = os.path.join(directory, '*.tfrecord')
    files = tf.io.gfile.glob(pattern)
    print(files)
    ds = tf.data.TFRecordDataset(files).map(_parse_tf_record)
    return ds

Our example will use the `tf.distribute` library, which allows us to provide a [tf.distribute.Strategy](https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy) for handling distribution details. We just need to create the right sub-class, in this case a [TPUStrategy](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/TPUStrategy), which we will need to initialize with a [TPUClusterResolver](https://www.tensorflow.org/api_docs/python/tf/distribute/cluster_resolver/TPUClusterResolver) that takes in an address for the TPU.  Because this example runs in Colab, we call this function without an address, and it uses the environment variable `COLAB_TPU_ADDR` to resolve our TPU address.

In [0]:
def create_distribute_strategy(strategy_name, endpoint=None):
    if strategy_name == 'tpu':
        if endpoint is None:
            endpoint = 'grpc://' + os.environ['COLAB_TPU_ADDR']
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=endpoint)
        tf.config.experimental_connect_to_cluster(resolver)
        # This is the TPU initialization code that has to be at the beginning.
        tf.tpu.experimental.initialize_tpu_system(resolver)
        print("All devices: ", tf.config.list_logical_devices('TPU'))
        strategy = tf.distribute.experimental.TPUStrategy(resolver)
    elif strategy_name == 'mirror':
        num_gpus = get_num_gpus_multiworker()
        devices = ['/device:GPU:{}'.format(i) for i in range(num_gpus)]
        strategy = tf.distribute.MirroredStrategy(devices)
    else:
        raise Exception(f"Unsupported strategy {strategy_name}")
    return strategy
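
The endpoint resolution inside `create_distribute_strategy` can be isolated and checked without a TPU. The sketch below is a hypothetical helper (`resolve_tpu_endpoint` is not part of MEAD/Baseline or TensorFlow) that mirrors the `COLAB_TPU_ADDR` lookup and raises a clearer error when the variable is missing, which usually means the Colab runtime type is not set to TPU.

```python
import os


def resolve_tpu_endpoint(endpoint=None, environ=None):
    """Return a grpc:// TPU address, falling back to COLAB_TPU_ADDR."""
    environ = os.environ if environ is None else environ
    if endpoint is not None:
        return endpoint  # an explicit address always wins
    addr = environ.get('COLAB_TPU_ADDR')
    if not addr:
        raise RuntimeError(
            "COLAB_TPU_ADDR is not set; pass an endpoint explicitly "
            "or switch the runtime to a TPU"
        )
    return 'grpc://' + addr


# A dict stands in for os.environ so the example runs anywhere
example = resolve_tpu_endpoint(environ={'COLAB_TPU_ADDR': '10.101.52.138:8470'})
print(example)
```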

Now that we have set up a lot of the boilerplate, we can go ahead and create our:
1. `tf.distribute.Strategy` using the function above
2. `baseline.Vectorizer` which we will use to initialize the vocabulary of our model
3. `baseline.embeddings` which we will use to initialize a lookup table that projects from our vocabulary size to a hidden unit size (`MLM_MODEL_SZ` below)
4. `baseline.tf.lm.TransformerMaskedLanguageModel` which is the model we will be training.  It is built internally on `8 mile` by using an [eight_mile.tf.layers.TransformerEncoderStack](https://github.com/dpressel/mead-baseline/blob/master/layers/eight_mile/tf/layers.py), which composes a stack of `Transformer` encoders with interleaved multi-head attention and FFN sub-stacks.

The `TransformerMaskedLanguageModel` is a type of `class AbstractGeneratorModel(LanguageModelBase)`, which defines the abstract phases to create a `LanguageModel`:

```python
    def create_layers(self, embeddings, **kwargs):
        self.embeddings = self.init_embed(embeddings, **kwargs)
        self.embeddings_proj = self.init_embeddings_proj(**kwargs)
        self.generator = self.init_generate(**kwargs)
        self.output_layer = self.init_output(embeddings, **kwargs)
```

This gets sub-classed by the `TransformerLanguageModel` to provide a `Transformer`-based `generator` object:

```python
    def init_generate(self, **kwargs):
        pdrop = float(kwargs.get('dropout', 0.1))
        layers = kwargs.get('layers', kwargs.get('num_layers', 1))
        d_model = int(kwargs.get('d_model', kwargs.get('hsz')))
        num_heads = kwargs.get('num_heads', 4)
        d_ff = int(kwargs.get('d_ff', 4 * d_model))
        rpr_k = kwargs.get('rpr_k')
        d_k = kwargs.get('d_k')
        scale = bool(kwargs.get('scale', True))
        activation = kwargs.get('activation', 'gelu')
        layer_norm_eps = kwargs.get('layer_norm_eps', 1e-12)
        layer_norms_after = kwargs.get('layer_norms_after', False)
        return TransformerEncoderStack(num_heads, d_model=d_model, pdrop=pdrop, scale=scale,
                                       layers=layers, d_ff=d_ff, rpr_k=rpr_k, d_k=d_k,
                                       activation=activation, layer_norm_eps=layer_norm_eps,
                                       layer_norms_after=layer_norms_after)
```
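
To see how these keyword arguments resolve for values such as `hsz=512`, here is a small standalone sketch that repeats the default-handling logic from `init_generate` without constructing any layers. The helper name `resolve_encoder_hparams` is hypothetical, and only a subset of the arguments is reproduced.

```python
def resolve_encoder_hparams(**kwargs):
    """Apply the same fallbacks as init_generate and return plain values."""
    # d_model falls back to hsz; d_ff falls back to 4 * d_model
    d_model = int(kwargs.get('d_model', kwargs.get('hsz')))
    return {
        'pdrop': float(kwargs.get('dropout', 0.1)),
        'layers': kwargs.get('layers', kwargs.get('num_layers', 1)),
        'd_model': d_model,
        'num_heads': kwargs.get('num_heads', 4),
        'd_ff': int(kwargs.get('d_ff', 4 * d_model)),
        'rpr_k': kwargs.get('rpr_k'),
    }


hparams = resolve_encoder_hparams(hsz=512, dropout=0.1, num_heads=8, layers=8, rpr_k=8)
print(hparams)
```

Because `d_ff` is not passed here, it resolves to four times the model size, which is the same value the notebook sets explicitly through `MLM_FFN_SZ`.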

In [0]:


MLM_MODEL_SZ = 512
MLM_FFN_SZ = 4 * MLM_MODEL_SZ
MLM_CONTEXT_LENGTH = 256
MLM_REL_POS_REPR = 8
MLM_DROPOUT = 0.1
MLM_NUM_HEADS = 8
MLM_NUM_LAYERS = 8
SET_TRAIN_FLAG(True)

SUBWORD_MODEL_FILE = './wt2/wiki.train.bpe.50k.codes'
SUBWORD_VOCAB_FILE = './wt2/wiki.train.bpe.50k.vocab'

logging.basicConfig(level=logging.INFO)

strategy = create_distribute_strategy("tpu")
num_replicas = strategy.num_replicas_in_sync
logger.info(f"Using {num_replicas} replicas in this job.")
vectorizer = BPEVectorizer1D(model_file=SUBWORD_MODEL_FILE,
                             vocab_file=SUBWORD_VOCAB_FILE,
                             mxlen=MLM_CONTEXT_LENGTH)  # max sequence length is the context window, not the model width
vocab = {'x': vectorizer.vocab}
preproc_data = baseline.embeddings.load_embeddings('x',
                                                   dsz=MLM_MODEL_SZ,
                                                   known_vocab=vocab['x'],
                                                   preserve_vocab_indices=True,
                                                   embed_type="default")
vocabs = preproc_data['vocab']
vocab_size = max(vocabs.values())
embeddings = {'x': preproc_data['embeddings']}

model = TransformerMaskedLanguageModel.create(embeddings,
                                              hsz=MLM_MODEL_SZ,
                                              d_ff=MLM_FFN_SZ,
                                              tie_weights=True,
                                              dropout=MLM_DROPOUT,
                                              num_heads=MLM_NUM_HEADS,
                                              layers=MLM_NUM_LAYERS,
                                              rpr_k=MLM_REL_POS_REPR,
                                              src_keys=['x'], tgt_key='x')


INFO:absl:Entering into master device scope: /job:worker/replica:0/task:0/device:CPU:0
INFO:tensorflow:Initializing the TPU system: grpc://10.101.52.138:8470
INFO:tensorflow:Clearing out eager caches
INFO:tensorflow:Finished initializing TPU system.
All devices:  [LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:7', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:6', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:5', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:4', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:0', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:1', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:2', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:3', device_type='TPU')]
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
INFO:mead-transformers-tpu:Using 8 replicas in this job.


We have defined our model and our loss function, as well as the initial step of our data pipeline using `TFRecordDataset`.  Now we will provide functions that distribute our reader over each replica using the `strategy.experimental_distribute_datasets_from_function` API call and batch the dataset into our per-replica batch size.

In [0]:
BATCH_SIZE = 8 * 20
TRAIN_BUCKET = "gs://lm-sample/wt2/train"
VALID_BUCKET = "gs://lm-sample/wt2/valid"
MD_TRAIN_FILE = "./wt2/train/md.yml"
MD_TEST_FILE = "./wt2/valid/md.yml"
def get_num_samples(f):
    yml = read_yaml(f)
    return yml['num_samples']

def dataset_train_fn(input_context):
    batch_size = input_context.get_per_replica_batch_size(BATCH_SIZE)
    ds = get_dataset(TRAIN_BUCKET).batch(batch_size)
    return ds.shard(
        input_context.num_input_pipelines, input_context.input_pipeline_id
    )
train_loader = strategy.experimental_distribute_datasets_from_function(dataset_train_fn)

def dataset_test_fn(input_context):
    batch_size = input_context.get_per_replica_batch_size(BATCH_SIZE)
    ds = get_dataset(VALID_BUCKET).batch(batch_size)
    return ds.shard(
        input_context.num_input_pipelines, input_context.input_pipeline_id
    )
valid_loader = strategy.experimental_distribute_datasets_from_function(dataset_test_fn)

num_train_samples = get_num_samples(MD_TRAIN_FILE)
num_valid_samples = get_num_samples(MD_TEST_FILE)
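
The two pieces of bookkeeping in these functions are the per-replica batch size and the sharding of the dataset across input pipelines. The following pure-Python sketch imitates both on plain lists. It illustrates the intended semantics only and is not the `tf.distribute` or `tf.data` implementation; the helper names are hypothetical.

```python
def per_replica_batch_size(global_batch_size, num_replicas):
    """Split a global batch evenly across replicas."""
    if global_batch_size % num_replicas != 0:
        raise ValueError("global batch size must be divisible by the replica count")
    return global_batch_size // num_replicas


def shard(items, num_shards, index):
    """Keep every num_shards-th item, starting at position index."""
    return [item for i, item in enumerate(items) if i % num_shards == index]


# With the notebook's global batch of 8 * 20 and 8 TPU cores
replica_batch = per_replica_batch_size(8 * 20, 8)

# Two input pipelines would each see alternating batches
batches = list(range(10))
second_pipeline = shard(batches, 2, 1)

# A single input pipeline sees everything
single_pipeline = shard(batches, 1, 0)
print(replica_batch, second_pipeline, single_pipeline)
```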

Now we can create our loss function, and set some variables so we can get periodic updates while we train.

In [0]:
loss_function = Loss(vocab_size, MLM_CONTEXT_LENGTH)
steps_per_epoch = num_train_samples // BATCH_SIZE
steps_per_valid_epoch = num_valid_samples // BATCH_SIZE
update_on = steps_per_epoch // 2
report_on = update_on // 4
logger.info(f"Steps per epoch: {steps_per_epoch}. Update every {update_on} steps.")


INFO:mead-transformers-tpu:Steps per epoch: 56. Update every 28 steps.
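
The logged values follow from integer division. As a worked example, assume the training metadata reports 9,000 samples. This is a hypothetical figure chosen to be consistent with the log line above; the real `num_samples` is read from `md.yml`.

```python
BATCH_SIZE = 8 * 20
assumed_num_train_samples = 9000  # hypothetical, not read from md.yml

steps_per_epoch = assumed_num_train_samples // BATCH_SIZE  # full batches only
update_on = steps_per_epoch // 2   # timing is logged twice per epoch
report_on = update_on // 4         # running loss is logged four times per update
print(steps_per_epoch, update_on, report_on)
```

With these numbers the running loss is logged every 7 steps and the timing every 28 steps, so each half-epoch shows four loss lines followed by the elapsed-time lines.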


The `8 mile` API provides an `EagerOptimizer`, which takes in a loss function (or function object in our case) and applies it at every step of training to provide a per-replica loss.  We are going to create a learning regimen that starts with a linear warmup to the target learning rate over `WARMUP_STEPS`, composed with a learning rate scheduler that provides a cosine decay.

We will specify that our model is to be trained with [adamw](https://www.fast.ai/2018/07/02/adam-weight-decay/) (we use the `tensorflow_addons` implementation underneath) with weight decay given by `WD` below:

In [0]:
LR = 1.0e-4 * 8
EPOCHS = 10
GRAD_CLIP = 1.0
OPTIM = "adamw"
WD = 1.0e-5
WARMUP_STEPS = 10000
lr_decay = CosineDecaySchedulerTensorFlow(steps_per_epoch)
linear_warmup = WarmupLinearSchedulerTensorFlow(WARMUP_STEPS, lr=LR)
lr_sched = CompositeLRSchedulerTensorFlow(linear_warmup, lr_decay, lr=LR)

optimizer = EagerOptimizer(loss_function, global_step=1, lr=LR, optim=OPTIM,
                           learning_rate_decay_fn=lr_sched,
                           weight_decay=WD, clip=GRAD_CLIP)

INFO:mead.layers:adamw(eta=0.000800 beta1=0.900000, beta2=0.999000, eps=0.000000, wd=0.000010)
INFO:mead.layers:clip gradients at 1.0
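
To give a feel for the shape of such a schedule, here is a pure-Python sketch of a linear warmup followed by a cosine decay. It uses the textbook formulas and assumes the decay phase starts counting once warmup ends and decays to zero; it is not the `8 mile` scheduler code, and the function name `warmup_then_cosine` is hypothetical.

```python
import math


def warmup_then_cosine(step, lr, warmup_steps, decay_steps):
    """Learning rate at a given step: linear ramp, then half-cosine decay."""
    if step < warmup_steps:
        return lr * step / warmup_steps
    # Steps past the end of the decay window stay at the final value
    t = min(step - warmup_steps, decay_steps)
    return lr * 0.5 * (1.0 + math.cos(math.pi * t / decay_steps))


LR = 1.0e-4 * 8
# 56 steps per epoch for 10 epochs, using the values from this notebook
rate_after_training = warmup_then_cosine(560, LR, warmup_steps=10000, decay_steps=56)
print(rate_after_training)
```

Under these assumptions, the 560 optimizer steps taken in this notebook all fall inside the 10,000-step warmup, so the rate would still be ramping up when training ends, reaching about 4.5e-5 rather than the 8e-4 target.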


We are going to define two autograph-compiled functions, one for distributed training and one for distributed testing.  Each implementation proxies to an underlying per-replica function, which optimizes (in the training case) and returns the per-replica loss; these losses are then accumulated in the distributed function.  The functions use `strategy.experimental_run_v2` to call the per-replica versions, and `strategy.reduce` to sum the results.

In [0]:

def _replicated_train_step(inputs):
    x, y = inputs
    per_replica_loss = optimizer.update(model, {'x': x}, y, num_replicas)
    return per_replica_loss

@tf.function
def _distributed_train_step(inputs: Tuple[tf.Tensor, tf.Tensor]):
    per_replica_loss = strategy.experimental_run_v2(_replicated_train_step, args=(inputs,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)

valid_loss_function = Loss(vocab_size, MLM_CONTEXT_LENGTH)
def _replicated_test_step(inputs):
    x, y = inputs
    per_replica_loss = valid_loss_function(model, {'x': x}, y) / num_replicas
    return per_replica_loss

@tf.function
def _distributed_test_step(inputs: Tuple[tf.Tensor, tf.Tensor]):
    per_replica_loss = strategy.experimental_run_v2(_replicated_test_step, args=(inputs,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)
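
Dividing each replica's loss by `num_replicas` and then reducing with `SUM` yields the mean of the per-replica losses. A minimal pure-Python sketch of that arithmetic, with made-up loss values and no TensorFlow involved (the helper name `scaled_sum_reduce` is hypothetical):

```python
def scaled_sum_reduce(replica_losses):
    """Scale each replica's loss by 1/num_replicas, then sum across replicas."""
    num_replicas = len(replica_losses)
    scaled = [loss / num_replicas for loss in replica_losses]
    return sum(scaled)


# Eight made-up per-replica losses, one per TPU core
replica_losses = [7.0, 7.5, 6.5, 7.0, 7.25, 6.75, 7.0, 7.0]
reduced = scaled_sum_reduce(replica_losses)
mean = sum(replica_losses) / len(replica_losses)
print(reduced, mean)
```

The validation step performs this division explicitly. For training, `num_replicas` is passed to `optimizer.update`, presumably for the same scaling.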

We have finally set up all the boilerplate and we can run a normal training loop.  The only difference for a distributed program within the training loop is that we call the operations within a `strategy.scope()`.

In [0]:
steps = 1
start_epoch = 0

with strategy.scope():

    SET_TRAIN_FLAG(True)
    for epoch in range(start_epoch, EPOCHS):
        avg_loss = Average('average_train_loss')
        metrics = {}
        start = time.time()
        train_iter = iter(train_loader)
        for i in range(steps_per_epoch):
            steps += 1
            loss = _distributed_train_step(next(train_iter))
            avg_loss.update(loss.numpy().item())
            if (i + 1) % report_on == 0:
                logging.info(avg_loss)
            if (i + 1) % update_on == 0:
                elapsed = (time.time() - start)/60
                ##print(avg_loss.avg, math.exp(avg_loss.avg))
                logging.info('elapsed time this epoch %d min', elapsed)
                logging.info('elapsed step time %f steps/min', i/elapsed)

        # How much time elapsed in minutes
        elapsed = (time.time() - start)/60
        # This is the average training token-level loss across all machines
        train_token_loss = avg_loss.avg
        # This is the token-level training perplexity
        train_token_ppl = math.exp(train_token_loss)
        metrics['train_elapsed_min'] = elapsed
        metrics['average_train_loss'] = train_token_loss
        metrics['average_train_token_ppl'] = train_token_ppl
        avg_valid_loss = Average('average_valid_loss')
        start = time.time()
        SET_TRAIN_FLAG(False)
        valid_iter = iter(valid_loader)
        for i in range(steps_per_valid_epoch):
            valid_loss = _distributed_test_step(next(valid_iter))
            avg_valid_loss.update(valid_loss.numpy().item())

        valid_token_loss = avg_valid_loss.avg
        valid_token_ppl = math.exp(valid_token_loss)

        elapsed = (time.time() - start)/60
        metrics['valid_elapsed_min'] = elapsed
        metrics['average_valid_loss'] = valid_token_loss
        metrics['average_valid_token_ppl'] = valid_token_ppl
        print(json.dumps(metrics, indent=4, sort_keys=True))

['gs://lm-sample/wt2/train/train-1.tfrecord']


INFO:root:average_train_loss 7.089441 (7.075352)
INFO:root:average_train_loss 6.983522 (7.072307)
INFO:root:average_train_loss 7.122105 (7.080230)
INFO:root:average_train_loss 6.993566 (7.073015)
INFO:root:elapsed time this epoch 0 min
INFO:root:elapsed step time 31.587442 steps/min
INFO:root:average_train_loss 7.015446 (7.071518)
INFO:root:average_train_loss 7.071340 (7.078360)
INFO:root:average_train_loss 7.175150 (7.081775)
INFO:root:average_train_loss 7.149497 (7.087573)
INFO:root:elapsed time this epoch 0 min
INFO:root:elapsed step time 56.690000 steps/min


['gs://lm-sample/wt2/valid/valid-1.tfrecord']
{
    "average_train_loss": 7.087573375020709,
    "average_train_token_ppl": 1196.9996043469598,
    "average_valid_loss": 6.938916206359863,
    "average_valid_token_ppl": 1031.6515115010525,
    "train_elapsed_min": 0.9702444513638814,
    "valid_elapsed_min": 0.07268238464991252
}
['gs://lm-sample/wt2/train/train-1.tfrecord']


INFO:root:average_train_loss 6.971631 (6.975532)
INFO:root:average_train_loss 6.896413 (6.974734)
INFO:root:average_train_loss 7.057155 (6.988991)
INFO:root:average_train_loss 6.938303 (6.988653)
INFO:root:elapsed time this epoch 0 min
INFO:root:elapsed step time 217.506247 steps/min
INFO:root:average_train_loss 6.951414 (6.993201)
INFO:root:average_train_loss 7.001421 (6.999779)
INFO:root:average_train_loss 7.125483 (7.006451)
INFO:root:average_train_loss 7.086274 (7.014928)
INFO:root:elapsed time this epoch 0 min
INFO:root:elapsed step time 228.429860 steps/min


['gs://lm-sample/wt2/valid/valid-1.tfrecord']
{
    "average_train_loss": 7.014928187642779,
    "average_train_token_ppl": 1113.1267070885322,
    "average_valid_loss": 6.9435882568359375,
    "average_valid_token_ppl": 1036.4827164652493,
    "train_elapsed_min": 0.2408829132715861,
    "valid_elapsed_min": 0.01825094223022461
}
['gs://lm-sample/wt2/train/train-1.tfrecord']


INFO:root:average_train_loss 6.949949 (6.972112)
INFO:root:average_train_loss 6.888988 (6.968611)
INFO:root:average_train_loss 7.066749 (6.981960)
INFO:root:average_train_loss 6.911725 (6.984062)
INFO:root:elapsed time this epoch 0 min
INFO:root:elapsed step time 218.363423 steps/min
INFO:root:average_train_loss 6.959283 (6.988287)
INFO:root:average_train_loss 6.978216 (6.993257)
INFO:root:average_train_loss 7.109393 (6.997729)
INFO:root:average_train_loss 7.081411 (7.006621)
INFO:root:elapsed time this epoch 0 min
INFO:root:elapsed step time 229.776564 steps/min


['gs://lm-sample/wt2/valid/valid-1.tfrecord']
{
    "average_train_loss": 7.006621216024671,
    "average_train_token_ppl": 1103.9182950867048,
    "average_valid_loss": 6.945255374908447,
    "average_valid_token_ppl": 1038.2120966736757,
    "train_elapsed_min": 0.23940510749816896,
    "valid_elapsed_min": 0.018021090825398763
}
['gs://lm-sample/wt2/train/train-1.tfrecord']


INFO:root:average_train_loss 6.937329 (6.982219)
INFO:root:average_train_loss 6.914935 (6.972902)
INFO:root:average_train_loss 7.038431 (6.983937)
INFO:root:average_train_loss 6.900294 (6.978404)
INFO:root:elapsed time this epoch 0 min
INFO:root:elapsed step time 219.663637 steps/min
INFO:root:average_train_loss 6.969196 (6.983082)
INFO:root:average_train_loss 7.082540 (7.005794)
INFO:root:average_train_loss 7.117290 (7.015462)
INFO:root:average_train_loss 7.050060 (7.019021)
INFO:root:elapsed time this epoch 0 min
INFO:root:elapsed step time 230.590045 steps/min


['gs://lm-sample/wt2/valid/valid-1.tfrecord']
{
    "average_train_loss": 7.019020863941738,
    "average_train_token_ppl": 1117.6917095470744,
    "average_valid_loss": 6.979404735565185,
    "average_valid_token_ppl": 1074.2786967269392,
    "train_elapsed_min": 0.23855806191762288,
    "valid_elapsed_min": 0.01808029810587565
}
['gs://lm-sample/wt2/train/train-1.tfrecord']


INFO:root:average_train_loss 6.999094 (7.022191)
INFO:root:average_train_loss 6.926431 (7.010803)
INFO:root:average_train_loss 7.144818 (7.032847)
INFO:root:average_train_loss 6.982699 (7.044690)
INFO:root:elapsed time this epoch 0 min
INFO:root:elapsed step time 217.842766 steps/min
INFO:root:average_train_loss 6.976415 (7.045303)
INFO:root:average_train_loss 7.050766 (7.052213)
INFO:root:average_train_loss 7.113155 (7.053249)
INFO:root:average_train_loss 7.044514 (7.052607)
INFO:root:elapsed time this epoch 0 min
INFO:root:elapsed step time 228.551451 steps/min


['gs://lm-sample/wt2/valid/valid-1.tfrecord']
{
    "average_train_loss": 7.052606642246246,
    "average_train_token_ppl": 1155.867753089225,
    "average_valid_loss": 6.956824016571045,
    "average_valid_token_ppl": 1050.2925428358872,
    "train_elapsed_min": 0.24068913062413533,
    "valid_elapsed_min": 0.01954260269800822
}
['gs://lm-sample/wt2/train/train-1.tfrecord']


INFO:root:average_train_loss 7.016996 (7.027875)
INFO:root:average_train_loss 6.958185 (7.035856)
INFO:root:average_train_loss 7.132082 (7.055464)
INFO:root:average_train_loss 6.961725 (7.056203)
INFO:root:elapsed time this epoch 0 min
INFO:root:elapsed step time 216.214068 steps/min
INFO:root:average_train_loss 7.000763 (7.054603)
INFO:root:average_train_loss 7.048105 (7.062166)
INFO:root:average_train_loss 7.126616 (7.063251)
INFO:root:average_train_loss 7.084837 (7.064421)
INFO:root:elapsed time this epoch 0 min
INFO:root:elapsed step time 229.089940 steps/min


['gs://lm-sample/wt2/valid/valid-1.tfrecord']
{
    "average_train_loss": 7.064421279089792,
    "average_train_token_ppl": 1169.6049007445342,
    "average_valid_loss": 6.942471885681153,
    "average_valid_token_ppl": 1035.3262626940764,
    "train_elapsed_min": 0.24015008608500163,
    "valid_elapsed_min": 0.01828513542811076
}
['gs://lm-sample/wt2/train/train-1.tfrecord']


INFO:root:average_train_loss 6.998410 (7.023473)
INFO:root:average_train_loss 6.926228 (7.013489)
INFO:root:average_train_loss 7.112935 (7.030364)
INFO:root:average_train_loss 6.970592 (7.036053)
INFO:root:elapsed time this epoch 0 min
INFO:root:elapsed step time 215.566643 steps/min
INFO:root:average_train_loss 6.992733 (7.038232)
INFO:root:average_train_loss 7.018197 (7.044628)
INFO:root:average_train_loss 7.112901 (7.077833)
INFO:root:average_train_loss 7.080790 (7.075710)
INFO:root:elapsed time this epoch 0 min
INFO:root:elapsed step time 227.145991 steps/min


['gs://lm-sample/wt2/valid/valid-1.tfrecord']
{
    "average_train_loss": 7.0757095984050205,
    "average_train_token_ppl": 1182.8825746549512,
    "average_valid_loss": 6.940936183929443,
    "average_valid_token_ppl": 1033.7375305604623,
    "train_elapsed_min": 0.24223616123199462,
    "valid_elapsed_min": 0.017551032702128093
}
['gs://lm-sample/wt2/train/train-1.tfrecord']


INFO:root:average_train_loss 7.021813 (7.036692)
INFO:root:average_train_loss 6.927490 (7.029205)
INFO:root:average_train_loss 7.088665 (7.038347)
INFO:root:average_train_loss 6.951767 (7.034799)
INFO:root:elapsed time this epoch 0 min
INFO:root:elapsed step time 218.921018 steps/min
INFO:root:average_train_loss 6.998722 (7.034758)
INFO:root:average_train_loss 7.028143 (7.042662)
INFO:root:average_train_loss 7.117557 (7.044002)
INFO:root:average_train_loss 7.061752 (7.044619)
INFO:root:elapsed time this epoch 0 min
INFO:root:elapsed step time 231.270327 steps/min


['gs://lm-sample/wt2/valid/valid-1.tfrecord']
{
    "average_train_loss": 7.044618563992636,
    "average_train_token_ppl": 1146.6713706393016,
    "average_valid_loss": 6.9357068061828615,
    "average_valid_token_ppl": 1028.3458364119842,
    "train_elapsed_min": 0.23789680004119873,
    "valid_elapsed_min": 0.01736583709716797
}
['gs://lm-sample/wt2/train/train-1.tfrecord']


INFO:root:average_train_loss 7.008987 (7.018630)
INFO:root:average_train_loss 6.920163 (7.015868)
INFO:root:average_train_loss 7.086972 (7.028645)
INFO:root:average_train_loss 6.942828 (7.026551)
INFO:root:elapsed time this epoch 0 min
INFO:root:elapsed step time 220.335643 steps/min
INFO:root:average_train_loss 6.978358 (7.025822)
INFO:root:average_train_loss 7.010882 (7.032273)
INFO:root:average_train_loss 7.114468 (7.033459)
INFO:root:average_train_loss 7.066011 (7.035308)
INFO:root:elapsed time this epoch 0 min
INFO:root:elapsed step time 230.143539 steps/min


['gs://lm-sample/wt2/valid/valid-1.tfrecord']
{
    "average_train_loss": 7.035307569163186,
    "average_train_token_ppl": 1136.0442706486592,
    "average_valid_loss": 6.935028457641602,
    "average_valid_token_ppl": 1027.6484960605894,
    "train_elapsed_min": 0.23902015288670858,
    "valid_elapsed_min": 0.01749055782953898
}
['gs://lm-sample/wt2/train/train-1.tfrecord']


INFO:root:average_train_loss 7.004523 (7.016095)
INFO:root:average_train_loss 6.914292 (7.010748)
INFO:root:average_train_loss 7.080095 (7.021781)
INFO:root:average_train_loss 6.939723 (7.020098)
INFO:root:elapsed time this epoch 0 min
INFO:root:elapsed step time 217.255786 steps/min
INFO:root:average_train_loss 6.974094 (7.019442)
INFO:root:average_train_loss 7.014464 (7.026440)
INFO:root:average_train_loss 7.108461 (7.027479)
INFO:root:average_train_loss 7.059165 (7.029342)
INFO:root:elapsed time this epoch 0 min
INFO:root:elapsed step time 227.802929 steps/min


['gs://lm-sample/wt2/valid/valid-1.tfrecord']
{
    "average_train_loss": 7.029342080865588,
    "average_train_token_ppl": 1129.287385942992,
    "average_valid_loss": 6.934982681274414,
    "average_valid_token_ppl": 1027.6014551223839,
    "train_elapsed_min": 0.24149733384450275,
    "valid_elapsed_min": 0.01960286299387614
}
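
The perplexity values reported above are the exponential of the average token-level loss, as computed in the loop with `math.exp`. This can be checked directly against the first epoch's metrics (the variable names here are only for illustration):

```python
import math

# Loss values copied from the first metrics block in the output above
first_epoch_train_loss = 7.087573375020709
first_epoch_valid_loss = 6.938916206359863

train_ppl = math.exp(first_epoch_train_loss)
valid_ppl = math.exp(first_epoch_valid_loss)
print(train_ppl, valid_ppl)
```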


### Wrap-Up and Next Steps

In this example, we used the MEAD/Baseline API code to train a Transformer Language Model on TPUs.  To keep things very simple, we did not write out any checkpoints or any training logs.  To see how you would do this with the `CheckpointManager`, see the [full API example](https://github.com/dpressel/mead-baseline/blob/master/api-examples/pretrain-tlm-tf.py).

Also, as this is just a sample, we used a very small dataset, [wikitext-2](https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/).  To make a compelling MLM, you would want to use a much larger dataset and probably spend a bit more time with the hyper-parameters.