Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions flax/optim/adafactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,15 @@
from .. import struct
from .base import OptimizerDef

import jax
import jax.numpy as jnp

import numpy as onp


Dtype = Any


@struct.dataclass
class _AdafactorHyperParams:
learning_rate: Optional[float]
Expand Down Expand Up @@ -68,7 +72,8 @@ def __init__(self,
weight_decay_rate: Optional[float] = None,
min_dim_size_to_factor: int = 128,
epsilon1: float = 1e-30,
epsilon2: float = 1e-3):
epsilon2: float = 1e-3,
dtype_momentum: Dtype = jnp.float32):
"""Constructor for the Adafactor optimizer.

Args:
Expand All @@ -91,11 +96,13 @@ def __init__(self,
are at least this size.
epsilon1: Regularization constant for squared gradient.
epsilon2: Regularization constant for parameter scale.
dtype_momentum: dtype of momentum buffers.
"""
hyper_params = _AdafactorHyperParams(
learning_rate, factored, multiply_by_parameter_scale,
beta1, decay_rate, step_offset, clipping_threshold,
weight_decay_rate, min_dim_size_to_factor, epsilon1, epsilon2)
self.dtype_momentum = jax.dtypes.canonicalize_dtype(dtype_momentum)
super().__init__(hyper_params)

@staticmethod
Expand Down Expand Up @@ -137,7 +144,7 @@ def init_param_state(self, param):
else:
state['v'] = jnp.zeros(param.shape, dtype=jnp.float32)
if self.hyper_params.beta1 is not None:
state['m'] = jnp.zeros(param.shape, dtype=jnp.float32)
state['m'] = jnp.zeros(param.shape, dtype=self.dtype_momentum)
return _AdafactorParamState(**state)

def apply_param_gradient(self, step, hyper_params, param, state, grad):
Expand Down Expand Up @@ -192,7 +199,7 @@ def apply_param_gradient(self, step, hyper_params, param, state, grad):
if beta1 is not None:
new_m = beta1 * state.m + (1.0 - beta1) * subtrahend
subtrahend = new_m
updates['m'] = new_m
updates['m'] = new_m.astype(self.dtype_momentum)

if weight_decay_rate is not None:
new_param = (1.0 - weight_decay_rate) * param - subtrahend
Expand Down