<a href="https://colab.research.google.com/github/hellozhaojian/transformers/blob/master/bert_fine_tuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**简介**

本文主要内容如何使用TPU完成对bert的基准模型进行fine-tuning。
前置条件：
1. 在google cloud里有一个项目。
   本次教程中项目名称为 pre-train-bert-sogou； 我们在bucket bert-sogou-pretrain中放置数据，配置和模型。


主要步骤如下：

1. 将及基准模型和数据放置到google cloud项目中。

2. 将数据准备为tf-record的模式

3. 训练模型。



**数据、配置、模型准备**


* 登录google cloud

In [1]:
! gcloud auth application-default login
from google.colab import auth


Go to the following link in your browser:

    https://accounts.google.com/o/oauth2/auth?code_challenge=G-n10zWLrBMaKTjdgJXeneERZa4qOVZvTix1YJnG7GI&prompt=select_account&code_challenge_method=S256&access_type=offline&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&response_type=code&client_id=764086051850-6qr4p6gpi6hn506pt8ejuq83di341hur.apps.googleusercontent.com&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcloud-platform+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Faccounts.reauth


Enter verification code: 4/tAGWALB9Dv2KWiCLNhOju8CIy8wYSbHK51RFJP2BKjt3SaaYVk4VhO8

Credentials saved to file: [/content/.config/application_default_credentials.json]

These credentials will be used by any library that requests
Application Default Credentials.

To generate an access token for other uses, run:
  gcloud auth application-default print-access-token


* 代码准备

In [2]:
! git clone https://github.com/google-research/bert
! ls ./

fatal: destination path 'bert' already exists and is not an empty directory.
bert  sample_data


* 数据模型准备

In [8]:
GOOGLE_CLOUD_PROJECT_NAME = "pre-train-bert-sogou" #@param {type: "string" }
BUCKET_NAME = "bert-sogou-pretrain"  #@param {type: "string"}
BASE_MODEL_DIR = "fine_tuning/base_model" #@param {type: "string"}
NEW_MODEL_DIR = "fine_tuning/model" #@param {type: "string"}
MODEL_NAME = "chinese_L-12_H-768_A-12" #@param {type: "string"}
INPUT_DATA_DIR = "fine_tuning/data/zh_wiki_news_2016" #@param {type: "string"}
PROCESSES = 5 #@param {type: "integer"}
DO_LOWER_CASE = True
MAX_SEQ_LENGTH = 128 #@param {type : "integer"}
MASKED_LM_PROB = 0.15 #@param {type: "number" }
# xxxx
MAX_PREDICTIONS = 20 #@param {type: "integer"



! gcloud config set project pre-train-bert-sogou
base_model_name = "gs://{}/{}/{}".format(BUCKET_NAME, BASE_MODEL_DIR, MODEL_NAME)
fine_tuning_name = "gs://{}/{}/{}".format(BUCKET_NAME, NEW_MODEL_DIR, MODEL_NAME)
! gsutil rm -rf $fine_tuning_name
! gsutil cp -r $base_model_name $fine_tuning_name

Updated property [core/project].
Removing gs://bert-sogou-pretrain/fine_tuning/model/chinese_L-12_H-768_A-12/bert_config.json#1573303171402418...
Removing gs://bert-sogou-pretrain/fine_tuning/model/chinese_L-12_H-768_A-12/bert_model.ckpt.data-00000-of-00001#1573303171634993...
Removing gs://bert-sogou-pretrain/fine_tuning/model/chinese_L-12_H-768_A-12/bert_model.ckpt.index#1573303171825338...
Removing gs://bert-sogou-pretrain/fine_tuning/model/chinese_L-12_H-768_A-12/bert_model.ckpt.meta#1573303171987890...
/ [4 objects]                                                                   
==> NOTE: You are performing a sequence of gsutil operations that may
run significantly faster if you instead use gsutil -m rm ... Please
see the -m section under "gsutil help options" for further information
about when gsutil -m can be advantageous.

Removing gs://bert-sogou-pretrain/fine_tuning/model/chinese_L-12_H-768_A-12/checkpoint#1573303172143031...
Removing gs://bert-sogou-pretrain/fine_tuning/m

