16 changes: 8 additions & 8 deletions examples/bert/run_glue_finetuning.py
@@ -18,7 +18,7 @@
import datasets
import keras_tuner
import tensorflow as tf
-import tensorflow_text as tftext
+import tensorflow_text as tf_text
from absl import app
from absl import flags
from tensorflow import keras
@@ -81,20 +81,20 @@ def pack_inputs(
):
# In case inputs weren't truncated (as they should have been),
# fall back to some ad-hoc truncation.
-trimmed_segments = tftext.RoundRobinTrimmer(
+trimmed_segments = tf_text.RoundRobinTrimmer(
seq_length - len(inputs) - 1
).trim(inputs)
# Combine segments.
-segments_combined, segment_ids = tftext.combine_segments(
+segments_combined, segment_ids = tf_text.combine_segments(
trimmed_segments,
start_of_sequence_id=start_of_sequence_id,
end_of_segment_id=end_of_segment_id,
)
# Pad to dense Tensors.
-input_word_ids, _ = tftext.pad_model_inputs(
+input_word_ids, _ = tf_text.pad_model_inputs(
segments_combined, seq_length, pad_value=padding_id
)
-input_type_ids, input_mask = tftext.pad_model_inputs(
+input_type_ids, input_mask = tf_text.pad_model_inputs(
segment_ids, seq_length, pad_value=0
)
# Assemble nest of input tensors as expected by BERT model.
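As a side note (not part of the PR), the renamed tf_text calls above compose as in the following minimal standalone sketch; the token ids, the [CLS]/[SEP]/pad ids, and the seq_length of 8 are made-up values for illustration only.

import tensorflow as tf
import tensorflow_text as tf_text

seq_length = 8
cls_id, sep_id, pad_id = 101, 102, 0  # hypothetical special-token ids

# Two already-tokenized segments (e.g. a sentence pair) as ragged id tensors.
segment_a = tf.ragged.constant([[7, 8, 9, 10, 11], [12, 13]])
segment_b = tf.ragged.constant([[20, 21], [22, 23, 24]])
inputs = [segment_a, segment_b]

# Budget seq_length minus one [CLS] and one [SEP] per segment, allocated
# round-robin across the segments.
trimmed = tf_text.RoundRobinTrimmer(seq_length - len(inputs) - 1).trim(inputs)

# Insert the special-token ids and build the matching segment-id tensor.
combined, segment_ids = tf_text.combine_segments(
    trimmed, start_of_sequence_id=cls_id, end_of_segment_id=sep_id
)

# Pad the ragged results into dense [batch, seq_length] tensors plus a mask.
input_word_ids, _ = tf_text.pad_model_inputs(
    combined, seq_length, pad_value=pad_id
)
input_type_ids, input_mask = tf_text.pad_model_inputs(
    segment_ids, seq_length, pad_value=0
)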
@@ -184,8 +184,8 @@ def build(self, hp):
optimizer=keras.optimizers.Adam(
learning_rate=hp.Choice("lr", [5e-5, 4e-5, 3e-5, 2e-5])
),
-loss="sparse_categorical_crossentropy",
-metrics=["accuracy"],
+loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
return finetuning_model
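The compile change above pairs a logits-producing classification head with a loss that expects logits. A minimal sketch of the same pattern, assuming a hypothetical 2-class head with no softmax activation (the 768-wide input stands in for BERT's pooled output):

from tensorflow import keras

pooled = keras.Input(shape=(768,))      # stand-in for the BERT pooled output
logits = keras.layers.Dense(2)(pooled)  # no activation, so outputs are raw logits
model = keras.Model(pooled, logits)

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=5e-5),
    # from_logits=True matches the raw-logit head above; a softmax head
    # would use the default from_logits=False instead.
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)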

@@ -197,7 +197,7 @@ def main(_):
with open(FLAGS.vocab_file, "r") as vocab_file:
for line in vocab_file:
vocab.append(line.strip())
-tokenizer = tftext.BertTokenizer(
+tokenizer = tf_text.BertTokenizer(
FLAGS.vocab_file,
lower_case=FLAGS.do_lower_case,
token_out_type=tf.int32,
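For context, a tokenizer constructed this way is typically applied as in the following standalone sketch; the vocabulary path and sentences are hypothetical, and merge_dims collapses the word/word-piece split that BertTokenizer.tokenize returns.

import tensorflow as tf
import tensorflow_text as tf_text

tokenizer = tf_text.BertTokenizer(
    "path/to/vocab.txt",  # hypothetical vocabulary file
    lower_case=True,
    token_out_type=tf.int32,
)
sentences = tf.constant(["The quick brown fox.", "Jumped over the lazy dog."])
# tokenize() returns a [batch, words, wordpieces] ragged tensor; merging the
# last two dimensions yields the flat per-example id sequences used for packing.
token_ids = tokenizer.tokenize(sentences).merge_dims(-2, -1)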