The code works correctly on Colab (where I installed the latest versions of JAX and Keras), so I assume it also works on Linux machines.
On my Windows machine, it works as intended with the TensorFlow backend. However, with the JAX backend, model creation fails. Interestingly, initializing the custom layer works, and I can even run a forward pass through it.
Testing layer initialization...
x.shape (1, 32, 32, 96) (1, 32, 32, 96)
B 1
x.shape (1, 32, 32, 96) (1, 32, 32, 96)
B 1
inputs.shape (1, 32, 32, 96)
out.shape (1, 32, 32, 96)
-------------------------------------------
-------------------------------------------
-------------------------------------------
Testing layer inside Model.
inputs.shape (None, 32, 32, 96)
x.shape (83, 32, 32, 96) (83, 32, 32, 96)
B 83
C:\Users\vaibh\miniconda3\envs\keras_univ\Lib\site-packages\keras\src\layers\layer.py:1248: UserWarning: Layer 'MobileViTBlock-1' looks like it has unbuilt state, but Keras is not able to trace the layer `call()` in order to build it automatically. Possible causes:
1. The `call()` method of your layer may be crashing. Try to `__call__()` the layer eagerly on some test input first to see if it works. E.g. `x = np.random.random((3, 4)); y = layer(x)`
2. If the `call()` method is correct, then you may need to implement the `def build(self, input_shape)` method on your layer. It should create all variables used by the layer (e.g. by calling `layer.build()` on all its children layers).
Exception encountered: ''Shapes must be 1D sequences of concrete values of integer type, got (None, 32, 32, 96).''
warnings.warn(
C:\Users\vaibh\miniconda3\envs\keras_univ\Lib\site-packages\keras\src\layers\layer.py:357: UserWarning: `build()` was called on layer 'MobileViTBlock-1', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.
warnings.warn(
x.shape (83, 32, 32, 96) (83, 32, 32, 96)
B 83
Traceback (most recent call last):
File "C:\Users\vaibh\OneDrive\Desktop\Work\other_work_mine\Mine\keras-vision\keras_vision\MobileViT_v1\minimal_example.py", line 276, in <module>
create_test_model()
File "C:\Users\vaibh\OneDrive\Desktop\Work\other_work_mine\Mine\keras-vision\keras_vision\MobileViT_v1\minimal_example.py", line 232, in create_test_model
mvitblk = MobileViT_v1_Block(
^^^^^^^^^^^^^^^^^^^
File "C:\Users\vaibh\miniconda3\envs\keras_univ\Lib\site-packages\keras\src\utils\traceback_utils.py", line 122, in error_handler
raise e.with_traceback(filtered_tb) from None
File "C:\Users\vaibh\miniconda3\envs\keras_univ\Lib\site-packages\keras\src\backend\jax\core.py", line 253, in compute_output_spec
_, jax_out = jax.make_jaxpr(wrapped_fn, return_shape=True)(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Exception encountered when calling MobileViT_v1_Block.call().
Shapes must be 1D sequences of concrete values of integer type, got (None, 32, 32, 96).
Arguments received by MobileViT_v1_Block.call():
• args=('<KerasTensor shape=(None, 32, 32, 96), dtype=float32, sparse=None, name=keras_tensor_4>',)
• kwargs=<class 'inspect._empty'>
Edit: It works properly with the PyTorch and TensorFlow backends, both on Colab and on my local Windows machine.
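The second warning in the log (`build()` was called ... does not have a `build()` method) is a separate issue from the crash. Below is a minimal sketch of the pattern that warning asks for, using a hypothetical `SketchBlock` with two `Conv2D` children (this is not the MobileViT block from the issue). Every child is built explicitly inside `build()`, so Keras does not need to trace `call()` just to discover which weights exist. It assumes Keras 3 with the JAX backend, matching the failing setup.

```python
import os

# Must be set before Keras is imported; "jax" matches the failing setup in this issue.
os.environ.setdefault("KERAS_BACKEND", "jax")

import numpy as np
import keras
from keras import layers


class SketchBlock(layers.Layer):
    """Hypothetical stand-in for a custom block; not the MobileViT block itself."""

    def __init__(self, hidden_dim, out_dim, **kwargs):
        super().__init__(**kwargs)
        # Child layers are created here but left unbuilt.
        self.local_conv = layers.Conv2D(hidden_dim, 3, padding="same")
        self.project_out = layers.Conv2D(out_dim, 1)

    def build(self, input_shape):
        # Build each child from shapes alone, so Keras does not have to trace
        # call() to find out which weights exist.
        self.local_conv.build(input_shape)
        hidden_shape = self.local_conv.compute_output_shape(input_shape)
        self.project_out.build(hidden_shape)

    def call(self, inputs):
        return self.project_out(self.local_conv(inputs))


block = SketchBlock(hidden_dim=144, out_dim=96, name="sketch_block")
inputs = keras.Input(shape=(32, 32, 96))
outputs = block(inputs)  # symbolic call: build() runs with a None batch dim
model = keras.Model(inputs, outputs)
print(outputs.shape)  # (None, 32, 32, 96)

y = model(np.zeros((2, 32, 32, 96), dtype="float32"))  # eager forward pass
print(tuple(y.shape))  # (2, 32, 32, 96)
```

This should remove both `UserWarning`s, but it does not by itself fix the `TypeError`: Keras still traces `call()` through `jax.make_jaxpr` to compute the output spec, so any shape problem inside `call()` will still surface there.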
veb-101 changed the title from "Unable to create model with Keras=3.2.1 and jax[cpu]=0.4.26 on Windows 11 machine." to "Unable to create model with Keras=3.2.1 and jax[cpu]=0.4.26" on Apr 20, 2024.
I've simplified the Colab code further and added print statements for debugging. The code consistently fails at the reshape(...) call, and if I skip the reshape, the transpose fails with the same error:
Shapes must be 1D sequences of concrete values of integer type, got (None, 32, 32, 144).
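A commonly used way to keep a symbolic batch dimension out of reshape is to pass -1 for it and take the spatial and channel sizes from static configuration rather than from the input's shape. The sketch below assumes the failing code is a channels-last MobileViT unfold/fold (patch extraction and its inverse). `unfold` and `fold` are hypothetical helpers, not the author's implementation, and whether this avoids the Windows-specific failure is untested. The `ops` parameter defaults to NumPy so the sketch runs on its own; `keras.ops.reshape` and `keras.ops.transpose` take the same positional arguments, so passing `ops=keras.ops` inside a layer's `call()` should work the same way.

```python
import numpy as np


def unfold(x, height, width, channels, patch_h, patch_w, ops=np):
    """Hypothetical helper: (B, H, W, C) feature map -> (B, P, N, C) patches.

    P = patch_h * patch_w (pixels per patch), N = number of patches.
    H, W and C are passed in as plain ints instead of being read from x.shape.
    """
    nph, npw = height // patch_h, width // patch_w
    # -1 lets reshape infer the batch size, so no None/symbolic batch
    # value ever appears inside the target shape tuple.
    x = ops.reshape(x, (-1, nph, patch_h, npw, patch_w, channels))
    # Move the within-patch axes in front of the patch-grid axes.
    x = ops.transpose(x, (0, 2, 4, 1, 3, 5))  # (B, ph, pw, nph, npw, C)
    return ops.reshape(x, (-1, patch_h * patch_w, nph * npw, channels))


def fold(x, height, width, channels, patch_h, patch_w, ops=np):
    """Inverse of unfold: (B, P, N, C) patches -> (B, H, W, C) feature map."""
    nph, npw = height // patch_h, width // patch_w
    x = ops.reshape(x, (-1, patch_h, patch_w, nph, npw, channels))
    x = ops.transpose(x, (0, 3, 1, 4, 2, 5))  # (B, nph, ph, npw, pw, C)
    return ops.reshape(x, (-1, height, width, channels))


x = np.random.rand(2, 32, 32, 96).astype("float32")
patches = unfold(x, 32, 32, 96, patch_h=2, patch_w=2)
print(patches.shape)  # (2, 4, 256, 96)
restored = fold(patches, 32, 32, 96, patch_h=2, patch_w=2)
print(np.array_equal(restored, x))  # True
```

With 2×2 patches on a 32×32×96 map this gives P = 4 pixels per patch and N = 256 patches, and fold exactly undoes unfold. The only dimension that reshape has to infer is the batch, which is the one that is None while the model is being built.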
Hello, I'm attempting to create the MobileViT model using Keras 3.
There are two main aspects:
Here is my code: Keras3_jax_windows.py
Code snippet:
Output:
I know the code fails at this point: