From e14840ea849bfc10bbaf0d32057558dd43aef9d2 Mon Sep 17 00:00:00 2001 From: Alexander Kolesnikov Date: Tue, 19 Jan 2021 13:55:16 -0800 Subject: [PATCH] Add dtype option for momentum buffers to adafactor. PiperOrigin-RevId: 352647669 --- flax/optim/adafactor.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/flax/optim/adafactor.py b/flax/optim/adafactor.py index ed56fea13..fff3ef1e8 100644 --- a/flax/optim/adafactor.py +++ b/flax/optim/adafactor.py @@ -23,11 +23,15 @@ from .. import struct from .base import OptimizerDef +import jax import jax.numpy as jnp import numpy as onp +Dtype = Any + + @struct.dataclass class _AdafactorHyperParams: learning_rate: Optional[float] @@ -68,7 +72,8 @@ def __init__(self, weight_decay_rate: Optional[float] = None, min_dim_size_to_factor: int = 128, epsilon1: float = 1e-30, - epsilon2: float = 1e-3): + epsilon2: float = 1e-3, + dtype_momentum: Dtype = jnp.float32): """Constructor for the Adafactor optimizer. Args: @@ -91,11 +96,13 @@ def __init__(self, are at least this size. epsilon1: Regularization constant for squared gradient. epsilon2: Regularization constant for parameter scale. + dtype_momentum: dtype of momentum buffers. """ hyper_params = _AdafactorHyperParams( learning_rate, factored, multiply_by_parameter_scale, beta1, decay_rate, step_offset, clipping_threshold, weight_decay_rate, min_dim_size_to_factor, epsilon1, epsilon2) + self.dtype_momentum = jax.dtypes.canonicalize_dtype(dtype_momentum) super().__init__(hyper_params) @staticmethod @@ -137,7 +144,7 @@ def init_param_state(self, param): else: state['v'] = jnp.zeros(param.shape, dtype=jnp.float32) if self.hyper_params.beta1 is not None: - state['m'] = jnp.zeros(param.shape, dtype=jnp.float32) + state['m'] = jnp.zeros(param.shape, dtype=self.dtype_momentum) return _AdafactorParamState(**state) def apply_param_gradient(self, step, hyper_params, param, state, grad): @@ -192,7 +199,7 @@ def apply_param_gradient(self, step, hyper_params, param, state, grad): if beta1 is not None: new_m = beta1 * state.m + (1.0 - beta1) * subtrahend subtrahend = new_m - updates['m'] = new_m + updates['m'] = new_m.astype(self.dtype_momentum) if weight_decay_rate is not None: new_param = (1.0 - weight_decay_rate) * param - subtrahend