In [83]:
import keras
from keras import layers, activations, ops
import tensorflow as tf
import os
#tf.config.experimental.set_jit_compiler_enabled(False)
#os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices=false'

In [84]:
num_classes = 100
input_shape = (32, 32, 3)

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")

x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 1)
x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 1)


In [85]:
learning_rate = 0.001
weight_decay = 0.0001
batch_size = 5
num_epochs = 2 # It is recommended to run 50 epochs to observe improvements in accuracy
dropout_rate = 0.2
image_size = 64  # We'll resize input images to this size.
patch_size = 2  # Size of the patches to be extract from the input images.
num_patches = (image_size // patch_size) ** 2  # Size of the data array.
latent_dim = 256  # Size of the latent array.
projection_dim = 256  # Embedding size of each element in the data and latent arrays.
num_heads = 8  # Number of Transformer heads.
ffn_units = [
    projection_dim,
    projection_dim,
]  # Size of the Transformer Feedforward network.
num_transformer_blocks = 4
num_iterations = 2  # Repetitions of the cross-attention and Transformer modules.
classifier_units = [
    projection_dim,
    num_classes,
]  # Size of the Feedforward network of the final classifier.

print(f"Image size: {image_size} X {image_size} = {image_size ** 2}")
print(f"Patch size: {patch_size} X {patch_size} = {patch_size ** 2} ")
print(f"Patches per image: {num_patches}")
print(f"Elements per patch (3 channels): {(patch_size ** 2) * 3}")
print(f"Latent array shape: {latent_dim} X {projection_dim}")
print(f"Data array shape: {num_patches} X {projection_dim}")

Image size: 64 X 64 = 4096
Patch size: 2 X 2 = 4 
Patches per image: 1024
Elements per patch (3 channels): 12
Latent array shape: 256 X 256
Data array shape: 1024 X 256


In [86]:
data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)

In [87]:
def create_ffn(hidden_units, dropout_rate):
    ffn_layers = []
    for units in hidden_units[:-1]:
        ffn_layers.append(layers.Dense(units, activation=activations.gelu))

    ffn_layers.append(layers.Dense(units=hidden_units[-1]))
    ffn_layers.append(layers.Dropout(dropout_rate))

    ffn = keras.Sequential(ffn_layers)
    return ffn

In [88]:
def create_ffn(hidden_units, dropout_rate):
    ffn_layers = []
    for units in hidden_units[:-1]:
        ffn_layers.append(layers.Dense(units, activation=activations.gelu))

    ffn_layers.append(layers.Dense(units=hidden_units[-1]))
    ffn_layers.append(layers.Dropout(dropout_rate))

    ffn = keras.Sequential(ffn_layers)
    return ffn

In [89]:
class Patches(layers.Layer):
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = ops.shape(images)[0]
        patches = ops.image.extract_patches(
            images=images,
            size=(self.patch_size, self.patch_size),
            strides=(self.patch_size, self.patch_size),
            dilation_rate=1,
            padding="valid",
        )
        patch_dims = patches.shape[-1]
        patches = ops.reshape(patches, [batch_size, -1, patch_dims])
        return patches

In [90]:
class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super().__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patches):
        positions = ops.arange(start=0, stop=self.num_patches, step=1)
        encoded = self.projection(patches) + self.position_embedding(positions)
        return encoded

