Description
It seems that `LSTMCell` keeps `rngs` in its state (line 137 at commit a8a192f):

```python
self.rngs = rngs
```
Is this intentional? Why?
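One way to check which entries of a split state hold typed PRNG keys is to walk the pytree and test each leaf's dtype. The sketch below uses only standard JAX tree utilities; the helper name `find_prng_key_paths` and the example dict are hypothetical, and the dict only loosely mimics the layout of a module state that holds an RNG stream.

```python
import jax
import jax.numpy as jnp


def find_prng_key_paths(tree):
    """Return the paths of all leaves whose dtype is a typed PRNG key."""
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    return [
        jax.tree_util.keystr(path)
        for path, leaf in leaves_with_paths
        # Keys from jax.random.key carry an extended dtype that is a subtype
        # of jax.dtypes.prng_key; ordinary float/int arrays do not.
        if hasattr(leaf, "dtype")
        and jax.dtypes.issubdtype(leaf.dtype, jax.dtypes.prng_key)
    ]


# Hypothetical nested dict loosely mimicking a module state that holds a
# parameter plus an RNG stream (a typed key and a uint32 counter).
state = {
    "kernel": jnp.ones((3, 4)),
    "rngs": {"default": {"key": jax.random.key(0), "count": jnp.uint32(0)}},
}
print(find_prng_key_paths(state))
```

Since an NNX `State` is also a pytree, passing `nnx.split(model)[1]` to the same helper should list any key-typed entries of a real model. This is a quick way to see whether a given module stores an `Rngs` object, though the exact paths depend on the Flax version.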
I stumbled upon this because my recipe for checkpointing breaks when my model contains an LSTM:
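```python
from flax import nnx
import orbax.checkpoint as ocp

def savemodel(model, path):
    # Split the module into its static graph definition and its state
    # (every Variable, including any RNG keys the module holds).
    _, state = nnx.split(model)
    checkpointer = ocp.StandardCheckpointer()
    checkpointer.save(path, state)
```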
Calling `savemodel(model, path)` throws:

```
TypeError: JAX array with PRNGKey dtype cannot be converted to a NumPy array. Use jax.random.key_data(arr) if you wish to extract the underlying integer array.
```
This was surprising because I have used the same recipe before with other, non-LSTM modules and never had a problem.