Commit e29145a
* Adding more extensive exponential decay schedule.
* Adding reporting of learning rate to the Optax optimizers.

PiperOrigin-RevId: 543428869
Change-Id: Ic17beea2081868c2d29fc0e9a80d48a981d10cc7
FermiNet Contributor authored and jsspencer committed Jul 18, 2023
1 parent d3acd19 commit e29145a
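
The changes named in the commit message (the decay schedule and the learning-rate reporting) are not part of the file shown below. As a rough, hypothetical sketch of how an exponential decay schedule and learning-rate reporting can be wired up with Optax, with all names and hyperparameter values illustrative rather than FermiNet's:

import jax.numpy as jnp
import optax

# Toy parameter pytree; stands in for the network parameters.
params = {'w': jnp.zeros(3)}

# Exponential decay schedule; every value here is an assumed placeholder.
schedule = optax.exponential_decay(
    init_value=1e-4,        # starting learning rate
    transition_steps=1000,  # steps per decay period
    decay_rate=0.99,        # multiplicative decay per period
    transition_begin=5000,  # delay before decay starts
)

# inject_hyperparams keeps the evaluated learning rate in the optimizer
# state, which is one way to make it available for logging.
optimizer = optax.inject_hyperparams(optax.adam)(learning_rate=schedule)
opt_state = optimizer.init(params)
print(opt_state.hyperparams['learning_rate'])  # 1e-4 until decay begins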
Showing 1 changed file with 1 addition and 9 deletions: ferminet/curvature_tags_and_blocks.py
--- a/ferminet/curvature_tags_and_blocks.py
+++ b/ferminet/curvature_tags_and_blocks.py
@@ -13,7 +13,7 @@
 # limitations under the License.

 """Curvature blocks for FermiNet."""
-from typing import Any, Mapping, Optional, Sequence, Set, Tuple
+from typing import Any, Mapping, Sequence, Set, Tuple
 import jax
 import jax.numpy as jnp
 import kfac_jax
@@ -56,8 +56,6 @@ def update_curvature_matrix_estimate(
       ema_old: Numeric,
       ema_new: Numeric,
       batch_size: int,
-      pmap_axis_name: Optional[str],
-      sync: Array | bool = True,
   ) -> kfac_jax.TwoKroneckerFactored.State:
     estimation_data = dict(**estimation_data)
     x, = estimation_data["inputs"]
@@ -72,8 +70,6 @@ def update_curvature_matrix_estimate(
         ema_old=ema_old,
         ema_new=ema_new,
         batch_size=batch_size,
-        pmap_axis_name=pmap_axis_name,
-        sync=sync,
     )


@@ -96,7 +92,6 @@ def update_curvature_matrix_estimate(
       ema_old: Numeric,
       ema_new: Numeric,
       batch_size: int,
-      pmap_axis_name: Optional[str],
   ) -> kfac_jax.TwoKroneckerFactored.State:
     x, = estimation_data["inputs"]
     dy, = estimation_data["outputs_tangent"]
@@ -111,9 +106,6 @@ def update_curvature_matrix_estimate(
     state.inputs_factor.update(inputs_cov, ema_old, ema_new)
     state.outputs_factor.update(outputs_cov, ema_old, ema_new)

-    state.inputs_factor.sync(pmap_axis_name)
-    state.outputs_factor.sync(pmap_axis_name)
-
     return state

   def _init(
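
Taken together, the hunks above drop pmap_axis_name and sync from both update_curvature_matrix_estimate overrides and remove the state.inputs_factor.sync / state.outputs_factor.sync calls, which suggests that cross-device synchronization is handled outside the per-block update in the kfac_jax version this targets. A self-contained toy sketch of the resulting override shape, using hypothetical classes rather than kfac_jax itself:

class BaseBlock:
  """Stand-in for the kfac_jax base block; only the call shape matters here."""

  def update_curvature_matrix_estimate(self, state, estimation_data,
                                       ema_old, ema_new, batch_size):
    # The real base class updates the Kronecker factors; this toy just
    # returns the state unchanged.
    return state


class CustomBlock(BaseBlock):
  """Mirrors the post-change override: no pmap_axis_name, no sync."""

  def update_curvature_matrix_estimate(self, state, estimation_data,
                                       ema_old, ema_new, batch_size):
    estimation_data = dict(**estimation_data)
    x, = estimation_data['inputs']  # same unpacking as in the diff
    del x  # block-specific reshaping of the inputs would happen here
    return super().update_curvature_matrix_estimate(
        state=state,
        estimation_data=estimation_data,
        ema_old=ema_old,
        ema_new=ema_new,
        batch_size=batch_size,
    )


block = CustomBlock()
print(block.update_curvature_matrix_estimate(
    state='state', estimation_data={'inputs': [1.0]},
    ema_old=0.95, ema_new=0.05, batch_size=16))  # prints: state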
