Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added default params rng to .apply #3698

Merged
merged 1 commit into from
Feb 20, 2024
Merged

Conversation

chiamp
Copy link
Collaborator

@chiamp chiamp commented Feb 15, 2024

Added default params rng to .apply.

Similarly to how you can get the same behavior by doing the following with .init:

v = model.init({'params': key1}, x)
v2 = model.init(key1, x)

This PR allows you to do the same with .apply:

out = model.apply(v, x, rngs={'params': key2})
out2 = model.apply(v, x, rngs=key2)

@chiamp chiamp self-assigned this Feb 15, 2024
@@ -1074,9 +1074,19 @@ def apply(
def wrapper(
variables: VariableDict,
*args,
rngs: Optional[RNGSequences] = None,
rngs: Optional[Union[PRNGKey, RNGSequences]] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
rngs: Optional[Union[PRNGKey, RNGSequences]] = None,
rngs: PRNGKey | RNGSequences | None = None,

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

won't this fail Github CI for python 3.9?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

our minimum Python version is still 3.9.

**kwargs,
) -> Union[Any, Tuple[Any, Union[VariableDict, Dict[str, Any]]]]:
if rngs is not None:
if not _is_valid_rng(rngs) and not _is_valid_rngs(rngs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if not _is_valid_rng(rngs) and not _is_valid_rngs(rngs):
if not _is_valid_rng(rngs):

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

they are two separate functions: _is_valid_rng checks if the rng key rngs is valid, and _is_valid_rngs checks if the dictionary mapping rngs is valid (recursively)

@copybara-service copybara-service bot merged commit daf06ea into google:main Feb 20, 2024
19 checks passed
@chiamp chiamp deleted the apply_rng branch February 21, 2024 21:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants