diff --git a/keras_nlp/__init__.py b/keras_nlp/__init__.py index 06afe69361..39173c72ef 100644 --- a/keras_nlp/__init__.py +++ b/keras_nlp/__init__.py @@ -15,5 +15,6 @@ from keras_nlp import layers from keras_nlp import metrics from keras_nlp import tokenizers +from keras_nlp import utils __version__ = "0.1.1" diff --git a/keras_nlp/utils/tensor_utils.py b/keras_nlp/utils/tensor_utils.py new file mode 100644 index 0000000000..26fc815f11 --- /dev/null +++ b/keras_nlp/utils/tensor_utils.py @@ -0,0 +1,47 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf + + +def _decode_strings_to_utf8(inputs): + """Recursively decodes to list of strings with 'utf-8' encoding.""" + if isinstance(inputs, bytes): + # Handles the case when the input is a scalar string. + return inputs.decode("utf-8") + else: + # Recursively iterate when input is a list. + return [_decode_strings_to_utf8(x) for x in inputs] + + +def tensor_to_string_list(inputs): + """Detokenize and convert tensor to nested lists of python strings. + + This is a convenience method which converts each byte string to a python + string. + + Args: + inputs: Input tensor, or dict/list/tuple of input tensors. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + """ + if not isinstance(inputs, (tf.RaggedTensor, tf.Tensor)): + inputs = tf.convert_to_tensor(inputs) + if isinstance(inputs, tf.RaggedTensor): + list_outputs = inputs.to_list() + elif isinstance(inputs, tf.Tensor): + list_outputs = inputs.numpy() + if inputs.shape.rank != 0: + list_outputs = list_outputs.tolist() + return _decode_strings_to_utf8(list_outputs) diff --git a/keras_nlp/utils/tensor_utils_test.py b/keras_nlp/utils/tensor_utils_test.py new file mode 100644 index 0000000000..2c9103f3fc --- /dev/null +++ b/keras_nlp/utils/tensor_utils_test.py @@ -0,0 +1,33 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf +from tensor_utils import tensor_to_string_list + + +class TensorToStringListTest(tf.test.TestCase): + def test_detokenize_to_strings_for_ragged(self): + input_data = tf.ragged.constant([["▀▁▂▃", "samurai"]]) + detokenize_output = tensor_to_string_list(input_data) + self.assertAllEqual(detokenize_output, [["▀▁▂▃", "samurai"]]) + + def test_detokenize_to_strings_for_dense(self): + input_data = tf.constant([["▀▁▂▃", "samurai"]]) + detokenize_output = tensor_to_string_list(input_data) + self.assertAllEqual(detokenize_output, [["▀▁▂▃", "samurai"]]) + + def test_detokenize_to_strings_for_scalar(self): + input_data = tf.constant("▀▁▂▃") + detokenize_output = tensor_to_string_list(input_data) + self.assertEqual(detokenize_output, "▀▁▂▃")