TF backend: keras.ops.histogram calls .numpy() in graph mode and crashes with AttributeError on symbolic tensors #21708

@jacob-talroo

Summary

On the TensorFlow backend, keras.ops.histogram crashes in graph mode because the TF implementation converts a tensor to a Python list via .numpy().tolist(). That conversion only works on eager tensors, so it fails on symbolic tensors, such as those seen inside a tf.function or while Keras traces a Functional model. The same code works on the JAX backend.

Minimal reproducible example

# Repro: Keras 3 + TensorFlow backend, graph mode, jit or not
import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras

tf.config.run_functions_eagerly(False)  # ensure graph

class HistogramProbe(keras.layers.Layer):
    def __init__(self, bins=8, lo=0.0, hi=1.0, **kwargs):
        super().__init__(**kwargs)
        self.bins = int(bins)
        self.lo = float(lo)
        self.hi = float(hi)

    def call(self, x):
        flat = keras.ops.reshape(x, (-1,))
        # This triggers TF backend histogram implementation which does .numpy().tolist()
        counts, edges = keras.ops.histogram(flat, bins=self.bins, range=(self.lo, self.hi))
        return x  # identity on data path

inp = keras.Input(shape=(16,), name="inp")
h = keras.layers.Dense(8, activation="relu")(inp)
h = HistogramProbe(name="probe")(h)
out = keras.layers.Dense(1)(h)
model = keras.Model(inp, out)

model.compile(optimizer="adam", loss="mse", jit_compile=True)  # jit_compile can be True or False
x = tf.random.uniform((32, 16))
y = tf.zeros((32, 1))

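# Fails with AttributeError in graph mode. Per the error below, it is actually raised earlier,
# when HistogramProbe is first called on the symbolic Keras tensor during model construction.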
model.fit(x, y, epochs=1, batch_size=8)

See https://colab.research.google.com/drive/19R8qt7UjmX6Qz4YNkmz7KWWe3wO27MEw?usp=sharing

Observed error (trimmed)

AttributeError: Exception encountered when calling HistogramProbe.call().

Could not automatically infer the output shape / dtype of 'probe' ...
Error encountered:
'SymbolicTensor' object has no attribute 'numpy'

Arguments received by HistogramProbe.call():
  • args=('<KerasTensor shape=(None, 8), dtype=float32, ...>',)

Root cause (source references)
The TF backend implementation of keras.ops.histogram builds the bin edges as a tensor and then converts them to a Python list:

bin_edges = tf.linspace(min_val, max_val, bins + 1)
bin_edges_list = bin_edges.numpy().tolist()  # <-- requires eager tensor
bin_indices = tf.raw_ops.Bucketize(input=x, boundaries=bin_edges_list[1:-1])

https://github.com/keras-team/keras/blob/b491c860fc2750e2b6006b55358d3251dbb4a9f0/keras/src/backend/tensorflow/numpy.py#L2966C1-L2966C48

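Inside a graph (a tf.function, or Keras's output-shape inference for a Functional model), bin_edges is a symbolic tensor with no concrete value, so .numpy() raises the AttributeError above. The conversion is presumably there because tf.raw_ops.Bucketize takes boundaries as a static list-of-floats attribute rather than as a tensor input.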
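The failure does not depend on Keras: the same .numpy().tolist() pattern breaks in any traced TensorFlow function. A minimal sketch (edges_as_list is a hypothetical helper that only mirrors the backend pattern):

```python
import tensorflow as tf


def edges_as_list(x, bins):
    # Same pattern as the backend code: build edges as a tensor, then pull them
    # out as Python floats with .numpy().tolist().
    edges = tf.linspace(tf.reduce_min(x), tf.reduce_max(x), bins + 1)
    return edges.numpy().tolist()


x = tf.constant([0.0, 0.5, 1.0])

# Eager call: edges is an EagerTensor, so .numpy() works.
eager_edges = edges_as_list(x, 4)
print(eager_edges)  # [0.0, 0.25, 0.5, 0.75, 1.0]

# Traced call: inside tf.function, edges is symbolic and has no .numpy().
graph_error = None
try:
    tf.function(edges_as_list)(x, 4)
except AttributeError as err:
    graph_error = err
print(type(graph_error).__name__)  # AttributeError
```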

Expected behavior

keras.ops.histogram should be graph-safe on the TensorFlow backend (or clearly documented as eager-only). Ideally it should run inside a Keras model without errors.
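One possible graph-safe approach, sketched under the assumption that the target is np.histogram semantics, keeps the bin edges as a tensor and replaces Bucketize with tf.searchsorted, which accepts tensor boundaries. graph_safe_histogram is a hypothetical name; this is not the Keras code or a tested patch, and it has not been tried with jit_compile=True.

```python
import numpy as np
import tensorflow as tf


def graph_safe_histogram(x, bins, range=None):
    """Hypothetical graph-safe histogram (a sketch, not the Keras implementation).

    Follows np.histogram semantics: bins are half-open [e_i, e_{i+1}) except the
    last, which is closed, and values outside `range` are not counted.
    """
    x = tf.reshape(tf.convert_to_tensor(x), [-1])
    if range is None:
        lo, hi = tf.reduce_min(x), tf.reduce_max(x)
    else:
        lo, hi = range
    lo = tf.cast(lo, x.dtype)
    hi = tf.cast(hi, x.dtype)

    # The edges stay a tensor; nothing is converted to Python values.
    edges = tf.linspace(lo, hi, bins + 1)

    # tf.searchsorted accepts tensor boundaries (Bucketize needs a Python list).
    # side="right" sends x == edge to the upper bin; searching only the interior
    # edges maps x == hi (and anything above it) to the last bin, index bins - 1.
    idx = tf.searchsorted(edges[1:-1], x, side="right")

    # Count only in-range values by summing a 0/1 weight per bin, which keeps all
    # shapes static instead of dropping elements with tf.boolean_mask.
    in_range = tf.cast((x >= lo) & (x <= hi), tf.int32)
    counts = tf.math.unsorted_segment_sum(in_range, idx, num_segments=bins)
    return counts, edges


values = np.array([0.05, 0.2, 0.26, 0.5, 0.99, 1.0, 1.5, -0.1], dtype=np.float32)

# Run it traced (graph mode) to show that no eager-only call is involved.
counts, edges = tf.function(graph_safe_histogram)(tf.constant(values), 4, (0.0, 1.0))
ref_counts, ref_edges = np.histogram(values, bins=4, range=(0.0, 1.0))
print(counts.numpy(), ref_counts)  # both [2 1 1 2]
```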

Environment


keras: 3.11.3
tensorflow: 2.19.0
jax: 0.5.3
python: 3.12.11 (main, Jun  4 2025, 08:56:18) [GCC 11.4.0]
platform: Linux-6.6.97+-x86_64-with-glibc2.35
backend: tensorflow

Workarounds

Something like the following, built on tf.histogram_fixed_width, which does not need .numpy():

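import tensorflow as tf

def tf_histogram_fixed_width_1d(x, bins, range):
    lo = range[0]
    hi = range[1]
    x = tf.reshape(x, [-1])
    # Note: histogram_fixed_width already maps values below lo to the first bin and
    # values above hi to the last bin, so this clip does not change the counts.
    x = tf.clip_by_value(x, lo, hi)
    counts = tf.histogram_fixed_width(values=x, value_range=[lo, hi], nbins=bins)
    return counts, tf.linspace(lo, hi, bins + 1)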

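However, this does not exactly match np.histogram at the edges: tf.histogram_fixed_width counts values below lo in the first bin and values above hi in the last bin (the clip_by_value call has the same effect), whereas np.histogram ignores values outside range. Values exactly equal to hi land in the last bin in both.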
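A variant of the workaround above that drops out-of-range values before counting, so its counts match np.histogram. It is a sketch assuming `range` is a pair of Python floats; masked_fixed_width_histogram is a hypothetical name.

```python
import numpy as np
import tensorflow as tf


@tf.function
def masked_fixed_width_histogram(x, bins, range):
    # Sketch assuming `range` is a pair of Python floats.
    lo, hi = range
    x = tf.reshape(x, [-1])
    # np.histogram ignores values outside [lo, hi]; histogram_fixed_width would
    # instead count them in the first or last bin, so drop them first.
    x = tf.boolean_mask(x, (x >= lo) & (x <= hi))
    value_range = tf.constant([lo, hi], dtype=x.dtype)
    # x == hi falls into the last bin, matching np.histogram's closed last bin.
    counts = tf.histogram_fixed_width(x, value_range, nbins=bins)
    edges = tf.linspace(value_range[0], value_range[1], bins + 1)
    return counts, edges


values = np.array([0.05, 0.2, 0.26, 0.5, 0.99, 1.0, 1.5, -0.1], dtype=np.float32)
counts, edges = masked_fixed_width_histogram(tf.constant(values), 4, (0.0, 1.0))
ref_counts, _ = np.histogram(values, bins=4, range=(0.0, 1.0))
print(counts.numpy(), ref_counts)  # both [2 1 1 2]
```

tf.boolean_mask produces a data-dependent shape, which XLA may not accept under jit_compile=True; that case has not been checked here.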

Notes

  • The same repro runs fine on the JAX backend.
  • The failure happens whether jit_compile=True or False, as long as the call happens in graph mode / symbolic context.

Thanks!
