Skip to content

Commit

Permalink
feat: add normalize_by_update_norm
Browse files Browse the repository at this point in the history
  • Loading branch information
SauravMaheshkar committed Jun 25, 2024
1 parent 8a3ee74 commit 6a6803e
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
33 changes: 33 additions & 0 deletions optax/_src/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1419,6 +1419,39 @@ def update_fn(
return base.GradientTransformationExtraArgs(_init_empty_state, update_fn)


class NormalizeByUpdateNormState(NamedTuple):
"""State for normalize_by_update_norm."""
scale: float
eps: float


def normalize_by_update_norm(
scale: float = 1.0, eps: float = 1e-6
) -> base.GradientTransformation:
"""
Scale by the inverse of the gradient norm.
Args:
scale: (float) scaling factor
eps: (float) jitter term to avoid dividing by 0
Returns:
A `GradientTransformation` object.
"""

def init_fn(params):
del params
return NormalizeByUpdateNormState(scale, eps)

def update_fn(updates, state, params=None):
del params
g_norm = (utils.global_norm(updates) + eps) / scale
updates = jtu.tree_map(lambda g: g / g_norm, updates)
return updates, state

return base.GradientTransformation(init_fn, update_fn)


### Legacy symbols to be removed. ###


Expand Down
1 change: 1 addition & 0 deletions optax/_src/transform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def setUp(self):
('param_block_norm', transform.scale_by_param_block_norm),
('param_block_rms', transform.scale_by_param_block_rms),
('distance_over_gradients', transform.scale_by_distance_over_gradients),
('normalize_by_update_norm', transform.normalize_by_update_norm),
])
def test_scalers(self, scaler_constr):
params = self.init_params
Expand Down

0 comments on commit 6a6803e

Please sign in to comment.