IntegerLookup with XLA Compilation Fails to Enable JIT in TensorFlow 2.16.1 and Keras 3.2.1 #774

Closed
jacob-talroo opened this issue Apr 12, 2024 · 3 comments

Comments

@jacob-talroo

System information

  • Custom code: Yes
  • OS Platform and Distribution: Google Colab
  • TensorFlow installed from: Binary
  • TensorFlow version: 2.16.1
  • Keras version: 3.2.1
  • Python version: 3.10.12
  • GPU model and memory: Colab default GPU setup
  • Exact command to reproduce: See the Colab notebook and source code provided below.

Describe the problem:
When a model contains an IntegerLookup layer, setting jit_compile=True still does not result in XLA compilation in TensorFlow 2.16.1 with Keras 3.2.1. The model is expected to be compiled with XLA to speed up training, but Keras disables XLA even though JIT compilation was explicitly requested.

Describe the current behavior:
When jit_compile=True is passed to model.compile(), Keras emits a warning and falls back to jit_compile=False, so training runs without XLA and is less efficient than expected.
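The fallback can be pictured as an all-or-nothing check over the model's layers. The sketch below is a simplified, hypothetical re-creation in plain Python, not Keras source: the `Layer` dataclass and `resolve_jit_compile` function are illustrative names, and it assumes (as the warning text suggests) that a single layer without XLA support is enough to disable XLA for the whole model.

```python
import warnings
from dataclasses import dataclass


@dataclass
class Layer:
    # Hypothetical stand-in for a Keras layer; only the JIT flag matters here.
    name: str
    supports_jit: bool = True


def resolve_jit_compile(layers, jit_compile=True):
    """Simplified model of the fallback: one non-JIT layer disables XLA."""
    if jit_compile and not all(layer.supports_jit for layer in layers):
        warnings.warn(
            "Model doesn't support `jit_compile=True`. "
            "Proceeding with `jit_compile=False`."
        )
        return False
    return jit_compile


model_layers = [
    Layer("integer_lookup", supports_jit=False),  # assumed to opt out of XLA
    Layer("embedding"),
    Layer("dense"),
]
print(resolve_jit_compile(model_layers))      # False, after a UserWarning
print(resolve_jit_compile(model_layers[1:]))  # True: lookup layer removed
```

Under this model, the only way to keep XLA is to make sure every layer inside the compiled model supports it.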

Describe the expected behavior:
The model should honor jit_compile=True and be compiled with XLA, which is expected to make training more efficient.

Standalone code to reproduce the issue:
A Colab notebook demonstrates the issue. The minimal code below reproduces the behavior:

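import tensorflow as tf
import keras

def make_dataset(batch_size=32):
    # Hypothetical synthetic stand-in (the original dataset code was not shared):
    # integer IDs in [1, 10] with random binary labels, repeated indefinitely.
    ids = tf.random.uniform((1000,), minval=1, maxval=11, dtype=tf.int64)
    labels = tf.cast(tf.random.uniform((1000, 1)) > 0.5, tf.float32)
    return tf.data.Dataset.from_tensor_slices((ids, labels)).batch(batch_size).repeat()

def make_model():
    vocabulary = list(range(1, 11))
    return keras.Sequential([
        # Maps raw IDs to indices; index 0 is the out-of-vocabulary bucket.
        keras.layers.IntegerLookup(vocabulary=vocabulary),
        # 12 rows cover the 11 lookup indices; the Keras 2 `input_length`
        # argument is omitted because Keras 3 deprecates it.
        keras.layers.Embedding(len(vocabulary) + 2, 8),
        keras.layers.Dense(1, activation="sigmoid"),
    ])

# Model and training setup
dataset = make_dataset()
model = make_model()
# With Keras 3.2.1 this emits the jit_compile warning shown below.
model.compile(
    loss=keras.losses.BinaryCrossentropy(),
    optimizer=keras.optimizers.AdamW(),
    metrics=["AUC"],
    jit_compile=True,
)
model.fit(dataset, steps_per_epoch=100, epochs=2, verbose=1)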

Source code / logs:
Please refer to the provided Colab notebook for full logs and traceback. The key message indicating the problem is:

UserWarning: Model doesn't support `jit_compile=True`. Proceeding with `jit_compile=False`.
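To see which layer triggers this warning, the layers can be inspected before compiling. This sketch rests on an assumption: that Keras 3 layers expose a `supports_jit` attribute which `compile()` consults when deciding whether to honor `jit_compile=True`. Using `getattr` with a default of `True` keeps it from failing on Keras versions without that attribute. With the setup described in this report, IntegerLookup is the layer expected to show up in the list.

```python
import keras

# Same layer stack as the reproduction above (vocabulary 1..10).
vocabulary = list(range(1, 11))
model = keras.Sequential([
    keras.layers.IntegerLookup(vocabulary=vocabulary),
    keras.layers.Embedding(len(vocabulary) + 2, 8),
    keras.layers.Dense(1, activation="sigmoid"),
])

# Assumption: a layer with supports_jit == False blocks XLA for the model.
# Layers lacking the attribute are treated as JIT-compatible.
blocking = [
    layer.__class__.__name__
    for layer in model.layers
    if not getattr(layer, "supports_jit", True)
]
print(blocking)
```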

Additional context:
This problem was previously reported in #418 and was believed to be resolved; however, further testing shows that it still occurs, so it needs to be re-evaluated by the development team.

@tilakrayal
Collaborator

@jacob-talroo,
I tried both TensorFlow v2.15 and v2.16: on v2.15 no warning appears during execution, whereas on v2.16 the warning is displayed. I suspect this comes from Keras 3, which is the default Keras version in TensorFlow 2.16; with TensorFlow 2.15 the default is Keras 2.15. Please find the gist for reference, and since the behavior occurs with Keras 3, please open a new issue on the keras-team/keras repo. Thank you!

@jacob-talroo
Author

Thank you - I just filed keras-team/keras#19521 to track.

@tilakrayal
Collaborator

@jacob-talroo,
Could you please close this issue, since it is now being tracked there? Thank you!
