Unable to create model with Keras=3.2.1 and jax[cpu]=0.4.26 #19539

Open
veb-101 opened this issue Apr 17, 2024 · 3 comments
Labels: backend:jax, stat:awaiting keras-eng, type:Bug

Comments

veb-101 commented Apr 17, 2024

Hello, I'm trying to build the MobileViT model with Keras 3.

There are two main aspects:

  1. Creating a custom Layer class.
  2. Utilizing the custom Layer class in a model via the functional paradigm.

Here is my code: Keras3_jax_windows.py

The code runs correctly on Colab (where I installed the latest versions of JAX and Keras), so I assume it should also work on Linux machines.

On my Windows machine, it works as intended with the TensorFlow backend. With the JAX backend, however, model creation fails. Interestingly, instantiating the custom layer on its own works, and I can even run a forward pass through it.

Code snippet:

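import keras
from keras import Input, Model

# MobileViT_v1_Block is the custom layer defined in the linked file above.


def create_test_model():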
    # Define input shape
    input_shape = (32, 32, 96)  # Assuming input shape is (32, 32, 96), adjust as necessary

    # Create input layer
    inputs = Input(shape=input_shape)

    print("inputs.shape", inputs.shape)
    mvitblk = MobileViT_v1_Block(
        out_filters=96,
        embedding_dim=144,
        transformer_repeats=2,
        name="MobileViTBlock-1",
        attention_drop=0.2,
        linear_drop=0.2,
    )(inputs)

    # Create model
    model = Model(inputs=inputs, outputs=mvitblk)

    # Print model summary
    model.summary()
    return model


if __name__ == "__main__":
    batch = 1
    H = W = 32
    C = 96
    P = 2 * 2
    L = 4
    embedding_dim = 144

    print("Testing layer initialization...")
    mvitblk = MobileViT_v1_Block(
        out_filters=96,
        embedding_dim=144,
        transformer_repeats=2,
        name="MobileViTBlock-1",
        attention_drop=0.2,
        linear_drop=0.2,
    )

    inputs = keras.random.normal((batch, H, W, C))

    out = mvitblk(inputs)
    print("inputs.shape", inputs.shape)
    print("out.shape", out.shape)

    print("-------------------------------------------\n" * 3)
    print("Testing layer inside Model.")
    # Test model creation
    create_test_model()

Output:

Testing layer initialization...
x.shape (1, 32, 32, 96) (1, 32, 32, 96)
B 1
x.shape (1, 32, 32, 96) (1, 32, 32, 96)
B 1
inputs.shape (1, 32, 32, 96)
out.shape (1, 32, 32, 96)
-------------------------------------------
-------------------------------------------
-------------------------------------------

Testing layer inside Model.
inputs.shape (None, 32, 32, 96)
x.shape (83, 32, 32, 96) (83, 32, 32, 96)
B 83
C:\Users\vaibh\miniconda3\envs\keras_univ\Lib\site-packages\keras\src\layers\layer.py:1248: UserWarning: Layer 'MobileViTBlock-1' looks like it has unbuilt state, but Keras is not able to trace the layer `call()` in order to build it automatically. Possible causes:
1. The `call()` method of your layer may be crashing. Try to `__call__()` the layer eagerly on some test input first to see if it works. E.g. `x = np.random.random((3, 4)); y = layer(x)`
2. If the `call()` method is correct, then you may need to implement the `def build(self, input_shape)` method on your layer. It should create all variables used by the layer (e.g. by calling `layer.build()` on all its children layers).
Exception encountered: ''Shapes must be 1D sequences of concrete values of integer type, got (None, 32, 32, 96).''
  warnings.warn(
C:\Users\vaibh\miniconda3\envs\keras_univ\Lib\site-packages\keras\src\layers\layer.py:357: UserWarning: `build()` was called on layer 'MobileViTBlock-1', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.
  warnings.warn(
x.shape (83, 32, 32, 96) (83, 32, 32, 96)
B 83
Traceback (most recent call last):
  File "C:\Users\vaibh\OneDrive\Desktop\Work\other_work_mine\Mine\keras-vision\keras_vision\MobileViT_v1\minimal_example.py", line 276, in <module>
    create_test_model()
  File "C:\Users\vaibh\OneDrive\Desktop\Work\other_work_mine\Mine\keras-vision\keras_vision\MobileViT_v1\minimal_example.py", line 232, in create_test_model
    mvitblk = MobileViT_v1_Block(
              ^^^^^^^^^^^^^^^^^^^
  File "C:\Users\vaibh\miniconda3\envs\keras_univ\Lib\site-packages\keras\src\utils\traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "C:\Users\vaibh\miniconda3\envs\keras_univ\Lib\site-packages\keras\src\backend\jax\core.py", line 253, in compute_output_spec
    _, jax_out = jax.make_jaxpr(wrapped_fn, return_shape=True)(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Exception encountered when calling MobileViT_v1_Block.call().

Shapes must be 1D sequences of concrete values of integer type, got (None, 32, 32, 96).

Arguments received by MobileViT_v1_Block.call():
  • args=('<KerasTensor shape=(None, 32, 32, 96), dtype=float32, sparse=None, name=keras_tensor_4>',)
  • kwargs=<class 'inspect._empty'>
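
The error message above indicates that a shape tuple containing `None` (the unknown batch dimension of the `KerasTensor`) reached an operation that requires concrete integers. A minimal pure-Python sketch for spotting such dimensions before a shape is handed to a reshape; the helper name `symbolic_dims` is hypothetical and not part of Keras or JAX:

```python
def symbolic_dims(shape):
    """Return the indices of dimensions that are not concrete integers."""
    return [i for i, dim in enumerate(shape) if not isinstance(dim, int)]


# Shape reported in the traceback: the batch dimension is unknown (None).
print(symbolic_dims((None, 32, 32, 96)))  # [0]
# Shape seen in the eager forward pass: every dimension is concrete.
print(symbolic_dims((1, 32, 32, 96)))     # []
```

A non-empty result means the shape cannot safely be used in arithmetic such as `B * num_patches_h`.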

I know the code fails at this point:

def unfolding(
    x,
    B: int = 1,
    D: int = 144,
    patch_h: int = 2,
    patch_w: int = 2,
    num_patches_h: int = 10,
    num_patches_w: int = 10,
):
    """
    ### Notations (wrt paper) ###
        B/b = batch
        P/p = patch_size
        N/n = number of patches
        D/d = embedding_dim
    H, W
    [                            [
        [1, 2, 3, 4],     Goal      [1, 3, 9, 11],
        [5, 6, 7, 8],     ====>     [2, 4, 10, 12],
        [9, 10, 11, 12],            [5, 7, 13, 15],
        [13, 14, 15, 16],           [6, 8, 14, 16]
    ]                            ]
    """
    print("B", B)
    # [B, H, W, D] --> [B*nh, ph, nw, pw*D]
    reshaped_fm = kops.reshape(x, (B * num_patches_h, patch_h, num_patches_w, patch_w * D)) # <--- Here
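
One possible way to keep the batch size out of the reshape target is to let the backend infer the leading dimension with `-1`, which `keras.ops.reshape` accepts for one dimension. The sketch below only computes the target shape; the helper name `unfold_target_shape` is hypothetical, and whether this change resolves the error reported here is not confirmed in this thread.

```python
def unfold_target_shape(input_shape, patch_h=2, patch_w=2):
    """Target shape for [B, H, W, D] --> [B*nh, ph, nw, pw*D].

    The leading dimension is returned as -1 so that it never depends on
    the (possibly unknown) batch size.
    """
    _, height, width, channels = input_shape
    if height % patch_h or width % patch_w:
        raise ValueError("H and W must be divisible by the patch size")
    num_patches_w = width // patch_w
    return (-1, patch_h, num_patches_w, patch_w * channels)


# Works even when the batch dimension is None, as in the functional model.
print(unfold_target_shape((None, 32, 32, 144)))  # (-1, 2, 16, 288)
```

The result could then be passed to the reshape call, for example `kops.reshape(x, unfold_target_shape(x.shape, patch_h, patch_w))`, in place of the tuple that multiplies by `B`.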

sachinprasadhs added the keras-team-review-pending, type:Bug, and backend:jax labels on Apr 17, 2024
SamanehSaadat removed the keras-team-review-pending label on Apr 18, 2024
sachinprasadhs added the stat:awaiting keras-eng label on Apr 18, 2024

veb-101 commented Apr 18, 2024

Update: Unfortunately, it no longer works on Colab either.

colab link

Edit: It works properly with the PyTorch and TensorFlow backends, both on Colab and on my local Windows machine.

veb-101 changed the title from "Unable to create model with Keras=3.2.1 and jax[cpu]=0.4.26 on Windows 11 machine." to "Unable to create model with Keras=3.2.1 and jax[cpu]=0.4.26" on Apr 20, 2024

fchollet reopened this on Apr 21, 2024

veb-101 commented May 3, 2024

I've simplified the Colab code further and added print statements for debugging. The code consistently fails at the reshape(...) call; if I skip the reshape, the transpose fails with the same error:

Shapes must be 1D sequences of concrete values of integer type, got (None, 32, 32, 144).
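
When debugging the reshape and transpose steps, it can help to compare against a backend-free reference. The following pure-Python sketch reproduces the 4x4 example from the `unfolding` docstring for a single channel; `unfold_reference` is a hypothetical helper written for illustration, not the code from the Colab notebook:

```python
def unfold_reference(grid, patch_h=2, patch_w=2):
    """Rearrange an H x W grid into rows indexed by [pixel-in-patch][patch]."""
    height, width = len(grid), len(grid[0])
    num_patches_h, num_patches_w = height // patch_h, width // patch_w
    rows = []
    for p_i in range(patch_h):          # row offset inside a patch
        for p_j in range(patch_w):      # column offset inside a patch
            rows.append([
                grid[n_i * patch_h + p_i][n_j * patch_w + p_j]
                for n_i in range(num_patches_h)   # patch row
                for n_j in range(num_patches_w)   # patch column
            ])
    return rows


grid = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
for row in unfold_reference(grid):
    print(row)
```

Each printed row collects the same pixel position from every 2x2 patch, which matches the "Goal" matrix in the docstring and gives a fixed target to check the backend ops against.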

5 participants