<a href="https://colab.research.google.com/github/haifeng-jin/Colabs/blob/main/keras_issue_19583.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install keras==3.3.2

Collecting keras==3.3.2
  Using cached keras-3.3.2-py3-none-any.whl (1.1 MB)
Installing collected packages: keras
  Attempting uninstall: keras
    Found existing installation: keras 3.2.0
    Uninstalling keras-3.2.0:
      Successfully uninstalled keras-3.2.0
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow 2.15.0 requires keras<2.16,>=2.15.0, but you have keras 3.3.2 which is incompatible.[0m[31m
[0mSuccessfully installed keras-3.3.2


In [2]:
"""Reproduce automatic casting tf.uint8 to tf.float32 in symbolic execution"""

import tensorflow as tf
import keras

print(f"TensorFlow version: {tf.version.VERSION}")
print(f"Keras version: {keras.version()}")

ds_tensor = tf.constant([1], dtype=tf.uint8)

ds = tf.data.Dataset.from_tensors(ds_tensor)
print(f"Original dataset: {repr(ds)}")
print(f"Original dataset items: {repr(list(ds))}")

category_encoding = keras.layers.CategoryEncoding(2, output_mode="one_hot")

# This is what I wanted, as it reduces the memory bandwidth required to send
# training data to the GPU
category_encoding_model = keras.Sequential(
    [
        keras.layers.Input(tf.TensorShape((1,)), dtype=tf.uint8),
        category_encoding,
    ]
)
category_encoding_model.summary()
print(
    f"Category encoding model input dtype: {repr(category_encoding_model.input_dtype)}"
)

try:
    category_encoding_model.predict(ds)
except TypeError as e:
    print(f"Could not predict with uint8 dataset: {e}")

print()

# Baring that I can try to use it in the tf.data API instead.
try:
    ds.map(category_encoding)
except TypeError as e:
    print(f"Could not map dataset: {e}")

print()

# Lets try to reproduce this out of keras
@tf.function
def simple_one_hot(indices: tf.Tensor, depth: tf.Tensor) -> tf.Tensor:
    """Just 1-hot encode the input indices symbolicly"""

    assert tf.is_symbolic_tensor(indices)
    return tf.one_hot(indices, depth)


nokeras_encoded = simple_one_hot(ds_tensor, tf.constant(2, dtype=tf.int32))
print(f"Symbolicly encoded categories using tf.one_hot: {repr(nokeras_encoded)}")
# Looks like it's a keras bug, or something more complex is happening behind the
# scenes of CategoryEncoding


# Lets try to narrow that down:
@tf.function
def keras_one_hot(indices: tf.Tensor, depth: tf.Tensor) -> tf.Tensor:
    """Just 1-hot encode the input indices symbolicly, but using the keras tensorflow backend"""

    backend = category_encoding.backend
    return backend.nn.one_hot(indices, depth)


kerasnn_encoded = keras_one_hot(ds_tensor, tf.constant(2, dtype=tf.int32))
print(
    f"Symbolicly encoded categories using keras backend one_hot: {repr(kerasnn_encoded)}"
)


# Maybe the problem is here:
# https://github.com/keras-team/keras/blob/f77b020/keras/layers/preprocessing/tf_data_layer.py#L30
@tf.function
def map_structure(inputs: tf.Tensor) -> any:
    """Recreate the map structure call in TFDataLayer"""

    backend = category_encoding.backend
    return keras.tree.map_structure(
        lambda x: backend.convert_to_tensor(x, dtype=category_encoding.compute_dtype),
        inputs,
    )


print(f"CategoryEncoding compute dtype: {repr(category_encoding.compute_dtype)}")
mapped_indices = map_structure(ds_tensor)
print(f"keras.utils.tree.map_structure with converted dtype: {repr(mapped_indices)}")
# Looks like thats it—the compute_dtype property is the _output dtype_—which is often the input
# dtype ... but not for CategoryEncoding.

# Why does the behavior differ when executing eagerly though? 🤔
# Because of this condition:
# https://github.com/keras-team/keras/blob/f77b020/keras/layers/preprocessing/tf_data_layer.py#L23
eager_encoded = category_encoding(ds_tensor)
print(f"Eagerly encoded categories: {repr(eager_encoded)}")

TensorFlow version: 2.15.0
Keras version: 3.3.2
Original dataset: <_TensorDataset element_spec=TensorSpec(shape=(1,), dtype=tf.uint8, name=None)>
Original dataset items: [<tf.Tensor: shape=(1,), dtype=uint8, numpy=array([1], dtype=uint8)>]


Category encoding model input dtype: 'uint8'
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 117ms/step


Symbolicly encoded categories using tf.one_hot: <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0., 1.]], dtype=float32)>
Symbolicly encoded categories using keras backend one_hot: <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0., 1.]], dtype=float32)>
CategoryEncoding compute dtype: 'float32'
keras.utils.tree.map_structure with converted dtype: <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>
Eagerly encoded categories: <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0., 1.]], dtype=float32)>