**准备tf-record数据**

In [0]:
XARGS_CMD = ("gsutil ls gs://{}/{}/*_* | "
             "awk 'BEGIN{{FS=\"/\"}}{{print $NF}}' | "
             "xargs -n 1 -P {} -I{} "
             "python3 bert/create_pretraining_data.py "
             "--input_file=gs://{}/{}/{} "
             "--output_file={}/{}.tfrecord "
             "--vocab_file={} "
             "--do_lower_case={} "
             "--max_predictions_per_seq={} "
             "--max_seq_length={} "
             "--masked_lm_prob={} "
             "--random_seed=34 "
             "--dupe_factor=5")

VOC_FNAME = "gs://{}/{}/{}/vocab.txt".format(BUCKET_NAME, NEW_MODEL_DIR, MODEL_NAME)
TF_RECORD_DIR = "gs://{}/{}_tfrecord".format(BUCKET_NAME, INPUT_DATA_DIR)

XARGS_CMD = XARGS_CMD.format(BUCKET_NAME, INPUT_DATA_DIR, 
                             PROCESSES, '{}',  BUCKET_NAME, INPUT_DATA_DIR, '{}', 
                             TF_RECORD_DIR, '{}',
                             VOC_FNAME, DO_LOWER_CASE, 
                             MAX_PREDICTIONS, MAX_SEQ_LENGTH, MASKED_LM_PROB)

print (XARGS_CMD)
! gsutil mkdir -p $TF_RECORD_DIR 
! $XARGS_CMD


gsutil ls gs://bert-sogou-pretrain/fine_tuning/data/zh_wiki_news_2016/*_* | awk 'BEGIN{FS="/"}{print $NF}' | xargs -n 1 -P 5 -I{} python3 bert/create_pretraining_data.py --input_file=gs://bert-sogou-pretrain/fine_tuning/data/zh_wiki_news_2016/{} --output_file=gs://bert-sogou-pretrain/fine_tuning/data/zh_wiki_news_2016_tfrecord/{}.tfrecord --vocab_file=gs://bert-sogou-pretrain/fine_tuning/model/chinese_L-12_H-768_A-12/vocab.txt --do_lower_case=True --max_predictions_per_seq=20 --max_seq_length=128 --masked_lm_prob=0.15 --random_seed=34 --dupe_factor=5
CommandException: The mb command requires at least 1 argument. Usage:

  gsutil mb [-b <on|off>] [-c class] [-l location] [-p proj_id]
            [--retention time] url...

For additional help run:
  gsutil help mb


W1109 12:44:16.302721 139643941230464 module_wrapper.py:139] From bert/create_pretraining_data.py:437: The name tf.logging.set_verbosity is deprecated. Please use tf.compat.v1.logging.set_verbosity instead.


W1109 12:44:16.3

**训练模型**

* 链接TPU


In [5]:
import os
import logging

import tensorflow as tf

log = logging.getLogger("pre-train-bert")
auth.authenticate_user()

if 'COLAB_TPU_ADDR' in os.environ:
  log.info("Using TPU runtime")
  USE_TPU = True
  TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']

  with tf.Session(TPU_ADDRESS) as session:
    print(TPU_ADDRESS)
    log.info('TPU address is ' + TPU_ADDRESS)
    tf.contrib.cloud.configure_gcs(session)
else:
  log.warning('Not connected to TPU runtime')
  USE_TPU = False
print(USE_TPU)

! gsutil ls gs://bert-sogou-pretrain/fine_tuning/model/chinese_L-12_H-768_A-12/bert_model.ckpt

grpc://10.85.214.218:8470
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

True
CommandException: One or more URLs matched no objects.


* 设置训练参数

In [6]:
from bert import modeling, optimization, tokenization

# Input data pipeline config
TRAIN_BATCH_SIZE = 128 #@param {type:"integer"}
MAX_PREDICTIONS = 20 #@param {type:"integer"}
MAX_SEQ_LENGTH = 128 #@param {type:"integer"}
MASKED_LM_PROB = 0.15 #@param