In [91]:
def create_cross_attention_module(
    latent_dim, data_dim, projection_dim, ffn_units, dropout_rate
):
    inputs = {
        # Recieve the latent array as an input of shape [1, latent_dim, projection_dim].
        "latent_array": layers.Input(
            shape=(latent_dim, projection_dim), name="latent_array"
        ),
        # Recieve the data_array (encoded image) as an input of shape [batch_size, data_dim, projection_dim].
        "data_array": layers.Input(shape=(data_dim, projection_dim), name="data_array"),
    }

    # Apply layer norm to the inputs
    latent_array = layers.LayerNormalization(epsilon=1e-6)(inputs["latent_array"])
    data_array = layers.LayerNormalization(epsilon=1e-6)(inputs["data_array"])

    # Create query tensor: [1, latent_dim, projection_dim].
    query = layers.Dense(units=projection_dim)(latent_array)
    # Create key tensor: [batch_size, data_dim, projection_dim].
    key = layers.Dense(units=projection_dim)(data_array)
    # Create value tensor: [batch_size, data_dim, projection_dim].
    value = layers.Dense(units=projection_dim)(data_array)

    # Generate cross-attention outputs: [batch_size, latent_dim, projection_dim].
    attention_output = layers.Attention(use_scale=True, dropout=0.1)(
        [query, key, value], return_attention_scores=False
    )
    # Skip connection 1.
    attention_output = layers.Add()([attention_output, latent_array])

    # Apply layer norm.
    attention_output = layers.LayerNormalization(epsilon=1e-6)(attention_output)
    # Apply Feedforward network.
    ffn = create_ffn(hidden_units=ffn_units, dropout_rate=dropout_rate)
    outputs = ffn(attention_output)
    # Skip connection 2.
    outputs = layers.Add()([outputs, attention_output])

    # Create the Keras model.
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model

In [92]:
def create_transformer_module(
    latent_dim,
    projection_dim,
    num_heads,
    num_transformer_blocks,
    ffn_units,
    dropout_rate,
):
    # input_shape: [1, latent_dim, projection_dim]
    inputs = layers.Input(shape=(latent_dim, projection_dim))

    x0 = inputs
    # Create multiple layers of the Transformer block.
    for _ in range(num_transformer_blocks):
        # Apply layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(x0)
        # Create a multi-head self-attention layer.
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, x0])
        # Apply layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # Apply Feedforward network.
        ffn = create_ffn(hidden_units=ffn_units, dropout_rate=dropout_rate)
        x3 = ffn(x3)
        # Skip connection 2.
        x0 = layers.Add()([x3, x2])

    # Create the Keras model.
    model = keras.Model(inputs=inputs, outputs=x0)
    return model

In [93]:
class Perceiver(keras.Model):
    def __init__(
        self,
        patch_size,
        data_dim,
        latent_dim,
        projection_dim,
        num_heads,
        num_transformer_blocks,
        ffn_units,
        dropout_rate,
        num_iterations,
        classifier_units,
    ):
        super().__init__()

        self.latent_dim = latent_dim
        self.data_dim = data_dim
        self.patch_size = patch_size
        self.projection_dim = projection_dim
        self.num_heads = num_heads
        self.num_transformer_blocks = num_transformer_blocks
        self.ffn_units = ffn_units
        self.dropout_rate = dropout_rate
        self.num_iterations = num_iterations
        self.classifier_units = classifier_units

    def build(self, input_shape):
        # Create latent array.
        self.latent_array = self.add_weight(
            shape=(self.latent_dim, self.projection_dim),
            initializer="random_normal",
            trainable=True,
        )

        # Create patching module.
        self.patcher = Patches(self.patch_size)

        # Create patch encoder.
        self.patch_encoder = PatchEncoder(self.data_dim, self.projection_dim)

        # Create cross-attenion module.
        self.cross_attention = create_cross_attention_module(
            self.latent_dim,
            self.data_dim,
            self.projection_dim,
            self.ffn_units,
            self.dropout_rate,
        )

        # Create Transformer module.
        self.transformer = create_transformer_module(
            self.latent_dim,
            self.projection_dim,
            self.num_heads,
            self.num_transformer_blocks,
            self.ffn_units,
            self.dropout_rate,
        )

        # Create global average pooling layer.
        self.global_average_pooling = layers.GlobalAveragePooling1D()

        # Create a classification head.
        self.classification_head = create_ffn(
            hidden_units=self.classifier_units, dropout_rate=self.dropout_rate
        )

        super().build(input_shape)

    def call(self, inputs):
        # Augment data.
        augmented = data_augmentation(inputs)
        # Create patches.
        patches = self.patcher(augmented)
        # Encode patches.
        encoded_patches = self.patch_encoder(patches)
        # Prepare cross-attention inputs.
        cross_attention_inputs = {
            "latent_array": ops.expand_dims(self.latent_array, 0),
            "data_array": encoded_patches,
        }
        # Apply the cross-attention and the Transformer modules iteratively.
        for _ in range(self.num_iterations):
            # Apply cross-attention from the latent array to the data array.
            latent_array = self.cross_attention(cross_attention_inputs)
            # Apply self-attention Transformer to the latent array.
            latent_array = self.transformer(latent_array)
            # Set the latent array of the next iteration.
            cross_attention_inputs["latent_array"] = latent_array

        # Apply global average pooling to generate a [batch_size, projection_dim] repesentation tensor.
        representation = self.global_average_pooling(latent_array)
        # Generate logits.
        logits = self.classification_head(representation)
        return logits

