From ed20584bd7c6419baf94b74fdb5ec93b41856042 Mon Sep 17 00:00:00 2001
From: Iurii Kemaev
Date: Thu, 11 Aug 2022 03:10:13 -0700
Subject: [PATCH] Update Returns section in gradient transformations'
 docstrings.

PiperOrigin-RevId: 466908313
---
 optax/_src/base.py      |  4 ++--
 optax/_src/clipping.py  |  8 ++++----
 optax/_src/constrain.py |  2 +-
 optax/_src/transform.py | 42 ++++++++++++++++++++---------------------
 4 files changed, 28 insertions(+), 28 deletions(-)

diff --git a/optax/_src/base.py b/optax/_src/base.py
index 08dec9b4e..97ff04a67 100644
--- a/optax/_src/base.py
+++ b/optax/_src/base.py
@@ -131,7 +131,7 @@ def identity() -> GradientTransformation:
     to be left unchanged when the updates are applied to them.
 
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(_):
@@ -161,7 +161,7 @@ def set_to_zero() -> GradientTransformation:
     parameters, unnecessary computations will in general be dropped.
 
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(params):
diff --git a/optax/_src/clipping.py b/optax/_src/clipping.py
index e3223a3f3..7778893e2 100644
--- a/optax/_src/clipping.py
+++ b/optax/_src/clipping.py
@@ -37,7 +37,7 @@ def clip(max_delta: chex.Numeric) -> base.GradientTransformation:
     max_delta: The maximum absolute value for each element in the update.
 
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(params):
@@ -63,7 +63,7 @@ def clip_by_block_rms(threshold: float) -> base.GradientTransformation:
     threshold: The maximum rms for the gradient of each param vector or matrix.
 
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(params):
@@ -98,7 +98,7 @@ def clip_by_global_norm(max_norm: float) -> base.GradientTransformation:
     max_norm: The maximum global norm for an update.
 
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(params):
@@ -201,7 +201,7 @@ def adaptive_grad_clip(clipping: float,
     eps: An epsilon term to prevent clipping of zero-initialized params.
 
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(params):
diff --git a/optax/_src/constrain.py b/optax/_src/constrain.py
index 8a6386bb0..e98d12abb 100644
--- a/optax/_src/constrain.py
+++ b/optax/_src/constrain.py
@@ -38,7 +38,7 @@ def keep_params_nonnegative() -> base.GradientTransformation:
   When params is negative the transformed update will move them to 0.
 
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   def init_fn(params):
diff --git a/optax/_src/transform.py b/optax/_src/transform.py
index b4c3b21bb..ee6e22827 100644
--- a/optax/_src/transform.py
+++ b/optax/_src/transform.py
@@ -55,7 +55,7 @@ def trace(
       `None` then the `dtype` is inferred from `params` and `updates`.
 
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   accumulator_dtype = utils.canonicalize_dtype(accumulator_dtype)
@@ -143,7 +143,7 @@ def ema(
       then the `dtype` is inferred from `params` and `updates`.
 
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
   """
 
   accumulator_dtype = utils.canonicalize_dtype(accumulator_dtype)
@@ -186,7 +186,7 @@ def scale_by_rss(
     eps: A small floating point value to avoid zero denominator.
 
   Returns:
-    An (init_fn, update_fn) tuple.
+    A `GradientTransformation` object.
""" def init_fn(params): @@ -228,7 +228,7 @@ def scale_by_rms( initial_scale: initial value for second moment Returns: - An (init_fn, update_fn) tuple. + A `GradientTransformation` object. """ def init_fn(params): @@ -268,7 +268,7 @@ def scale_by_stddev( initial_scale: initial value for second moment Returns: - An (init_fn, update_fn) tuple. + A `GradientTransformation` object. """ def init_fn(params): @@ -318,7 +318,7 @@ def scale_by_adam( `None` then the `dtype is inferred from `params` and `updates`. Returns: - An (init_fn, update_fn) tuple. + A `GradientTransformation` object. """ mu_dtype = utils.canonicalize_dtype(mu_dtype) @@ -360,7 +360,7 @@ def scale_by_adamax( eps: term added to the denominator to improve numerical stability. Returns: - An (init_fn, update_fn) tuple. + A `GradientTransformation` object. """ def init_fn(params): @@ -393,7 +393,7 @@ def scale( step_size: a scalar corresponding to a fixed scaling factor for updates. Returns: - An (init_fn, update_fn) tuple. + A `GradientTransformation` object. """ def init_fn(params): @@ -420,7 +420,7 @@ def scale_by_param_block_norm( min_scale: minimum scaling factor. Returns: - An (init_fn, update_fn) tuple. + A `GradientTransformation` object. """ def init_fn(params): @@ -450,7 +450,7 @@ def scale_by_param_block_rms( min_scale: minimum scaling factor. Returns: - An (init_fn, update_fn) tuple. + A `GradientTransformation` object. """ def init_fn(params): @@ -495,7 +495,7 @@ def scale_by_belief( gradient transformation (e.g. for meta-learning), this must be non-zero. Returns: - An (init_fn, update_fn) tuple. + A `GradientTransformation` object. """ def init_fn(params): @@ -545,7 +545,7 @@ def scale_by_yogi( Only positive values are allowed. Returns: - An (init_fn, update_fn) tuple. + A `GradientTransformation` object. """ def init_fn(params): @@ -591,7 +591,7 @@ def scale_by_radam( threshold: Threshold for variance tractability Returns: - An (init_fn, update_fn) tuple. + A `GradientTransformation` object. """ ro_inf = 2./(1 - b2) - 1 @@ -643,7 +643,7 @@ def add_decayed_weights( apply the transformation to, and `False` for those you want to skip. Returns: - An (init_fn, update_fn) tuple. + A `GradientTransformation` object. """ def init_fn(params): @@ -680,7 +680,7 @@ def scale_by_schedule( the step_size to multiply the updates by. Returns: - An (init_fn, update_fn) tuple. + A `GradientTransformation` object. """ def init_fn(params): @@ -723,7 +723,7 @@ def scale_by_trust_ratio( eps: additive constant added to the denominator for numerical stability. Returns: - An (init_fn, update_fn) tuple. + A `GradientTransformation` object. """ def init_fn(params): @@ -777,7 +777,7 @@ def add_noise( seed: seed for random number generation. Returns: - An (init_fn, update_fn) tuple. + A `GradientTransformation` object. """ def init_fn(params): @@ -825,7 +825,7 @@ def apply_every( k: emit non-zero gradients every k steps, otherwise accumulate them. Returns: - An (init_fn, update_fn) tuple. + A `GradientTransformation` object. """ def init_fn(params): @@ -863,7 +863,7 @@ def centralize() -> base.GradientTransformation: [Yong et al, 2020](https://arxiv.org/abs/2004.01461) Returns: - An (init_fn, update_fn) tuple. + A `GradientTransformation` object. """ def init_fn(params): @@ -900,7 +900,7 @@ def scale_by_sm3( eps: term added to the denominator to improve numerical stability. Returns: - An (init_fn, update_fn) tuple. + A `GradientTransformation` object. 
""" def zeros_for_dim(p): @@ -966,7 +966,7 @@ def scale_by_optimistic_gradient( beta: (float) coefficient for negative momentum Returns: - An (init_fn, update_fn) tuple. + A `GradientTransformation` object. """ def init_fn(params):