# Training procedure config
EVAL_BATCH_SIZE = 64
LEARNING_RATE = 2e-5
TRAIN_STEPS = 1000 #@param {type:"integer"}
SAVE_CHECKPOINTS_STEPS = 25 #@param {type:"integer"}
NUM_TPU_CORES = 8



BERT_GCS_DIR = fine_tuning_name+"_running"

! gsutil mkdir $BERT_GCS_DIR

DATA_GCS_DIR = TF_RECORD_DIR

VOCAB_FILE = VOC_FNAME

CONFIG_FILE = "gs://{}/{}/{}/bert_config.json".format(BUCKET_NAME, BASE_MODEL_DIR, MODEL_NAME)


#! gsutil ls $BERT_GCS_DIR

INIT_CHECKPOINT = "{}/bert_model.ckpt".format(base_model_name)
#"gs://bert-sogou-pretrain/fine_tuning/base_model/chinese_L-12_H-768_A-12/bert_model.ckpt" 
TMP_INIT_CHECKPOINT = tf.train.latest_checkpoint(BERT_GCS_DIR)
if TMP_INIT_CHECKPOINT is not None:
    INIT_CHECKPOINT = TMP_INIT_CHECKPOINT


bert_config = modeling.BertConfig.from_json_file(CONFIG_FILE)
input_files = tf.gfile.Glob(os.path.join(DATA_GCS_DIR,'*tfrecord'))

log.info("Using checkpoint: {}".format(INIT_CHECKPOINT))

log.info("Using {} data shards".format(len(input_files)))

! gsutil ls $INIT_CHECKPOINT*




CommandException: The mb command requires a URL that specifies a bucket.
"gs://bert-sogou-pretrain/fine_tuning/model/chinese_L-12_H-768_A-12_running" is not valid.

gs://bert-sogou-pretrain/fine_tuning/model/chinese_L-12_H-768_A-12_running/model.ckpt-25.data-00000-of-00001
gs://bert-sogou-pretrain/fine_tuning/model/chinese_L-12_H-768_A-12_running/model.ckpt-25.index
gs://bert-sogou-pretrain/fine_tuning/model/chinese_L-12_H-768_A-12_running/model.ckpt-25.meta


* 训练模型

In [7]:
import sys
sys.path.append("bert")
from bert.run_pretraining import input_fn_builder, model_fn_builder
from bert import modeling, optimization, tokenization


model_fn = model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=INIT_CHECKPOINT,
      learning_rate=LEARNING_RATE,
      num_train_steps=TRAIN_STEPS,
      num_warmup_steps=10,
      use_tpu=USE_TPU,
      use_one_hot_embeddings=True)

tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)

run_config = tf.contrib.tpu.RunConfig(
    cluster=tpu_cluster_resolver,
    model_dir=BERT_GCS_DIR,
    save_checkpoints_steps=SAVE_CHECKPOINTS_STEPS,
    tpu_config=tf.contrib.tpu.TPUConfig(
        iterations_per_loop=SAVE_CHECKPOINTS_STEPS,
        num_shards=NUM_TPU_CORES,
        per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))

estimator = tf.contrib.tpu.TPUEstimator(
    use_tpu=USE_TPU,
    model_fn=model_fn,
    config=run_config,
    train_batch_size=TRAIN_BATCH_SIZE,
    eval_batch_size=EVAL_BATCH_SIZE)
  
train_input_fn = input_fn_builder(
        input_files=input_files,
        max_seq_length=MAX_SEQ_LENGTH,
        max_predictions_per_seq=MAX_PREDICTIONS,
        is_training=True)

estimator.train(input_fn=train_input_fn, max_steps=TRAIN_STEPS)
                

INFO:tensorflow:Using config: {'_model_dir': 'gs://bert-sogou-pretrain/fine_tuning/model/chinese_L-12_H-768_A-12_running', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 25, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.85.214.218:8470"
    }
  }
}
isolate_session_state: true
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f57d8a8ee10>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.85.214.218:8470', '_evaluation_master': 'grpc://10.85.214.218

KeyboardInterrupt: ignored