In [94]:
def run_experiment(model):
    # Create ADAM instead of LAMB optimizer with weight decay. (LAMB isn't supported yet)
    optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

    # Compile the model.
    model.compile(
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="acc"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top5-acc"),
        ],
    )

    # Create a learning rate scheduler callback.
    reduce_lr = keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss", factor=0.2, patience=3
    )

    # Create an early stopping callback.
    early_stopping = keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=15, restore_best_weights=True
    )

    # Fit the model.
    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[early_stopping, reduce_lr],
    )

    _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    # Return history to plot learning curves.
    return history

In [95]:
perceiver_classifier = Perceiver(
    patch_size,
    num_patches,
    latent_dim,
    projection_dim,
    num_heads,
    num_transformer_blocks,
    ffn_units,
    dropout_rate,
    num_iterations,
    classifier_units,
)


history = run_experiment(perceiver_classifier)

Epoch 1/2


2025-10-21 12:20:58.010836: W tensorflow/core/framework/op_kernel.cc:1855] OP_REQUIRES failed at xla_ops.cc:590 : INVALID_ARGUMENT: Detected unsupported operations when trying to compile graph __inference_one_step_on_data_137092[] on XLA_GPU_JIT: ImageProjectiveTransformV3 (No registered 'ImageProjectiveTransformV3' OpKernel for XLA_GPU_JIT devices compatible with node {{node perceiver_9_1/data_augmentation_1/random_zoom_3_1/ImageProjectiveTransformV3}}){{node perceiver_9_1/data_augmentation_1/random_zoom_3_1/ImageProjectiveTransformV3}}
The op is created at: 
File "usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
File "usr/lib/python3.10/runpy.py", line 86, in _run_code
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/traitlets/config/application.py", line 1075, in launch_instance
File "home/xy/Desktop/ml/keras/test/keras_tes

InvalidArgumentError: Graph execution error:

