Update Returns section in gradient transformations' docstrings. #388

Merged 1 commit on Aug 11, 2022
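In optax, `GradientTransformation` is a NamedTuple bundling an `init` function and an `update` function, which is why the docstrings below now name the returned type rather than describing it as an (init_fn, update_fn) tuple. A minimal usage sketch (illustrative only, assuming a standard optax/JAX install):

```python
import jax.numpy as jnp
import optax

# identity() returns a GradientTransformation: a NamedTuple pairing an
# `init` function (builds the optimizer state) with an `update` function
# (maps gradients to parameter updates).
tx = optax.identity()

params = {"w": jnp.ones(3)}
state = tx.init(params)

grads = {"w": jnp.array([0.1, -0.2, 0.3])}
updates, state = tx.update(grads, state, params)
# identity() leaves the updates untouched, so `updates` equals `grads`.
```

Either phrasing is technically accurate, since the NamedTuple still unpacks as a pair, but naming the type points readers at its documented fields.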
4 changes: 2 additions & 2 deletions optax/_src/base.py
@@ -131,7 +131,7 @@ def identity() -> GradientTransformation:
to be left unchanged when the updates are applied to them.

Returns:
-An (init_fn, update_fn) tuple.
+A `GradientTransformation` object.
"""

def init_fn(_):
@@ -161,7 +161,7 @@ def set_to_zero() -> GradientTransformation:
parameters, unnecessary computations will in general be dropped.

Returns:
-An (init_fn, update_fn) tuple.
+A `GradientTransformation` object.
"""

def init_fn(params):
8 changes: 4 additions & 4 deletions optax/_src/clipping.py
@@ -37,7 +37,7 @@ def clip(max_delta: chex.Numeric) -> base.GradientTransformation:
max_delta: The maximum absolute value for each element in the update.

Returns:
-An (init_fn, update_fn) tuple.
+A `GradientTransformation` object.
"""

def init_fn(params):
@@ -63,7 +63,7 @@ def clip_by_block_rms(threshold: float) -> base.GradientTransformation:
threshold: The maximum rms for the gradient of each param vector or matrix.

Returns:
-An (init_fn, update_fn) tuple.
+A `GradientTransformation` object.
"""

def init_fn(params):
@@ -98,7 +98,7 @@ def clip_by_global_norm(max_norm: float) -> base.GradientTransformation:
max_norm: The maximum global norm for an update.

Returns:
-An (init_fn, update_fn) tuple.
+A `GradientTransformation` object.
"""

def init_fn(params):
@@ -201,7 +201,7 @@ def adaptive_grad_clip(clipping: float,
eps: An epsilon term to prevent clipping of zero-initialized params.

Returns:
-An (init_fn, update_fn) tuple.
+A `GradientTransformation` object.
"""

def init_fn(params):
2 changes: 1 addition & 1 deletion optax/_src/constrain.py
@@ -38,7 +38,7 @@ def keep_params_nonnegative() -> base.GradientTransformation:
When params is negative the transformed update will move them to 0.

Returns:
-An (init_fn, update_fn) tuple.
+A `GradientTransformation` object.
"""

def init_fn(params):
42 changes: 21 additions & 21 deletions optax/_src/transform.py
@@ -55,7 +55,7 @@ def trace(
`None` then the `dtype` is inferred from `params` and `updates`.

Returns:
-An (init_fn, update_fn) tuple.
+A `GradientTransformation` object.
"""

accumulator_dtype = utils.canonicalize_dtype(accumulator_dtype)
@@ -143,7 +143,7 @@ def ema(
then the `dtype` is inferred from `params` and `updates`.

Returns:
-An (init_fn, update_fn) tuple.
+A `GradientTransformation` object.
"""

accumulator_dtype = utils.canonicalize_dtype(accumulator_dtype)
@@ -186,7 +186,7 @@ def scale_by_rss(
eps: A small floating point value to avoid zero denominator.

Returns:
-An (init_fn, update_fn) tuple.
+A `GradientTransformation` object.
"""

def init_fn(params):
@@ -228,7 +228,7 @@ def scale_by_rms(
initial_scale: initial value for second moment

Returns:
-An (init_fn, update_fn) tuple.
+A `GradientTransformation` object.
"""

def init_fn(params):
@@ -268,7 +268,7 @@ def scale_by_stddev(
initial_scale: initial value for second moment

Returns:
-An (init_fn, update_fn) tuple.
+A `GradientTransformation` object.
"""

def init_fn(params):
@@ -318,7 +318,7 @@ def scale_by_adam(
`None` then the `dtype is inferred from `params` and `updates`.

Returns:
-An (init_fn, update_fn) tuple.
+A `GradientTransformation` object.
"""

mu_dtype = utils.canonicalize_dtype(mu_dtype)
@@ -360,7 +360,7 @@ def scale_by_adamax(
eps: term added to the denominator to improve numerical stability.

Returns:
-An (init_fn, update_fn) tuple.
+A `GradientTransformation` object.
"""

def init_fn(params):
@@ -393,7 +393,7 @@ def scale(
step_size: a scalar corresponding to a fixed scaling factor for updates.

Returns:
-An (init_fn, update_fn) tuple.
+A `GradientTransformation` object.
"""

def init_fn(params):
@@ -420,7 +420,7 @@ def scale_by_param_block_norm(
min_scale: minimum scaling factor.

Returns:
-An (init_fn, update_fn) tuple.
+A `GradientTransformation` object.
"""

def init_fn(params):
@@ -450,7 +450,7 @@ def scale_by_param_block_rms(
min_scale: minimum scaling factor.

Returns:
-An (init_fn, update_fn) tuple.
+A `GradientTransformation` object.
"""

def init_fn(params):
@@ -495,7 +495,7 @@ def scale_by_belief(
gradient transformation (e.g. for meta-learning), this must be non-zero.

Returns:
-An (init_fn, update_fn) tuple.
+A `GradientTransformation` object.
"""

def init_fn(params):
@@ -545,7 +545,7 @@ def scale_by_yogi(
Only positive values are allowed.

Returns:
-An (init_fn, update_fn) tuple.
+A `GradientTransformation` object.
"""

def init_fn(params):
@@ -591,7 +591,7 @@ def scale_by_radam(
threshold: Threshold for variance tractability

Returns:
-An (init_fn, update_fn) tuple.
+A `GradientTransformation` object.
"""

ro_inf = 2./(1 - b2) - 1
@@ -643,7 +643,7 @@ def add_decayed_weights(
apply the transformation to, and `False` for those you want to skip.

Returns:
-An (init_fn, update_fn) tuple.
+A `GradientTransformation` object.
"""

def init_fn(params):
@@ -680,7 +680,7 @@ def scale_by_schedule(
the step_size to multiply the updates by.

Returns:
-An (init_fn, update_fn) tuple.
+A `GradientTransformation` object.
"""

def init_fn(params):
@@ -723,7 +723,7 @@ def scale_by_trust_ratio(
eps: additive constant added to the denominator for numerical stability.

Returns:
-An (init_fn, update_fn) tuple.
+A `GradientTransformation` object.
"""

def init_fn(params):
@@ -777,7 +777,7 @@ def add_noise(
seed: seed for random number generation.

Returns:
-An (init_fn, update_fn) tuple.
+A `GradientTransformation` object.
"""

def init_fn(params):
@@ -825,7 +825,7 @@ def apply_every(
k: emit non-zero gradients every k steps, otherwise accumulate them.

Returns:
-An (init_fn, update_fn) tuple.
+A `GradientTransformation` object.
"""

def init_fn(params):
@@ -863,7 +863,7 @@ def centralize() -> base.GradientTransformation:
[Yong et al, 2020](https://arxiv.org/abs/2004.01461)

Returns:
-An (init_fn, update_fn) tuple.
+A `GradientTransformation` object.
"""

def init_fn(params):
@@ -900,7 +900,7 @@ def scale_by_sm3(
eps: term added to the denominator to improve numerical stability.

Returns:
-An (init_fn, update_fn) tuple.
+A `GradientTransformation` object.
"""

def zeros_for_dim(p):
@@ -966,7 +966,7 @@ def scale_by_optimistic_gradient(
beta: (float) coefficient for negative momentum

Returns:
-An (init_fn, update_fn) tuple.
+A `GradientTransformation` object.
"""

def init_fn(params):
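Because every factory touched by this PR returns a `GradientTransformation`, the objects also compose directly. A short illustrative sketch combining a few of the transformations documented above via `optax.chain` (all standard optax APIs; the hyperparameter values are arbitrary):

```python
import jax.numpy as jnp
import optax

# clip(), scale_by_adam() and scale() each return a GradientTransformation,
# so chain() can combine them into a single GradientTransformation.
tx = optax.chain(
    optax.clip(1.0),          # clip each update element to [-1, 1]
    optax.scale_by_adam(),    # rescale with Adam's moment estimates
    optax.scale(-0.001),      # negative step size for gradient descent
)

params = {"w": jnp.zeros(2)}
state = tx.init(params)

grads = {"w": jnp.array([2.0, -0.5])}
updates, state = tx.update(grads, state, params)
params = optax.apply_updates(params, updates)
```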