Quick question: by "arbitrary pytrees", do you mean "support a list of pytrees as lazy arguments"? I.e., update `_parse_spec` to flatten the spec, turn each leaf into a `ShapeDtypeStruct`, and unflatten the result.
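A minimal sketch of the flatten / `ShapeDtypeStruct` / unflatten approach described in the comment. It assumes a spec leaf is either a shape tuple (using a default dtype) or a `(shape, dtype)` pair. `parse_spec` and `_is_spec` are hypothetical names, not Flax's actual `_parse_spec`.

```python
import jax
import jax.numpy as jnp


def _is_spec(x):
    # Assumed spec format (hypothetical): a shape tuple of ints, or a
    # (shape, dtype) pair whose dtype is not itself a tuple.
    if isinstance(x, tuple) and all(isinstance(d, int) for d in x):
        return True
    return (isinstance(x, tuple) and len(x) == 2
            and isinstance(x[0], tuple)
            and all(isinstance(d, int) for d in x[0])
            and not isinstance(x[1], tuple))


def parse_spec(spec, default_dtype=jnp.float32):
    # 1. Flatten the spec pytree, stopping at spec leaves.
    leaves, treedef = jax.tree_util.tree_flatten(spec, is_leaf=_is_spec)
    # 2. Turn each leaf into a ShapeDtypeStruct.
    structs = []
    for leaf in leaves:
        if len(leaf) == 2 and isinstance(leaf[0], tuple):
            shape, dtype = leaf
        else:
            shape, dtype = leaf, default_dtype
        structs.append(jax.ShapeDtypeStruct(shape, dtype))
    # 3. Unflatten back into the original container structure.
    return jax.tree_util.tree_unflatten(treedef, structs)


spec = {"x": (2, 3), "y": ((4,), jnp.int32), "z": [(), (5, 1)]}
lazy_args = parse_spec(spec)
```

The `is_leaf` predicate is needed because a shape tuple is itself a pytree; without it, `tree_flatten` would split shapes into individual ints.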
`init_by_shape` only supports a list of arrays as lazy arguments. Instead, it would be better to support arbitrary pytrees.
The easiest way to support this is by using `ShapeDtypeStruct` in JAX, similar to `jax.eval_shape`.
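For reference, `jax.eval_shape` already accepts arbitrary pytrees whose leaves are `jax.ShapeDtypeStruct`s: it traces the function abstractly and returns a pytree of `ShapeDtypeStruct`s without allocating or computing real arrays. In this sketch `init_fn` is a hypothetical stand-in for a module's init, and it only illustrates the shape-level behavior, not how Flax's `init_by_shape` produces concrete parameters.

```python
import jax
import jax.numpy as jnp


def init_fn(rng, inputs):
    # Hypothetical init: parameter shapes depend on nested (dict/list) inputs.
    x = inputs["image"]
    y = inputs["tokens"][0]
    w = jax.random.normal(rng, (x.shape[-1], 8))
    b = jnp.zeros((y.shape[-1],))
    return {"w": w, "b": b}


# Lazy arguments as a pytree of ShapeDtypeStructs (no data allocated).
lazy_inputs = {
    "image": jax.ShapeDtypeStruct((1, 28, 28, 3), jnp.float32),
    "tokens": [jax.ShapeDtypeStruct((1, 16), jnp.int32)],
}

param_shapes = jax.eval_shape(init_fn, jax.random.PRNGKey(0), lazy_inputs)
```

Here `param_shapes` is `{"w": ShapeDtypeStruct((3, 8), float32), "b": ShapeDtypeStruct((16,), float32)}`; the nested input structure needs no special handling.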