Description
Summary
On the TensorFlow backend, `keras.ops.histogram` crashes in graph mode because the TF implementation converts a tensor to a Python list via `.numpy().tolist()`. That is not graph-safe: symbolic tensors have no `.numpy()` method. Examples include the tensors seen while building a Functional model and those inside a traced train step. The same code works on the JAX backend.
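The failure mode can be reproduced without Keras. The sketch below uses the same `.numpy().tolist()` pattern. It works on an eager tensor but raises `AttributeError` once it runs inside a `tf.function` trace, because there `tf.linspace` returns a symbolic tensor. `edges_as_list` and `traced` are illustrative names only. `autograph=False` is used just to keep the traced code plain Python.

```python
import tensorflow as tf


def edges_as_list(bins):
    # Same pattern as the TF backend: tensor -> NumPy array -> Python list.
    return tf.linspace(0.0, 1.0, bins + 1).numpy().tolist()


# Eager mode: tf.linspace returns an EagerTensor, so .numpy() exists.
print(edges_as_list(4))


@tf.function(autograph=False)
def traced(x):
    # During tracing, tf.linspace returns a symbolic tensor without .numpy().
    return x + float(len(edges_as_list(4)))


try:
    traced(tf.constant(0.0))
except AttributeError as err:
    print("graph mode:", err)
```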
Minimal reproducible example
```python
# Repro: Keras 3 + TensorFlow backend, graph mode, with or without jit_compile
import os
os.environ["KERAS_BACKEND"] = "tensorflow"  # must be set before importing keras

import tensorflow as tf
import keras

tf.config.run_functions_eagerly(False)  # ensure graph mode


class HistogramProbe(keras.layers.Layer):
    def __init__(self, bins=8, lo=0.0, hi=1.0, **kwargs):
        super().__init__(**kwargs)
        self.bins = int(bins)
        self.lo = float(lo)
        self.hi = float(hi)

    def call(self, x):
        flat = keras.ops.reshape(x, (-1,))
        # This hits the TF backend histogram implementation, which calls .numpy().tolist()
        counts, edges = keras.ops.histogram(flat, bins=self.bins, range=(self.lo, self.hi))
        return x  # identity on the data path


inp = keras.Input(shape=(16,), name="inp")
h = keras.layers.Dense(8, activation="relu")(inp)
h = HistogramProbe(name="probe")(h)  # AttributeError is raised here (symbolic call on a KerasTensor)
out = keras.layers.Dense(1)(h)
model = keras.Model(inp, out)
model.compile(optimizer="adam", loss="mse", jit_compile=True)  # jit_compile can be True or False

x = tf.random.uniform((32, 16))
y = tf.zeros((32, 1))

# fit() would also trace HistogramProbe.call in graph mode
model.fit(x, y, epochs=1, batch_size=8)
```

Colab reproduction: https://colab.research.google.com/drive/19R8qt7UjmX6Qz4YNkmz7KWWe3wO27MEw?usp=sharing
Observed error (trimmed)
```
AttributeError: Exception encountered when calling HistogramProbe.call().
Could not automatically infer the output shape / dtype of 'probe' ...
Error encountered:
'SymbolicTensor' object has no attribute 'numpy'
Arguments received by HistogramProbe.call():
  • args=('<KerasTensor shape=(None, 8), dtype=float32, ...>',)
```
Root cause (source references)
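In the TF backend implementation of `keras.ops.histogram`, the code does:

```python
bin_edges = tf.linspace(min_val, max_val, bins + 1)
bin_edges_list = bin_edges.numpy().tolist()  # <-- requires an eager tensor
bin_indices = tf.raw_ops.Bucketize(input=x, boundaries=bin_edges_list[1:-1])
```

`tf.raw_ops.Bucketize` takes `boundaries` as a list of Python floats (an op attribute, not a tensor input), which is presumably why the edges are converted to a list. However, `.numpy()` only exists on eager tensors. During Keras's symbolic shape inference, and inside a traced train step, `bin_edges` is a symbolic tensor, so the call raises the `AttributeError` shown above.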
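A minimal graph-safe sketch, assuming float32 inputs. It keeps the edges as a tensor and reproduces Bucketize's semantics with a broadcast comparison instead of a Python list of boundaries. The bin index is the number of interior edges `<=` the value. `histogram_graph_safe` is a hypothetical name, and this is not the Keras implementation. The design trades memory (an N x (bins - 1) comparison matrix) for static shapes and no host round-trip. That should make it friendlier to `jit_compile=True`, but this is an untested assumption.

```python
import tensorflow as tf


def histogram_graph_safe(x, bins=10, range=None):
    """Hypothetical graph-safe histogram sketch (not the Keras implementation).

    `range` mirrors the keras.ops.histogram argument name; it shadows the
    builtin only inside this function.
    """
    x = tf.reshape(tf.convert_to_tensor(x, dtype=tf.float32), [-1])
    if range is None:
        lo, hi = tf.reduce_min(x), tf.reduce_max(x)
    else:
        lo = tf.constant(range[0], dtype=tf.float32)
        hi = tf.constant(range[1], dtype=tf.float32)
    # Edges stay a tensor; nothing is pulled back to Python with .numpy().
    edges = tf.linspace(lo, hi, bins + 1)
    interior = edges[1:-1]
    # Bucketize semantics: bin index = number of interior edges <= value.
    # For bins == 1, `interior` is empty and every index is 0.
    idx = tf.reduce_sum(tf.cast(x[:, None] >= interior[None, :], tf.int32), axis=1)
    # Like np.histogram: values outside [lo, hi] are ignored, hi goes in the last bin.
    in_range = tf.cast((x >= lo) & (x <= hi), tf.float32)
    counts = tf.reduce_sum(tf.one_hot(idx, bins) * in_range[:, None], axis=0)
    return counts, edges
```

As a stopgap, this could replace the `keras.ops.histogram` call inside `HistogramProbe.call`. It uses only TF ops, at the cost of tying the layer to the TensorFlow backend.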
Expected behavior
keras.ops.histogram should be graph-safe on the TensorFlow backend (or clearly documented as eager-only). Ideally it should run inside a Keras model without errors.
Environment
keras: 3.11.3
tensorflow: 2.19.0
jax: 0.5.3
python: 3.12.11 (main, Jun 4 2025, 08:56:18) [GCC 11.4.0]
platform: Linux-6.6.97+-x86_64-with-glibc2.35
backend: tensorflow
Workarounds
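A possible workaround for a fixed range is `tf.histogram_fixed_width`. The edges stay tensors, so nothing goes through `.numpy()`:

```python
import tensorflow as tf


def tf_histogram_fixed_width_1d(x, bins, range):
    # `range` mirrors the keras.ops.histogram argument name (it shadows the builtin here).
    lo, hi = float(range[0]), float(range[1])  # floats so value_range matches float32 x
    x = tf.reshape(x, [-1])
    x = tf.clip_by_value(x, lo, hi)  # redundant: histogram_fixed_width already clamps into the edge bins
    counts = tf.histogram_fixed_width(values=x, value_range=[lo, hi], nbins=bins)
    return counts, tf.linspace(lo, hi, bins + 1)
```

However, this does not match `np.histogram` at the edges. `tf.histogram_fixed_width` counts values below `lo` in the first bin and values above `hi` in the last bin (the clip only makes this explicit). `np.histogram`, by contrast, ignores values outside `range`.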
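One way to address that edge issue while keeping `tf.histogram_fixed_width` is to subtract the out-of-range counts from the two edge bins afterwards. This is a sketch, assuming float32 inputs and a fixed range. `histogram_fixed_width_np_style` and its `value_range` parameter are hypothetical names (renamed to avoid shadowing the builtin). A remaining caveat: `histogram_fixed_width` computes bin indices arithmetically. Values that land exactly on an interior edge may therefore still be binned differently from `np.histogram` when the edges are not exactly representable.

```python
import tensorflow as tf


def histogram_fixed_width_np_style(x, bins, value_range):
    """Hypothetical workaround: drop out-of-range values like np.histogram does."""
    lo, hi = float(value_range[0]), float(value_range[1])
    x = tf.reshape(tf.convert_to_tensor(x, dtype=tf.float32), [-1])
    # histogram_fixed_width puts x < lo into bin 0 and x > hi into the last bin.
    counts = tf.histogram_fixed_width(x, value_range=[lo, hi], nbins=bins)
    # Undo that by subtracting the out-of-range counts from the two edge bins.
    below = tf.reduce_sum(tf.cast(x < lo, tf.int32))
    above = tf.reduce_sum(tf.cast(x > hi, tf.int32))
    correction = (tf.one_hot(0, bins, dtype=tf.int32) * below
                  + tf.one_hot(bins - 1, bins, dtype=tf.int32) * above)
    return counts - correction, tf.linspace(lo, hi, bins + 1)
```

Using `tf.one_hot` for the correction also keeps `bins == 1` correct, since both corrections then land in the single bin.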
Notes
- The same repro runs fine on the JAX backend.
- The failure happens whether `jit_compile=True` or `False`, as long as the call happens in graph mode / a symbolic context.
Thanks!