diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000000..7acc255cfd --- /dev/null +++ b/examples/README.md @@ -0,0 +1,60 @@ +# KerasNLP Example Models + +This examples directory contains complete recipes for training popular model +architectures using KerasNLP. These are not part of the library itself, but +rather serve to demonstrate how to use the library for common tasks, while +simultaneously providing a mechanism to rigorously test library components. + +This directory is complementary to the NLP examples on +[keras.io](https://keras.io/examples/). If you want to contribute a KerasNLP +example and you're not sure if it should live on keras.io or in this directory, +here's how they differ: + +- If an example can fit in < 300 lines of code and run in a Colab, + put it on keras.io. +- If an example is too big for a single script or has high compute requirements, + add it here. + +In general, we will have a fairly high bar for what models we support in this +directory. They should be widely used, practical models for solving standard +problems in NLP. + +## Anatomy of an example + +Given a model named `modelname`, which involves both pretraining and finetuning +on a downstream task, the contents of the `modelname` directory should be as +follows: + +```shell +modelname +├── README.md +├── __init__.py +├── modelname_config.py +├── modelname_model.py +├── modelname_preprocess.py +├── modelname_train.py +└── modelname_finetune_X.py +``` + +- `README.md`: The README should contain complete instructions for downloading + data and training a model from scratch. +- `__init__.py`: Empty (it's for imports). +- `modelname_config.py`: This file should contain most of the configuration for + the model architecture, learning rate, etc, using simple Python constants. We + would like to avoid complex configuration setups (json, yaml, etc). +- `modelname_preprocess.py`: If necessary. Standalone script to preprocess + inputs. If possible, prefer doing preprocessing dynamically with tf.data + inside the training and finetuning scripts. +- `modelname_model.py`: This file should contain the actual `keras.Model` and + any custom layers needed for the example. Use KerasNLP components where ever + possible. +- `modelname_train.py`: This file should be a runnable training script for + pretraining. If possible, this script should preprocess data dynamically + during training using `tf.data` and KerasNLP components (e.g. tokenizers). +- `modelname_finetune_X.py`: Optional. There can be any number of these files, + for each task `X` we would like to support for finetuning. The file should be + a runnable training script which loads and finetunes a pretrained model. + +## Instructions for running on Google Cloud + +TODO(https://github.com/keras-team/keras-nlp/issues/178) diff --git a/examples/bert/README.md b/examples/bert/README.md index 7913ea68a1..02dc241ecf 100644 --- a/examples/bert/README.md +++ b/examples/bert/README.md @@ -1,14 +1,14 @@ -# BERT with keras-nlp +# BERT with KerasNLP -This example will show how to train a Bidirectional Encoder -Representations from Transformers (BERT) model end-to-end using the keras-nlp +This example demonstrates how to train a Bidirectional Encoder +Representations from Transformers (BERT) model end-to-end using the KerasNLP library. This README contains instructions on how to run pretraining directly -from raw data, followed by fine tuning and evaluation on the GLUE dataset. +from raw data, followed by finetuning and evaluation on the GLUE dataset. ## Quickly test out the code To exercise the code in this directory by training a tiny BERT model, you can -run the following commands from the base of the keras-nlp repository. This can +run the following commands from the base directory of the repository. This can be useful to validate any code changes, but note that a useful BERT model would need to be trained for much longer on a much larger dataset. @@ -18,47 +18,42 @@ DATA_URL=https://storage.googleapis.com/tensorflow/keras-nlp/examples/bert # Create a virtual env and install dependencies. mkdir $OUTPUT_DIR -python3 -m venv $OUTPUT_DIR/env -source $OUTPUT_DIR/env/bin/activate +python3 -m venv $OUTPUT_DIR/env && source $OUTPUT_DIR/env/bin/activate pip install -e ".[tests,examples]" # Download example data. wget ${DATA_URL}/bert_vocab_uncased.txt -O $OUTPUT_DIR/bert_vocab_uncased.txt wget ${DATA_URL}/wiki_example_data.txt -O $OUTPUT_DIR/wiki_example_data.txt -# Run preprocessing. -python3 examples/bert/create_sentence_split_data.py \ +# Parse input data and split into sentences. +python3 examples/tools/split_sentences.py \ --input_files $OUTPUT_DIR/wiki_example_data.txt \ - --output_directory $OUTPUT_DIR/sentence-split-data --num_shards 1 -python3 examples/bert/create_pretraining_data.py \ + --output_directory $OUTPUT_DIR/sentence-split-data +# Preprocess input for pretraining. +python3 examples/bert/bert_preprocess.py \ --input_files $OUTPUT_DIR/sentence-split-data/ \ --vocab_file $OUTPUT_DIR/bert_vocab_uncased.txt \ --output_file $OUTPUT_DIR/pretraining-data/pretraining.tfrecord - -# Run pretraining. -python3 examples/bert/run_pretraining.py \ +# Run pretraining for 100 train steps only. +python3 examples/bert/bert_train.py \ --input_files $OUTPUT_DIR/pretraining-data/ \ --vocab_file $OUTPUT_DIR/bert_vocab_uncased.txt \ - --bert_config_file examples/bert/configs/bert_tiny.json \ - --num_warmup_steps 20 \ - --num_train_steps 200 \ - --saved_model_output $OUTPUT_DIR/model/ - + --saved_model_output $OUTPUT_DIR/model/ \ + --num_train_steps 100 # Run finetuning. -python3 examples/bert/run_glue_finetuning.py \ +python3 examples/bert/bert_finetune_glue.py \ --saved_model_input $OUTPUT_DIR/model/ \ - --vocab_file $OUTPUT_DIR/bert_vocab_uncased.txt \ - --bert_config_file examples/bert/configs/bert_tiny.json + --vocab_file $OUTPUT_DIR/bert_vocab_uncased.txt ``` ## Installing dependencies -Pip dependencies for all keras-nlp examples are listed in `setup.py`. To install -both the keras-nlp library from source and all other dependencies required to -run the example, run the below command. You may want to install to a self -contained environment (e.g. a container or a virtualenv). +Pip dependencies for all KerasNLP examples are listed in `setup.py`. The +following command will create a virtual environment, install all dependencies, +and install KerasNLP from source. ```shell +python3 -m venv path/to/venv && source path/to/venv/bin/activate pip install -e ".[examples]" ``` @@ -66,7 +61,7 @@ pip install -e ".[examples]" Training a BERT model happens in two stages. First, the model is "pretrained" on a large corpus of input text. This is computationally expensive. After -pretraining, the model can be "fine tuned" on a downstream task with much +pretraining, the model can be "finetuned" on a downstream task with a much smaller amount of labeled data. ### Downloading pretraining data @@ -75,7 +70,8 @@ The GLUE pretraining data (Wikipedia + BooksCorpus) is fairly large. The raw input data takes roughly ~20GB of space, and after preprocessing, the full corpus will take ~400GB. -The latest wikipedia dump can be downloaded [at this link](https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2), +The latest wikipedia dump can be downloaded +[at this link](https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2), or via command line: ```shell @@ -84,14 +80,14 @@ curl -O https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.x The dump can be extracted with the `wikiextractor` tool. ```shell -python -m wikiextractor.WikiExtractor enwiki-latest-pages-articles.xml.bz2 +python3 -m wikiextractor.WikiExtractor enwiki-latest-pages-articles.xml.bz2 ``` BooksCorpus is no longer hosted by -[it's creators](https://yknzhu.wixsite.com/mbweb), but you can find instructions -for downloading or reproducing the corpus in this -[repository](https://github.com/soskek/bookcorpus). We suggest the pre-made file -downloads listed at the top of the README. Alternatively, you can forgo it +[its creators](https://yknzhu.wixsite.com/mbweb), but you can find instructions +for downloading or reproducing the corpus in +[this repository](https://github.com/soskek/bookcorpus). We suggest the pre-made file +downloads listed at the top of the README. Alternatively, you can forgo it entirely and pretrain solely on wikipedia. Preparing the pretraining data will happen in two stages. First, raw text needs @@ -101,31 +97,27 @@ next sentence predictions. ### Splitting raw text into sentences -The `create_sentence_split_data.py` will process raw input files and split them -into output files where each line contains a sentence, and a blank line marks -the start of a new document. +Next, use `examples/tools/split_sentences.py` to process raw input files and +split them into output files where each line contains a sentence, and a blank +line marks the start of a new document. We need this for the next-sentence +prediction task used by BERT. -The script supports two types of inputs files. Plain text files, where each -individual file is assumed to be an entire document, and wikipedia dump files -in the format outputted by the wikiextractor tool (each document is enclosed in -`` tags). - -For example, if wikipedia files are located in `~/datasets/wikipedia` and +For example, if Wikipedia files are located in `~/datasets/wikipedia` and bookscorpus in `~/datasets/bookscorpus`, the following command will output sentence split documents to a configurable number of output file shards: ```shell -python examples/bert/create_sentence_split_data.py \ +python3 examples/tools/split_sentences.py \ --input_files ~/datasets/wikipedia,~/datasets/bookscorpus \ --output_directory ~/datasets/sentence-split-data ``` ### Computing a WordPiece vocabulary -The `create_vocabulary.py` script allows you to compute your own WordPiece -vocabulary for use with BERT. In most cases however, it is desirable to use the -standard BERT vocabularies from the original models. You can download the -English uncased vocabulary +The easiest and best approach when training BERT is to use the official +vocabularies from the original project, which have become somewhat standard. + +You can download the English uncased vocabulary [here](https://storage.googleapis.com/tensorflow/keras-nlp/examples/bert/bert_vocab_uncased.txt), or in your terminal run: @@ -133,11 +125,12 @@ or in your terminal run: curl -O https://storage.googleapis.com/tensorflow/keras-nlp/examples/bert/bert_vocab_uncased.txt ``` +You can also use `examples/tools/train_word_piece_vocab.py` to train your own. + ### Tokenize, mask, and combine sentences into training examples -The `create_pretraining_data.py` scrip will take in a set of sentence split -files, and set up training examples for the next sentence prediction and masked -word tasks. +The `bert_preprocess.py` script will take in a set of sentence split files, and +set up training examples for the next sentence prediction and masked word tasks. The output of the script will be TFRecord files with a number of fields per example. Below shows a complete output example with the addition of a string @@ -172,21 +165,22 @@ with the following: ```shell for file in path/to/sentence-split-data/*; do output="path/to/pretraining-data/$(basename -- "$file" .txt).tfrecord" - python examples/bert/create_pretraining_data.py \ + python3 examples/bert/bert_preprocess.py \ --input_files ${file} \ --vocab_file bert_vocab_uncased.txt \ --output_file ${output} done ``` -If memory is available, this could be further sped up by running this script -multiple times in parallel: +If enough memory is available, this could be further sped up by running this script +multiple times in parallel. The following will take 3-4 hours on the entire dataset +on an 8 core machine. ```shell NUM_JOBS=5 for file in path/to/sentence-split-data/*; do output="path/to/pretraining-data/$(basename -- "$file" .txt).tfrecord" - echo python examples/bert/create_pretraining_data.py \ + echo python3 examples/bert/bert_preprocess.py \ --input_files ${file} \ --vocab_file bert_vocab_uncased.txt \ --output_file ${output} @@ -196,17 +190,17 @@ done | parallel -j ${NUM_JOBS} To preview a sample of generated data files, you can run the command below: ```shell -python -c "from keras_nlp.utils.tensor_utils import preview_tfrecord; preview_tfrecord('/path/to/tfrecord_file')" +python3 -c "from examples.utils.data_utils import preview_tfrecord; preview_tfrecord('path/to/tfrecord_file')" ``` ### Running BERT pretraining -After preprocessing, we can run pretraining with the `run_pretraining.py` +After preprocessing, we can run pretraining with the `bert_train.py` script. This will train a model and save it to the `--saved_model_output` directory. ```shell -python3 examples/bert/run_pretraining.py \ +python3 examples/bert/bert_train.py \ --input_files path/to/data/ \ --vocab_file path/to/bert_vocab_uncased.txt \ --bert_config_file examples/bert/configs/bert_tiny.json \ @@ -217,18 +211,18 @@ python3 examples/bert/run_pretraining.py \ After pretraining, we can evaluate the performance of a BERT model with the General Language Understanding Evaluation (GLUE) benchmark. This will -fine tune the model and running classification for a number of downstream tasks. +finetune the model and running classification for a number of downstream tasks. -The `run_glue_finetuning.py` script downloads the GLUE data for a specific +The `bert_finetune_glue.py` script downloads the GLUE data for a specific tasks, reloads the pretraining model with appropriate finetuning heads, and runs training for a few epochs to finetune the model. ```shell -python3 examples/bert/run_glue_finetuning.py \ +python3 examples/bert/bert_finetune_glue.py \ --saved_model_input path/to/model/ \ --vocab_file path/to/bert_vocab_uncased.txt \ --bert_config_file examples/bert/configs/bert_tiny.json \ ``` -The script could be easily adapted to any other text classification fine-tuning +The script could be easily adapted to any other text classification finetuning tasks, where inputs can be any number of raw text sentences per sample. diff --git a/examples/bert/bert_config.py b/examples/bert/bert_config.py new file mode 100644 index 0000000000..80334ae610 --- /dev/null +++ b/examples/bert/bert_config.py @@ -0,0 +1,102 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +MODEL_CONFIGS = { + "tiny": { + "num_layers": 2, + "hidden_size": 128, + "hidden_dropout": 0.1, + "num_attention_heads": 2, + "attention_dropout": 0.1, + "inner_size": 512, + "inner_activation": "gelu", + "initializer_range": 0.02, + }, + "mini": { + "num_layers": 4, + "hidden_size": 256, + "hidden_dropout": 0.1, + "num_attention_heads": 4, + "attention_dropout": 0.1, + "inner_size": 1024, + "inner_activation": "gelu", + "initializer_range": 0.02, + }, + "small": { + "num_layers": 4, + "hidden_size": 512, + "hidden_dropout": 0.1, + "num_attention_heads": 8, + "attention_dropout": 0.1, + "inner_size": 2048, + "inner_activation": "gelu", + "initializer_range": 0.02, + }, + "medium": { + "num_layers": 8, + "hidden_size": 512, + "hidden_dropout": 0.1, + "num_attention_heads": 8, + "attention_dropout": 0.1, + "inner_size": 2048, + "inner_activation": "gelu", + "initializer_range": 0.02, + }, + "base": { + "num_layers": 12, + "hidden_size": 768, + "hidden_dropout": 0.1, + "num_attention_heads": 12, + "attention_dropout": 0.1, + "inner_size": 3072, + "inner_activation": "gelu", + "initializer_range": 0.02, + }, + "large": { + "num_layers": 24, + "hidden_size": 1024, + "hidden_dropout": 0.1, + "num_attention_heads": 16, + "attention_dropout": 0.1, + "inner_size": 4096, + "inner_activation": "gelu", + "initializer_range": 0.02, + }, +} + +# Currently we have the same set of training parameters for all configurations. +# We should see if we need to split this for different architecture sizes. + +PREPROCESSING_CONFIG = { + "max_seq_length": 512, + "max_predictions_per_seq": 76, + "dupe_factor": 10, + "masked_lm_prob": 0.15, + "short_seq_prob": 0.1, +} + +TRAINING_CONFIG = { + "batch_size": 256, + "epochs": 10, + "learning_rate": 1e-4, + "num_train_steps": 1_000_000, + # Percentage of training steps used for learning rate warmup. + "warmup_percentage": 0.1, +} + +FINETUNING_CONFIG = { + "batch_size": 32, + "epochs": 3, + "learning_rates": [5e-5, 4e-5, 3e-5, 2e-5], +} diff --git a/examples/bert/run_glue_finetuning.py b/examples/bert/bert_finetune_glue.py similarity index 82% rename from examples/bert/run_glue_finetuning.py rename to examples/bert/bert_finetune_glue.py index d15ecbbb29..b7e378d657 100644 --- a/examples/bert/run_glue_finetuning.py +++ b/examples/bert/bert_finetune_glue.py @@ -13,8 +13,6 @@ # limitations under the License. """Run finetuning on a GLUE task.""" -import json - import datasets import keras_tuner import tensorflow as tf @@ -23,54 +21,54 @@ from absl import flags from tensorflow import keras +from examples.bert.bert_config import FINETUNING_CONFIG +from examples.bert.bert_config import MODEL_CONFIGS +from examples.bert.bert_config import PREPROCESSING_CONFIG + FLAGS = flags.FLAGS flags.DEFINE_string( - "bert_config_file", - None, - "The json config file for the bert model parameters.", + "model_size", + "tiny", + "One of: tiny, mini, small, medium, base, or large.", ) flags.DEFINE_string( "vocab_file", None, - "The vocabulary file that the BERT model was trained on.", + "The vocabulary file for tokenization.", ) flags.DEFINE_string( "saved_model_input", None, - "The directory containing the input pretrained model to finetune.", + "The directory to load the pretrained model.", ) flags.DEFINE_string( - "saved_model_output", None, "The directory to save the finetuned model in." + "saved_model_output", + None, + "The directory to save the finetuned model.", ) - flags.DEFINE_string( - "task_name", "mrpc", "The name of the GLUE task to finetune on." + "task_name", + "mrpc", + "The name of the GLUE task to finetune on.", ) flags.DEFINE_bool( "do_lower_case", True, - "Whether to lower case the input text. Should be True for uncased " - "models and False for cased models.", + "Whether to lower case the input text.", ) flags.DEFINE_bool( "do_evaluation", True, - "Whether to run evaluation on the validation set for a given task.", + "Whether to run evaluation on test data.", ) -flags.DEFINE_integer("batch_size", 32, "The batch size.") - -flags.DEFINE_integer("epochs", 3, "The number of training epochs.") - -flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.") - def pack_inputs( inputs, @@ -173,23 +171,25 @@ def call(self, inputs): class BertHyperModel(keras_tuner.HyperModel): """Creates a hypermodel to help with the search space for finetuning.""" - def __init__(self, bert_config): - self.bert_config = bert_config + def __init__(self, model_config): + self.model_config = model_config def build(self, hp): model = keras.models.load_model(FLAGS.saved_model_input, compile=False) - bert_config = self.bert_config + model_config = self.model_config finetuning_model = BertClassificationFinetuner( bert_model=model, - hidden_size=bert_config["hidden_size"], + hidden_size=model_config["hidden_size"], num_classes=3 if FLAGS.task_name in ("mnli", "ax") else 2, initializer=keras.initializers.TruncatedNormal( - stddev=bert_config["initializer_range"] + stddev=model_config["initializer_range"] ), ) finetuning_model.compile( optimizer=keras.optimizers.Adam( - learning_rate=hp.Choice("lr", [5e-5, 4e-5, 3e-5, 2e-5]) + learning_rate=hp.Choice( + "lr", FINETUNING_CONFIG["learning_rates"] + ), ), loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=[keras.metrics.SparseCategoricalAccuracy()], @@ -213,14 +213,13 @@ def main(_): end_id = vocab.index("[SEP]") pad_id = vocab.index("[PAD]") - with open(FLAGS.bert_config_file, "r") as bert_config_file: - bert_config = json.loads(bert_config_file.read()) + model_config = MODEL_CONFIGS[FLAGS.model_size] def preprocess_data(inputs, labels): inputs = [tokenizer.tokenize(x).merge_dims(1, -1) for x in inputs] inputs = pack_inputs( inputs, - FLAGS.max_seq_length, + PREPROCESSING_CONFIG["max_seq_length"], start_of_sequence_id=start_id, end_of_segment_id=end_id, padding_id=pad_id, @@ -230,18 +229,19 @@ def preprocess_data(inputs, labels): # Read and preprocess GLUE task data. train_ds, test_ds, validation_ds = load_data(FLAGS.task_name) - train_ds = train_ds.batch(FLAGS.batch_size).map( + batch_size = FINETUNING_CONFIG["batch_size"] + train_ds = train_ds.batch(batch_size).map( preprocess_data, num_parallel_calls=tf.data.AUTOTUNE ) - validation_ds = validation_ds.batch(FLAGS.batch_size).map( + validation_ds = validation_ds.batch(batch_size).map( preprocess_data, num_parallel_calls=tf.data.AUTOTUNE ) - test_ds = test_ds.batch(FLAGS.batch_size).map( + test_ds = test_ds.batch(batch_size).map( preprocess_data, num_parallel_calls=tf.data.AUTOTUNE ) # Create a hypermodel object for a RandomSearch. - hypermodel = BertHyperModel(bert_config) + hypermodel = BertHyperModel(model_config) # Initialize the random search over the 4 learning rate parameters, for 4 # trials and 3 epochs for each trial. @@ -253,7 +253,11 @@ def preprocess_data(inputs, labels): project_name="hyperparameter_tuner_results", ) - tuner.search(train_ds, epochs=FLAGS.epochs, validation_data=validation_ds) + tuner.search( + train_ds, + epochs=FINETUNING_CONFIG["epochs"], + validation_data=validation_ds, + ) # Extract the best hyperparameters after the search. best_hp = tuner.get_best_hyperparameters()[0] @@ -276,6 +280,5 @@ def preprocess_data(inputs, labels): if __name__ == "__main__": flags.mark_flag_as_required("vocab_file") - flags.mark_flag_as_required("bert_config_file") flags.mark_flag_as_required("saved_model_input") app.run(main) diff --git a/examples/bert/create_pretraining_data.py b/examples/bert/bert_preprocess.py similarity index 93% rename from examples/bert/create_pretraining_data.py rename to examples/bert/bert_preprocess.py index 7bf07d37d9..36b6bc0e27 100644 --- a/examples/bert/create_pretraining_data.py +++ b/examples/bert/bert_preprocess.py @@ -44,7 +44,8 @@ from absl import app from absl import flags -from examples.bert.bert_utils import list_filenames_for_arg +from examples.bert.bert_config import PREPROCESSING_CONFIG +from examples.utils.scripting_utils import list_filenames_for_arg # Tokenization will happen with tensorflow and can easily OOM a GPU. # Restrict the script to run CPU as GPU will not offer speedup here anyway. @@ -55,7 +56,7 @@ flags.DEFINE_string( "input_files", None, - "Comma seperated list of directories, files, or globs.", + "Comma seperated list of directories, globs or files.", ) flags.DEFINE_string( @@ -67,39 +68,19 @@ flags.DEFINE_string( "vocab_file", None, - "The vocabulary file that the BERT model was trained on.", + "The vocabulary file for tokenization.", ) flags.DEFINE_bool( "do_lower_case", True, - "Whether to lower case the input text. Should be True for uncased " - "models and False for cased models.", + "Whether to lower case the input text.", ) -flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.") - -flags.DEFINE_integer( - "max_predictions_per_seq", - 20, - "Maximum number of masked LM predictions per sequence.", -) - -flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.") - flags.DEFINE_integer( - "dupe_factor", - 10, - "Number of times to duplicate the input data (with different masks).", -) - -flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.") - -flags.DEFINE_float( - "short_seq_prob", - 0.1, - "Probability of creating sequences which are shorter than the " - "maximum length.", + "random_seed", + 12345, + "Random seed for data generation.", ) @@ -500,11 +481,11 @@ def main(_): input_filenames, tokenizer, vocab, - FLAGS.max_seq_length, - FLAGS.dupe_factor, - FLAGS.short_seq_prob, - FLAGS.masked_lm_prob, - FLAGS.max_predictions_per_seq, + PREPROCESSING_CONFIG["max_seq_length"], + PREPROCESSING_CONFIG["dupe_factor"], + PREPROCESSING_CONFIG["short_seq_prob"], + PREPROCESSING_CONFIG["masked_lm_prob"], + PREPROCESSING_CONFIG["max_predictions_per_seq"], rng, ) @@ -515,8 +496,8 @@ def main(_): write_instance_to_example_files( instances, vocab, - FLAGS.max_seq_length, - FLAGS.max_predictions_per_seq, + PREPROCESSING_CONFIG["max_seq_length"], + PREPROCESSING_CONFIG["max_predictions_per_seq"], FLAGS.output_file, ) diff --git a/examples/bert/run_pretraining.py b/examples/bert/bert_train.py similarity index 88% rename from examples/bert/run_pretraining.py rename to examples/bert/bert_train.py index 66768b023e..ce831a303d 100644 --- a/examples/bert/run_pretraining.py +++ b/examples/bert/bert_train.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import sys import tensorflow as tf @@ -20,58 +19,42 @@ from absl import flags from tensorflow import keras +from examples.bert.bert_config import MODEL_CONFIGS +from examples.bert.bert_config import PREPROCESSING_CONFIG +from examples.bert.bert_config import TRAINING_CONFIG from examples.bert.bert_model import BertModel -from examples.bert.bert_utils import list_filenames_for_arg +from examples.utils.scripting_utils import list_filenames_for_arg FLAGS = flags.FLAGS flags.DEFINE_string( "input_files", None, - "Comma seperated list of directories, files, or globs for input data.", + "Comma seperated list of directories, globs or files.", ) flags.DEFINE_string( - "saved_model_output", None, "Output directory to save the model to." + "saved_model_output", + None, + "Output directory to save the model to.", ) flags.DEFINE_string( - "bert_config_file", - None, - "The json config file for the bert model parameters.", + "model_size", + "tiny", + "One of: tiny, mini, small, medium, base, or large.", ) flags.DEFINE_string( "vocab_file", None, - "The vocabulary file that the BERT model was trained on.", -) - -flags.DEFINE_integer("epochs", 10, "The number of training epochs.") - -flags.DEFINE_integer("batch_size", 256, "The training batch size.") - -flags.DEFINE_float("learning_rate", 1e-4, "The initial learning rate for Adam.") - -flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.") - -flags.DEFINE_integer( - "max_predictions_per_seq", - 20, - "Maximum number of masked LM predictions per sequence.", -) - -flags.DEFINE_integer( - "num_warmup_steps", - 10000, - "The number of warmup steps during which the learning rate will increase " - "till a threshold.", + "The vocabulary file for tokenization.", ) flags.DEFINE_integer( "num_train_steps", - 1000000, - "The total fixed number of steps till which the model will train.", + None, + "Override the pre-configured number of train steps..", ) @@ -345,8 +328,8 @@ def __call__(self, step): def decode_record(record): """Decodes a record to a TensorFlow example.""" - seq_length = FLAGS.max_seq_length - lm_length = FLAGS.max_predictions_per_seq + seq_length = PREPROCESSING_CONFIG["max_seq_length"] + lm_length = PREPROCESSING_CONFIG["max_predictions_per_seq"] name_to_features = { "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64), "input_mask": tf.io.FixedLenFeature([seq_length], tf.int64), @@ -379,8 +362,7 @@ def main(_): for line in vocab_file: vocab.append(line.strip()) - with open(FLAGS.bert_config_file, "r") as bert_config_file: - bert_config = json.loads(bert_config_file.read()) + model_config = MODEL_CONFIGS[FLAGS.model_size] # Decode and batch data. dataset = tf.data.TFRecordDataset(input_filenames) @@ -388,22 +370,33 @@ def main(_): lambda record: decode_record(record), num_parallel_calls=tf.data.experimental.AUTOTUNE, ) - dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True) + dataset = dataset.batch(TRAINING_CONFIG["batch_size"], drop_remainder=True) dataset = dataset.repeat() # Create a BERT model the input config. model = BertModel( vocab_size=len(vocab), - **bert_config, + **model_config, ) # Make sure model has been called. model(model.inputs) model.summary() + # Allow overriding train steps from the command line for quick testing. + if FLAGS.num_train_steps is not None: + num_train_steps = FLAGS.num_train_steps + else: + num_train_steps = TRAINING_CONFIG["num_train_steps"] + num_warmup_steps = int( + num_train_steps * TRAINING_CONFIG["warmup_percentage"] + ) + epochs = TRAINING_CONFIG["epochs"] + steps_per_epoch = num_train_steps // epochs + learning_rate_schedule = LinearDecayWithWarmup( - learning_rate=FLAGS.learning_rate, - num_warmup_steps=FLAGS.num_warmup_steps, - num_train_steps=FLAGS.num_train_steps, + learning_rate=TRAINING_CONFIG["learning_rate"], + num_warmup_steps=num_warmup_steps, + num_train_steps=num_train_steps, ) # Wrap with pretraining heads and call fit. @@ -412,9 +405,8 @@ def main(_): optimizer=keras.optimizers.Adam(learning_rate=learning_rate_schedule) ) # TODO(mattdangerw): Add TPU strategy support. - steps_per_epoch = FLAGS.num_train_steps // FLAGS.epochs pretraining_model.fit( - dataset, epochs=FLAGS.epochs, steps_per_epoch=steps_per_epoch + dataset, epochs=epochs, steps_per_epoch=steps_per_epoch ) print(f"Saving to {FLAGS.saved_model_output}") @@ -424,6 +416,5 @@ def main(_): if __name__ == "__main__": flags.mark_flag_as_required("input_files") flags.mark_flag_as_required("vocab_file") - flags.mark_flag_as_required("bert_config_file") flags.mark_flag_as_required("saved_model_output") app.run(main) diff --git a/examples/bert/configs/bert_base.json b/examples/bert/configs/bert_base.json deleted file mode 100644 index 80eb26216a..0000000000 --- a/examples/bert/configs/bert_base.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "num_layers": 12, - "hidden_size": 768, - "hidden_dropout": 0.1, - "num_attention_heads": 12, - "attention_dropout": 0.1, - "inner_size": 3072, - "inner_activation": "gelu", - "initializer_range": 0.02, - "max_sequence_length": 512 -} diff --git a/examples/bert/configs/bert_large.json b/examples/bert/configs/bert_large.json deleted file mode 100644 index 1638feb049..0000000000 --- a/examples/bert/configs/bert_large.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "num_layers": 24, - "hidden_size": 1024, - "hidden_dropout": 0.1, - "num_attention_heads": 16, - "attention_dropout": 0.1, - "inner_size": 4096, - "inner_activation": "gelu", - "initializer_range": 0.02, - "max_sequence_length": 512 -} diff --git a/examples/bert/configs/bert_medium.json b/examples/bert/configs/bert_medium.json deleted file mode 100644 index af12f066b3..0000000000 --- a/examples/bert/configs/bert_medium.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "num_layers": 8, - "hidden_size": 512, - "hidden_dropout": 0.1, - "num_attention_heads": 8, - "attention_dropout": 0.1, - "inner_size": 2048, - "inner_activation": "gelu", - "initializer_range": 0.02, - "max_sequence_length": 512 -} diff --git a/examples/bert/configs/bert_mini.json b/examples/bert/configs/bert_mini.json deleted file mode 100644 index 8fd4fa7e5e..0000000000 --- a/examples/bert/configs/bert_mini.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "num_layers": 4, - "hidden_size": 256, - "hidden_dropout": 0.1, - "num_attention_heads": 4, - "attention_dropout": 0.1, - "inner_size": 1024, - "inner_activation": "gelu", - "initializer_range": 0.02, - "max_sequence_length": 512 -} diff --git a/examples/bert/configs/bert_small.json b/examples/bert/configs/bert_small.json deleted file mode 100644 index 1d744e87d8..0000000000 --- a/examples/bert/configs/bert_small.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "num_layers": 4, - "hidden_size": 512, - "hidden_dropout": 0.1, - "num_attention_heads": 8, - "attention_dropout": 0.1, - "inner_size": 2048, - "inner_activation": "gelu", - "initializer_range": 0.02, - "max_sequence_length": 512 -} diff --git a/examples/bert/configs/bert_tiny.json b/examples/bert/configs/bert_tiny.json deleted file mode 100644 index 3521a015df..0000000000 --- a/examples/bert/configs/bert_tiny.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "num_layers": 2, - "hidden_size": 128, - "hidden_dropout": 0.1, - "num_attention_heads": 2, - "attention_dropout": 0.1, - "inner_size": 512, - "inner_activation": "gelu", - "initializer_range": 0.02, - "max_sequence_length": 512 -} diff --git a/examples/tools/README.md b/examples/tools/README.md new file mode 100644 index 0000000000..efe525e96f --- /dev/null +++ b/examples/tools/README.md @@ -0,0 +1,37 @@ +# KerasNLP Modeling Tools + +This directory contains runnable scripts that are not specific to a specific +model architecture, but still useful for end-to-end workflows. + +## split_sentences.py + +The `split_sentences.py` script will process raw input files and split them into +output files where each line contains a sentence, and a blank line marks the +start of a new document. This is useful for tasks like next sentence prediction +where the boundaries between sentences are needed for training. + +The script supports two types of inputs files. Plain text files, where each +individual file is assumed to be an entire document, and wikipedia dump files +in the format outputted by the wikiextractor tool (each document is enclosed in +`` tags). + +Example usage: + +```shell +python examples/tools/split_sentences.py \ + --input_files ~/datasets/wikipedia,~/datasets/bookscorpus \ + --output_directory ~/datasets/sentence-split-data +``` + +### train_word_piece_vocabulary.py + +The `train_word_piece_vocabulary.py` script allows you to compute your own +WordPiece vocabulary. + +Example usage: + +```shell +python examples/tools/train_word_piece_vocabulary.py \ + --input_files ~/datasets/my-raw-dataset/ \ + --output_file vocab.txt +``` diff --git a/examples/tools/__init__.py b/examples/tools/__init__.py new file mode 100644 index 0000000000..6e4df4e727 --- /dev/null +++ b/examples/tools/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/examples/bert/create_sentence_split_data.py b/examples/tools/split_sentences.py similarity index 97% rename from examples/bert/create_sentence_split_data.py rename to examples/tools/split_sentences.py index e11818a37f..373ab6a498 100644 --- a/examples/bert/create_sentence_split_data.py +++ b/examples/tools/split_sentences.py @@ -22,7 +22,7 @@ output file shards can be controlled with `--num_jobs` and `--num_shards`. Usage: -python create_sentence_split_data.py \ +python examples/tools/create_sentence_split_data.py \ --input_files ~/datasets/wikipedia,~/datasets/bookscorpus \ --output_directory ~/datasets/bert-sentence-split-data """ @@ -38,7 +38,7 @@ from absl import flags from tensorflow import keras -from examples.bert.bert_utils import list_filenames_for_arg +from examples.utils.scripting_utils import list_filenames_for_arg FLAGS = flags.FLAGS diff --git a/examples/bert/create_vocabulary.py b/examples/tools/train_word_piece_vocab.py similarity index 96% rename from examples/bert/create_vocabulary.py rename to examples/tools/train_word_piece_vocab.py index 34529f6a7e..d5cbff169c 100644 --- a/examples/bert/create_vocabulary.py +++ b/examples/tools/train_word_piece_vocab.py @@ -16,7 +16,7 @@ This script will create wordpiece vocabularies suitable for pretraining BERT. Usage: -python create_vocabulary.py \ +python examples/tools/train_word_piece_vocabulary.py \ --input_files ~/datasets/bert-sentence-split-data/ \ --output_file vocab.txt """ @@ -29,7 +29,7 @@ from absl import flags from tensorflow_text.tools.wordpiece_vocab import bert_vocab_from_dataset -from examples.bert.bert_utils import list_filenames_for_arg +from examples.utils.scripting_utils import list_filenames_for_arg FLAGS = flags.FLAGS diff --git a/examples/utils/__init__.py b/examples/utils/__init__.py new file mode 100644 index 0000000000..6e4df4e727 --- /dev/null +++ b/examples/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/examples/utils/data_utils.py b/examples/utils/data_utils.py new file mode 100644 index 0000000000..b9f50af6ff --- /dev/null +++ b/examples/utils/data_utils.py @@ -0,0 +1,30 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utility functions for data handling.""" + +import os + +import tensorflow as tf +from google import protobuf + + +def preview_tfrecord(filepath): + """Pretty prints a single record from a tfrecord file.""" + dataset = tf.data.TFRecordDataset(os.path.expanduser(filepath)) + example = tf.train.Example() + example.ParseFromString(next(iter(dataset)).numpy()) + formatted = protobuf.text_format.MessageToString( + example, use_short_repeated_primitives=True + ) + print(formatted) diff --git a/examples/bert/bert_utils.py b/examples/utils/scripting_utils.py similarity index 95% rename from examples/bert/bert_utils.py rename to examples/utils/scripting_utils.py index 83eb6ed722..586a5c127d 100644 --- a/examples/bert/bert_utils.py +++ b/examples/utils/scripting_utils.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Utility files for BERT scripts.""" +"""Utility functions for writing training scripts.""" import glob import os diff --git a/keras_nlp/utils/tensor_utils.py b/keras_nlp/utils/tensor_utils.py index 4f00b2b14c..26fc815f11 100644 --- a/keras_nlp/utils/tensor_utils.py +++ b/keras_nlp/utils/tensor_utils.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os - import tensorflow as tf -from google import protobuf def _decode_strings_to_utf8(inputs): @@ -48,14 +45,3 @@ def tensor_to_string_list(inputs): if inputs.shape.rank != 0: list_outputs = list_outputs.tolist() return _decode_strings_to_utf8(list_outputs) - - -def preview_tfrecord(filepath): - """Pretty prints a single record from a tfrecord file.""" - dataset = tf.data.TFRecordDataset(os.path.expanduser(filepath)) - example = tf.train.Example() - example.ParseFromString(next(iter(dataset)).numpy()) - formatted = protobuf.text_format.MessageToString( - example, use_short_repeated_primitives=True - ) - print(formatted)