
model.export on JAX saves all model weights as constants embedded in the graph #19132

Open
martin-gorner opened this issue Feb 1, 2024 · 4 comments
Labels
backend:jax, stat:awaiting keras-eng, type:Bug

Comments

@martin-gorner
Contributor

Repro colab:

https://colab.research.google.com/drive/1QHg0zpFsJS6qfTDfBwts8KLule7B84RO?usp=sharing

Requested fix: when exporting a model through jax2tf, weights must be wrapped in tf.Variable before jax2tf is called.

Relevant jax2tf documentation (https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#saved-model-with-parameters):
"Some special care is needed to ensure that the model parameters are not embedded as constants in the graph and are instead saved separately as variables. This is useful for two reasons: the parameters could be very large and exceed the 2GB limits of the GraphDef part of the SavedModel, or you may want to fine-tune the model and change the value of the parameters."

@sachinprasadhs added the backend:jax, keras-team-review-pending, and type:Bug labels on Feb 2, 2024
@nkovela1 self-assigned this and unassigned SuryanarayanaY on Feb 6, 2024
@SamanehSaadat removed the keras-team-review-pending label on Feb 8, 2024
@fchollet
Member

@nkovela1 this is fixed, right?

@nkovela1
Contributor

@fchollet Yes, this is fixed. Closing the issue, thanks!


@martin-gorner
Contributor Author

martin-gorner commented Feb 27, 2024

I don't see any change in the repro Colab. As far as I can tell, it still saves all variables as constants in the graph.
I did test with keras-nightly; see the repro Colab.

@martin-gorner reopened this on Feb 27, 2024
@sachinprasadhs added the stat:awaiting keras-eng label on May 15, 2024