diff --git a/keras_nlp/utils/tensor_utils.py b/keras_nlp/utils/tensor_utils.py index 26fc815f11..4f00b2b14c 100644 --- a/keras_nlp/utils/tensor_utils.py +++ b/keras_nlp/utils/tensor_utils.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + import tensorflow as tf +from google import protobuf def _decode_strings_to_utf8(inputs): @@ -45,3 +48,14 @@ def tensor_to_string_list(inputs): if inputs.shape.rank != 0: list_outputs = list_outputs.tolist() return _decode_strings_to_utf8(list_outputs) + + +def preview_tfrecord(filepath): + """Pretty prints a single record from a tfrecord file.""" + dataset = tf.data.TFRecordDataset(os.path.expanduser(filepath)) + example = tf.train.Example() + example.ParseFromString(next(iter(dataset)).numpy()) + formatted = protobuf.text_format.MessageToString( + example, use_short_repeated_primitives=True + ) + print(formatted)