Skip to content

Commit

Permalink
Call repeat() for validation data as well, and stop dropping the remainder batch
Browse files (browse the repository at this point in the history)
  • Loading branch information
cfregly committed Apr 2, 2020
1 parent 3668089 commit e5b2ea9
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions 06_train/new_src_bert_tf2/tf_bert_reviews.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -59,9 +59,10 @@ def file_based_input_dataset_builder(channel,
print('***** Using input_filenames {}'.format(input_filenames))
dataset = tf.data.TFRecordDataset(input_filenames)

dataset = dataset.repeat()

if is_training:
dataset = dataset.repeat()
dataset = dataset.shuffle(buffer_size=1)
dataset = dataset.shuffle(buffer_size=1000)

name_to_features = {
"input_ids": tf.io.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64),
Expand Down Expand Up @@ -140,7 +141,7 @@ def _decode_record(record, name_to_features):
input_filenames=train_data_filenames,
pipe_mode=pipe_mode,
is_training=True,
drop_remainder=True).map(select_data_and_label_from_record)
drop_remainder=False).map(select_data_and_label_from_record)

validation_data_filenames = glob('{}/*.tfrecord'.format(validation_data))
print(validation_data_filenames)
Expand All @@ -149,7 +150,7 @@ def _decode_record(record, name_to_features):
input_filenames=validation_data_filenames,
pipe_mode=pipe_mode,
is_training=False,
drop_remainder=True).map(select_data_and_label_from_record)
drop_remainder=False).map(select_data_and_label_from_record)

tf.config.optimizer.set_jit(USE_XLA)
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": USE_AMP})
Expand Down

0 comments on commit e5b2ea9

Please sign in to comment.