Add atol option to contrib.reduce_on_plateau() #698

Conversation

@stefanocortinovis (Contributor) commented Jan 9, 2024

This PR replaces the threshold argument in the current implementation of reduce_on_plateau() with two arguments, rtol and atol. The change is propagated to reduce_on_plateau.py and reduce_on_plateau.ipynb where needed to retain the current behaviour.

The behaviour of the current implementation corresponds to the case atol = 0 and rtol > 0 in the new one, and is kept as the default. Adding atol makes it possible to set an absolute tolerance for detecting a new best loss, or to mix relative and absolute tolerances by choosing values greater than zero for both rtol and atol.

To implement this change, I've modified the inequality has_improved = jnp.where((loss / state.best_loss - 1) < -threshold, 1, 0) to has_improved = jnp.where(loss < (1 - rtol) * state.best_loss - atol, 1, 0). Notice that moving the term state.best_loss from the denominator of the left-hand side in the first inequality to the right-hand side in the second one has the added benefit of correctly handling losses that can attain negative values (e.g. negative ELBOs with continuous distributions) when rtol = 0 and atol > 0.
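For concreteness, here is a standalone sketch of the two criteria (the helper below is illustrative only; the real check lives inside update_fn):

import jax.numpy as jnp

def improved(loss, best_loss, rtol, atol):
    # New criterion: improvement iff loss < (1 - rtol) * best_loss - atol.
    return jnp.where(loss < (1 - rtol) * best_loss - atol, 1, 0)

# The old relative check, (loss / best_loss - 1) < -threshold, breaks for
# negative losses: -1.05 / -1.0 - 1 = 0.05 is never below -threshold, so the
# improvement from -1.0 to -1.05 would be missed. The new form with rtol = 0
# and atol > 0 detects it:
print(improved(jnp.asarray(-1.05), jnp.asarray(-1.0), rtol=0.0, atol=0.01))  # 1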

@fabianp (Member) commented Jan 11, 2024

Thanks @stefanocortinovis! Having a relative tolerance makes a lot of sense to me. The only downside I see is that it no longer follows the PyTorch implementation, but we don't have to follow it, so I think that's fine.

Any opinion on this, @vz415 @vroulet, since you contributed the reduce_on_plateau implementation and reviewed it, respectively?

@vz415 (Contributor) commented Jan 12, 2024

Nice catch! This does handle negative losses correctly, but I'd suggest adding an assert statement or some other safety measure to ensure users set rtol > 0 and atol > 0. Hate to be that guy, but it might also help to have a warning if a user's loss is negative but rtol = 0 and atol > 0 are not set.

@stefanocortinovis (Contributor, Author)

Thanks for the comments! I added validation checks for atol and rtol. However, I haven't seen runtime warnings raised from within the update_fn of any other GradientTransformation. How would you suggest proceeding?
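For reference, the init-time checks look roughly like the following sketch (the exact conditions and messages in the PR may differ):

if rtol < 0.0 or atol < 0.0:
    raise ValueError(
        f"rtol and atol must be non-negative, got rtol = {rtol} and atol = {atol}."
    )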

@fabianp (Member) commented Jan 14, 2024

I suggest we use the warnings module in the Python standard library (https://docs.python.org/3/library/warnings.html) since that seems to be what jax uses (see for instance https://github.com/google/jax/blob/f539187c053bf1819f05d0f8e9e66e45da2af17b/jax/_src/array.py#L456)
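For example, something along these lines, with an illustrative condition and message (this snippet is not from the PR):

import warnings

warnings.warn(
    "Loss is negative; consider setting rtol = 0 and atol > 0 so that "
    "improvements are detected reliably."
)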

@stefanocortinovis (Contributor, Author)

Thanks for the suggestion, @fabianp. However, I'm still not convinced raising warnings based on the value of loss would be a good idea.

In theory, raising such a warning would involve adding something like

if loss < 0.0 and some_other_condition:
    warnings.warn('some warning')  # assumes `import warnings` at module level

to the update_fn of reduce_on_plateau.

However, update_fn will usually be called within a jitted block. Something like:

import jax.numpy as jnp
import optax
from jax import jit
from optax import contrib

@jit
def make_step(params, transform_state):
    updates = {"params": 1.0}
    loss = jnp.asarray(-1.0)  # inside jit, loss is a traced array

    # reduce_on_plateau receives the loss through its update_fn
    updates, _ = transform.update(updates=updates, state=transform_state, loss=loss)
    params = optax.apply_updates(params, updates)

    return params, loss

params = {"params": 2.0}
transform = contrib.reduce_on_plateau()
transform_state = transform.init(params)

make_step(params, transform_state)

Hence, if update_fn involves warnings based on the sign of loss, make_step will raise the usual TracerBoolConversionError because loss is a traced array.
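A minimal standalone repro of that failure:

import jax
import jax.numpy as jnp

@jax.jit
def f(loss):
    if loss < 0.0:  # Python-level branch forces bool() on a traced array
        pass  # stand-in for warnings.warn(...)
    return loss

f(jnp.asarray(-1.0))  # raises jax.errors.TracerBoolConversionError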

Am I missing something here?

@fabianp (Member) commented Jan 15, 2024

You're absolutely right @stefanocortinovis , I wasn't thinking about jitting.

copybara-service bot merged commit 1decd3e into google-deepmind:main on Jan 20, 2024 (6 checks passed).
@fabianp (Member) commented Jan 20, 2024

Thanks for the contribution @stefanocortinovis !
