<a href="https://colab.research.google.com/github/diego-feijo/bertpt/blob/master/Pre_training_ALBERT_from_scratch_with_cloud_TPU_Wikipedia.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Pre-training ALBERT from Wikipedia Dump

This notebook pre-trains an [ALBERT](https://github.com/google-research/ALBERT) model from a Wikipedia dump using a free Colab TPU v2.

These are the steps to follow:

1. Setting Up the Environment
2. Download and Prepare Data
3. Extract Raw Text
4. Build SentencePiece Model
5. Generate Pre-training Data
6. Train the Model

Because we use a Colab TPU, a [Google Cloud Storage bucket](https://cloud.google.com/tpu/docs/quickstart) is required. New users receive [$300 free credit](https://cloud.google.com/free/) for one year to get started with any GCP product.

After each step, we save persistent data so we can always stop and resume from the last finished step.

**Note** 
The only parameter you *really have to set* is BUCKET_NAME, which appears in steps 3 to 6. Everything else has default values that should work for most use cases.

**Note** 
Pre-training an ALBERT-Base model on a TPU v2 takes about 17 hours. Google Colab is not designed for such long-running jobs and interrupts the training process every 8 hours or so. For uninterrupted training, consider running on your own Cloud TPU v2 instance; a preemptible TPU is cheaper, but it can also be interrupted.

**Credits**
This tutorial is adapted from https://towardsdatascience.com/pre-training-bert-from-scratch-with-cloud-tpu-6e2f71028379

MIT License

Copyright (c) 2019 Diego de Vargas Feijo

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

## Step 1: Setting Up the Environment
Install dependencies, import the globally required packages, and authenticate with your Google account so that the notebook and the Colab TPU can access your Cloud Storage bucket.

In [0]:
!pip install --upgrade -q sentencepiece

import json
import logging
import nltk
import os
import random
import sentencepiece as spm
import sys
import tensorflow as tf

from glob import glob
from google.colab import auth

auth.authenticate_user()
  
# configure logging
log = logging.getLogger('tensorflow')
log.setLevel(logging.INFO)

# create formatter and add it to the handlers
formatter = logging.Formatter('%(asctime)s:  %(message)s')
sh = logging.StreamHandler()
sh.setLevel(logging.INFO)
sh.setFormatter(formatter)
log.handlers = [sh]

if 'COLAB_TPU_ADDR' in os.environ:
  log.info("Using TPU runtime")
  USE_TPU = True
  TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']

  with tf.Session(TPU_ADDRESS) as session:
    log.info('TPU address is ' + TPU_ADDRESS)
    # Upload credentials to TPU.
    with open('/content/adc.json', 'r') as f:
      auth_info = json.load(f)
    tf.contrib.cloud.configure_gcs(session, credentials=auth_info)
    
else:
  log.warning('Not connected to TPU runtime')
  USE_TPU = False

[?25l[K     |▎                               | 10kB 21.6MB/s eta 0:00:01[K     |▋                               | 20kB 2.2MB/s eta 0:00:01[K     |█                               | 30kB 3.2MB/s eta 0:00:01[K     |█▎                              | 40kB 2.1MB/s eta 0:00:01[K     |█▋                              | 51kB 2.6MB/s eta 0:00:01[K     |██                              | 61kB 3.1MB/s eta 0:00:01[K     |██▏                             | 71kB 3.6MB/s eta 0:00:01[K     |██▌                             | 81kB 4.0MB/s eta 0:00:01[K     |██▉                             | 92kB 4.5MB/s eta 0:00:01[K     |███▏                            | 102kB 3.5MB/s eta 0:00:01[K     |███▌                            | 112kB 3.5MB/s eta 0:00:01[K     |███▉                            | 122kB 3.5MB/s eta 0:00:01[K     |████                            | 133kB 3.5MB/s eta 0:00:01[K     |████▍                           | 143kB 3.5MB/s eta 0:00:01[K     |████▊                     

2019-12-09 19:16:48,080:  Using TPU runtime
2019-12-09 19:16:48,082:  TPU address is grpc://10.80.30.250:8470


### Clone and Patch ALBERT Sources

Some ALBERT sources use deprecated APIs that generate a lot of warnings. We also make some minor changes so that the scripts run smoothly on Colab.

In [0]:
# Clone the repository
!test -d ALBERT || git clone https://github.com/google-research/ALBERT.git ALBERT

# Avoid deprecation warnings (the g flag replaces every match on a line)
!sed -i 's/tf\.logging/tf.compat.v1.logging/g' ALBERT/*.py
!sed -i 's/tf\.app\.run/tf.compat.v1.app.run/g' ALBERT/*.py

# Avoid error when the line contains only one number
!sed -i 's/i.lower()/str(i).lower()/' ALBERT/create_pretraining_data.py

# Define a dummy "f" flag: the Colab kernel is started with -f <connection file>,
# which the flag parser would otherwise reject
!sed -i 's/FLAGS = flags.FLAGS/FLAGS=flags.FLAGS\n\nflags.DEFINE_string("f", "", "Dummy flag. Not used.")/' ALBERT/run_pretraining.py

# Mute overly verbose output
!sed -i 's/tf.compat.v1.logging.info/# tf.compat.v1.logging.info/' ALBERT/tokenization.py

Cloning into 'ALBERT'...
remote: Enumerating objects: 56, done.
remote: Counting objects: 100% (56/56), done.
remote: Compressing objects: 100% (46/46), done.
remote: Total 56 (delta 28), reused 37 (delta 9), pack-reused 0
Unpacking objects: 100% (56/56), done.
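Keep in mind that sed's `s` command without the `g` flag replaces only the first match on each line, so a line such as `tf.logging.set_verbosity(tf.logging.INFO)` is only half patched. The sketch below performs the same rewrite in Python with `re.sub`, which replaces every match by default; `patch_source` is an illustrative name, not part of the ALBERT repository.

```python
import re

def patch_source(text):
  """Rewrite every tf.logging / tf.app.run reference to its tf.compat.v1 form."""
  # re.sub replaces all non-overlapping matches (like sed's g flag).
  # The rewrite is idempotent: tf.compat.v1.logging contains no 'tf.logging'.
  text = re.sub(r'\btf\.logging\b', 'tf.compat.v1.logging', text)
  text = re.sub(r'\btf\.app\.run\b', 'tf.compat.v1.app.run', text)
  return text

print(patch_source('tf.logging.set_verbosity(tf.logging.INFO)'))
# tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
```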


In [0]:
if 'ALBERT' not in sys.path:
  sys.path.append('ALBERT')

import modeling, optimization, tokenization

2019-12-09 19:16:57,732:  From ALBERT/lamb_optimizer.py:33: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.



## Step 2: Download and Prepare Data

The Wikipedia dump is available in XML format, so we need to extract the raw text from it.


In [0]:
LANG = "pt"  #@param ['en', 'es', 'it', 'fr', 'pt']

Wikipedia dumps are dated the 1st or the 20th of each month, but the date on which a dump actually finishes varies. Instead of inspecting the dump status page, we guess which dump is most likely finished.
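Because the rule depends only on the current date, it can be written as a pure function and checked by hand. The sketch below mirrors the heuristic used by `get_last_dump` in the next cell, but takes the date as a parameter and handles January explicitly so the month never becomes 0. `guess_dump_date` and `dump_url` are illustrative names; the URL pattern is the one used by the download cell.

```python
import datetime

def guess_dump_date(today):
  """Return the YYYYMMDD id of the dump most likely to be finished on `today`."""
  year, month = today.year, today.month
  if 8 < today.day < 25:
    day = 1          # this month's 1st dump should be done by the 9th
  elif today.day >= 25:
    day = 20         # this month's 20th dump should be done by the 25th
  else:
    day = 1          # early in the month: use the previous month's 1st dump
    month -= 1
    if month == 0:   # January -> December of the previous year
      year, month = year - 1, 12
  return '{}{:02d}{:02d}'.format(year, month, day)

def dump_url(lang, dump_date):
  """Build the pages-articles-multistream URL for a language and dump date."""
  return ('https://dumps.wikimedia.org/{0}wiki/{1}/'
          '{0}wiki-{1}-pages-articles-multistream.xml.bz2').format(lang, dump_date)

print(guess_dump_date(datetime.date(2019, 12, 9)))  # 20191201
print(guess_dump_date(datetime.date(2020, 1, 3)))   # 20191201
print(dump_url('pt', '20191201'))
```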

In [0]:
import datetime

def get_last_dump():
  today = datetime.datetime.now()
  year, month = today.year, today.month

  if today.day > 8 and today.day < 25:
    day = 1
  elif today.day >= 25:
    day = 20
  else:
    # Early in the month, fall back to the previous month's dump
    # (December of the previous year when it is January).
    day = 1
    month -= 1
    if month == 0:
      year, month = year - 1, 12
  return '{}{:02d}{:02d}'.format(year, month, day)
  

In [0]:
last_dump = get_last_dump()
corpus = tf.keras.utils.get_file(
    "{}wiki.bz2".format(LANG),
    "https://dumps.wikimedia.org/{}wiki/{}/{}wiki-{}-pages-articles-multistream.xml.bz2".format(
        LANG,
        last_dump,
        LANG,
        last_dump
    ))
!bzip2 -d {corpus}

Downloading data from https://dumps.wikimedia.org/ptwiki/20191201/ptwiki-20191201-pages-articles-multistream.xml.bz2


## Step 3: Extract Raw Text
We use WikiExtractor to strip the XML markup and keep only the raw text.

In [0]:
!test -d wikiextractor || git clone https://github.com/attardi/wikiextractor.git

Cloning into 'wikiextractor'...
remote: Enumerating objects: 607, done.

In [0]:
WIKI_INPUT_FILE, _ = os.path.splitext(corpus)
WIKI_EXTRACTED_DIR = "wikimedia" 

tf.io.gfile.makedirs(WIKI_EXTRACTED_DIR)
!python3 wikiextractor/WikiExtractor.py -q -c -o {WIKI_EXTRACTED_DIR} {WIKI_INPUT_FILE}

In [0]:
import nltk
nltk.download('punkt')

# Punkt sentence tokenizer language (not a stemmer); should match LANG above
LANG_STM = "portuguese" # @param ['danish', 'english', 'finnish', 'french', 'german', 'italian', 'norwegian', 'portuguese', 'spanish', 'swedish']

sent_tokenizer = nltk.data.load('tokenizers/punkt/{}.pickle'.format(LANG_STM))

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Prepare the input for the create_pretraining_data.py script.

The input file format requires:
- One sentence per line. These should ideally be actual sentences, not entire paragraphs or arbitrary spans of text. (Because we use the sentence boundaries for the "next sentence prediction" task).
- Blank lines between documents. Document boundaries are needed so that the "next sentence prediction" task doesn't span between documents.
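A minimal sketch of that conversion for WikiExtractor-style output, assuming the `<doc ...>` line, title line, body, `</doc>` layout that the cell below handles. To keep it self-contained, it replaces NLTK's Punkt model with a naive regex splitter. `split_sentences` and `to_pretraining_format` are hypothetical helpers, and the regex is far less accurate than Punkt (for example, it splits after abbreviations).

```python
import re

def split_sentences(paragraph):
  # Naive stand-in for Punkt: split after ., ! or ? followed by whitespace.
  return [s for s in re.split(r'(?<=[.!?])\s+', paragraph.strip()) if s]

def to_pretraining_format(extracted_text):
  """Convert WikiExtractor output to one sentence per line, blank line between docs."""
  out = []
  is_title = False
  for line in extracted_text.splitlines():
    if line.startswith('</doc>'):
      out.append('')                 # blank line marks the end of a document
    elif line.startswith('<doc'):
      is_title = True                # the next line is the article title
    elif is_title:
      is_title = False               # skip the title line
    elif len(line.strip()) <= 1:
      continue                       # skip empty and single-character lines
    else:
      out.extend(split_sentences(line))
  return '\n'.join(out) + '\n'

sample = ('<doc id="1" title="Stock Car">\nStock Car\n\n'
          'Primeira frase. Segunda frase!\n</doc>\n')
print(to_pretraining_format(sample))
```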


In [0]:
import bz2

PRC_DATA_FPATH = "proc_wikimedia.txt" #@param {type: "string"}

with open(PRC_DATA_FPATH, "w", encoding="utf-8") as fo:
  for group in os.listdir(WIKI_EXTRACTED_DIR):
    basedir = os.path.join(WIKI_EXTRACTED_DIR, group)
    for fz in os.listdir(basedir):
      if fz.endswith('.bz2'):
        with bz2.BZ2File(os.path.join(basedir, fz), 'r') as fi:
          contents = fi.read()
        is_title = False
        text = contents.decode('utf-8')
        for l in text.splitlines():
          if l.startswith('</doc>'):
            # Empty line for each new document
            fo.write("\n")
          elif len(l.strip()) <= 1:
            # Skip empty and single-character lines
            pass
          elif l.startswith('<doc'):
            # After the heading, there is a title
            is_title = True
          elif is_title:
            # Skip the title line, reset variable
            is_title = False
          else:
            # Wikipedia puts multiple sentences on one line;
            # split them so there is one sentence per line
            sentences = sent_tokenizer.tokenize(l)
            for sentence in sentences:
              fo.write(sentence + "\n")


In [0]:
BUCKET_NAME = "<Insert Bucket Name Here>" # @param {type: "string"}

!head {PRC_DATA_FPATH}
!test -f {PRC_DATA_FPATH}.gz || gzip < {PRC_DATA_FPATH} > {PRC_DATA_FPATH}.gz
tf.gfile.MakeDirs("gs://{}/datasets/".format(BUCKET_NAME))
!gsutil -m cp {PRC_DATA_FPATH}.gz gs://{BUCKET_NAME}/datasets/

O termo Stock Car refere-se a uma forma de corrida automobilística popular principalmente nos Estados Unidos, no Canadá, no México, na Grã-Bretanha, na Austrália, no Brasil, entre outros países.
Os "Stock Cars" em seu senso original descrevem um automóvel de passeio usado em competições que não possui modificações especiais para corrida.
Posteriormente passaram a ser carros de passeios modificados para corridas.
Atualmente, muitas categorias usam carros especiais feitos exclusivamente para as corridas, mas mantendo o design parecido com os carros de passeio, ao contrário dos carros de monoposto, com pneus a mostra.
Atualmente, as corridas de Stock Cars são considerados como sendo uma categoria das corridas de turismo.
Também chamado de "street stock", são os veículos que podem ser comprados pelo público em geral, os chamados carros de produção, geralmente somente modificações por questões de segurança são permitidas nesses modelos.
Semelhantes ao "pure stock", mas com permissão de modi

## Step 4: Build SentencePiece Model
In this step we generate the model configuration file and train the SentencePiece encoder that converts text to integer ids.

In [0]:
BUCKET_NAME = "<Insert Bucket Name Here>" # @param {type: "string"}

MODEL_DIR = "albert_cased_L-12_H-768_A-12" #@param {type: "string"}
VOC_SIZE = 30000 #@param {type:"integer"}

!test -f {PRC_DATA_FPATH}.gz || gsutil -m cp gs://{BUCKET_NAME}/datasets/{PRC_DATA_FPATH}.gz .
!test -f {PRC_DATA_FPATH} || gzip -d < {PRC_DATA_FPATH}.gz > {PRC_DATA_FPATH}

Copying gs://diego-feijo_datasets/datasets/proc_wikimedia.txt.gz...
/ [1/1 files][550.1 MiB/550.1 MiB] 100% Done                                    
Operation completed over 1 objects/550.1 MiB.                                    


Build the ALBERT configuration. The values below correspond to the Base model; for reference, the released models are:
- Base Model: https://tfhub.dev/google/albert_base/2
- Large Model: https://tfhub.dev/google/albert_large/2
- X-Large Model: https://tfhub.dev/google/albert_xlarge/2
- XX-Large Model: https://tfhub.dev/google/albert_xxlarge/2

It is not feasible to pre-train models bigger than Large using Colab.

In [0]:
# use this for ALBERT-base
albert_config = {
  "attention_probs_dropout_prob": 0.1, 
  "hidden_act": "gelu", 
  "hidden_dropout_prob": 0.1, 
  "embedding_size": 128,
  "hidden_size": 768, 
  "initializer_range": 0.02, 
  "intermediate_size": 3072, 
  "max_position_embeddings": 512, 
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "num_hidden_groups": 1,
  "net_structure_type": 0,
  "gap_size": 0, 
  "num_memory_blocks": 0, 
  "inner_group_num": 1,
  "down_scale_factor": 1,
  "type_vocab_size": 2,
  "vocab_size": VOC_SIZE
}

with open("albert_config.json", "w") as fo:
  json.dump(albert_config, fo, indent=2)
!gsutil -m cp albert_config.json gs://{BUCKET_NAME}/{MODEL_DIR}/

Copying file://albert_config.json [Content-Type=application/json]...
/ [1/1 files][  483.0 B/  483.0 B] 100% Done                                    
Operation completed over 1 objects/483.0 B.                                      
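The `embedding_size` of 128 reflects ALBERT's factorized embedding parameterization: tokens are embedded in a small space of size E and then projected to the hidden size H, so the embedding block needs V·E + E·H parameters instead of the V·H of a BERT-style table. A quick check with the values from the configuration above (biases ignored; `embedding_params` is an illustrative helper):

```python
def embedding_params(vocab_size, embedding_size, hidden_size, factorized=True):
  """Parameter count of the token-embedding block (biases ignored)."""
  if factorized:
    # V x E lookup table plus an E x H projection up to the hidden size
    return vocab_size * embedding_size + embedding_size * hidden_size
  return vocab_size * hidden_size  # BERT-style V x H table

print(embedding_params(30000, 128, 768))                    # 3938304
print(embedding_params(30000, 128, 768, factorized=False))  # 23040000
```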


[SentencePiece](https://github.com/google/sentencepiece) will be used to encode the text.

We need to train the SentencePiece model and build our vocabulary. This requires a lot of RAM: even the 35 GB offered by Colab may not be enough if all the raw text is used. SUBSAMPLE_SIZE sets how many sentences SentencePiece samples from the input (its `--input_sentence_size` option), which controls how much memory is used.

If you run out of memory, reduce SUBSAMPLE_SIZE.

The monolingual BERT and ALBERT models use a vocabulary of about 30,000 tokens, while multilingual BERT uses about 120,000. It is not clear whether increasing VOC_SIZE improves the model.

NUM_PLACEHOLDERS reserves entries at the end of the vocabulary so that they can be used after pre-training, during fine-tuning.

In [0]:
PRC_DATA_FPATH = "proc_wikimedia.txt"  #@param {type: "string"}
MODEL_PREFIX = 'tokenizer' #@param {type:"string"}
SUBSAMPLE_SIZE = 10000000 #@param {type:"integer"}
# Number of reserved tokens at the end of the vocabulary.
# This should only be used when the training data contains a small set of
# very frequent tokens.
NUM_PLACEHOLDERS = 0 #@param {type:"integer"}

SPM_COMMAND = ('--input={} --model_prefix={} '
               '--vocab_size={} --input_sentence_size={} '
               '--shuffle_input_sentence=true ' 
               '--pad_piece=[PAD] '
               '--unk_piece=[UNK] '
               '--pad_id=0 --unk_id=1 --user_defined_symbols=[CLS],[SEP],[MASK] ' 
               '--bos_id=-1 --eos_id=-1 ').format(
               PRC_DATA_FPATH, MODEL_PREFIX, 
               VOC_SIZE - NUM_PLACEHOLDERS, SUBSAMPLE_SIZE)

In [0]:
DEMO_MODE = True # @param {type: "boolean"}

# Reduce the number of lines to train faster
if DEMO_MODE:
  !head -1000000 {PRC_DATA_FPATH} > {PRC_DATA_FPATH}.tmp
  !mv {PRC_DATA_FPATH}.tmp {PRC_DATA_FPATH}

This training can take a while. Grab a coffee.

In [0]:
spm.SentencePieceTrainer.Train(SPM_COMMAND)

True

In [0]:
!test -f {MODEL_PREFIX}.tar.gz && rm {MODEL_PREFIX}.tar.gz
!tar czvf {MODEL_PREFIX}.tar.gz {MODEL_PREFIX}.*
tf.gfile.MakeDirs("gs://{}/{}/".format(BUCKET_NAME, MODEL_DIR))
!gsutil -m cp {MODEL_PREFIX}.tar.gz gs://{BUCKET_NAME}/{MODEL_DIR}/

tokenizer.model
tokenizer.vocab
Copying file://tokenizer.tar.gz [Content-Type=application/x-tar]...
/ [1/1 files][575.0 KiB/575.0 KiB] 100% Done                                    
Operation completed over 1 objects/575.0 KiB.                                    


Now let's check that the SentencePiece tokenizer works with ALBERT's tokenization code.

SentencePiece has created two files: tokenizer.model and tokenizer.vocab. Let's have a look at the learned vocabulary:

In [0]:
VOC_FNAME = "{}.vocab".format(MODEL_PREFIX)
MDL_FNAME = "{}.model".format(MODEL_PREFIX)

!head {VOC_FNAME}
!wc -l {VOC_FNAME}

[PAD]	0
[UNK]	0
[CLS]	0
[SEP]	0
[MASK]	0
,	-3.0199
▁de	-3.14827
.	-3.52606
▁a	-3.82674
▁e	-3.85135
30000 tokenizer.vocab
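Each line of `tokenizer.vocab` holds a piece and its score separated by a tab, and the line number is the piece's id (which is why `[PAD]` gets 0 and `[MASK]` gets 4). A small sketch that rebuilds the piece-to-id mapping from that format; `read_vocab` is an illustrative helper, not part of the SentencePiece or ALBERT API:

```python
import io

def read_vocab(lines):
  """Map each SentencePiece piece to its id (the 0-based line number)."""
  vocab = {}
  for idx, line in enumerate(lines):
    piece, _score = line.rstrip('\n').split('\t')
    vocab[piece] = idx
  return vocab

# A few lines copied from the output above, in the same tab-separated format.
sample = io.StringIO('[PAD]\t0\n[UNK]\t0\n[CLS]\t0\n[SEP]\t0\n[MASK]\t0\n,\t-3.0199\n')
vocab = read_vocab(sample)
print(vocab['[MASK]'], vocab[','])  # 4 5
```

With the real file, `with open(VOC_FNAME, encoding='utf-8') as f: vocab = read_vocab(f)` should yield 30000 entries, matching `wc -l`.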


Now let's see how the new vocabulary works in practice:

In [0]:
testcase = "[CLS] [MASK] Sentença de mérito. [SEP] Embargos de declaração 普通话.[SEP]"

bert_tokenizer = tokenization.FullTokenizer(VOC_FNAME, do_lower_case=False, spm_model_file=MDL_FNAME)
tokens = bert_tokenizer.tokenize(testcase)
ids = bert_tokenizer.convert_tokens_to_ids(tokens)
print(tokens)
print(ids)

['▁', '[CLS]', '▁', '[MASK]', '▁Se', 'nte', 'nça', '▁de', '▁mérito', '.', '▁', '[SEP]', '▁Em', 'bar', 'gos', '▁de', '▁declaração', '▁', '普通话', '.', '[SEP]']
[19, 2, 19, 4, 332, 291, 2601, 6, 10323, 7, 19, 3, 43, 2466, 3362, 6, 5711, 19, 1, 7, 3]


Looking good!
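The `▁` (U+2581) prefix marks the start of a word, which is how SentencePiece keeps the original whitespace recoverable. A rough sketch of turning plain pieces back into text; `pieces_to_text` is illustrative, and the SentencePiece processor's own `DecodePieces` method remains the reliable way to detokenize.

```python
WORD_START = '\u2581'  # '▁', SentencePiece's word-boundary marker

def pieces_to_text(pieces):
  """Join pieces and turn each word-start marker back into a space."""
  return ''.join(pieces).replace(WORD_START, ' ').strip()

print(pieces_to_text(['▁Se', 'nte', 'nça', '▁de', '▁mérito', '.']))
# Sentença de mérito.
```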

## Step 5: Generate Pre-training Data
The pre-training data is a collection of TFRecord files in which the text is replaced by its encoded ids, some tokens are masked, and each example is labeled with whether one segment follows the other.

So, a text such as:

- "This is just one of many samples. This is a sentence that follows."

would be converted to something like:
- "\[CLS\] 23 45 \[MASK\] ... \[SEP\] 67 89 ... \[SEP\] \[PAD\] \[PAD\]"


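The SentencePiece model is used to encode the text into ids. The create_pretraining_data.py script adds the special tokens ('\[CLS\]', '\[SEP\]', '\[MASK\]') and pads each sequence with id 0 ('\[PAD\]'); text that SentencePiece cannot encode maps to '\[UNK\]'.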
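The layout of a single example can be sketched as follows, assuming the ids from the vocabulary printed in Step 4 (`[PAD]`=0, `[CLS]`=2, `[SEP]`=3). `build_example` is a hypothetical helper that mirrors the BERT-style format: `[CLS] A [SEP] B [SEP]`, segment ids 0 for A and 1 for B, and zero padding up to the maximum length. It is not the script's actual code and leaves out masking.

```python
PAD_ID, CLS_ID, SEP_ID = 0, 2, 3  # ids from the vocabulary printed in Step 4

def build_example(ids_a, ids_b, max_seq_length):
  """Lay out two segments as [CLS] A [SEP] B [SEP] and pad to max_seq_length."""
  input_ids = [CLS_ID] + ids_a + [SEP_ID] + ids_b + [SEP_ID]
  segment_ids = [0] * (len(ids_a) + 2) + [1] * (len(ids_b) + 1)
  if len(input_ids) > max_seq_length:
    raise ValueError('segments too long; the real script truncates the pair')
  input_mask = [1] * len(input_ids)        # 1 = real token, 0 = padding
  padding = max_seq_length - len(input_ids)
  return (input_ids + [PAD_ID] * padding,
          segment_ids + [0] * padding,
          input_mask + [0] * padding)

ids, segs, mask = build_example([23, 45, 67], [89, 90], max_seq_length=12)
print(ids)   # [2, 23, 45, 67, 3, 89, 90, 3, 0, 0, 0, 0]
print(segs)  # [0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0]
print(mask)  # [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]
```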

In [0]:
BUCKET_NAME = "<Insert Bucket Name Here>" # @param {type: "string"}

MODEL_DIR = 'albert_cased_L-12_H-768_A-12' #@param {type:"string"}
MODEL_PREFIX = 'tokenizer' #@param {type:"string"}
VOC_FNAME = "{}.vocab".format(MODEL_PREFIX)
MDL_FNAME = "{}.model".format(MODEL_PREFIX)
PRC_DATA_FPATH = "proc_wikimedia.txt" #@param {type:"string"}

!test -f {PRC_DATA_FPATH}.gz || gsutil -m cp gs://{BUCKET_NAME}/datasets/{PRC_DATA_FPATH}.gz .
!test -f {PRC_DATA_FPATH} || gzip -d < {PRC_DATA_FPATH}.gz > {PRC_DATA_FPATH}
!test -f {MODEL_PREFIX}.tar.gz || gsutil -m cp gs://{BUCKET_NAME}/{MODEL_DIR}/{MODEL_PREFIX}.tar.gz .
!test -f {VOC_FNAME} || tar xzvf {MODEL_PREFIX}.tar.gz

In [0]:
DEMO_MODE = True # @param {type: "boolean"}

# Reduce the number of lines to train faster
if DEMO_MODE:
  !head -1000000 {PRC_DATA_FPATH} > {PRC_DATA_FPATH}.tmp
  !mv {PRC_DATA_FPATH}.tmp {PRC_DATA_FPATH}

Since our corpus can be large, we will split it into shards:

In [0]:
!rm -rf ./shards
!mkdir ./shards
!split -a 4 -l 256000 -d {PRC_DATA_FPATH} ./shards/shard_
!ls ./shards/

shard_0000  shard_0001	shard_0002  shard_0003
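`split -a 4 -l 256000 -d` starts a new file every 256,000 lines and names the pieces with a 4-digit numeric suffix. The sketch below reproduces only the naming (`shard_names` is an illustrative helper). Because the cut is by line count, a document that straddles a shard boundary ends up split across two shards.

```python
import math

def shard_names(num_lines, lines_per_shard=256000, prefix='shard_', digits=4):
  """Names produced by `split -d -a DIGITS -l LINES_PER_SHARD FILE PREFIX`."""
  num_shards = math.ceil(num_lines / lines_per_shard)
  return [f'{prefix}{i:0{digits}d}' for i in range(num_shards)]

# DEMO_MODE keeps 1,000,000 lines -> 4 shards, matching the listing above.
print(shard_names(1000000))  # ['shard_0000', 'shard_0001', 'shard_0002', 'shard_0003']
```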


The **MAX_SEQ_LENGTH** (maximum sequence length) supported by the model is 512, but training at that length is much slower because the cost of self-attention grows quadratically with the sequence length. The BERT authors trained 90% of the steps with length 128 and the remaining 10% with length 512.

To simulate this schedule, first create the pre-training data and train with length 128; then create the pre-training data again with length 512, change the configuration, and continue training.
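A back-of-the-envelope view of the quadratic cost: each attention head builds a score matrix with one entry per pair of positions, so going from 128 to 512 tokens multiplies that part of the work by 16, while per-token layers such as the feed-forward blocks grow only 4 times.

```python
def attention_scores(seq_len):
  """Entries in one head's seq_len x seq_len attention score matrix."""
  return seq_len * seq_len

ratio_attention = attention_scores(512) / attention_scores(128)  # quadratic part
ratio_linear = 512 / 128                                          # per-token layers
print(ratio_attention, ratio_linear)  # 16.0 4.0
```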

The **DUPE_FACTOR** defines how many times the input data is duplicated. Each duplicate is masked at random, so having as many duplicates as possible makes good use of the data. However, values larger than 20 may generate files larger than 1 GB per shard. Larger files do not make pre-training slower, but they require a lot of storage space.
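A simplified sketch of what a different random mask per duplicate means. It follows the BERT-style recipe: select about MASKED_LM_PROB of the tokens (at least one, at most MAX_PREDICTIONS), then replace 80% of them with `[MASK]`, 10% with a random id, and leave 10% unchanged. `mask_tokens` is a hypothetical helper; the real script also implements whole-word masking (`--do_whole_word_mask`) and other variants, which are omitted here.

```python
import random

CLS_ID, SEP_ID, MASK_ID = 2, 3, 4  # ids from the vocabulary printed in Step 4
NUM_SPECIAL = 5                    # [PAD], [UNK], [CLS], [SEP], [MASK]

def mask_tokens(ids, masked_lm_prob, max_predictions, vocab_size, rng):
  """Return (masked ids, masked positions, original ids at those positions)."""
  candidates = [i for i, t in enumerate(ids) if t not in (CLS_ID, SEP_ID)]
  num_to_predict = min(max_predictions,
                       max(1, int(round(len(ids) * masked_lm_prob))))
  num_to_predict = min(num_to_predict, len(candidates))
  positions = sorted(rng.sample(candidates, num_to_predict))
  output = list(ids)
  for pos in positions:
    r = rng.random()
    if r < 0.8:
      output[pos] = MASK_ID                                 # 80%: [MASK]
    elif r < 0.9:
      output[pos] = rng.randrange(NUM_SPECIAL, vocab_size)  # 10%: random id
    # remaining 10%: keep the original token
  return output, positions, [ids[p] for p in positions]

example = [2, 23, 45, 67, 3, 89, 90, 3]
for dupe in range(4):  # e.g. DUPE_FACTOR = 4: same input, four independent masks
  masked, positions, labels = mask_tokens(example, 0.15, 20, 30000, random.Random(dupe))
  print(masked, positions, labels)
```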


In [0]:
MAX_SEQ_LENGTH = 128 #@param {type:"integer"}
MASKED_LM_PROB = 0.15 #@param {type: "number"}
MAX_PREDICTIONS = 20 #@param {type:"integer"}
# Strip diacritics and Lowercase
DO_LOWER_CASE = False #@param {type:"boolean"}
DO_WHOLE_WORD_MASK = True #@param {type:"boolean"}
PROCESSES = 2 #@param {type:"integer"}
PRETRAINING_DIR = "gs://{}/{}/pretraining_data_{}".format(BUCKET_NAME, MODEL_DIR, MAX_SEQ_LENGTH)
DUPE_FACTOR = 4 #@param {type:"integer"}

In [0]:
PRETRAINING_DIR

'gs://diego-feijo_datasets/albert_cased_L-12_H-768_A-12/pretraining_data_128'

Now we need to call the *create_pretraining_data.py* script for each shard. To that end, we use the *xargs* command, which processes PROCESSES shards in parallel.

This step takes a while to run. The data generated from each shard is saved in permanent storage (the bucket).

If you need to resume this step, check the bucket for the generated files and manually delete the local shards that have already been processed.
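The manual resume procedure can also be sketched in code: compare the local shard names with the `.tfrecord` files already in the bucket and keep only the missing ones. `pending_shards` is a hypothetical helper, and the `gs://my-bucket/...` paths are placeholders; in the notebook the existing outputs could be listed with `tf.io.gfile.glob(PRETRAINING_DIR + '/*.tfrecord')`.

```python
import os

def pending_shards(local_shards, existing_outputs):
  """Return local shards whose <shard>.tfrecord output is not yet in the bucket."""
  done = {os.path.basename(p)[:-len('.tfrecord')]
          for p in existing_outputs if p.endswith('.tfrecord')}
  return sorted(s for s in local_shards if s not in done)

local = ['shard_0000', 'shard_0001', 'shard_0002', 'shard_0003']
existing = ['gs://my-bucket/albert/pretraining_data_128/shard_0000.tfrecord',
            'gs://my-bucket/albert/pretraining_data_128/shard_0001.tfrecord']
print(pending_shards(local, existing))  # ['shard_0002', 'shard_0003']
```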

In [0]:
XARGS_CMD = ('ls ./shards | '
      'xargs -n 1 -P {} -I{} '
      'python3 ALBERT/create_pretraining_data.py '
      '--input_file=./shards/{} '
      '--output_file={}/{}.tfrecord '
      '--vocab_file={} '
      '--spm_model_file={} '
      '--do_lower_case={} '
      '--do_whole_word_mask={} '
      '--max_predictions_per_seq={} '
      '--max_seq_length={} '
      '--masked_lm_prob={} '
      '--dupe_factor={} ')
XARGS_CMD = XARGS_CMD.format(PROCESSES, '{}', '{}',
                             PRETRAINING_DIR, '{}', 
                             VOC_FNAME, MDL_FNAME, DO_LOWER_CASE,
                             DO_WHOLE_WORD_MASK, MAX_PREDICTIONS, 
                             MAX_SEQ_LENGTH, MASKED_LM_PROB, DUPE_FACTOR)
!$XARGS_CMD


W1209 17:25:41.247547 140407190394752 module_wrapper.py:139] From ALBERT/create_pretraining_data.py:618: The name tf.logging.INFO is deprecated. Please use tf.compat.v1.logging.INFO instead.


W1209 17:25:41.308561 140687693150080 module_wrapper.py:139] From ALBERT/create_pretraining_data.py:618: The name tf.logging.INFO is deprecated. Please use tf.compat.v1.logging.INFO instead.


W1209 17:25:41.325810 140407190394752 module_wrapper.py:139] From ALBERT/create_pretraining_data.py:626: The name tf.gfile.Glob is deprecated. Please use tf.io.gfile.glob instead.

INFO:tensorflow:*** Reading from input files ***
I1209 17:25:41.326774 140407190394752 create_pretraining_data.py:628] *** Reading from input files ***
INFO:tensorflow:  ./shards/shard_0001
I1209 17:25:41.326937 140407190394752 create_pretraining_data.py:630]   ./shards/shard_0001

W1209 17:25:41.327742 140407190394752 module_wrapper.py:139] From ALBERT/create_pretraining_data.py:228: The name tf.gfile.GFile is deprecated. Pleas

## Step 6: Training the Model

If you need to resume from an interrupted training, you may skip steps 2-5 and proceed from here.

In [0]:
BUCKET_NAME = "<Insert Bucket Name Here>" # @param {type: "string"}

MODEL_DIR = 'albert_cased_L-12_H-768_A-12' #@param {type:"string"}

# Input data pipeline config
TRAIN_BATCH_SIZE = 256 #@param {type:"integer"}
MAX_PREDICTIONS = 20 #@param {type:"integer"}
MAX_SEQ_LENGTH = 128 #@param {type:"integer"}
MASKED_LM_PROB = 0.15 #@param

PRETRAINING_DIR = "pretraining_data_{}".format(MAX_SEQ_LENGTH)

# Training procedure config
EVAL_BATCH_SIZE = 64
LEARNING_RATE = 0.00176
TRAIN_STEPS = 175000 #@param {type:"integer"}
SAVE_CHECKPOINTS_STEPS = 5000 #@param {type:"integer"}
NUM_TPU_CORES = 8

if BUCKET_NAME:
  BUCKET_PATH = "gs://{}".format(BUCKET_NAME)
else:
  BUCKET_PATH = "."

ALBERT_GCS_DIR = "{}/{}".format(BUCKET_PATH, MODEL_DIR)
DATA_GCS_DIR = "{}/{}".format(ALBERT_GCS_DIR, PRETRAINING_DIR)

CONFIG_FILE = os.path.join(ALBERT_GCS_DIR, "albert_config.json")

INIT_CHECKPOINT = tf.train.latest_checkpoint(ALBERT_GCS_DIR)

albert_config = modeling.AlbertConfig.from_json_file(CONFIG_FILE)
input_files = tf.gfile.Glob(os.path.join(DATA_GCS_DIR,'*tfrecord'))

log.info("Using checkpoint: {}".format(INIT_CHECKPOINT))
log.info("Using {} data shards".format(len(input_files)))

2019-12-09 19:24:57,946:  From ALBERT/modeling.py:115: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.

2019-12-09 19:24:58,268:  Using checkpoint: None
2019-12-09 19:24:58,274:  Using 4 data shards


Prepare the training run configuration, build the estimator and input function, power up the bass cannon.

In [0]:
from run_pretraining import input_fn_builder, model_fn_builder


model_fn = model_fn_builder(
      albert_config=albert_config,
      init_checkpoint=INIT_CHECKPOINT,
      learning_rate=LEARNING_RATE,
      num_train_steps=TRAIN_STEPS,
      num_warmup_steps=3125,
      use_tpu=USE_TPU,
      optimizer="lamb",
      poly_power=1.0,
      start_warmup_step=0,
      use_one_hot_embeddings=USE_TPU)

tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)

run_config = tf.contrib.tpu.RunConfig(
    cluster=tpu_cluster_resolver,
    model_dir=ALBERT_GCS_DIR,
    save_checkpoints_steps=SAVE_CHECKPOINTS_STEPS,
    tpu_config=tf.contrib.tpu.TPUConfig(
        iterations_per_loop=SAVE_CHECKPOINTS_STEPS,
        num_shards=NUM_TPU_CORES,
        per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))

estimator = tf.contrib.tpu.TPUEstimator(
    use_tpu=USE_TPU,
    model_fn=model_fn,
    config=run_config,
    train_batch_size=TRAIN_BATCH_SIZE,
    eval_batch_size=EVAL_BATCH_SIZE)
  
train_input_fn = input_fn_builder(
        input_files=input_files,
        max_seq_length=MAX_SEQ_LENGTH,
        max_predictions_per_seq=MAX_PREDICTIONS,
        is_training=True)

2019-12-09 19:25:10,812:  Estimator's model_fn (<function model_fn_builder.<locals>.model_fn at 0x7f5c7a8df378>) includes params argument, but params are not passed to Estimator.
2019-12-09 19:25:10,814:  Using config: {'_model_dir': 'gs://diego-feijo_datasets/albert_cased_L-12_H-768_A-12', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 5000, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.80.30.250:8470"
    }
  }
}
isolate_session_state: true
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object 
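The `model_fn` above is built with `num_warmup_steps=3125` and `poly_power=1.0`. Assuming ALBERT's `optimization.create_optimizer` follows the BERT-style schedule (linear warm-up from 0 to LEARNING_RATE, then polynomial decay towards 0 at TRAIN_STEPS, computed from the global step), the learning rate at a given step can be sketched as follows. `learning_rate_at` is an illustrative helper; check `optimization.py` in the cloned repository for the exact definition.

```python
def learning_rate_at(step, init_lr=0.00176, num_train_steps=175000,
                     num_warmup_steps=3125, poly_power=1.0):
  """Assumed schedule: linear warm-up, then polynomial decay to zero."""
  if num_warmup_steps and step < num_warmup_steps:
    return init_lr * step / num_warmup_steps
  progress = min(step, num_train_steps) / num_train_steps
  return init_lr * (1.0 - progress) ** poly_power

for step in (0, 1000, 3125, 87500, 175000):
  print(step, learning_rate_at(step))
```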

Fire!

In [0]:
estimator.train(input_fn=train_input_fn, max_steps=TRAIN_STEPS)

2019-12-09 19:25:35,194:  Querying Tensorflow master (grpc://10.80.30.250:8470) for TPU system metadata.
2019-12-09 19:25:35,210:  Found TPU system:
2019-12-09 19:25:35,211:  *** Num TPU Cores: 8
2019-12-09 19:25:35,211:  *** Num TPU Workers: 1
2019-12-09 19:25:35,212:  *** Num TPU Cores Per Worker: 8
2019-12-09 19:25:35,212:  *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 13302816656652693886)
2019-12-09 19:25:35,214:  *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 5639868938437605861)
2019-12-09 19:25:35,214:  *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 5966812053291465332)
2019-12-09 19:25:35,215:  *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 10829184446052066539)
2019-12-09 19:25:35,215:  *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU

Training the model with the default parameters for 175k steps will take ~20 hours. 
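Some quick arithmetic behind that estimate, using the default TRAIN_BATCH_SIZE, TRAIN_STEPS and SAVE_CHECKPOINTS_STEPS and taking the ~20 hours at face value:

```python
# Rough figures for the default Step 6 settings (assumed ~20 hours of training).
train_steps = 175000
train_batch_size = 256
save_checkpoints_steps = 5000
hours = 20

sequences_seen = train_steps * train_batch_size           # 44,800,000 sequences
steps_per_second = train_steps / (hours * 3600)           # about 2.43 steps/s
num_checkpoints = train_steps // save_checkpoints_steps   # 35 checkpoints written
# iterations_per_loop equals save_checkpoints_steps, so an interruption can
# lose up to one checkpoint interval of work:
max_minutes_lost = save_checkpoints_steps / steps_per_second / 60  # about 34 minutes
print(sequences_seen, round(steps_per_second, 2), num_checkpoints, round(max_minutes_lost))
```

Only the most recent 5 checkpoints are kept on disk, according to `_keep_checkpoint_max` in the run configuration printed above.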

If the kernel is restarted, you can always continue training from the latest checkpoint by running Step 1 and then Step 6 again.
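`tf.train.latest_checkpoint` (used in the Step 6 setup cell) reads the `checkpoint` state file in ALBERT_GCS_DIR, and because `max_steps` in `estimator.train` is a total step count, a resumed run only performs the steps that remain. As an illustration, the hypothetical helper below recovers the latest global step from Estimator-style checkpoint file names; it does not read the state file.

```python
import re

# Estimator checkpoints are saved as model.ckpt-<global step>.{index,data-*,meta}
CKPT_RE = re.compile(r'model\.ckpt-(\d+)\.index$')

def latest_step(filenames):
  """Return the highest global step among checkpoint index files, or None."""
  steps = []
  for name in filenames:
    match = CKPT_RE.search(name)
    if match:
      steps.append(int(match.group(1)))
  return max(steps) if steps else None

files = ['checkpoint', 'graph.pbtxt', 'model.ckpt-5000.index',
         'model.ckpt-10000.index', 'model.ckpt-10000.data-00000-of-00001']
step = latest_step(files)
remaining = 175000 - step  # steps still to run with max_steps=TRAIN_STEPS
print(step, remaining)     # 10000 165000
```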

This concludes the guide to pre-training ALBERT from scratch on a Cloud TPU. However, the really fun stuff is still to come, so stay tuned.

Keep learning!