From f41d2fb5ad5065c073798289b0456e1cc8ed8f9f Mon Sep 17 00:00:00 2001 From: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Date: Wed, 18 Sep 2024 22:20:04 -0700 Subject: [PATCH 1/2] Preprocessing decorator fixes (#1843) * Fix handling bytestring input to tokenizers, preprocessing * Fix no convert scope in multithreaded contexts --- keras_nlp/src/utils/tensor_utils.py | 11 +++++------ keras_nlp/src/utils/tensor_utils_test.py | 11 +++++++++++ 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/keras_nlp/src/utils/tensor_utils.py b/keras_nlp/src/utils/tensor_utils.py index 7502c38bcf..26d603a5d2 100644 --- a/keras_nlp/src/utils/tensor_utils.py +++ b/keras_nlp/src/utils/tensor_utils.py @@ -30,20 +30,19 @@ NO_CONVERT_COUNTER = threading.local() -NO_CONVERT_COUNTER.count = 0 @contextlib.contextmanager def no_convert_scope(): try: - NO_CONVERT_COUNTER.count += 1 + NO_CONVERT_COUNTER.count = getattr(NO_CONVERT_COUNTER, "count", 0) + 1 yield finally: - NO_CONVERT_COUNTER.count -= 1 + NO_CONVERT_COUNTER.count = getattr(NO_CONVERT_COUNTER, "count", 0) - 1 def in_no_convert_scope(): - return NO_CONVERT_COUNTER.count > 0 + return getattr(NO_CONVERT_COUNTER, "count", 0) > 0 def preprocessing_function(fn): @@ -119,7 +118,7 @@ def convert_preprocessing_inputs(x): return {k: convert_preprocessing_inputs(x[k]) for k, v in x.items()} if isinstance(x, tuple): return tuple(convert_preprocessing_inputs(v) for v in x) - if isinstance(x, str): + if isinstance(x, (str, bytes)): return tf.constant(x) if isinstance(x, list): try: @@ -132,7 +131,7 @@ def convert_preprocessing_inputs(x): # If ragged conversion failed return to the numpy error. raise e # If we have a string input, use tf.tensor. - if numpy_x.dtype.type is np.str_: + if numpy_x.dtype.type is np.str_ or numpy_x.dtype.type is np.bytes_: return tf.convert_to_tensor(x) # Numpy will default to int64, int32 works with more ops. 
if numpy_x.dtype == np.int64: diff --git a/keras_nlp/src/utils/tensor_utils_test.py b/keras_nlp/src/utils/tensor_utils_test.py index 463a267292..e9c5e97844 100644 --- a/keras_nlp/src/utils/tensor_utils_test.py +++ b/keras_nlp/src/utils/tensor_utils_test.py @@ -49,6 +49,17 @@ def test_strings(self): self.assertIsInstance(outputs, list) self.assertEqual(outputs, inputs) + def test_bytestrings(self): + inputs = ["one".encode("utf-8"), "two".encode("utf-8")] + # Convert to tf. + outputs = convert_preprocessing_inputs(inputs) + self.assertIsInstance(outputs, tf.Tensor) + self.assertAllEqual(outputs, tf.constant(inputs)) + # Convert from tf. + outputs = convert_preprocessing_outputs(outputs) + self.assertIsInstance(outputs, list) + self.assertEqual(outputs, [x.decode("utf-8") for x in inputs]) + def test_ragged(self): inputs = [np.ones((1, 3)), np.ones((1, 2))] # Convert to tf. From ad6f259fcb6cc2cd09215d5feb8f576ee827d5cb Mon Sep 17 00:00:00 2001 From: Matt Watson Date: Thu, 19 Sep 2024 00:07:28 -0700 Subject: [PATCH 2/2] Version bump dev release --- keras_nlp/src/version_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_nlp/src/version_utils.py b/keras_nlp/src/version_utils.py index 2c868b4750..5e2156102c 100644 --- a/keras_nlp/src/version_utils.py +++ b/keras_nlp/src/version_utils.py @@ -15,7 +15,7 @@ from keras_nlp.src.api_export import keras_nlp_export # Unique source of truth for the version number. -__version__ = "0.15.1.dev0" +__version__ = "0.15.1.dev1" @keras_nlp_export("keras_nlp.version")