
Update Returns section in gradient transformations' docstrings.
PiperOrigin-RevId: 466908313
hbq1 authored and OptaxDev committed Aug 11, 2022
1 parent 52daea3 commit ed20584
Showing 4 changed files with 28 additions and 28 deletions.
4 changes: 2 additions & 2 deletions optax/_src/base.py
@@ -131,7 +131,7 @@ def identity() -> GradientTransformation:
   to be left unchanged when the updates are applied to them.
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(_):
@@ -161,7 +161,7 @@ def set_to_zero() -> GradientTransformation:
   parameters, unnecessary computations will in general be dropped.
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(params):
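For reference (not part of the diff): the object named by the new docstring wording is `optax.GradientTransformation`, a NamedTuple whose fields are exactly the former `(init_fn, update_fn)` pair, so the two phrasings describe the same value. A minimal usage sketch, assuming `optax` and `jax` are installed:

```python
import jax.numpy as jnp
import optax

tx = optax.identity()           # a GradientTransformation
params = {'w': jnp.ones(3)}
state = tx.init(params)         # init_fn: builds the (empty) optimizer state
grads = {'w': jnp.full(3, 0.5)}
updates, state = tx.update(grads, state, params)  # update_fn: identity passes grads through
```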
8 changes: 4 additions & 4 deletions optax/_src/clipping.py
@@ -37,7 +37,7 @@ def clip(max_delta: chex.Numeric) -> base.GradientTransformation:
     max_delta: The maximum absolute value for each element in the update.
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(params):
@@ -63,7 +63,7 @@ def clip_by_block_rms(threshold: float) -> base.GradientTransformation:
     threshold: The maximum rms for the gradient of each param vector or matrix.
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(params):
@@ -98,7 +98,7 @@ def clip_by_global_norm(max_norm: float) -> base.GradientTransformation:
     max_norm: The maximum global norm for an update.
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(params):
@@ -201,7 +201,7 @@ def adaptive_grad_clip(clipping: float,
     eps: An epsilon term to prevent clipping of zero-initialized params.
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(params):
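Also not part of the diff: a sketch of one of the clipping transforms touched above, with arbitrary example values.

```python
import jax.numpy as jnp
import optax

clip_tx = optax.clip_by_global_norm(1.0)       # GradientTransformation
params = {'w': jnp.zeros(2)}
grads = {'w': jnp.array([3.0, 4.0])}           # global norm 5.0
state = clip_tx.init(params)
clipped, state = clip_tx.update(grads, state)  # rescaled so the global norm is at most 1.0
```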
2 changes: 1 addition & 1 deletion optax/_src/constrain.py
@@ -38,7 +38,7 @@ def keep_params_nonnegative() -> base.GradientTransformation:
   When params is negative the transformed update will move them to 0.
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(params):
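A sketch of the constraint transform above (again not part of the diff); note that it needs `params` at update time so it can clip updates that would push a parameter below zero.

```python
import jax.numpy as jnp
import optax

tx = optax.keep_params_nonnegative()
params = {'w': jnp.array([0.1, 0.3])}
state = tx.init(params)
updates = {'w': jnp.array([-0.5, -0.2])}
safe, state = tx.update(updates, state, params)  # first update clipped to -0.1 so w stays >= 0
```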
42 changes: 21 additions & 21 deletions optax/_src/transform.py
@@ -55,7 +55,7 @@ def trace(
       `None` then the `dtype` is inferred from `params` and `updates`.
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   accumulator_dtype = utils.canonicalize_dtype(accumulator_dtype)
@@ -143,7 +143,7 @@ def ema(
       then the `dtype` is inferred from `params` and `updates`.
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   accumulator_dtype = utils.canonicalize_dtype(accumulator_dtype)
@@ -186,7 +186,7 @@ def scale_by_rss(
     eps: A small floating point value to avoid zero denominator.
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(params):
@@ -228,7 +228,7 @@ def scale_by_rms(
     initial_scale: initial value for second moment
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(params):
@@ -268,7 +268,7 @@ def scale_by_stddev(
     initial_scale: initial value for second moment
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(params):
@@ -318,7 +318,7 @@ def scale_by_adam(
       `None` then the `dtype is inferred from `params` and `updates`.
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   mu_dtype = utils.canonicalize_dtype(mu_dtype)
@@ -360,7 +360,7 @@ def scale_by_adamax(
     eps: term added to the denominator to improve numerical stability.
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(params):
@@ -393,7 +393,7 @@ def scale(
     step_size: a scalar corresponding to a fixed scaling factor for updates.
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(params):
@@ -420,7 +420,7 @@ def scale_by_param_block_norm(
     min_scale: minimum scaling factor.
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(params):
@@ -450,7 +450,7 @@ def scale_by_param_block_rms(
     min_scale: minimum scaling factor.
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(params):
@@ -495,7 +495,7 @@ def scale_by_belief(
       gradient transformation (e.g. for meta-learning), this must be non-zero.
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(params):
@@ -545,7 +545,7 @@ def scale_by_yogi(
       Only positive values are allowed.
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(params):
@@ -591,7 +591,7 @@ def scale_by_radam(
     threshold: Threshold for variance tractability
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   ro_inf = 2./(1 - b2) - 1
@@ -643,7 +643,7 @@ def add_decayed_weights(
       apply the transformation to, and `False` for those you want to skip.
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(params):
@@ -680,7 +680,7 @@ def scale_by_schedule(
       the step_size to multiply the updates by.
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(params):
@@ -723,7 +723,7 @@ def scale_by_trust_ratio(
     eps: additive constant added to the denominator for numerical stability.
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(params):
@@ -777,7 +777,7 @@ def add_noise(
     seed: seed for random number generation.
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(params):
@@ -825,7 +825,7 @@ def apply_every(
     k: emit non-zero gradients every k steps, otherwise accumulate them.
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(params):
@@ -863,7 +863,7 @@ def centralize() -> base.GradientTransformation:
     [Yong et al, 2020](https://arxiv.org/abs/2004.01461)
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(params):
@@ -900,7 +900,7 @@ def scale_by_sm3(
     eps: term added to the denominator to improve numerical stability.
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def zeros_for_dim(p):
@@ -966,7 +966,7 @@ def scale_by_optimistic_gradient(
     beta: (float) coefficient for negative momentum
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(params):
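Not part of the diff: the transforms above are usually not applied alone; they compose via `optax.chain`, which again returns a single `GradientTransformation`. An illustrative (not prescriptive) Adam-style optimizer built from pieces touched in this commit:

```python
import optax

learning_rate = 1e-3  # arbitrary example value
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),   # from clipping.py
    optax.scale_by_adam(),            # from transform.py
    optax.scale(-learning_rate),      # negative sign: updates are added to params
)
```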
