diff --git a/keras_nlp/tokenizers/unicode_character_tokenizer.py b/keras_nlp/tokenizers/unicode_character_tokenizer.py
index ade085e399..3d21e9ca85 100644
--- a/keras_nlp/tokenizers/unicode_character_tokenizer.py
+++ b/keras_nlp/tokenizers/unicode_character_tokenizer.py
@@ -175,9 +175,9 @@ def __init__(
             kwargs["dtype"] = tf.int32
         else:
             dtype = tf.dtypes.as_dtype(kwargs["dtype"])
-            if not dtype.is_integer and dtype != tf.string:
+            if not dtype.is_integer:
                 raise ValueError(
-                    "Output dtype must be an integer type of a string. "
+                    "Output dtype must be an integer type. "
                     f"Received: dtype={dtype}"
                 )

@@ -251,6 +251,7 @@ def tokenize(self, inputs):
             replacement_char=self.replacement_char,
             input_encoding=self.input_encoding,
         )
+        tokens = tf.cast(tokens, self.compute_dtype)

         if self.sequence_length:
             output_shape = tokens.shape.as_list()
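
The first hunk narrows the accepted output dtypes to integers only, and the second casts the decoded codepoints to the layer's compute dtype. A minimal usage sketch of the resulting behavior, assuming the class exported from this module is `keras_nlp.tokenizers.UnicodeCharacterTokenizer` and that `dtype` is forwarded through `**kwargs` as shown above:

```python
import tensorflow as tf
import keras_nlp  # assumed import path for the tokenizer package

# Integer dtypes are accepted; tokenize() now casts the int32 codepoints
# returned by tf.strings.unicode_decode to the requested compute dtype.
tokenizer = keras_nlp.tokenizers.UnicodeCharacterTokenizer(dtype="int64")
tokens = tokenizer.tokenize(tf.constant(["hello"]))
print(tokens.dtype)  # tf.int64

# A string dtype no longer passes the check and raises a ValueError:
#   keras_nlp.tokenizers.UnicodeCharacterTokenizer(dtype="string")
```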