Detected at node perceiver_9_1/data_augmentation_1/random_zoom_3_1/ImageProjectiveTransformV3 defined at (most recent call last):
<stack traces unavailable>
Detected at node perceiver_9_1/data_augmentation_1/random_zoom_3_1/ImageProjectiveTransformV3 defined at (most recent call last):
<stack traces unavailable>
Detected unsupported operations when trying to compile graph __inference_one_step_on_data_137092[] on XLA_GPU_JIT: ImageProjectiveTransformV3 (No registered 'ImageProjectiveTransformV3' OpKernel for XLA_GPU_JIT devices compatible with node {{node perceiver_9_1/data_augmentation_1/random_zoom_3_1/ImageProjectiveTransformV3}}){{node perceiver_9_1/data_augmentation_1/random_zoom_3_1/ImageProjectiveTransformV3}}
The op is created at: 
File "usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
File "usr/lib/python3.10/runpy.py", line 86, in _run_code
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/traitlets/config/application.py", line 1075, in launch_instance
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 739, in start
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 211, in start
File "usr/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
File "usr/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
File "usr/lib/python3.10/asyncio/events.py", line 80, in _run
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 519, in dispatch_queue
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 508, in process_one
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 400, in dispatch_shell
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 368, in execute_request
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 767, in execute_request
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 455, in do_execute
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/ipykernel/zmqshell.py", line 577, in run_cell
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3077, in run_cell
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3132, in _run_cell
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 128, in _pseudo_sync_runner
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3336, in run_cell_async
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3519, in run_ast_nodes
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3579, in run_code
File "tmp/ipykernel_13395/1091006353.py", line 15, in <module>
File "tmp/ipykernel_13395/2824899570.py", line 26, in run_experiment
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/backend/tensorflow/trainer.py", line 377, in fit
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/backend/tensorflow/trainer.py", line 220, in function
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 833, in __call__
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 889, in _call
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 696, in _initialize
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 178, in trace_function
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 283, in _maybe_define_function
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 310, in _create_concrete_function
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py", line 1060, in func_graph_from_py_func
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 599, in wrapped_fn
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py", line 41, in autograph_handler
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 339, in converted_call
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 459, in _call_unconverted
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 643, in wrapper
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/backend/tensorflow/trainer.py", line 133, in multi_step_on_iterator
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 833, in __call__
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 906, in _call
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 132, in call_function
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 178, in trace_function
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 283, in _maybe_define_function
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 310, in _create_concrete_function
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py", line 1060, in func_graph_from_py_func
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 599, in wrapped_fn
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py", line 41, in autograph_handler
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 331, in converted_call
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 459, in _call_unconverted
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 643, in wrapper
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/backend/tensorflow/trainer.py", line 114, in one_step_on_data
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/distribute/distribute_lib.py", line 1673, in run
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/distribute/distribute_lib.py", line 3263, in call_for_each_replica
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/distribute/distribute_lib.py", line 4061, in _call_for_each_replica
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 643, in wrapper
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/backend/tensorflow/trainer.py", line 60, in train_step
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/layers/layer.py", line 936, in __call__
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/ops/operation.py", line 58, in __call__
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 156, in error_handler
File "tmp/ipykernel_13395/34977848.py", line 73, in call
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/layers/layer.py", line 936, in __call__
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/ops/operation.py", line 58, in __call__
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 156, in error_handler
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/models/sequential.py", line 220, in call
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/models/functional.py", line 183, in call
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/ops/function.py", line 177, in _run_through_graph
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/models/functional.py", line 648, in call
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/layers/preprocessing/tf_data_layer.py", line 43, in __call__
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/layers/layer.py", line 936, in __call__
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/ops/operation.py", line 58, in __call__
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 156, in error_handler
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/layers/preprocessing/image_preprocessing/base_image_preprocessing_layer.py", line 208, in call
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/layers/preprocessing/image_preprocessing/random_zoom.py", line 179, in transform_images
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/layers/preprocessing/image_preprocessing/random_zoom.py", line 376, in _zoom_inputs
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/keras/src/backend/tensorflow/image.py", line 361, in affine_transform
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/util/tf_export.py", line 377, in wrapper
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/ops/gen_image_ops.py", line 2592, in image_projective_transform_v3
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/framework/op_def_library.py", line 796, in _apply_op_helper
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py", line 614, in _create_op_internal
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/framework/ops.py", line 2726, in _create_op_internal
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/framework/ops.py", line 1221, in from_node_def
	tf2xla conversion failed while converting __inference_one_step_on_data_137092[]. Run with TF_DUMP_GRAPH_PREFIX=/path/to/dump/dir and --vmodule=xla_compiler=2 to obtain a dump of the compiled functions.
	 [[StatefulPartitionedCall]] [Op:__inference_multi_step_on_iterator_137678]

/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 889, in _call
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 696, in _initialize
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 178, in trace_function
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 283, in _maybe_define_function
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 310, in _create_concrete_function
File "home/xy/Desktop/ml/keras/test/keras_test/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py", line 1060, 