UndefinedPoly error in jax2tf #5915
Shape polymorphism is not yet supported in jax2tf. We had it working a few months ago, but we rolled that support back because there were some unsoundness bugs. We do want to support this again, but it requires some changes in JAX core.
Thanks for the info, @gnecula! It might be worth adding a short note to the examples (unless I missed it). I'll add support only for static dimensions for now.
I'll keep this issue open as a reminder to go through the examples and remove mentions of batch polymorphism.
Is there an issue for batch polymorphism in general that I can follow?
I created issue #6080.
Hey, I am trying to create a saved model with a variable batch dimension from JAX. In my own code I was getting an `UndefinedPoly` error involving the named batch dimension at some point inside the Flax module, so, just to be sure, I ran the `saved_model_main.py` example with `--serving_batch_size=-1`, and I get the same error. It seems that JAX is not able to broadcast to "named" dimensions. Creating a saved model with a static dimension works just fine.