GradClipByValue callback #315

Merged: 13 commits, Aug 31, 2021
3 changes: 3 additions & 0 deletions HISTORY.md
@@ -1,5 +1,8 @@
### ensmallen ?.??.?: "???"
###### ????-??-??
* Add gradient value clipping and gradient norm scaling callbacks
([#315](https://github.com/mlpack/ensmallen/pull/315)).

* Remove superfluous CMake option to build the tests
([#313](https://github.com/mlpack/ensmallen/pull/313)).

87 changes: 85 additions & 2 deletions doc/callbacks.md
@@ -136,14 +136,97 @@ EarlyStopAtMinLoss cb(
      // You could also, e.g., print the validation loss here to watch it converge.
      return lrfValidation.Evaluate(coordinates);
    });

arma::mat coordinates = lrfTrain.GetInitialPoint();
SMORMS3 smorms3;
smorms3.Optimize(lrfTrain, coordinates, cb);
```

</details>

### GradClipByNorm

One difficulty with optimization is that large parameter gradients can cause an
optimizer to update the parameters strongly into a region where the loss
function is much greater, effectively undoing much of the work done to get to
the current solution. Such large updates can also cause numerical overflow or
underflow, a failure mode often referred to as "exploding gradients". The
exploding gradient problem can be caused by:

* choosing the wrong learning rate, which leads to huge updates in the
  gradients;
* failing to scale a data set, which leads to very large differences between
  data points; or
* applying a loss function that computes very large error values.

A common answer to the exploding gradient problem is to modify the gradient
before the update step is applied. One option is to clip the norm `||g||` of
the gradient `g` before a parameter update: given the gradient and a maximum
norm value, the callback rescales the gradient so that its L2-norm is less
than or equal to the given maximum norm value.
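
In code, the rule is a norm check followed by a rescale (a minimal sketch,
assuming an Armadillo gradient matrix; the free-standing helper is
illustrative, since the callback applies this inside its `Gradient()` hook):

```c++
#include <armadillo>

// Sketch of the rule GradClipByNorm applies to each incoming gradient.
void ClipByNorm(arma::mat& gradient, const double maxNorm)
{
  const double norm = arma::norm(gradient);  // L2-norm of the gradient.
  if (norm > maxNorm)
    gradient = maxNorm * gradient / norm;    // New norm is exactly maxNorm.
}
```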

#### Constructors

* `GradClipByNorm(`_`maxNorm`_`)`

#### Attributes

| **type** | **name** | **description** | **default** |
|----------|----------|-----------------|-------------|
| `double` | **`maxNorm`** | The maximum allowed L2-norm of the gradient. | |

#### Examples:

<details open>
<summary>Click to collapse/expand example code.
</summary>

```c++
AdaDelta optimizer(1.0, 1, 0.99, 1e-8, 1000, 1e-9, true);

RosenbrockFunction f;
arma::mat coordinates = f.GetInitialPoint();
optimizer.Optimize(f, coordinates, GradClipByNorm(0.3));
```

</details>

### GradClipByValue

Like `GradClipByNorm` above, this callback addresses the "exploding gradients"
problem: large parameter gradients can push the optimizer far into a region
where the loss function is much greater, effectively undoing much of the work
done to get to the current solution, and can cause numerical overflow or
underflow. See the `GradClipByNorm` section above for common causes.

A common answer to the exploding gradient problem is to modify the gradient
before the update step is applied. One option is to clip each element of the
gradient to a fixed interval `[min, max]` before a parameter update.
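
Element-wise clipping amounts to a clamp (a minimal sketch, assuming an
Armadillo gradient matrix; the free-standing helper is illustrative, since the
callback applies this inside its `Gradient()` hook):

```c++
#include <armadillo>

// Sketch of the rule GradClipByValue applies to each incoming gradient.
void ClipByValue(arma::mat& gradient, const double min, const double max)
{
  // Entries below min become min; entries above max become max.
  gradient = arma::clamp(gradient, min, max);
}
```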

#### Constructors

* `GradClipByValue(`_`min, max`_`)`

#### Attributes

| **type** | **name** | **description** | **default** |
|----------|----------|-----------------|-------------|
| `double` | **`min`** | The minimum value to clip to. | |
| `double` | **`max`** | The maximum value to clip to. | |

#### Examples:

<details open>
<summary>Click to collapse/expand example code.
</summary>

```c++
AdaDelta optimizer(1.0, 1, 0.99, 1e-8, 1000, 1e-9, true);

RosenbrockFunction f;
arma::mat coordinates = f.GetInitialPoint();
optimizer.Optimize(f, coordinates, GradClipByValue(0, 1.3));
```

</details>

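Both clipping callbacks compose with other callbacks; they can simply be
passed together to `Optimize()`. A short sketch combining element-wise
clipping with a progress bar (the particular pairing is illustrative):

```c++
AdaDelta optimizer(1.0, 1, 0.99, 1e-8, 1000, 1e-9, true);

RosenbrockFunction f;
arma::mat coordinates = f.GetInitialPoint();

// Callbacks are passed one after another to Optimize().
optimizer.Optimize(f, coordinates, GradClipByValue(0, 1.3), ProgressBar());
```
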
### PrintLoss

Callback that prints loss to stdout or a specified output stream.
@@ -208,7 +291,7 @@ optimizer.Optimize(f, coordinates, ProgressBar());

</details>

### Report

Callback that prints an optimizer report to stdout or a specified output stream.

2 changes: 2 additions & 0 deletions include/ensmallen.hpp
@@ -74,6 +74,8 @@
// Callbacks.
#include "ensmallen_bits/callbacks/callbacks.hpp"
#include "ensmallen_bits/callbacks/early_stop_at_min_loss.hpp"
#include "ensmallen_bits/callbacks/grad_clip_by_norm.hpp"
#include "ensmallen_bits/callbacks/grad_clip_by_value.hpp"
#include "ensmallen_bits/callbacks/print_loss.hpp"
#include "ensmallen_bits/callbacks/progress_bar.hpp"
#include "ensmallen_bits/callbacks/query_front.hpp"
60 changes: 60 additions & 0 deletions include/ensmallen_bits/callbacks/grad_clip_by_norm.hpp
@@ -0,0 +1,60 @@
/**
 * @file grad_clip_by_norm.hpp
 * @author Marcus Edel
 *
 * Clip the gradients by multiplying the unit vector of the gradients with the
 * threshold.
 *
 * ensmallen is free software; you may redistribute it and/or modify it under
 * the terms of the 3-clause BSD license. You should have received a copy of
 * the 3-clause BSD license along with ensmallen. If not, see
 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
 */
#ifndef ENSMALLEN_CALLBACKS_GRAD_CLIP_BY_NORM_HPP
#define ENSMALLEN_CALLBACKS_GRAD_CLIP_BY_NORM_HPP

namespace ens {

/**
 * Clip the gradients by multiplying the unit vector of the gradients with the
 * threshold.
 */
class GradClipByNorm
{
 public:
  /**
   * Set up the gradient clip by norm callback class with the maximum clipping
   * value.
   *
   * @param maxNorm The maximum clipping value.
   */
  GradClipByNorm(const double maxNorm) : maxNorm(maxNorm)
  { /* Nothing to do here. */ }

  /**
   * Callback function called at any call to Gradient().
   *
   * @param optimizer The optimizer used to update the function.
   * @param function Function to optimize.
   * @param coordinates Starting point.
   * @param gradient Matrix that holds the gradient.
   */
  template<typename OptimizerType, typename FunctionType, typename MatType>
  void Gradient(OptimizerType& /* optimizer */,
                FunctionType& /* function */,
                const MatType& /* coordinates */,
                MatType& gradient)
  {
    const double gradientNorm = arma::norm(gradient);
    if (gradientNorm > maxNorm)
      gradient = maxNorm * gradient / gradientNorm;
  }

 private:
  //! The maximum clipping value for gradient clipping.
  const double maxNorm;
};

} // namespace ens

#endif
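
Not part of the PR, but a quick standalone sanity check of this header (the
optimizer and function arguments are unused by the hook, so dummy stand-ins
suffice):

```c++
#include <ensmallen.hpp>

int main()
{
  arma::mat g = {{3.0, 4.0}};     // L2-norm is 5.
  ens::GradClipByNorm clip(1.0);

  int dummy = 0;                  // Ignored by the callback.
  clip.Gradient(dummy, dummy, g, g);

  g.print("clipped gradient");    // Rescaled so that its norm is 1.
}
```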
60 changes: 60 additions & 0 deletions include/ensmallen_bits/callbacks/grad_clip_by_value.hpp
@@ -0,0 +1,60 @@
/**
 * @file grad_clip_by_value.hpp
 * @author Marcus Edel
 *
 * Clips the gradient to a specified min and max.
 *
 * ensmallen is free software; you may redistribute it and/or modify it under
 * the terms of the 3-clause BSD license. You should have received a copy of
 * the 3-clause BSD license along with ensmallen. If not, see
 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
 */
#ifndef ENSMALLEN_CALLBACKS_GRAD_CLIP_BY_VALUE_HPP
#define ENSMALLEN_CALLBACKS_GRAD_CLIP_BY_VALUE_HPP

namespace ens {

/**
 * Clip the gradient to a specified min and max.
 */
class GradClipByValue
{
 public:
  /**
   * Set up the gradient clip by value callback class with the min and max
   * value.
   *
   * @param min The minimum value to clip to.
   * @param max The maximum value to clip to.
   */
  GradClipByValue(const double min, const double max) : lower(min), upper(max)
  { /* Nothing to do here. */ }

  /**
   * Callback function called at any call to Gradient().
   *
   * @param optimizer The optimizer used to update the function.
   * @param function Function to optimize.
   * @param coordinates Starting point.
   * @param gradient Matrix that holds the gradient.
   */
  template<typename OptimizerType, typename FunctionType, typename MatType>
  void Gradient(OptimizerType& /* optimizer */,
                FunctionType& /* function */,
                const MatType& /* coordinates */,
                MatType& gradient)
  {
    gradient = arma::clamp(gradient, lower, upper);
  }

 private:
  //! The minimum value to clip to.
  const double lower;

  //! The maximum value to clip to.
  const double upper;
};

} // namespace ens

#endif
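
Again, a quick standalone sanity check (not part of the PR; the dummy
arguments work because the hook ignores them):

```c++
#include <ensmallen.hpp>

int main()
{
  arma::mat g = {{-2.0, 0.5, 3.0}};
  ens::GradClipByValue clip(0.0, 1.3);

  int dummy = 0;  // Ignored by the callback.
  clip.Gradient(dummy, dummy, g, g);

  g.print("clipped gradient");  // Prints 0, 0.5, and 1.3.
}
```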
56 changes: 50 additions & 6 deletions include/ensmallen_bits/callbacks/traits.hpp
@@ -88,6 +88,15 @@ struct TypedForms
                            const MatType&,
                            const MatType&);

  //! This is the form of a bool Gradient() callback method where the gradient
  //! is modifiable.
  template<typename CallbackType>
  using GradientBoolModifiableForm =
      bool(CallbackType::*)(OptimizerType&,
                            FunctionType&,
                            const MatType&,
                            MatType&);

  //! This is the form of a void Gradient() callback method.
  template<typename CallbackType>
  using GradientVoidForm =
      void(CallbackType::*)(OptimizerType&,
                            FunctionType&,
                            const MatType&,
                            const MatType&);

  //! This is the form of a void Gradient() callback method where the gradient
  //! is modifiable.
  template<typename CallbackType>
  using GradientVoidModifiableForm =
      void(CallbackType::*)(OptimizerType&,
                            FunctionType&,
                            const MatType&,
                            MatType&);

  //! This is the form of a bool GradientConstraint() callback method.
  template<typename CallbackType>
  using GradientConstraintBoolForm =
      bool(CallbackType::*)(OptimizerType&,
                            FunctionType&,
                            const MatType&,
                            const size_t,
                            const MatType&);

  //! This is the form of a bool GradientConstraint() callback method where the
  //! gradient is modifiable.
  template<typename CallbackType>
  using GradientConstraintBoolModifiableForm =
      bool(CallbackType::*)(OptimizerType&,
                            FunctionType&,
                            const MatType&,
                            const size_t,
                            MatType&);

  //! This is the form of a void GradientConstraint() callback method.
  template<typename CallbackType>
  using GradientConstraintVoidForm =
      void(CallbackType::*)(OptimizerType&,
                            FunctionType&,
                            const MatType&,
                            const size_t,
                            const MatType&);

  //! This is the form of a void GradientConstraint() callback method where the
  //! gradient is modifiable.
  template<typename CallbackType>
  using GradientConstraintVoidModifiableForm =
      void(CallbackType::*)(OptimizerType&,
                            FunctionType&,
                            const MatType&,
                            const size_t,
                            MatType&);

  //! This is the form of a bool BeginOptimization() callback method.
  template<typename CallbackType>
  using BeginOptimizationBoolForm =

@@ -230,9 +268,9 @@ struct HasEvaluateSignature
{
  const static bool value =
      HasEvaluate<CallbackType, TypedForms<OptimizerType,
          FunctionType, MatType>::template EvaluateBoolForm>::value ||
      HasEvaluate<CallbackType, TypedForms<OptimizerType,
          FunctionType, MatType>::template EvaluateVoidForm>::value;
};

//! Utility struct, check if either void EvaluateConstraint() or
Expand All @@ -245,9 +283,9 @@ struct HasEvaluateConstraintSignature
{
  const static bool value =
      HasEvaluateConstraint<CallbackType, TypedForms<OptimizerType,
          FunctionType, MatType>::template EvaluateConstraintBoolForm>::value ||
      HasEvaluateConstraint<CallbackType, TypedForms<OptimizerType,
          FunctionType, MatType>::template EvaluateConstraintVoidForm>::value;
};

//! Utility struct, check if either void Gradient() or bool Gradient()
Expand All @@ -261,9 +299,15 @@ struct HasGradientSignature
{
  const static bool value =
      HasGradient<CallbackType, TypedForms<OptimizerType,
          FunctionType, MatType, Gradient>::template GradientBoolForm>::value ||
      HasGradient<CallbackType, TypedForms<OptimizerType,
          FunctionType, MatType,
          Gradient>::template GradientBoolModifiableForm>::value ||
      HasGradient<CallbackType, TypedForms<OptimizerType,
          FunctionType, MatType, Gradient>::template GradientVoidForm>::value ||
      HasGradient<CallbackType, TypedForms<OptimizerType,
          FunctionType, MatType,
          Gradient>::template GradientVoidModifiableForm>::value;
};

//! Utility struct, check if either void GradientConstraint() or
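
For context, these modifiable forms are what let callbacks such as
`GradClipByNorm` and `GradClipByValue` rewrite the gradient in place. A
minimal custom callback matching the void modifiable form (illustrative only,
not part of this PR):

```c++
// Matches GradientVoidModifiableForm: the non-const MatType& parameter is
// what distinguishes it from the read-only const MatType& form.
class HalveGradient
{
 public:
  template<typename OptimizerType, typename FunctionType, typename MatType>
  void Gradient(OptimizerType& /* optimizer */,
                FunctionType& /* function */,
                const MatType& /* coordinates */,
                MatType& gradient)
  {
    gradient *= 0.5;  // Modify the gradient in place before the update step.
  }
};
```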