Fix TypeError in contrib.reduce_on_plateau() when x64 is enabled #697

Merged
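
Why the fix is needed: `jax.lax.cond` requires its two branches to return outputs with identical dtypes. With `jax_enable_x64` enabled, Python int literals (the `0` plateau count and the `cooldown` argument) are promoted to int64, while the counters stored in `ReduceLROnPlateauState` are int32, so the two branches disagree and tracing fails with a TypeError. A minimal sketch of the mismatch, assuming x64 is enabled (illustrative only, not code from the PR):

    import jax
    jax.config.update('jax_enable_x64', True)
    import jax.numpy as jnp

    counter = jnp.asarray(5, jnp.int32)  # the state stores its counters as int32

    def in_cooldown():
      return 0  # Python int: weakly typed, traced as int64 under x64

    def not_in_cooldown():
      return counter - 1  # stays int32

    # Raises TypeError: branch outputs differ (int64 vs int32).
    jax.lax.cond(True, in_cooldown, not_in_cooldown)

The fix below pins both branch outputs to int32 so the dtypes match whether or not x64 is enabled.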
4 changes: 2 additions & 2 deletions optax/contrib/reduce_on_plateau.py
@@ -91,7 +91,7 @@ def update_fn(

     # We're in cooldown, so reduce the counter and ignore any bad epochs
     def in_cooldown():
-      new_plateau_count = 0
+      new_plateau_count = jnp.asarray(0, jnp.int32)
       new_lr = state.lr
       new_cooldown_counter = state.cooldown_counter - 1
       return new_plateau_count, new_lr, new_cooldown_counter
@@ -108,7 +108,7 @@ def not_in_cooldown():
       )
       new_cooldown_counter = jnp.where(
           curr_plateau_count == patience, cooldown, 0
-      )
+      ).astype(jnp.int32)
       return new_plateau_count, new_lr, new_cooldown_counter

     new_plateau_count, new_lr, new_cooldown_counter = jax.lax.cond(
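
The second hunk is needed because `jnp.where` promotes its Python-int operands: under x64, `jnp.where(curr_plateau_count == patience, cooldown, 0)` yields an int64 counter, while the `in_cooldown` branch now returns int32. A quick sketch of the promotion behavior (illustrative; `pred` is a stand-in for the plateau comparison):

    import jax
    jax.config.update('jax_enable_x64', True)
    import jax.numpy as jnp

    pred = jnp.asarray(True)
    print(jnp.where(pred, 5, 0).dtype)                    # int64 under x64
    print(jnp.where(pred, 5, 0).astype(jnp.int32).dtype)  # int32, matching the state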
129 changes: 85 additions & 44 deletions optax/contrib/reduce_on_plateau_test.py
@@ -15,82 +15,123 @@
 """Tests for `reduce_on_plateau.py`."""

 from absl.testing import absltest
+from absl.testing import parameterized
+
 import chex
 import jax.numpy as jnp
+from jax import config

 from optax import contrib


-class ReduceLROnPlateauTest(absltest.TestCase):
+class ReduceLROnPlateauTest(parameterized.TestCase):
   """Test for reduce_on_plateau scheduler."""

+  def setUp(self):
+    super().setUp()
+    self.patience = 5
+    self.cooldown = 5
+    self.transform = contrib.reduce_on_plateau(
+        factor=0.1,
+        patience=self.patience,
+        threshold=1e-4,
+        cooldown=self.cooldown,
+    )
+    self.updates = {'params': jnp.array(1.0)}  # dummy updates
+
+  def tearDown(self):
+    super().tearDown()
+    config.update('jax_enable_x64', False)
+
-  def test_learning_rate_reduced_after_cooldown_period_is_over(self):
-    """Test that learning rate is reduced after the cooldown period."""
-    # Define a dummy update and extra_args
-    updates = {'params': jnp.array(1.0)}
-    patience = 5
-    cooldown = 5
-    # Apply the transformation to the updates and state
-    transform = contrib.reduce_on_plateau(patience=patience, cooldown=cooldown)
-    state = transform.init(updates['params'])
-    for _ in range(patience + 1):
-      updates, state = transform.update(updates=updates, state=state, loss=1.0)
+  @parameterized.parameters(False, True)
+  def test_learning_rate_reduced_after_cooldown_period_is_over(
+      self, enable_x64
+  ):
+    """Test that learning rate is reduced after cooldown."""
+    # Enable float64 if requested
+    config.update('jax_enable_x64', enable_x64)
+
+    # Initialize the state
+    state = self.transform.init(self.updates['params'])
+
+    # Wait until patience runs out
+    for _ in range(self.patience + 1):
+      updates, state = self.transform.update(
+          updates=self.updates, state=state, loss=1.0
+      )
+
     # Check that learning rate is reduced
     # we access the fields inside new_state using indices instead of attributes
     # because otherwise pytype throws an error
     lr, best_loss, plateau_count, cooldown_counter = state
     chex.assert_trees_all_close(lr, 0.1)
     chex.assert_trees_all_close(best_loss, 1.0)
     chex.assert_trees_all_close(plateau_count, 0)
-    chex.assert_trees_all_close(cooldown_counter, cooldown)
+    chex.assert_trees_all_close(cooldown_counter, self.cooldown)
     chex.assert_trees_all_close(updates, {'params': jnp.array(0.1)})

-    _, state = transform.update(updates=updates, state=state, loss=1.0)
+    # One more step
+    _, state = self.transform.update(updates=updates, state=state, loss=1.0)
+
+    # Check that cooldown_counter is decremented
     lr, best_loss, plateau_count, cooldown_counter = state
     chex.assert_trees_all_close(lr, 0.1)
     chex.assert_trees_all_close(best_loss, 1.0)
     chex.assert_trees_all_close(plateau_count, 0)
-    chex.assert_trees_all_close(cooldown_counter, cooldown - 1)
+    chex.assert_trees_all_close(cooldown_counter, self.cooldown - 1)

-  def test_learning_rate_is_not_reduced(self):
-    """Test that plateau count resets after a new best loss is found."""
+  @parameterized.parameters(False, True)
+  def test_learning_rate_is_not_reduced(self, enable_x64):
+    """Test that plateau_count resets after a new best_loss is found."""
+
+    # Enable float64 if requested
+    config.update('jax_enable_x64', enable_x64)
+
+    # State with positive plateau_count
     state = contrib.ReduceLROnPlateauState(
-        best_loss=jnp.array(0.1, dtype=jnp.float32),
+        best_loss=jnp.array(1.0, dtype=jnp.float32),
         plateau_count=jnp.array(3, dtype=jnp.int32),
-        lr=jnp.array(0.01, dtype=jnp.float32),
+        lr=jnp.array(0.1, dtype=jnp.float32),
         cooldown_counter=jnp.array(0, dtype=jnp.int32),
     )
-    # Define a dummy update and extra_args
-    updates = {'params': 1}
-    # Apply the transformation to the updates and state
-    transform = contrib.reduce_on_plateau(
-        factor=0.5, patience=5, threshold=1e-4, cooldown=5
+
+    # Update with better loss
+    _, new_state = self.transform.update(
+        updates=self.updates, state=state, loss=0.1
     )
-    _, new_state = transform.update(updates=updates, state=state, loss=0.01)

-    # Check that plateau_count resets
     lr, best_loss, plateau_count, _ = new_state
+    # Check that plateau count resets
     chex.assert_trees_all_close(plateau_count, 0)
-    chex.assert_trees_all_close(lr, 0.01)
-    chex.assert_trees_all_close(best_loss, 0.01)
+    chex.assert_trees_all_close(lr, 0.1)
+    chex.assert_trees_all_close(best_loss, 0.1)

-  def test_learning_rate_not_reduced_during_cooldown(self):
+  @parameterized.parameters(False, True)
+  def test_learning_rate_not_reduced_during_cooldown(self, enable_x64):
     """Test that learning rate is not reduced during cooldown."""
-    # Define a state where cooldown_counter is positive
+
+    # Enable float64 if requested
+    config.update('jax_enable_x64', enable_x64)
+
+    # State with positive cooldown_counter
     state = contrib.ReduceLROnPlateauState(
-        best_loss=jnp.array(0.1, dtype=jnp.float32),
-        plateau_count=jnp.array(4, dtype=jnp.int32),
-        lr=jnp.array(0.01, dtype=jnp.float32),
+        best_loss=jnp.array(1.0, dtype=jnp.float32),
+        plateau_count=jnp.array(0, dtype=jnp.int32),
+        lr=jnp.array(0.1, dtype=jnp.float32),
         cooldown_counter=jnp.array(3, dtype=jnp.int32),
     )
-    # Define a dummy update and extra_args
-    updates = {'params': 1}
-    # Apply the transformation to the updates and state
-    transform = contrib.reduce_on_plateau(
-        factor=0.5, patience=5, threshold=1e-4, cooldown=5
+
+    # Update with worse loss
+    _, new_state = self.transform.update(
+        updates=self.updates, state=state, loss=2.0
     )
-    _, new_state = transform.update(updates=updates, state=state, loss=0.15)
-    # Check that learning rate is not reduced
-    lr, _, _, _ = new_state
-    chex.assert_trees_all_close(lr, 0.01)
+
+    # Check that learning rate is not reduced and
+    # plateau_count is not incremented
+    lr, best_loss, plateau_count, cooldown_counter = new_state
+    chex.assert_trees_all_close(lr, 0.1)
+    chex.assert_trees_all_close(best_loss, 1.0)
+    chex.assert_trees_all_close(plateau_count, 0)
+    chex.assert_trees_all_close(cooldown_counter, 2)


 if __name__ == '__main__':
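
For reference, the tests above drive the transformation directly, which is also the simplest way to use the scheduler standalone; a minimal sketch of that pattern (parameter values are illustrative):

    import jax.numpy as jnp
    from optax import contrib

    transform = contrib.reduce_on_plateau(patience=5, cooldown=5)
    params = jnp.array(1.0)
    state = transform.init(params)

    # `loss` is the extra argument consumed by reduce_on_plateau's update.
    updates, state = transform.update(
        updates={'params': params}, state=state, loss=1.0
    )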