Early stopping at loss min on validation set #165
Conversation
include/ensmallen_bits/callbacks/early_stop_at_min_loss_on_validation.hpp
 * Early stopping to terminate the optimization process early if the loss stops
 * decreasing.
 */
class EarlyStopAtMinLossOnValidation
zoq
Feb 17, 2020
Member
I guess instead of writing another callback class, we could also add another constructor to the existing EarlyStopAtMinLoss class, let me know what you think, might avoid some code duplication.
shrit
Feb 17, 2020
Author
Member
I agree I prefer one class with two constructors, the only reason for this prototype is our last discussion on the related issue. I will merge the two classes and push the code.
zoq
Feb 18, 2020
Member
Right, I guess that changed a bit because it's now part of ensmallen.
I only overloaded the
{
  return static_cast<FunctionType*>(
      static_cast<Function<FunctionType,
          MatType, GradType>*>(this))->Evaluate(coordinates);
zoq
Mar 2, 2020
Member
I think this is one of the reasons why @rcurtin proposed to implement this in mlpack. I'm wondering if we can put this into the callback itself, instead of extending the Function Type interface, which we like to keep to a minimum.
shrit
Mar 2, 2020
•
Author
Member
@zoq I agree. However, I tried to implement it first in mlpack, but in the mlpack case we would have to change the architecture of this callback.
I think it amounts to the same thing: even in mlpack we would have to add Evaluate() in ensmallen. Looking more closely, it seems to me that EndEpoch() requires FunctionType& function in our case, which is ens::Function. Suppose the ens::Function</*mlpack neural network model should go here*/>::Evaluate(arma::mat, arma::mat) interface does not exist. Can you give me more details about a possible implementation in mlpack?
Otherwise, I can try to move the overloaded interface into the callbacks.
rcurtin
Apr 2, 2020
•
Member
I think that maybe the best way to put this together would be to implement a callback that has structure a little bit like this:
class EarlyStoppingOnValidationSet // maybe can be named better?
{
 private:
  FFN<>& network;
  arma::mat& validationPredictors;
  arma::mat& validationResponses;
  double lastLoss; // initialize to DBL_MAX?

 public:
  template<typename OptimizerType, typename FunctionType, typename MatType>
  bool EndEpoch(OptimizerType& /* optimizer */,
                FunctionType& /* function */,
                const MatType& /* coordinates */,
                const size_t /* epoch */,
                const double /* objective */)
  {
    // The FFN we are holding internally should be the same one as what is
    // being optimized.
    const double validationLoss =
        network.Evaluate(validationPredictors, validationResponses);
    if (validationLoss > lastLoss)
      return true; // returning true terminates the optimization early
    lastLoss = validationLoss;
    return false;
  }
};
I've omitted some functions there, but maybe that's a decent enough idea? It would definitely have to be in the mlpack codebase since it depends on FFN. Let me know what you think---and sorry it took so long to get to this. 👍
shrit
Apr 2, 2020
Author
Member
Now, I understand the idea related to the FFN reference. I will implement it and see if it works.
Thanks. 👍
This issue has been automatically marked as stale because it has not had any recent activity. It will be closed in 7 days if no further activity occurs. Thank you for your contributions!
Keep open
@zoq can you have a look at this please? Thanks
 * @param patienceIn The number of epochs to wait after the minimum loss has
 *     been reached or no improvement has been made (Default: 10).
 */
EarlyStopAtMinLoss<AnnType>(AnnType& network,
zoq
May 7, 2020
Member
I guess since this can be used with other methods as well, we should rename the parameter, what about function?
rcurtin
May 7, 2020
Member
Right, seems like maybe this would be a perfect use case for a std::function? We could be really general about it. The user can just write a lambda that, when called, evaluates the loss on the given parameters...
// assume we built this network somehow and we want to optimize it...
FFN<...> network;
// we have some test set and responses...
arma::mat testSet;
arma::mat testResponses;
EarlyStopAtMinLoss e(
    [&network, &testSet, &testResponses](const arma::mat& /* parameters */)
        -> double
    {
      return network.Evaluate(testSet, testResponses);
    });
network.Train(trainSet, trainResponses, e, ...);
That's just a sketch, but that's sufficiently general that the user can do whatever they want with it, with the easiest example being using a validation/test set to compute a loss. I guess, from my end, I want to make sure that the callback isn't specific to the machine learning setting (since ensmallen is more general, it's just generic optimization).
zoq
May 7, 2020
Member
I like the idea.
rcurtin
May 26, 2020
Member
@shrit what do you think, would it be better to take any std::function here instead of requiring something that has an Evaluate() method?
shrit
May 28, 2020
Author
Member
@rcurtin Maybe I can add another constructor that takes a std::function; at the same time, the user might do anything in the std::function, not only Evaluate.
I am not sure I can replace everything with std::function, since network.Evaluate needs to be called at the end of each epoch, right? I am not sure.
rcurtin
May 28, 2020
Member
Yeah, adding another constructor could be a nice thing to do, so that the technique is more general. I guess, my hope is not to link the mlpack and ensmallen codebases too closely here, so that people can use this callback in non-mlpack situations. If the user passes a lambda too, that actually makes the implementation easier, since EarlyStopAtMinLoss doesn't have to then track the validation set and responses (it will be implicitly held in the given lambda/std::function).
void EarlyStopCallbacksFunctionTest(OptimizerType& optimizer)
{
  arma::mat data, testData, shuffledData;
  arma::Row<size_t> responses, testResponses, shuffledResponses;

  LogisticRegressionTestData(data, testData, shuffledData,
      responses, testResponses, shuffledResponses);

  LogisticRegression<> lr(shuffledData, shuffledResponses, 0.5);
  arma::mat coordinates = lr.GetInitialPoint();
  LogisticRegressionValidationFunction lrValidation(lr, coordinates);

  EarlyStopAtMinLoss<LogisticRegressionValidationFunction,
      arma::mat, arma::Row<size_t>> cb(
      lrValidation, testData, testResponses);

  optimizer.Optimize(lr, coordinates, cb);
}

TEST_CASE("EarlyStopAtMinLossCallbackTest", "[CallbacksTest]")
{
  SMORMS3 smorms3;
shrit
May 9, 2020
Author
Member
@zoq I have added the function as you have proposed, but I am not sure that we are stopping the optimization as we were doing before
// SGDTestFunction f;
// arma::mat coordinates = f.GetInitialPoint();

// // Instantiate the optimizer with a number of iterations that will take a
// // long time to finish.
// StandardSGD s(0.0003, 1, 2000000000, -10);
// s.ExactObjective() = true;

// // The optimization process should return in one second.
// const double result = s.Optimize(f, coordinates, EarlyStopAtMinLoss(100));

// REQUIRE(result == Approx(-1.0).epsilon(0.0005));
// REQUIRE(coordinates(0) == Approx(0.0).margin(1e-3));
// REQUIRE(coordinates(1) == Approx(0.0).margin(1e-7));
// REQUIRE(coordinates(2) == Approx(0.0).margin(1e-7));
shrit
May 9, 2020
Author
Member
I have commented this one out; if it is no longer required I will remove it.
rcurtin
Jun 30, 2020
Member
I think that we need to keep this test---I don't see a reason to remove it. Maybe I overlooked something?
@shrit can I ask what the status of this is? When I look at the diff, it seems like the code is specific to mlpack neural networks. The constructor you added is nice and I think that is the best interface we can give, but a couple comments:
Let me know what you think; I hope the comments here are helpful. It would be really great to incorporate this into a new ensmallen release. :)
@rcurtin I have applied the changes, your comments were very helpful. Thanks.
...allen_bits/problems/logistic_regression_function_validation_function.hpp
{
  SMORMS3 smorms3;
  EarlyStopCallbacksFunctionTest(smorms3);
}
rcurtin
Jun 30, 2020
Member
It would probably be good if we could also add a test that makes sure that the given lambda will terminate the optimization. I would suggest this... use, e.g., the Rosenbrock function with a custom lambda callback that terminates whenever any of the coordinates has an absolute value of less than 10. Then once the optimization terminates, make sure that we are nowhere near the minimum, and that at least one of the coordinates has an absolute value of less than 10. Let me know if I can clarify that idea---it's not the only possible idea, of course.
shrit
Jul 1, 2020
•
Author
Member
@rcurtin if you have a code sample of the idea, it would be great; I am truly not an expert in tests 👍.
rcurtin
Jul 2, 2020
Member
Ok, here is a rough outline. It is not perfect. In fact it may not even compile. But hopefully it will be enough to get the idea across.
// Maybe the name could be shorter...
TEST_CASE("EarlyStopAtMinLossWithCustomLambdaTest", "[CallbacksTest]")
{
  // Use the 50-dimensional Rosenbrock function.
  GeneralizedRosenbrockFunction f(50);

  // Start at some really large point.
  arma::mat coordinates = f.GetInitialPoint();
  coordinates.fill(100.0);

  EarlyStopAtMinLoss<arma::mat> cb(
      [&](const arma::mat& coordinates)
      {
        // Terminate if any coordinate has a value less than 10.
        return arma::any(arma::abs(coordinates) < 10.0);
      });

  SMORMS3 smorms3;
  smorms3.Optimize(f, coordinates, cb);

  // Make sure that we did not get to the optimum.
  for (size_t i = 0; i < coordinates.n_elem; ++i)
    REQUIRE(std::abs(coordinates[i]) >= 3.0);
}
Anyway, I think that will work. 👍
Hopefully these comments are helpful. :)
Now we can evaluate any function loss directly using a lambda function inside the constructor of EarlyStopAtMinLoss.
Signed-off-by: Omar Shrit <omar@shrit.me>
namespace ens {

/**
 * Early stopping to terminate the optimization process early if the loss stops
 * decreasing.
 */
template<typename MatType = arma::mat>
rcurtin
Jul 7, 2020
Member
Crap, actually, I just realized, to make this change would force a major version bump:
https://github.com/mlpack/ensmallen/blob/master/UPDATING.txt
I think we can fix this by doing the following:
- Rename this class to EarlyStopAtMinLossType.
- Create a typedef so that EarlyStopAtMinLoss is EarlyStopAtMinLossType<arma::mat>.
- Update the documentation (I'll provide a suggestion).
- Add a comment here suggesting that we should change EarlyStopAtMinLossType to EarlyStopAtMinLoss for ensmallen 3.10.0.
Everything looks great @shrit. I just left a couple final comments (I promise, no more! :)); basically I realized that there are some reverse compatibility issues that we do need to handle. Hopefully my suggestions are helpful---if you want to try it differently go ahead, the ideas I posted are just ideas. :)
Co-authored-by: Ryan Curtin <ryan@ratml.org>
Signed-off-by: Omar Shrit <omar@shrit.me>
Second approval provided automatically after 24 hours.
@rcurtin Perfect, you are welcome; I am very excited about this pull request.
Thanks @shrit! I'll try to do a release with this support tomorrow.
Maybe it makes sense to wait for #149 before we do a new release; from my side this is pretty close to getting merged.
The only thing is, the more releases, the closer I can get us to fully automatic releases. So maybe I can do one today (once I finish debugging the mlpack-bot bit I'm working on now), and then one again a couple days later?
I didn't think about that; makes sense to me.
This pull request implements the discussion we had in the issue: mlpack/mlpack#2185
This is only a first prototype; I have tested it on my neural network and it seems to be working.
I implemented it in ensmallen. I know that @rcurtin requested that it be implemented on the mlpack side, but while implementing it in mlpack I noticed that I would have to modify the ensmallen callback functions, so I implemented it in ensmallen.
Callback tests and documentation are still missing, but I do not know how to test it without data.