init_by_shape does not work on pytrees #171

Closed
jheek opened this issue Apr 2, 2020 · 3 comments

jheek (Member) commented Apr 2, 2020

`init_by_shape` only supports a list of arrays as lazy arguments. It would be better to support arbitrary pytrees instead.

The easiest way to support this is to use `jax.ShapeDtypeStruct`, similar to how `jax.eval_shape` works.
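For illustration only: `jax.eval_shape` already accepts an arbitrary pytree whose leaves are `jax.ShapeDtypeStruct` objects, and it returns the output shapes and dtypes without allocating real arrays. That is roughly the behavior being requested for `init_by_shape`. The function `apply_fn` and its input names below are hypothetical stand-ins for a model's forward pass.

```python
import jax
import jax.numpy as jnp


def apply_fn(inputs):
    # Hypothetical forward pass that takes a dict pytree of inputs
    # and returns a dict pytree of outputs.
    h = inputs["x"] @ inputs["w"]
    return {"out": h, "mean": h.mean(axis=-1)}


# Abstract inputs: only shapes and dtypes are described, no data is allocated.
abstract_inputs = {
    "x": jax.ShapeDtypeStruct((8, 3), jnp.float32),
    "w": jax.ShapeDtypeStruct((3, 5), jnp.float32),
}

# eval_shape traces apply_fn abstractly and returns a pytree of
# ShapeDtypeStruct with the same structure as the real output.
out_shapes = jax.eval_shape(apply_fn, abstract_inputs)
print(out_shapes["out"].shape, out_shapes["mean"].shape)  # → (8, 5) (8,)
```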

marcvanzee self-assigned this Apr 22, 2020
danielsuo (Collaborator) commented

Quick question: by arbitrary pytrees, do you mean "support a list of pytrees as lazy arguments"? I.e., update `_parse_spec` to flatten the spec, turn each leaf into a `ShapeDtypeStruct`, and unflatten?
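A minimal sketch of the flatten, convert, unflatten approach, assuming (hypothetically) that each spec leaf is a `(shape, dtype)` pair. Flax's actual `_parse_spec` and its spec format may differ; `_is_spec` and `parse_specs` are illustrative names, not part of Flax.

```python
import jax
import jax.numpy as jnp


def _is_spec(node):
    # Hypothetical convention: a spec is a (shape, dtype) pair whose shape
    # is a tuple of ints. Spec pairs are tuples themselves, so without this
    # is_leaf predicate tree_flatten would split them into separate leaves.
    return (
        isinstance(node, tuple)
        and len(node) == 2
        and isinstance(node[0], tuple)
        and all(isinstance(d, int) for d in node[0])
    )


def parse_specs(specs):
    """Hypothetical helper: turn a pytree of (shape, dtype) specs into a
    pytree of jax.ShapeDtypeStruct with the same structure."""
    # 1. Flatten, stopping at spec pairs.
    leaves, treedef = jax.tree_util.tree_flatten(specs, is_leaf=_is_spec)
    # 2. Convert each spec into an abstract value.
    structs = [jax.ShapeDtypeStruct(shape, dtype) for shape, dtype in leaves]
    # 3. Unflatten back into the original structure.
    return jax.tree_util.tree_unflatten(treedef, structs)


# Usage: a nested pytree of specs rather than a flat list of arrays.
specs = {
    "image": ((2, 28, 28, 1), jnp.float32),
    "labels": [((2,), jnp.int32)],
}
abstract = parse_specs(specs)
print(abstract["image"].shape, abstract["labels"][0].dtype)
```

The resulting pytree can be passed straight to `jax.eval_shape`, since its leaves are already `ShapeDtypeStruct` objects.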

AlexeyG (Collaborator) commented Jun 12, 2020

Marking this as on hold, since it is likely to become much easier once the FLIPs (Flax Improvement Proposals) land.

avital (Contributor) commented Dec 12, 2020

We no longer have `init_by_shape` in Linen, so I'll close this issue.

avital closed this as completed Dec 12, 2020