60 changes: 46 additions & 14 deletions examples/bert/run_glue_finetuning.py
@@ -16,6 +16,7 @@
import json

import datasets
import keras_tuner
import tensorflow as tf
import tensorflow_text as tftext
from absl import app
@@ -68,8 +69,6 @@

flags.DEFINE_integer("epochs", 3, "The number of training epochs.")

flags.DEFINE_float("learning_rate", 2e-5, "The initial learning rate for Adam.")

flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")


@@ -167,9 +166,32 @@ def call(self, inputs):
return self._logit_layer(outputs)


class BertHyperModel(keras_tuner.HyperModel):
    """KerasTuner hypermodel defining the finetuning search space."""

def __init__(self, bert_config):
self.bert_config = bert_config

def build(self, hp):
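        # `build` is called once per trial, so the pretrained backbone is
        # reloaded with identical starting weights for every trial.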
model = keras.models.load_model(FLAGS.saved_model_input, compile=False)
bert_config = self.bert_config
finetuning_model = BertClassificationFinetuner(
bert_model=model,
hidden_size=bert_config["hidden_size"],
num_classes=3 if FLAGS.task_name in ("mnli", "ax") else 2,
)
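        # `hp.Choice` registers "lr" as a tunable hyperparameter; the tuner
        # samples one of the listed values for each trial.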
finetuning_model.compile(
optimizer=keras.optimizers.Adam(
learning_rate=hp.Choice("lr", [5e-5, 4e-5, 3e-5, 2e-5])
),
loss="sparse_categorical_crossentropy",
metrics=["accuracy"],
)
return finetuning_model


def main(_):
print(f"Reading input model from {FLAGS.saved_model_input}")
model = keras.models.load_model(FLAGS.saved_model_input)

vocab = []
with open(FLAGS.vocab_file, "r") as vocab_file:
@@ -200,6 +222,7 @@ def preprocess_data(inputs, labels):

# Read and preprocess GLUE task data.
train_ds, test_ds, validation_ds = load_data(FLAGS.task_name)

train_ds = train_ds.batch(FLAGS.batch_size).map(
preprocess_data, num_parallel_calls=tf.data.AUTOTUNE
)
@@ -210,18 +233,27 @@ def preprocess_data(inputs, labels):
preprocess_data, num_parallel_calls=tf.data.AUTOTUNE
)

finetuning_model = BertClassificationFinetuner(
bert_model=model,
hidden_size=bert_config["hidden_size"],
num_classes=3 if FLAGS.task_name in ("mnli", "ax") else 2,
)
finetuning_model.compile(
optimizer=keras.optimizers.Adam(learning_rate=FLAGS.learning_rate),
loss="sparse_categorical_crossentropy",
metrics=["accuracy"],
# Create a hypermodel object for a RandomSearch.
hypermodel = BertHyperModel(bert_config)

    # Initialize a random search over the 4 candidate learning rates: 4 trials,
    # each trained for FLAGS.epochs epochs and ranked by validation loss. Trial
    # results are written under the "hyperparameter_tuner_results" directory.
tuner = keras_tuner.RandomSearch(
hypermodel=hypermodel,
objective=keras_tuner.Objective("val_loss", direction="min"),
max_trials=4,
overwrite=True,
project_name="hyperparameter_tuner_results",
)
finetuning_model.fit(
train_ds, epochs=FLAGS.epochs, validation_data=validation_ds

tuner.search(train_ds, epochs=FLAGS.epochs, validation_data=validation_ds)

# Extract the best hyperparameters after the search.
best_hp = tuner.get_best_hyperparameters()[0]
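    # `get_best_models` reloads the weights checkpointed during the best trial,
    # so the model does not need to be retrained here.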
finetuning_model = tuner.get_best_models()[0]

print(
f"The best hyperparameters found are:\nLearning Rate: {best_hp['lr']}"
)

if FLAGS.do_evaluation:
1 change: 1 addition & 0 deletions setup.py
@@ -53,6 +53,7 @@
"datasets", # For GLUE in BERT example.
"nltk",
"wikiextractor",
"keras-tuner",
],
},
classifiers=[