From 6867c3c11471ad9c21e48e4e0b4f2b0b6ed394d1 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Wed, 15 Jun 2022 14:37:06 -0700 Subject: [PATCH 1/3] Add cloud training support for BERT example --- examples/bert/bert_model.py | 310 +-------------------------- examples/bert/bert_train.py | 136 +++++++++--- examples/utils/google_cloud_utils.py | 57 +++++ 3 files changed, 172 insertions(+), 331 deletions(-) create mode 100644 examples/utils/google_cloud_utils.py diff --git a/examples/bert/bert_model.py b/examples/bert/bert_model.py index 69cb370803..237d00a77a 100644 --- a/examples/bert/bert_model.py +++ b/examples/bert/bert_model.py @@ -46,297 +46,6 @@ def make_attention_mask(inputs, mask): return tf.ones([batch_size, from_seq_length, 1], dtype=inputs.dtype) * mask -class TransformerEncoderBlock(keras.layers.Layer): - """TransformerEncoderBlock layer. - - This layer implements the Transformer Encoder from - "Attention Is All You Need". (https://arxiv.org/abs/1706.03762), - which combines a `keras.layers.MultiHeadAttention` layer with a - two-layer feedforward network. - - Args: - num_attention_heads: Number of attention heads. - inner_size: The output dimension of the first Dense layer in a - two-layer feedforward network. - inner_activation: The activation for the first Dense layer in a - two-layer feedforward network. - output_range: the sequence output range, [0, output_range) for - slicing the target sequence. `None` means the target sequence is - not sliced. - kernel_initializer: Initializer for dense layer kernels. - bias_initializer: Initializer for dense layer biases. - kernel_regularizer: Regularizer for dense layer kernels. - bias_regularizer: Regularizer for dense layer biases. - activity_regularizer: Regularizer for dense layer activity. - kernel_constraint: Constraint for dense layer kernels. - bias_constraint: Constraint for dense layer kernels. - use_bias: Whether to enable use_bias in attention layer. If set - False, use_bias in attention layer is disabled. - norm_first: Whether to normalize inputs to attention and - intermediate dense layers. If set False, output of attention and - intermediate dense layers is normalized. - norm_epsilon: Epsilon value to initialize normalization layers. - hidden_dropout: Dropout probability for the post-attention and - output dropout. - attention_dropout: Dropout probability for within the attention - layer. - inner_dropout: Dropout probability for the first Dense layer in a - two-layer feedforward network. - attention_initializer: Initializer for kernels of attention layers. - If set `None`, attention layers use kernel_initializer as - initializer for kernel. - attention_axes: axes over which the attention is applied. `None` - means attention over all axes, but batch, heads, and features. - **kwargs: keyword arguments. 
- - References: - [Attention Is All You Need](https://arxiv.org/abs/1706.03762) - [BERT: Pre-training of Deep Bidirectional Transformers for Language - Understanding](https://arxiv.org/abs/1810.04805) - """ - - def __init__( - self, - num_attention_heads, - inner_size, - inner_activation, - output_range=None, - kernel_initializer="glorot_uniform", - bias_initializer="zeros", - kernel_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - kernel_constraint=None, - bias_constraint=None, - use_bias=True, - norm_first=False, - norm_epsilon=1e-12, - hidden_dropout=0.0, - attention_dropout=0.0, - inner_dropout=0.0, - attention_initializer=None, - attention_axes=None, - **kwargs, - ): - super().__init__(**kwargs) - - self._num_heads = num_attention_heads - self._inner_size = inner_size - self._inner_activation = inner_activation - self._attention_dropout = attention_dropout - self._attention_dropout_rate = attention_dropout - self._hidden_dropout = hidden_dropout - self._hidden_dropout_rate = hidden_dropout - self._output_range = output_range - self._kernel_initializer = keras.initializers.get(kernel_initializer) - self._bias_initializer = keras.initializers.get(bias_initializer) - self._kernel_regularizer = keras.regularizers.get(kernel_regularizer) - self._bias_regularizer = keras.regularizers.get(bias_regularizer) - self._activity_regularizer = keras.regularizers.get( - activity_regularizer - ) - self._kernel_constraint = keras.constraints.get(kernel_constraint) - self._bias_constraint = keras.constraints.get(bias_constraint) - self._use_bias = use_bias - self._norm_first = norm_first - self._norm_epsilon = norm_epsilon - self._inner_dropout = inner_dropout - if attention_initializer: - self._attention_initializer = keras.initializers.get( - attention_initializer - ) - else: - self._attention_initializer = self._kernel_initializer - self._attention_axes = attention_axes - - def build(self, input_shape): - if isinstance(input_shape, tf.TensorShape): - input_tensor_shape = input_shape - elif isinstance(input_shape, (list, tuple)): - input_tensor_shape = tf.TensorShape(input_shape[0]) - else: - raise ValueError( - f"Unknown input shape type. Received: {type(input_shape)}" - ) - einsum_equation = "abc,cd->abd" - if len(input_tensor_shape.as_list()) > 3: - einsum_equation = "...bc,cd->...bd" - hidden_size = input_tensor_shape[-1] - if hidden_size % self._num_heads != 0: - raise ValueError( - f"The input size {hidden_size} is not a multiple of the number " - f"of attention heads {self._num_heads}" - ) - self._attention_head_size = int(hidden_size // self._num_heads) - common_kwargs = dict( - bias_initializer=self._bias_initializer, - kernel_regularizer=self._kernel_regularizer, - bias_regularizer=self._bias_regularizer, - activity_regularizer=self._activity_regularizer, - kernel_constraint=self._kernel_constraint, - bias_constraint=self._bias_constraint, - ) - self._attention_layer = keras.layers.MultiHeadAttention( - num_heads=self._num_heads, - key_dim=self._attention_head_size, - dropout=self._attention_dropout, - use_bias=self._use_bias, - kernel_initializer=self._attention_initializer, - attention_axes=self._attention_axes, - name="self_attention", - **common_kwargs, - ) - self._attention_dropout = keras.layers.Dropout( - rate=self._hidden_dropout - ) - # Use float32 in layernorm for numeric stability. It is probably safe in - # mixed_float16, but we haven't validated this yet. 
- self._attention_layer_norm = keras.layers.LayerNormalization( - name="self_attention_layer_norm", - axis=-1, - epsilon=self._norm_epsilon, - dtype=tf.float32, - ) - self._intermediate_dense = keras.layers.experimental.EinsumDense( - einsum_equation, - output_shape=(None, self._inner_size), - bias_axes="d", - kernel_initializer=self._kernel_initializer, - name="intermediate", - **common_kwargs, - ) - policy = keras.mixed_precision.global_policy() - if policy.name == "mixed_bfloat16": - # bfloat16 causes BERT with the LAMB optimizer to not converge - # as well, so we use float32. - # TODO(b/154538392): Investigate this. - policy = tf.float32 - self._intermediate_activation_layer = keras.layers.Activation( - self._inner_activation, dtype=policy - ) - self._inner_dropout_layer = keras.layers.Dropout( - rate=self._inner_dropout - ) - self._output_dense = keras.layers.experimental.EinsumDense( - einsum_equation, - output_shape=(None, hidden_size), - bias_axes="d", - name="output", - kernel_initializer=self._kernel_initializer, - **common_kwargs, - ) - self._hidden_dropout = keras.layers.Dropout(rate=self._hidden_dropout) - # Use float32 in layernorm for numeric stability. - self._output_layer_norm = keras.layers.LayerNormalization( - name="output_layer_norm", - axis=-1, - epsilon=self._norm_epsilon, - dtype=tf.float32, - ) - - super().build(input_shape) - - def get_config(self): - config = { - "num_attention_heads": self._num_heads, - "inner_size": self._inner_size, - "inner_activation": self._inner_activation, - "hidden_dropout": self._hidden_dropout_rate, - "attention_dropout": self._attention_dropout_rate, - "output_range": self._output_range, - "kernel_initializer": keras.initializers.serialize( - self._kernel_initializer - ), - "bias_initializer": keras.initializers.serialize( - self._bias_initializer - ), - "kernel_regularizer": keras.regularizers.serialize( - self._kernel_regularizer - ), - "bias_regularizer": keras.regularizers.serialize( - self._bias_regularizer - ), - "activity_regularizer": keras.regularizers.serialize( - self._activity_regularizer - ), - "kernel_constraint": keras.constraints.serialize( - self._kernel_constraint - ), - "bias_constraint": keras.constraints.serialize( - self._bias_constraint - ), - "use_bias": self._use_bias, - "norm_first": self._norm_first, - "norm_epsilon": self._norm_epsilon, - "inner_dropout": self._inner_dropout, - "attention_initializer": keras.initializers.serialize( - self._attention_initializer - ), - "attention_axes": self._attention_axes, - } - base_config = super().get_config() - return dict(list(base_config.items()) + list(config.items())) - - def call(self, query, key_value=None, attention_mask=None): - """Transformer self-attention encoder block call. - - Args: - query: The query for the multi-head attention layer. - key_value: Optional key/value tensor for multi-head attention. If - none supplied, the query will also be used. - attention_mask: Optional mask for the multi-head attention layer. - - Returns: - An output tensor with the same dimensions as input/query tensor. 
- """ - if self._output_range: - if self._norm_first: - source_tensor = query[:, 0 : self._output_range, :] - query = self._attention_layer_norm(query) - if key_value is not None: - key_value = self._attention_layer_norm(key_value) - target_tensor = query[:, 0 : self._output_range, :] - if attention_mask is not None: - attention_mask = attention_mask[:, 0 : self._output_range, :] - else: - if self._norm_first: - source_tensor = query - query = self._attention_layer_norm(query) - if key_value is not None: - key_value = self._attention_layer_norm(key_value) - target_tensor = query - - if key_value is None: - key_value = query - # TODO(mattdangerw): Use the build in masking mechanism. - attention_output = self._attention_layer( - query=target_tensor, value=key_value, attention_mask=attention_mask - ) - attention_output = self._attention_dropout(attention_output) - if self._norm_first: - attention_output = source_tensor + attention_output - else: - attention_output = self._attention_layer_norm( - target_tensor + attention_output - ) - if self._norm_first: - source_attention_output = attention_output - attention_output = self._output_layer_norm(attention_output) - inner_output = self._intermediate_dense(attention_output) - inner_output = self._intermediate_activation_layer(inner_output) - inner_output = self._inner_dropout_layer(inner_output) - layer_output = self._output_dense(inner_output) - layer_output = self._hidden_dropout(layer_output) - - if self._norm_first: - return source_attention_output + layer_output - - # During mixed precision training, layer norm output is always fp32 for - # now. Casts fp32 for the subsequent add. - layer_output = tf.cast(layer_output, tf.float32) - return self._output_layer_norm(layer_output + attention_output) - - # TODO(mattdangerw): This class is needed for TPU friendly embeddings, we should # remove it entirely and fix tf.keras.layers.Embedding as needed. 
class OnDeviceEmbedding(keras.layers.Layer): @@ -468,7 +177,6 @@ def __init__( initializer_range=0.02, max_sequence_length=512, type_vocab_size=2, - norm_first=False, **kwargs, ): super().__init__(**kwargs) @@ -487,7 +195,6 @@ def __init__( ) self.hidden_dropout = hidden_dropout self.attention_dropout = attention_dropout - self.norm_first = norm_first self._embedding_layer = OnDeviceEmbedding( vocab_size=vocab_size, @@ -523,13 +230,11 @@ def __init__( self._transformer_layers = [] for i in range(num_layers): - layer = TransformerEncoderBlock( - num_attention_heads=num_attention_heads, - inner_size=inner_size, - inner_activation=self.inner_activation, - hidden_dropout=hidden_dropout, - attention_dropout=attention_dropout, - norm_first=norm_first, + layer = keras_nlp.layers.TransformerEncoder( + num_heads=num_attention_heads, + intermediate_dim=inner_size, + activation=self.inner_activation, + dropout=hidden_dropout, kernel_initializer=self.initializer, name="transformer/layer_%d" % i, ) @@ -558,11 +263,9 @@ def call(self, inputs): embeddings = self._embedding_norm_layer(embeddings) embeddings = self._embedding_dropout(embeddings) - attention_mask = make_attention_mask(embeddings, input_mask) - x = embeddings for layer in self._transformer_layers: - x = layer(x, attention_mask=attention_mask) + x = layer(x, padding_mask=input_mask) return x def get_embedding_table(self): @@ -585,7 +288,6 @@ def get_config(self): "hidden_dropout": self.hidden_dropout, "attention_dropout": self.attention_dropout, "initializer_range": self.initializer_range, - "norm_first": self.norm_first, } ) return config diff --git a/examples/bert/bert_train.py b/examples/bert/bert_train.py index 77aca5911e..e6ff8f2e35 100644 --- a/examples/bert/bert_train.py +++ b/examples/bert/bert_train.py @@ -12,19 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +import datetime import os import shutil import sys +import google.cloud.logging import tensorflow as tf from absl import app from absl import flags +from absl import logging +from google.cloud import storage from tensorflow import keras from examples.bert.bert_config import MODEL_CONFIGS from examples.bert.bert_config import PREPROCESSING_CONFIG from examples.bert.bert_config import TRAINING_CONFIG from examples.bert.bert_model import BertModel +from examples.utils.google_cloud_utils import list_blobs_with_prefix from examples.utils.scripting_utils import list_filenames_for_arg FLAGS = flags.FLAGS @@ -53,6 +58,30 @@ "Skip restoring from checkpoint if True", ) +flags.DEFINE_bool( + "tpu_name", + None, + "The TPU to connect to. 
If None, TPU will not be used.", +) + +flags.DEFINE_bool( + "use_cloud_storage", + False, + "If True, data I/O will use cloud storage instead of local disks.", +) + +flags.DEFINE_string( + "gcs_bucket", + None, + "Name of GCS bucket.", +) + +flags.DEFINE_string( + "tensorboard_log_path", + None, + "The path to save tensorboard log to.", +) + flags.DEFINE_string( "model_size", "tiny", @@ -371,11 +400,76 @@ def decode_record(record): return example +def get_checkpoint_callback(): + if FLAGS.checkpoint_save_directory is not None: + if FLAGS.use_cloud_storage: + storage_client = storage.Client() + bucket = storage_client.get_bucket(FLAGS.gcs_bucket) + blobs = bucket.list_blobs(prefix=FLAGS.checkpoint_save_directory) + if FLAGS.skip_restore: + for blob in blobs: + blob.delete() + checkpoint_path = ( + "gs://" + + FLAGS.gcs_bucket + + "/" + + FLAGS.checkpoint_save_directory + ) + return tf.keras.callbacks.BackupAndRestore( + backup_dir=checkpoint_path, + ) + + else: + if os.path.exists(FLAGS.checkpoint_save_directory): + if not os.path.isdir(FLAGS.checkpoint_save_directory): + raise ValueError( + "`checkpoint_save_directory` should be a directory, " + f"but {FLAGS.checkpoint_save_directory} is not a " + "directory. Please set `checkpoint_save_directory` as " + "a directory." + ) + + elif FLAGS.skip_restore: + # Clear up the directory if users want to skip restoring. + shutil.rmtree(FLAGS.checkpoint_save_directory) + checkpoint_path = FLAGS.checkpoint_save_directory + return tf.keras.callbacks.BackupAndRestore( + backup_dir=checkpoint_path, + ) + + +def get_tensorboard_callback(): + timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + if FLAGS.use_cloud_storage: + log_dir = ( + "gs://" + + FLAGS.gcs_bucket + + "/" + + FLAGS.tensorboard_log_path + + timestamp + ) + else: + log_dir = FLAGS.tensorboard_log_path + timestamp + return tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1) + + def main(_): - print(f"Reading input data from {FLAGS.input_files}") - input_filenames = list_filenames_for_arg(FLAGS.input_files) + if FLAGS.use_cloud_storage: + # If the job is on cloud, we will use cloud logging. + tf.keras.utils.disable_interactive_logging() + client = google.cloud.logging.Client() + client.setup_logging() + + logging.info(f"Reading input data from {FLAGS.input_files}") + if FLAGS.read_from_gcs: + input_filenames = list_blobs_with_prefix( + FLAGS.gcs_bucket, FLAGS.input_files + ) + else: + input_filenames = list_filenames_for_arg(FLAGS.input_files) + if not input_filenames: - print("No input files found. Check `input_files` flag.") + logging.info("No input files found. Check `input_files` flag.") sys.exit(1) vocab = [] @@ -385,15 +479,15 @@ def main(_): model_config = MODEL_CONFIGS[FLAGS.model_size] - if tf.config.list_logical_devices("TPU"): + if FLAGS.tpu_name is None: + # Use default strategy if not using TPU. + strategy = tf.distribute.get_strategy() + else: # Connect to TPU and create TPU strategy. resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect( - tpu="local" + tpu=FLAGS.tpu_name ) strategy = tf.distribute.TPUStrategy(resolver) - else: - # Use default strategy if not using TPU. - strategy = tf.distribute.get_strategy() # Decode and batch data. 
dataset = tf.data.TFRecordDataset(input_filenames) @@ -437,23 +531,7 @@ def main(_): epochs = TRAINING_CONFIG["epochs"] steps_per_epoch = num_train_steps // epochs - callbacks = [] - if FLAGS.checkpoint_save_directory is not None: - if os.path.exists(FLAGS.checkpoint_save_directory): - if not os.path.isdir(FLAGS.checkpoint_save_directory): - raise ValueError( - "`checkpoint_save_directory` should be a directory, but " - f"{FLAGS.checkpoint_save_directory} is not a directory." - " Please set `checkpoint_save_directory` as a directory." - ) - - elif FLAGS.skip_restore: - # Clear up the directory if users want to skip restoring. - shutil.rmtree(FLAGS.checkpoint_save_directory) - checkpoint_path = FLAGS.checkpoint_save_directory + "/checkpoint" - callbacks.append( - tf.keras.callbacks.BackupAndRestore(backup_dir=checkpoint_path) - ) + callbacks = [get_checkpoint_callback(), get_tensorboard_callback()] pretraining_model.fit( dataset, @@ -462,8 +540,12 @@ def main(_): callbacks=callbacks, ) - print(f"Saving to {FLAGS.saved_model_output}") - model.save(FLAGS.saved_model_output) + if FLAGS.use_cloud_storage: + model_path = "gs://" + FLAGS.gcs_bucket + "/" + FLAGS.saved_model_output + else: + model_path = FLAGS.saved_model_output + logging.info(f"Saving to {FLAGS.saved_model_output}") + model.save(model_path) if __name__ == "__main__": diff --git a/examples/utils/google_cloud_utils.py b/examples/utils/google_cloud_utils.py new file mode 100644 index 0000000000..95f7c38f8b --- /dev/null +++ b/examples/utils/google_cloud_utils.py @@ -0,0 +1,57 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from google.cloud import storage + + +def list_blobs_with_prefix(bucket_name, prefix, delimiter=None): + """Lists all the blobs path in the bucket that begin with the prefix. + + This can be used to list all blobs in a "folder", e.g. "public/". + + The delimiter argument can be used to restrict the results to only the + "files" in the given "folder". Without the delimiter, the entire tree under + the prefix is returned. For example, given these blobs: + + a/1.txt + a/b/2.txt + + If you specify prefix ='a/', without a delimiter, you'll get back: + + a/1.txt + a/b/2.txt + + However, if you specify prefix='a/' and delimiter='/', you'll get back + only the file directly under 'a/': + + a/1.txt + + Args: + bucket_name: string, the cloud storage bucket name. + prefix: string, the prefix of the file path to look for blobs. + delimiter: string, the delimiter. + + Returns: + A list of GCS urls in the format "gs://bucket-name/file-path". 
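+
+    For example, with a hypothetical bucket "my-bucket" containing TFRecord
+    shards under "pretraining-data/", the call below would return the
+    fully-qualified URLs of those shards:
+
+        list_blobs_with_prefix("my-bucket", "pretraining-data/")
+        # ["gs://my-bucket/pretraining-data/shard-0.tfrecord", ...]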
+ """ + + storage_client = storage.Client() + + blobs = storage_client.list_blobs( + bucket_name, prefix=prefix, delimiter=delimiter + ) + file_prefix = "gs://" + bucket_name + "/" + files = [] + for blob in blobs: + files.append(file_prefix + blob.name) + return files From 1c51bfd130ffac43974761e21af13a74e19a2b52 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Wed, 22 Jun 2022 16:37:32 -0700 Subject: [PATCH 2/3] Change how we support cloud training --- examples/bert/bert_config.py | 18 ++---- examples/bert/bert_model.py | 18 ++---- examples/bert/bert_train.py | 114 ++++++++++++----------------------- 3 files changed, 52 insertions(+), 98 deletions(-) diff --git a/examples/bert/bert_config.py b/examples/bert/bert_config.py index 80334ae610..e5febb9254 100644 --- a/examples/bert/bert_config.py +++ b/examples/bert/bert_config.py @@ -16,9 +16,8 @@ "tiny": { "num_layers": 2, "hidden_size": 128, - "hidden_dropout": 0.1, + "dropout": 0.1, "num_attention_heads": 2, - "attention_dropout": 0.1, "inner_size": 512, "inner_activation": "gelu", "initializer_range": 0.02, @@ -26,9 +25,8 @@ "mini": { "num_layers": 4, "hidden_size": 256, - "hidden_dropout": 0.1, + "dropout": 0.1, "num_attention_heads": 4, - "attention_dropout": 0.1, "inner_size": 1024, "inner_activation": "gelu", "initializer_range": 0.02, @@ -36,9 +34,8 @@ "small": { "num_layers": 4, "hidden_size": 512, - "hidden_dropout": 0.1, + "dropout": 0.1, "num_attention_heads": 8, - "attention_dropout": 0.1, "inner_size": 2048, "inner_activation": "gelu", "initializer_range": 0.02, @@ -46,9 +43,8 @@ "medium": { "num_layers": 8, "hidden_size": 512, - "hidden_dropout": 0.1, + "dropout": 0.1, "num_attention_heads": 8, - "attention_dropout": 0.1, "inner_size": 2048, "inner_activation": "gelu", "initializer_range": 0.02, @@ -56,9 +52,8 @@ "base": { "num_layers": 12, "hidden_size": 768, - "hidden_dropout": 0.1, + "dropout": 0.1, "num_attention_heads": 12, - "attention_dropout": 0.1, "inner_size": 3072, "inner_activation": "gelu", "initializer_range": 0.02, @@ -66,9 +61,8 @@ "large": { "num_layers": 24, "hidden_size": 1024, - "hidden_dropout": 0.1, + "dropout": 0.1, "num_attention_heads": 16, - "attention_dropout": 0.1, "inner_size": 4096, "inner_activation": "gelu", "initializer_range": 0.02, diff --git a/examples/bert/bert_model.py b/examples/bert/bert_model.py index 237d00a77a..3027d8b3b5 100644 --- a/examples/bert/bert_model.py +++ b/examples/bert/bert_model.py @@ -141,12 +141,9 @@ class BertModel(keras.Model): vocab_size: The size of the token vocabulary. num_layers: The number of transformer layers. hidden_size: The size of the transformer hidden layers. - hidden_dropout: Dropout probability for the post-attention and output - dropout. + dropout: Dropout probability for the Transformer encoder. num_attention_heads: The number of attention heads for each transformer. The hidden size must be divisible by the number of attention heads. - attention_dropout: The dropout rate to use for the attention layers - within the transformer layers. inner_size: The output dimension of the first Dense layer in a two-layer feedforward network for each transformer. 
inner_activation: The activation for the first Dense layer in a @@ -169,9 +166,8 @@ def __init__( vocab_size, num_layers=12, hidden_size=768, - hidden_dropout=0.1, + dropout=0.1, num_attention_heads=12, - attention_dropout=0.1, inner_size=3072, inner_activation="gelu", initializer_range=0.02, @@ -193,8 +189,7 @@ def __init__( self.initializer = keras.initializers.TruncatedNormal( stddev=initializer_range ) - self.hidden_dropout = hidden_dropout - self.attention_dropout = attention_dropout + self.dropout = dropout self._embedding_layer = OnDeviceEmbedding( vocab_size=vocab_size, @@ -225,7 +220,7 @@ def __init__( ) self._embedding_dropout = keras.layers.Dropout( - rate=hidden_dropout, name="embedding_dropout" + rate=dropout, name="embedding_dropout" ) self._transformer_layers = [] @@ -234,7 +229,7 @@ def __init__( num_heads=num_attention_heads, intermediate_dim=inner_size, activation=self.inner_activation, - dropout=hidden_dropout, + dropout=dropout, kernel_initializer=self.initializer, name="transformer/layer_%d" % i, ) @@ -285,8 +280,7 @@ def get_config(self): "inner_activation": keras.activations.serialize( self.inner_activation ), - "hidden_dropout": self.hidden_dropout, - "attention_dropout": self.attention_dropout, + "dropout": self.dropout, "initializer_range": self.initializer_range, } ) diff --git a/examples/bert/bert_train.py b/examples/bert/bert_train.py index e6ff8f2e35..f56705ece3 100644 --- a/examples/bert/bert_train.py +++ b/examples/bert/bert_train.py @@ -13,31 +13,26 @@ # limitations under the License. import datetime -import os -import shutil import sys -import google.cloud.logging import tensorflow as tf from absl import app from absl import flags from absl import logging -from google.cloud import storage from tensorflow import keras from examples.bert.bert_config import MODEL_CONFIGS from examples.bert.bert_config import PREPROCESSING_CONFIG from examples.bert.bert_config import TRAINING_CONFIG from examples.bert.bert_model import BertModel -from examples.utils.google_cloud_utils import list_blobs_with_prefix -from examples.utils.scripting_utils import list_filenames_for_arg FLAGS = flags.FLAGS flags.DEFINE_string( - "input_files", + "data_directory", None, - "Comma seperated list of directories, globs or files.", + "The directory of training data. It can be a local disk path, or the URL " + "of Google cloud storage bucket.", ) flags.DEFINE_string( @@ -65,15 +60,9 @@ ) flags.DEFINE_bool( - "use_cloud_storage", + "enable_cloud_logging", False, - "If True, data I/O will use cloud storage instead of local disks.", -) - -flags.DEFINE_string( - "gcs_bucket", - None, - "Name of GCS bucket.", + "If True, the script will use cloud logging.", ) flags.DEFINE_string( @@ -401,79 +390,55 @@ def decode_record(record): def get_checkpoint_callback(): - if FLAGS.checkpoint_save_directory is not None: - if FLAGS.use_cloud_storage: - storage_client = storage.Client() - bucket = storage_client.get_bucket(FLAGS.gcs_bucket) - blobs = bucket.list_blobs(prefix=FLAGS.checkpoint_save_directory) - if FLAGS.skip_restore: - for blob in blobs: - blob.delete() - checkpoint_path = ( - "gs://" - + FLAGS.gcs_bucket - + "/" - + FLAGS.checkpoint_save_directory - ) - return tf.keras.callbacks.BackupAndRestore( - backup_dir=checkpoint_path, + if tf.io.gfile.exists(FLAGS.checkpoint_save_directory): + if not tf.io.gfile.isdir(FLAGS.checkpoint_save_directory): + raise ValueError( + "`checkpoint_save_directory` should be a directory, " + f"but {FLAGS.checkpoint_save_directory} is not a " + "directory. 
Please set `checkpoint_save_directory` as " + "a directory." ) - else: - if os.path.exists(FLAGS.checkpoint_save_directory): - if not os.path.isdir(FLAGS.checkpoint_save_directory): - raise ValueError( - "`checkpoint_save_directory` should be a directory, " - f"but {FLAGS.checkpoint_save_directory} is not a " - "directory. Please set `checkpoint_save_directory` as " - "a directory." - ) - - elif FLAGS.skip_restore: - # Clear up the directory if users want to skip restoring. - shutil.rmtree(FLAGS.checkpoint_save_directory) - checkpoint_path = FLAGS.checkpoint_save_directory - return tf.keras.callbacks.BackupAndRestore( - backup_dir=checkpoint_path, - ) + elif FLAGS.skip_restore: + # Clear up the directory if users want to skip restoring. + tf.io.gfile.rmtree(FLAGS.checkpoint_save_directory) + checkpoint_path = FLAGS.checkpoint_save_directory + return tf.keras.callbacks.BackupAndRestore( + backup_dir=checkpoint_path, + ) def get_tensorboard_callback(): timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") - if FLAGS.use_cloud_storage: - log_dir = ( - "gs://" - + FLAGS.gcs_bucket - + "/" - + FLAGS.tensorboard_log_path - + timestamp - ) - else: - log_dir = FLAGS.tensorboard_log_path + timestamp + log_dir = FLAGS.tensorboard_log_path + timestamp return tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1) def main(_): - if FLAGS.use_cloud_storage: + if FLAGS.enable_cloud_logging: # If the job is on cloud, we will use cloud logging. + import google.cloud.logging + tf.keras.utils.disable_interactive_logging() client = google.cloud.logging.Client() client.setup_logging() - logging.info(f"Reading input data from {FLAGS.input_files}") - if FLAGS.read_from_gcs: - input_filenames = list_blobs_with_prefix( - FLAGS.gcs_bucket, FLAGS.input_files + logging.info(f"Reading input data from {FLAGS.data_directory}") + if not tf.io.gfile.isdir(FLAGS.data_directory): + raise ValueError( + "`data_directory` should be a directory, " + f"but {FLAGS.data_directory} is not a directory. Please " + "set `data_directory` flag as a directory." ) - else: - input_filenames = list_filenames_for_arg(FLAGS.input_files) + files = tf.io.gfile.listdir(FLAGS.data_directory) + input_filenames = [FLAGS.data_directory + "/" + file for file in files] if not input_filenames: - logging.info("No input files found. Check `input_files` flag.") + logging.info("No input files found. 
Check `data_directory` flag.") sys.exit(1) vocab = [] - with open(FLAGS.vocab_file, "r") as vocab_file: + with tf.io.gfile.GFile(FLAGS.vocab_file) as vocab_file: for line in vocab_file: vocab.append(line.strip()) @@ -531,7 +496,11 @@ def main(_): epochs = TRAINING_CONFIG["epochs"] steps_per_epoch = num_train_steps // epochs - callbacks = [get_checkpoint_callback(), get_tensorboard_callback()] + callbacks = [] + if FLAGS.checkpoint_save_directory: + callbacks.append(get_checkpoint_callback()) + if FLAGS.tensorboard_log_path: + callbacks.append(get_tensorboard_callback()) pretraining_model.fit( dataset, @@ -540,16 +509,13 @@ def main(_): callbacks=callbacks, ) - if FLAGS.use_cloud_storage: - model_path = "gs://" + FLAGS.gcs_bucket + "/" + FLAGS.saved_model_output - else: - model_path = FLAGS.saved_model_output + model_path = FLAGS.saved_model_output logging.info(f"Saving to {FLAGS.saved_model_output}") model.save(model_path) if __name__ == "__main__": - flags.mark_flag_as_required("input_files") + flags.mark_flag_as_required("data_directory") flags.mark_flag_as_required("vocab_file") flags.mark_flag_as_required("saved_model_output") app.run(main) From 38b5c8087f8ac508de413d215c2f4f4138ca4168 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Fri, 24 Jun 2022 11:54:53 -0700 Subject: [PATCH 3/3] delete the unused util --- examples/bert/README.md | 11 ++++-- examples/bert/bert_train.py | 20 +++++----- examples/utils/google_cloud_utils.py | 57 ---------------------------- 3 files changed, 17 insertions(+), 71 deletions(-) delete mode 100644 examples/utils/google_cloud_utils.py diff --git a/examples/bert/README.md b/examples/bert/README.md index 93ccbafc19..3663763096 100644 --- a/examples/bert/README.md +++ b/examples/bert/README.md @@ -36,7 +36,7 @@ python3 examples/bert/bert_preprocess.py \ --output_file $OUTPUT_DIR/pretraining-data/pretraining.tfrecord # Run pretraining for 100 train steps only. python3 examples/bert/bert_train.py \ - --input_files $OUTPUT_DIR/pretraining-data/ \ + --input_directory $OUTPUT_DIR/pretraining-data/ \ --vocab_file $OUTPUT_DIR/bert_vocab_uncased.txt \ --saved_model_output $OUTPUT_DIR/model/ \ --num_train_steps 100 @@ -197,12 +197,14 @@ python3 -c "from examples.utils.data_utils import preview_tfrecord; preview_tfre After preprocessing, we can run pretraining with the `bert_train.py` script. This will train a model and save it to the `--saved_model_output` -directory. +directory. If you are willing to train from data stored on google cloud storage bucket (GCS), you can do it by setting the file path to +the URL of GCS bucket. For example, `--input_directory=gs://your-bucket-name/you-data-path`. You can also save models directly to GCS by the same approach. ```shell python3 examples/bert/bert_train.py \ - --input_files path/to/data/ \ + --input_directory path/to/data/ \ --vocab_file path/to/bert_vocab_uncased.txt \ + --model_size tiny \ --saved_model_output path/to/model/ ``` @@ -219,7 +221,8 @@ training for a few epochs to finetune the model. 
```shell python3 examples/bert/bert_finetune_glue.py \ --saved_model_input path/to/model/ \ - --vocab_file path/to/bert_vocab_uncased.txt + --vocab_file path/to/bert_vocab_uncased.txt \ + --task_name mrpc ``` The script could be easily adapted to any other text classification finetuning diff --git a/examples/bert/bert_train.py b/examples/bert/bert_train.py index f56705ece3..27e234dacd 100644 --- a/examples/bert/bert_train.py +++ b/examples/bert/bert_train.py @@ -29,7 +29,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string( - "data_directory", + "input_directory", None, "The directory of training data. It can be a local disk path, or the URL " "of Google cloud storage bucket.", @@ -423,18 +423,18 @@ def main(_): client = google.cloud.logging.Client() client.setup_logging() - logging.info(f"Reading input data from {FLAGS.data_directory}") - if not tf.io.gfile.isdir(FLAGS.data_directory): + logging.info(f"Reading input data from {FLAGS.input_directory}") + if not tf.io.gfile.isdir(FLAGS.input_directory): raise ValueError( - "`data_directory` should be a directory, " - f"but {FLAGS.data_directory} is not a directory. Please " - "set `data_directory` flag as a directory." + "`input_directory` should be a directory, " + f"but {FLAGS.input_directory} is not a directory. Please " + "set `input_directory` flag as a directory." ) - files = tf.io.gfile.listdir(FLAGS.data_directory) - input_filenames = [FLAGS.data_directory + "/" + file for file in files] + files = tf.io.gfile.listdir(FLAGS.input_directory) + input_filenames = [FLAGS.input_directory + "/" + file for file in files] if not input_filenames: - logging.info("No input files found. Check `data_directory` flag.") + logging.info("No input files found. Check `input_directory` flag.") sys.exit(1) vocab = [] @@ -515,7 +515,7 @@ def main(_): if __name__ == "__main__": - flags.mark_flag_as_required("data_directory") + flags.mark_flag_as_required("input_directory") flags.mark_flag_as_required("vocab_file") flags.mark_flag_as_required("saved_model_output") app.run(main) diff --git a/examples/utils/google_cloud_utils.py b/examples/utils/google_cloud_utils.py deleted file mode 100644 index 95f7c38f8b..0000000000 --- a/examples/utils/google_cloud_utils.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright 2022 The KerasNLP Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from google.cloud import storage - - -def list_blobs_with_prefix(bucket_name, prefix, delimiter=None): - """Lists all the blobs path in the bucket that begin with the prefix. - - This can be used to list all blobs in a "folder", e.g. "public/". - - The delimiter argument can be used to restrict the results to only the - "files" in the given "folder". Without the delimiter, the entire tree under - the prefix is returned. 
For example, given these blobs: - - a/1.txt - a/b/2.txt - - If you specify prefix ='a/', without a delimiter, you'll get back: - - a/1.txt - a/b/2.txt - - However, if you specify prefix='a/' and delimiter='/', you'll get back - only the file directly under 'a/': - - a/1.txt - - Args: - bucket_name: string, the cloud storage bucket name. - prefix: string, the prefix of the file path to look for blobs. - delimiter: string, the delimiter. - - Returns: - A list of GCS urls in the format "gs://bucket-name/file-path". - """ - - storage_client = storage.Client() - - blobs = storage_client.list_blobs( - bucket_name, prefix=prefix, delimiter=delimiter - ) - file_prefix = "gs://" + bucket_name + "/" - files = [] - for blob in blobs: - files.append(file_prefix + blob.name) - return files
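
Taken together, the three patches route all data I/O through `tf.io.gfile`, so the same `--input_directory`, `--vocab_file`, and `--saved_model_output` flags accept either local paths or `gs://` URLs, and the dedicated `google_cloud_utils.py` helper becomes unnecessary. Below is a minimal sketch (not part of the patch itself) of the file-listing logic the final `bert_train.py` relies on; the directory names are hypothetical and shown only to illustrate that local and GCS paths are handled identically.

```python
import tensorflow as tf


def list_input_files(input_directory):
    """Return full paths of all files under a local or GCS directory.

    tf.io.gfile dispatches on the path scheme, so "gs://bucket/path" and
    "/local/path" go through the same code.
    """
    if not tf.io.gfile.isdir(input_directory):
        raise ValueError(f"`{input_directory}` is not a directory.")
    files = tf.io.gfile.listdir(input_directory)
    return [input_directory + "/" + f for f in files]


# Hypothetical usage -- both calls return a flat list of file paths:
# list_input_files("/tmp/pretraining-data")
# list_input_files("gs://your-bucket-name/pretraining-data")
```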