# Pre-training BERT from scratch with cloud TPU

In this experiment, we pre-train [BERT](https://arxiv.org/abs/1810.04805), a state-of-the-art natural language understanding model, on MHC-ligand data using Google Cloud infrastructure.

MIT License

Copyright (c) [2019] [Antyukhov Denis Olegovich]

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

## Step 1: setting up training environment

In [0]:
!pip install sentencepiece
!git clone https://github.com/google-research/bert

import os
import sys
import json
import nltk
import random
import logging
import tensorflow as tf
import sentencepiece as spm

from glob import glob
from google.colab import auth, drive
from tensorflow.keras.utils import Progbar

sys.path.append("bert")

from bert import modeling, optimization, tokenization
from bert.run_pretraining import input_fn_builder, model_fn_builder

auth.authenticate_user()
  
# configure logging
log = logging.getLogger('tensorflow')
log.setLevel(logging.INFO)
tf.logging.set_verbosity(tf.logging.INFO)

# create formatter and add it to the handlers
formatter = logging.Formatter('%(asctime)s :  %(message)s')
sh = logging.StreamHandler()
sh.setLevel(logging.INFO)
sh.setFormatter(formatter)
log.handlers = [sh]

if 'COLAB_TPU_ADDR' in os.environ:
  log.info("Using TPU runtime")
  USE_TPU = True
  TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']

  with tf.Session(TPU_ADDRESS) as session:
    log.info('TPU address is ' + TPU_ADDRESS)
    # Upload credentials to TPU.
    with open('/content/adc.json', 'r') as f:
      auth_info = json.load(f)
    tf.contrib.cloud.configure_gcs(session, credentials=auth_info)
    
else:
  log.warning('Not connected to TPU runtime')
  USE_TPU = False

Collecting sentencepiece
  Downloading https://files.pythonhosted.org/packages/74/f4/2d5214cbf13d06e7cb2c20d84115ca25b53ea76fa1f0ade0e3c9749de214/sentencepiece-0.1.85-cp36-cp36m-manylinux1_x86_64.whl (1.0MB)
     |████████████████████████████████| 1.0MB 3.4MB/s 
Installing collected packages: sentencepiece
Successfully installed sentencepiece-0.1.85
Cloning into 'bert'...
remote: Enumerating objects: 336, done.
remote: Total 336 (delta 0), reused 0 (delta 0), pack-reused 336
Receiving objects: 100% (336/336), 291.40 KiB | 3.99 MiB/s, done.
Resolving deltas: 100% (184/184), done.

The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



2019-12-19 09:22:10,048 :  Using TPU runtime
2019-12-19 09:22:10,050 :  TPU address is grpc://10.86.137.10:8470
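The TPU detection in the cell above boils down to reading one environment variable. A minimal sketch of that logic, without TensorFlow, is given here; the helper name `tpu_address_from_env` is ours and not part of the notebook.

```python
import os

def tpu_address_from_env(env=os.environ):
    """Return the gRPC address of the Colab TPU, or None if no TPU runtime is attached."""
    # Colab exposes the TPU host and port in COLAB_TPU_ADDR when a TPU runtime is selected.
    addr = env.get("COLAB_TPU_ADDR")
    return "grpc://" + addr if addr else None

# A plain dict stands in for os.environ in this illustration.
print(tpu_address_from_env({"COLAB_TPU_ADDR": "10.86.137.10:8470"}))
print(tpu_address_from_env({}))
```

Passing the environment as an argument keeps the function easy to try outside Colab.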


## Step 2: getting the data

Upload the data set, which is a concatenation of c001, c002, and c003. It consists of peptides that are 8 to 15 amino acids long. Each peptide is rewritten with a space between consecutive amino acids, so that every residue resembles a word.

In [0]:
!tail sample_data/proc_dataset.txt

y y y s i k d i
y y y s p e q s k p d h l v
y y y t a k l s s r i d d
y y y t k e e q f
y y y t k e e q f
y y y v f d a i e q e
y y y v g f s y m m m r
y y y v t p n s d t a k y
y y y y q p p r v
y y y y v i d r p l y q
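The whitespace transformation described at the start of this step is not shown in the notebook. A minimal sketch of how it might be done is given below; the function name `space_separate` is hypothetical.

```python
def space_separate(peptide):
    """Insert a single space between consecutive amino-acid letters."""
    # Strip surrounding whitespace first so a trailing newline does not become a token.
    return " ".join(peptide.strip())

print(space_separate("YYYSIKDI"))
```

Lowercasing is left to the normalization in Step 3.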


## Step 3: preprocessing text

Our data set contains no punctuation, and amino acids are usually written in capital letters. Even so, we preprocess it as if it were regular text (lowercasing and stripping punctuation) to stay compatible with the rest of the pipeline.

In [0]:
regex_tokenizer = nltk.RegexpTokenizer(r"\w+")

def normalize_text(text):
  # lowercase text
  text = str(text).lower()
  # drop characters that cannot be encoded as UTF-8
  text = text.encode("utf-8", "ignore").decode()
  # remove punctuation symbols
  text = " ".join(regex_tokenizer.tokenize(text))
  return text

def count_lines(filename):
  count = 0
  with open(filename) as fi:
    for line in fi:
      count += 1
  return count
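To see what the normalization does to a single line, here is a standard-library sketch. It assumes that `re.findall` with the same pattern behaves like the NLTK tokenizer used above, and the name `normalize_peptide_line` is ours.

```python
import re

def normalize_peptide_line(text):
    # Lowercase, then keep only runs of word characters, joined by single spaces.
    return " ".join(re.findall(r"\w+", str(text).lower()))

print(normalize_peptide_line("Y Y Y, S-I K!\n"))  # y y y s i k
```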

Apply normalization to the whole dataset.

In [0]:
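# Assumption: path to the uploaded raw data set (placeholder value, adjust as needed).
# It must differ from PRC_DATA_FPATH, otherwise the input is truncated while being read.
RAW_DATA_FPATH = "sample_data/raw_dataset.txt" #@param {type: "string"}
PRC_DATA_FPATH = "sample_data/proc_dataset.txt" #@param {type: "string"}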

# apply normalization to the dataset
# this will take a minute or two

total_lines = count_lines(RAW_DATA_FPATH)
bar = Progbar(total_lines)

with open(RAW_DATA_FPATH, encoding="utf-8") as fi:
  with open(PRC_DATA_FPATH, "w", encoding="utf-8") as fo:
    for line in fi:
      fo.write(normalize_text(line) + "\n")
      bar.add(1)



## Step 4: building the vocabulary

In this step we build the vocabulary used to tokenize our words. The BERT paper uses a WordPiece tokenizer, which is not available as open source. We therefore use SentencePiece and modify its output to approximate WordPiece.

In [0]:
regex_tokenizer = nltk.RegexpTokenizer(r"\w+")
MODEL_PREFIX = "tokenizer"
VOC_SIZE = 42
SUBSAMPLE_SIZE = 12800000
NUM_PLACEHOLDERS = 20
SENTENCE_SIZE = 30

SPM_COMMAND = ('--input={} --model_prefix={} '
               '--vocab_size={} '
               '--shuffle_input_sentence=true ' 
               '--bos_id=-1 --eos_id=-1').format(
               PRC_DATA_FPATH, MODEL_PREFIX, 
               VOC_SIZE)

spm.SentencePieceTrainer.Train(SPM_COMMAND)

True

SentencePiece has created two files: tokenizer.model and tokenizer.vocab. Below we see the vocabulary:

In [0]:
!cat tokenizer.vocab

<unk>	0
▁l	-2.35302
▁a	-2.52203
▁g	-2.67796
▁v	-2.69782
▁e	-2.71504
▁s	-2.73995
▁i	-2.85276
▁k	-2.87146
▁r	-2.91556
▁d	-2.93561
▁t	-2.94742
▁p	-3.06394
▁n	-3.23505
▁q	-3.25101
▁f	-3.27259
▁y	-3.53229
▁h	-3.78897
▁m	-3.85723
▁c	-4.34252
▁	-4.4067
▁w	-4.54129
l	-6.73505
a	-6.90406
g	-7.06
v	-7.07986
e	-7.09708
s	-7.12199
i	-7.2348
k	-7.2535
r	-7.29761
d	-7.31766
t	-7.32947
p	-7.44598
n	-7.6171
q	-7.63306
f	-7.65465
y	-7.91435
h	-8.17104
m	-8.2393
c	-8.72462
w	-8.92341


In [0]:
def read_sentencepiece_vocab(filepath):
  voc = []
  with open(filepath, encoding='utf-8') as fi:
    for line in fi:
      voc.append(line.split("\t")[0])
  # skip the first <unk> token
  voc = voc[1:]
  return voc

snt_vocab = read_sentencepiece_vocab("{}.vocab".format(MODEL_PREFIX))
print("Learnt vocab size: {}".format(len(snt_vocab)))
print("Sample tokens: {}".format(random.sample(snt_vocab, 10)))

Learnt vocab size: 41
Sample tokens: ['▁f', 'i', '▁g', 'h', 't', '▁c', '▁h', '▁l', 'r', 'n']


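Next, we convert the SentencePiece tokens to the WordPiece convention: a token that starts a word loses its leading "▁", and a token that continues a word gets a "##" prefix.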

In [0]:
def parse_sentencepiece_token(token):
    if token.startswith("▁"):
        return token[1:]
    else:
        return "##" + token

In [0]:
bert_vocab = list(map(parse_sentencepiece_token, snt_vocab))
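As an illustration of the conversion, the snippet below repeats the function from the cell above and applies it to a few tokens taken from `tokenizer.vocab`. Note that the lone "▁" entry becomes an empty string.

```python
def parse_sentencepiece_token(token):
    # Same rule as in the notebook: "▁x" starts a word, a bare "x" continues one.
    if token.startswith("▁"):
        return token[1:]
    return "##" + token

sample = ["▁l", "▁a", "l", "a", "▁"]
converted = [parse_sentencepiece_token(t) for t in sample]
print(converted)  # ['l', 'a', '##l', '##a', '']
```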

We also add some special control symbols which are required by the BERT architecture. By convention, we put those at the beginning of the vocabulary.

In [0]:
ctrl_symbols = ["[PAD]","[UNK]","[CLS]","[SEP]","[MASK]"]
bert_vocab = ctrl_symbols + bert_vocab

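It is also customary to append placeholder tokens to the vocabulary. They are useful if one wishes to extend the pre-trained model with new, task-specific tokens. In this run the vocabulary already holds 46 tokens, more than VOC_SIZE = 42, so the range below is empty and no placeholders are added.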

In [0]:
bert_vocab += ["[UNUSED_{}]".format(i) for i in range(VOC_SIZE - len(bert_vocab))]
print(len(bert_vocab))

46
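The printed size of 46 follows from 5 control symbols plus 41 learnt tokens. The sketch below wraps the padding step in a hypothetical helper, `pad_vocab`, to show how it behaves for a target size below and above the current vocabulary size. The target of 64 is an arbitrary illustration.

```python
def pad_vocab(vocab, target_size):
    # Append "[UNUSED_i]" tokens until the vocabulary reaches target_size.
    # range() over a non-positive count is empty, so an oversized vocab is returned unchanged.
    return vocab + ["[UNUSED_{}]".format(i) for i in range(target_size - len(vocab))]

vocab = ["tok{}".format(i) for i in range(46)]  # stand-in for the 46 real tokens
print(len(pad_vocab(vocab, 42)))  # 46, nothing added
print(len(pad_vocab(vocab, 64)))  # 64, 18 placeholders added
```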


Finally, we write the obtained vocabulary to file.

In [0]:
VOC_FNAME = "vocab.txt"

with open(VOC_FNAME, "w") as fo:
  for token in bert_vocab:
    fo.write(token+"\n")

## Step 5: generating pre-training data

In [0]:
!mkdir ./shards
!split -a 4 -l 256000 -d $PRC_DATA_FPATH ./shards/shard_
!ls ./shards/

shard_0000  shard_0003	shard_0006  shard_0009	shard_0012  shard_0015
shard_0001  shard_0004	shard_0007  shard_0010	shard_0013
shard_0002  shard_0005	shard_0008  shard_0011	shard_0014


In [0]:
!wc -l ./shards/shard*

  256000 ./shards/shard_0000
  256000 ./shards/shard_0001
  256000 ./shards/shard_0002
  256000 ./shards/shard_0003
  256000 ./shards/shard_0004
  256000 ./shards/shard_0005
  256000 ./shards/shard_0006
  256000 ./shards/shard_0007
  256000 ./shards/shard_0008
  256000 ./shards/shard_0009
  256000 ./shards/shard_0010
  256000 ./shards/shard_0011
  256000 ./shards/shard_0012
  256000 ./shards/shard_0013
  256000 ./shards/shard_0014
   27495 ./shards/shard_0015
 3867495 total
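The shard listing can be cross-checked with a little arithmetic: 3,867,495 lines split into chunks of 256,000 give 15 full shards and one partial shard.

```python
total_lines = 3867495
lines_per_shard = 256000

# divmod returns the number of full shards and the size of the last, partial one.
full_shards, remainder = divmod(total_lines, lines_per_shard)
num_shards = full_shards + (1 if remainder else 0)

print(full_shards, remainder, num_shards)  # 15 27495 16
```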


Model-specific parameters.  

In [0]:
MAX_SEQ_LENGTH = 128
MASKED_LM_PROB = 0.15
MAX_PREDICTIONS = 20
DO_LOWER_CASE = True
PROCESSES = 2
PRETRAINING_DIR = "pretraining_data2"
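MASKED_LM_PROB and MAX_PREDICTIONS together bound how many tokens are masked per sequence. To our understanding of `create_pretraining_data.py`, the count is the rounded product of sequence length and masking probability, at least 1 and capped at the maximum. The sketch below reproduces that rule under this assumption; the function name is ours.

```python
def num_masked_tokens(seq_len, masked_lm_prob=0.15, max_predictions=20):
    # At least one token is masked, and never more than max_predictions.
    return min(max_predictions, max(1, int(round(seq_len * masked_lm_prob))))

print(num_masked_tokens(128))  # 19
print(num_masked_tokens(5))    # 1
print(num_masked_tokens(200))  # 20 (capped)
```

With 128 tokens and a probability of 0.15 the cap of 20 is almost reached, which is why the two values are usually chosen together.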

Call the *create_pretraining_data.py* script on every shard, running PROCESSES jobs in parallel via the *xargs* command.

In [0]:
XARGS_CMD = ("ls ./shards/ | "
             "xargs -n 1 -P {} -I{} "
             "python3 bert/create_pretraining_data.py "
             "--input_file=./shards/{} "
             "--output_file={}/{}.tfrecord "
             "--vocab_file={} "
             "--do_lower_case={} "
             "--max_predictions_per_seq={} "
             "--max_seq_length={} "
             "--masked_lm_prob={} "
             "--random_seed=34 "
             "--dupe_factor=5")

XARGS_CMD = XARGS_CMD.format(PROCESSES, '{}', '{}', PRETRAINING_DIR, '{}', 
                             VOC_FNAME, DO_LOWER_CASE, 
                             MAX_PREDICTIONS, MAX_SEQ_LENGTH, MASKED_LM_PROB)
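The command template mixes two kinds of `{}`: slots filled by Python and placeholders that must survive for *xargs* to substitute the shard name. Passing the literal string `'{}'` as a format argument keeps the latter intact. A shortened version of the template shows the effect:

```python
PROCESSES = 2
template = ("ls ./shards/ | "
            "xargs -n 1 -P {} -I{} "
            "python3 bert/create_pretraining_data.py "
            "--input_file=./shards/{} "
            "--output_file={}/{}.tfrecord")

# "{}" arguments are inserted verbatim, so xargs still sees its own placeholder.
cmd = template.format(PROCESSES, "{}", "{}", "pretraining_data2", "{}")
print(cmd)
```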

In [0]:
tf.gfile.MkDir(PRETRAINING_DIR)
!$XARGS_CMD



W1217 22:53:07.484345 139755324290944 module_wrapper.py:139] From bert/create_pretraining_data.py:437: The name tf.logging.set_verbosity is deprecated. Please use tf.compat.v1.logging.set_verbosity instead.


W1217 22:53:07.484632 139755324290944 module_wrapper.py:139] From bert/create_pretraining_data.py:437: The name tf.logging.INFO is deprecated. Please use tf.compat.v1.logging.INFO instead.


W1217 22:53:07.484893 139755324290944 module_wrapper.py:139] From /content/bert/tokenization.py:125: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.


W1217 22:53:07.486832 139755324290944 module_wrapper.py:139] From bert/create_pretraining_data.py:444: The name tf.gfile.Glob is deprecated. Please use tf.io.gfile.glob instead.



W1217 22:53:07.492668 140406266410880 module_wrapper.py:139] From bert/create_pretraining_data.py:437: The name tf.logging.set_verbosity is deprecated. Please use tf.compat.v1.logging.set_verbosity instead.


W1217 22:53:07.492921 140406

## Step 6: setting up persistent storage

In [0]:
BUCKET_NAME = "dl-project-nlp-bucket"
MODEL_DIR = "bert_model2"
tf.gfile.MkDir(MODEL_DIR)

if not BUCKET_NAME:
  log.warning("WARNING: BUCKET_NAME is not set. "
              "You will not be able to train the model.")

Below is the sample hyperparameter configuration for BERT-base.

In [0]:
# use this for BERT-base

bert_base_config = {
  "attention_probs_dropout_prob": 0.1, 
  "directionality": "bidi", 
  "hidden_act": "gelu", 
  "hidden_dropout_prob": 0.1, 
  "hidden_size": 768, 
  "initializer_range": 0.02, 
  "intermediate_size": 3072, 
  "max_position_embeddings": 512, 
  "num_attention_heads": 12, 
  "num_hidden_layers": 12, 
  "pooler_fc_size": 768, 
  "pooler_num_attention_heads": 12, 
  "pooler_num_fc_layers": 3, 
  "pooler_size_per_head": 128, 
  "pooler_type": "first_token_transform", 
  "type_vocab_size": 2, 
  # must cover every token id in vocab.txt (46 tokens here); VOC_SIZE (42) would be too small
  "vocab_size": len(bert_vocab)
}

with open("{}/bert_config.json".format(MODEL_DIR), "w") as fo:
  json.dump(bert_base_config, fo, indent=2)
  
with open("{}/{}".format(MODEL_DIR, VOC_FNAME), "w") as fo:
  for token in bert_vocab:
    fo.write(token+"\n")
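Before uploading, it can be worth checking that the configuration and the vocabulary agree. The helper below is a hypothetical sketch, not part of the BERT repository. It flags a `vocab_size` smaller than the vocabulary (for example 42 against 46 tokens) and a `hidden_size` that is not divisible by the number of attention heads.

```python
def check_bert_config(config, vocab):
    """Return a list of human-readable problems; an empty list means the checks passed."""
    problems = []
    if config["vocab_size"] < len(vocab):
        problems.append("vocab_size {} is smaller than the {} tokens in the vocabulary".format(
            config["vocab_size"], len(vocab)))
    if config["hidden_size"] % config["num_attention_heads"] != 0:
        problems.append("hidden_size is not divisible by num_attention_heads")
    return problems

vocab = ["tok{}".format(i) for i in range(46)]  # stand-in for the real vocabulary
too_small = {"vocab_size": 42, "hidden_size": 768, "num_attention_heads": 12}
consistent = {"vocab_size": 46, "hidden_size": 768, "num_attention_heads": 12}
print(check_bert_config(too_small, vocab))
print(check_bert_config(consistent, vocab))
```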

In [0]:
if BUCKET_NAME:
  !gsutil -m cp -r $MODEL_DIR $PRETRAINING_DIR gs://$BUCKET_NAME

Copying file://bert_model2/vocab.txt [Content-Type=text/plain]...
Copying file://pretraining_data2/shard_0002.tfrecord [Content-Type=application/octet-stream]...
Copying file://bert_model2/bert_config.json [Content-Type=application/json]...
Copying file://pretraining_data2/shard_0004.tfrecord [Content-Type=application/octet-stream]...
Copying file://pretraining_data2/shard_0014.tfrecord [Content-Type=application/octet-stream]...
Copying file://pretraining_data2/shard_0003.tfrecord [Content-Type=application/octet-stream]...
Copying file://pretraining_data2/shard_0012.tfrecord [Content-Type=application/octet-stream]...
==> NOTE: You are uploading one or more large file(s), which would run
significantly faster if you enable parallel composite uploads. This
feature can be enabled by editing the
"parallel_composite_upload_threshold" value in your .boto
configuration file. However, note that if you do this large files will
be uploaded as `composite objects
<https://cloud.google.com/storage/d

## Step 7: training the model

In [0]:
BUCKET_NAME = "dl-project-nlp-bucket"
MODEL_DIR = "bert_model2"
PRETRAINING_DIR = "pretraining_data2"
VOC_FNAME = "vocab.txt"

# Input data pipeline config
TRAIN_BATCH_SIZE = 128
MAX_PREDICTIONS = 20
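# NOTE: must equal the --max_seq_length used to build the TFRecords in Step 5
# (128 there), because the records store sequences of that fixed length.
MAX_SEQ_LENGTH = 30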
MASKED_LM_PROB = 0.15

# Training procedure config
EVAL_BATCH_SIZE = 64
LEARNING_RATE = 2e-5
TRAIN_STEPS = 1000000
SAVE_CHECKPOINTS_STEPS = 2500
NUM_TPU_CORES = 8

if BUCKET_NAME:
  BUCKET_PATH = "gs://{}".format(BUCKET_NAME)
else:
  BUCKET_PATH = "."

BERT_GCS_DIR = "{}/{}".format(BUCKET_PATH, MODEL_DIR)
DATA_GCS_DIR = "{}/{}".format(BUCKET_PATH, PRETRAINING_DIR)

VOCAB_FILE = os.path.join(BERT_GCS_DIR, VOC_FNAME)
CONFIG_FILE = os.path.join(BERT_GCS_DIR, "bert_config.json")

INIT_CHECKPOINT = tf.train.latest_checkpoint(BERT_GCS_DIR)

bert_config = modeling.BertConfig.from_json_file(CONFIG_FILE)
input_files = tf.gfile.Glob(os.path.join(DATA_GCS_DIR,'*tfrecord'))

log.info("Using checkpoint: {}".format(INIT_CHECKPOINT))
log.info("Using {} data shards".format(len(input_files)))

2019-12-19 09:22:37,669 :  From /content/bert/modeling.py:93: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.

2019-12-19 09:22:38,435 :  Using checkpoint: gs://dl-project-nlp-bucket/bert_model2/model.ckpt-850000
2019-12-19 09:22:38,436 :  Using 16 data shards
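The training settings imply a few derived quantities. Assuming, as is usual for `TPUEstimator`, that the train batch size is the global batch split evenly across cores, the numbers work out as follows:

```python
TRAIN_BATCH_SIZE = 128
NUM_TPU_CORES = 8
TRAIN_STEPS = 1000000

# Each core processes an equal slice of the global batch.
per_core_batch = TRAIN_BATCH_SIZE // NUM_TPU_CORES
# Total number of training sequences consumed over the full run.
sequences_seen = TRAIN_BATCH_SIZE * TRAIN_STEPS

print(per_core_batch)   # 16
print(sequences_seen)   # 128000000
```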


Prepare the training run configuration, build the estimator and input function, power up the bass cannon.

In [0]:
model_fn = model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=INIT_CHECKPOINT,
      learning_rate=LEARNING_RATE,
      num_train_steps=TRAIN_STEPS,
      num_warmup_steps=10,
      use_tpu=USE_TPU,
      use_one_hot_embeddings=True)

tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)

run_config = tf.contrib.tpu.RunConfig(
    cluster=tpu_cluster_resolver,
    model_dir=BERT_GCS_DIR,
    save_checkpoints_steps=SAVE_CHECKPOINTS_STEPS,
    tpu_config=tf.contrib.tpu.TPUConfig(
        iterations_per_loop=SAVE_CHECKPOINTS_STEPS,
        num_shards=NUM_TPU_CORES,
        per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))

estimator = tf.contrib.tpu.TPUEstimator(
    use_tpu=USE_TPU,
    model_fn=model_fn,
    config=run_config,
    train_batch_size=TRAIN_BATCH_SIZE,
    eval_batch_size=EVAL_BATCH_SIZE)
  
train_input_fn = input_fn_builder(
        input_files=input_files,
        max_seq_length=MAX_SEQ_LENGTH,
        max_predictions_per_seq=MAX_PREDICTIONS,
        is_training=True)

2019-12-19 09:22:45,932 :  Estimator's model_fn (<function model_fn_builder.<locals>.model_fn at 0x7fa196c6fd08>) includes params argument, but params are not passed to Estimator.
2019-12-19 09:22:45,934 :  Using config: {'_model_dir': 'gs://dl-project-nlp-bucket/bert_model2', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 2500, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.86.137.10:8470"
    }
  }
}
isolate_session_state: true
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fa13fc92
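The `learning_rate`, `num_train_steps` and `num_warmup_steps` arguments define the learning-rate schedule. To our understanding of `optimization.py` in the BERT repository, the rate rises linearly during warmup and then decays linearly to zero. The sketch below is a plain-Python approximation under that assumption, not the repository code.

```python
def bert_learning_rate(step, init_lr=2e-5, num_train_steps=1000000, num_warmup_steps=10):
    # Linear warmup from 0 to init_lr over the first num_warmup_steps steps.
    if step < num_warmup_steps:
        return init_lr * step / num_warmup_steps
    # Afterwards, linear decay that reaches 0 at num_train_steps.
    return init_lr * (1.0 - min(step, num_train_steps) / num_train_steps)

for step in (0, 5, 10, 500000, 1000000):
    print(step, bert_learning_rate(step))
```

With only 10 warmup steps out of 1,000,000, the schedule is effectively a pure linear decay.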

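Start training. Because a checkpoint already exists in the model directory, the estimator resumes from it instead of starting at step 0.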

In [0]:
estimator.train(input_fn=train_input_fn, max_steps=TRAIN_STEPS)

2019-12-19 09:22:55,430 :  Querying Tensorflow master (grpc://10.86.137.10:8470) for TPU system metadata.
2019-12-19 09:22:55,446 :  Found TPU system:
2019-12-19 09:22:55,447 :  *** Num TPU Cores: 8
2019-12-19 09:22:55,450 :  *** Num TPU Workers: 1
2019-12-19 09:22:55,452 :  *** Num TPU Cores Per Worker: 8
2019-12-19 09:22:55,453 :  *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 3963823626525284734)
2019-12-19 09:22:55,455 :  *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 13009607125317704663)
2019-12-19 09:22:55,457 :  *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 9266008656464707654)
2019-12-19 09:22:55,459 :  *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 13372683273593967058)
2019-12-19 09:22:55,460 :  *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:

<tensorflow_estimator.python.estimator.tpu.tpu_estimator.TPUEstimator at 0x7fa13fc888d0>
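Since `max_steps` is an absolute target for the global step, a resumed run only trains for the remaining steps. The hypothetical helper below estimates that number from a checkpoint path such as the one logged in Step 7.

```python
import re

def remaining_steps(checkpoint_path, max_steps):
    # The global step is encoded as the numeric suffix of the checkpoint name.
    match = re.search(r"ckpt-(\d+)$", checkpoint_path or "")
    done = int(match.group(1)) if match else 0
    return max(0, max_steps - done)

print(remaining_steps("gs://dl-project-nlp-bucket/bert_model2/model.ckpt-850000", 1000000))  # 150000
print(remaining_steps(None, 1000000))  # 1000000
```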

This notebook is based on the tutorial at https://towardsdatascience.com/pre-training-bert-from-scratch-with-cloud-tpu-6e2f71028379