Early stopping at loss min on validation set #165

Merged
merged 47 commits into mlpack:master from shrit:early_stopping Jul 14, 2020

Conversation

@shrit
Copy link
Member

@shrit shrit commented Feb 10, 2020

This pull request implements the discussion we had in issue mlpack/mlpack#2185.
This is only a first prototype; I have tested it on my neural network and it seems to work.
I implemented it in ensmallen. I know that @rcurtin requested that it be implemented on the mlpack side, but during the implementation in mlpack I noticed that I would have to modify ensmallen's callback functions, so I implemented it in ensmallen.
Callback tests and documentation are still missing, but I do not know how to test it without data.

/**
 * Early stopping to terminate the optimization process early if the loss stops
 * decreasing.
 */
class EarlyStopAtMinLossOnValidation
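The comment block above states the core idea: terminate once the loss stops decreasing. As a rough, self-contained sketch of that logic (all names here, such as `EarlyStopSketch`, `bestLoss`, and `steps`, are hypothetical illustrations, not the actual ensmallen implementation):

```cpp
#include <cfloat>
#include <cstddef>

// Minimal sketch of patience-based early stopping: remember the best loss
// seen so far and stop after `patience` epochs without improvement.
// All names here are hypothetical; this is not the ensmallen implementation.
class EarlyStopSketch
{
 public:
  explicit EarlyStopSketch(const std::size_t patience = 10) :
      patience(patience), steps(0), bestLoss(DBL_MAX) { }

  // Called once per epoch; returns true when optimization should stop.
  bool EndEpoch(const double loss)
  {
    if (loss < bestLoss)
    {
      bestLoss = loss;
      steps = 0; // The loss improved, so reset the patience counter.
      return false;
    }

    return (++steps >= patience);
  }

 private:
  std::size_t patience;
  std::size_t steps;
  double bestLoss;
};
```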


@zoq

zoq Feb 17, 2020
Member

I guess instead of writing another callback class, we could also add another constructor to the existing EarlyStopAtMinLoss class. Let me know what you think; it might avoid some code duplication.


@shrit

shrit Feb 17, 2020
Author Member

I agree; I prefer one class with two constructors. The only reason for this prototype is our last discussion on the related issue. I will merge the two classes and push the code.


@zoq

zoq Feb 18, 2020
Member

Right, I guess that changed a bit because it's now part of ensmallen.

@shrit shrit requested a review from zoq Feb 18, 2020
@shrit
Copy link
Member Author

@shrit shrit commented Feb 18, 2020

I only overloaded the Evaluate() function in the SGD tests to pass the test; I have not implemented it yet. I have tested it on a real data set and it seems to work.

@shrit shrit requested a review from zoq Feb 19, 2020
{
  return static_cast<FunctionType*>(
      static_cast<Function<FunctionType,
          MatType, GradType>*>(this))->Evaluate(coordinates);


@zoq

zoq Mar 2, 2020
Member

I think this is one of the reasons why @rcurtin proposed to implement this in mlpack. I'm wondering if we can put this into the callback itself, instead of extending the Function Type interface, which we like to keep to a minimum.


@shrit

shrit Mar 2, 2020
Author Member

@zoq I agree. However, I tried to implement it first in mlpack, but in the mlpack case we would have to change the architecture of this callback.
I think it is the same either way: even if we are in mlpack, we would have to add_evaluate in ensmallen. If we take a closer look, it seems to me that EndEpoch() requires FunctionType& function in our case, which is ens::Function. Suppose the ens::Function</* the mlpack neural network model should go here */>::Evaluate(arma::mat, arma::mat) interface does not exist. Can you give me more details about a possible implementation in mlpack?

Otherwise, I can try to move the overloaded interface into the callbacks.


@shrit

shrit Mar 25, 2020
Author Member

@rcurtin Do you have any thoughts on this, any starting guide from mlpack?


@rcurtin

rcurtin Apr 2, 2020
Member

I think that maybe the best way to put this together would be to implement a callback that has structure a little bit like this:

class EarlyStoppingOnValidationSet // maybe can be named better?
{
 private:
  FFN<>& network;
  arma::mat& validationPredictors;
  arma::mat& validationResponses;
  double lastLoss; // initialize to DBL_MAX?

 public:
  template<typename OptimizerType, typename FunctionType, typename MatType>
  bool EndEpoch(OptimizerType& /* optimizer */,
                FunctionType& /* function */,
                const MatType& /* coordinates */,
                const size_t /* epoch */,
                const double /* objective */)
  {
    // The FFN we are holding internally should be the same one being optimized.
    const double validationLoss =
        network.Evaluate(validationPredictors, validationResponses);
    if (validationLoss > lastLoss)
      return true; // or do we return false to terminate early?  I can't remember

    lastLoss = validationLoss;
    return false;
  }
};

I've omitted some functions there, but maybe that's a decent enough idea? It would definitely have to be in the mlpack codebase since it depends on FFN. Let me know what you think---and sorry it took so long to get to this. 👍


@shrit

shrit Apr 2, 2020
Author Member

Now, I understand the idea related to the FFN reference. I will implement it and see if it works.
Thanks. 👍

@mlpack-bot
Copy link

@mlpack-bot mlpack-bot bot commented May 3, 2020

This issue has been automatically marked as stale because it has not had any recent activity. It will be closed in 7 days if no further activity occurs. Thank you for your contributions! 👍

@mlpack-bot mlpack-bot bot added the s: stale label May 3, 2020
@shrit
Copy link
Member Author

@shrit shrit commented May 3, 2020

Keep open

@mlpack-bot mlpack-bot bot removed the s: stale label May 3, 2020
@zoq zoq added the s: keep open label May 5, 2020
@shrit
Copy link
Member Author

@shrit shrit commented May 5, 2020

@zoq can you have a look at this please? Thanks.

* @param patienceIn The number of epochs to wait after the minimum loss has
* been reached or no improvement has been made (Default: 10).
*/
EarlyStopAtMinLoss<AnnType>(AnnType& network,


@zoq

zoq May 7, 2020
Member

I guess since this can be used with other methods as well, we should rename the parameter, what about function?


@rcurtin

rcurtin May 7, 2020
Member

Right, seems like maybe this would be a perfect use case for a std::function? We could be really general about it. The user can just write a lambda that, when called, evaluates the loss on the given parameters...

// assume we built this network somehow and we want to optimize it...
FFN<...> network;
// we have some test set and responses...
arma::mat testSet;
arma::mat testResponses;

EarlyStopAtMinLoss e(
    [&network, &testSet, &testResponses](const arma::mat& /* parameters */) -> double
    {
      return network.Evaluate(testSet, testResponses);
    });

network.Train(trainSet, trainResponses, e, ...);

That's just a sketch, but that's sufficiently general that the user can do whatever they want with it, with the easiest example being using a validation/test set to compute a loss. I guess, from my end, I want to make sure that the callback isn't specific to the machine learning setting (since ensmallen is more general, it's just generic optimization).
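As a self-contained illustration of the point above (this is not the ensmallen API; `LossFunction` and `RunEpochs` are hypothetical names for this sketch), the callback only needs to store a std::function, and whatever data the loss depends on lives in the lambda's capture, so the callback itself stays free of machine-learning types:

```cpp
#include <cstddef>
#include <functional>
#include <vector>

// The callback side only ever sees this signature; the validation data
// is hidden inside the lambda's capture. Plain std::vector stands in for
// arma::mat to keep the sketch dependency-free.
using LossFunction = std::function<double(const std::vector<double>&)>;

// Stand-in for the optimizer's epoch loop: call the user's loss function
// once per "epoch" and return the last value, as EndEpoch() would see it.
double RunEpochs(const LossFunction& loss,
                 const std::vector<std::vector<double>>& epochParams)
{
  double lastLoss = 0.0;
  for (const std::vector<double>& params : epochParams)
    lastLoss = loss(params);
  return lastLoss;
}
```

A lambda capturing a validation set (here, squared error against fixed targets) then plugs straight into that signature.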


@zoq

zoq May 7, 2020
Member

I like the idea.


@rcurtin

rcurtin May 26, 2020
Member

@shrit what do you think, would it be better to take any std::function here instead of requiring something that has an Evaluate() method?


@shrit

shrit May 28, 2020
Author Member

@rcurtin Maybe I can add another constructor that takes a std::function; at the same time, the user might do anything in the std::function, not only call Evaluate().
I am not sure I can replace everything with a std::function, since network.Evaluate() needs to be called at the end of each epoch, right? I am not sure.


@rcurtin

rcurtin May 28, 2020
Member

Yeah, adding another constructor could be a nice thing to do, so that the technique is more general. I guess, my hope is not to link the mlpack and ensmallen codebases too closely here, so that people can use this callback in non-mlpack situations. If the user passes a lambda too, that actually makes the implementation easier, since EarlyStopAtMinLoss doesn't have to then track the validation set and responses (it will be implicitly held in the given lambda/std::function).

void EarlyStopCallbacksFunctionTest(OptimizerType& optimizer)
{
  arma::mat data, testData, shuffledData;
  arma::Row<size_t> responses, testResponses, shuffledResponses;

  LogisticRegressionTestData(data, testData, shuffledData,
      responses, testResponses, shuffledResponses);

  LogisticRegression<> lr(shuffledData, shuffledResponses, 0.5);
  arma::mat coordinates = lr.GetInitialPoint();
  LogisticRegressionValidationFunction lrValidation(lr, coordinates);

  EarlyStopAtMinLoss<LogisticRegressionValidationFunction,
      arma::mat, arma::Row<size_t>> cb(
      lrValidation, testData, testResponses);

  optimizer.Optimize(lr, coordinates, cb);
}

TEST_CASE("EarlyStopAtMinLossCallbackTest", "[CallbacksTest]")
{
  SMORMS3 smorms3;


@shrit

shrit May 9, 2020
Author Member

@zoq I have added the function as you proposed, but I am not sure that we are stopping the optimization as we did before.

// SGDTestFunction f;
// arma::mat coordinates = f.GetInitialPoint();

// // Instantiate the optimizer with a number of iterations that will take a
// // long time to finish.
// StandardSGD s(0.0003, 1, 2000000000, -10);
// s.ExactObjective() = true;

// // The optimization process should return in one second.
// const double result = s.Optimize(f, coordinates, EarlyStopAtMinLoss(100));

// REQUIRE(result == Approx(-1.0).epsilon(0.0005));
// REQUIRE(coordinates(0) == Approx(0.0).margin(1e-3));
// REQUIRE(coordinates(1) == Approx(0.0).margin(1e-7));
// REQUIRE(coordinates(2) == Approx(0.0).margin(1e-7));


@shrit

shrit May 9, 2020
Author Member

I have commented this one out; if it is no longer required, I will remove it.


@rcurtin

rcurtin Jun 30, 2020
Member

I think that we need to keep this test---I don't see a reason to remove it. Maybe I overlooked something?

@rcurtin
Copy link
Member

@rcurtin rcurtin commented Jun 22, 2020

@shrit can I ask what the status of this is? When I look at the diff, it seems like the code is specific to mlpack neural networks. The constructor you added is nice, and I think it is the best interface we can give, but a couple of comments:

  • I think now we can remove the AnnType& from the template parameters.
  • The constructor should probably accept a std::function<double(const MatType&)> (I think I am writing that signature right); basically we should be able to call the std::function with the given coordinates (of type const MatType&) and receive back a double.
  • Inside of the class, I think all we need to do is hold that std::function<>.
  • By default, we can have behavior such that if the std::function<> was not specified, we simply use function.Evaluate() like the current code does. (Maybe we need a bool member to specify whether or not we have a std::function.)

Let me know what you think; I hope the comments here are helpful. It would be really great to incorporate this into a new ensmallen release. :)
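A rough sketch of the design in those bullet points might look like this (hypothetical names and simplified types; this illustrates the fallback idea, not the real ensmallen code):

```cpp
#include <cfloat>
#include <cstddef>
#include <functional>
#include <utility>

// Sketch only: if the user supplied a std::function, use it to compute the
// loss; otherwise fall back to function.Evaluate(coordinates). A bool member
// records which case we are in, as suggested above.
template<typename MatType>
class EarlyStopAtMinLossSketch
{
 public:
  explicit EarlyStopAtMinLossSketch(const std::size_t patience = 10) :
      patience(patience), steps(0), bestLoss(DBL_MAX), hasLocalFunc(false) { }

  EarlyStopAtMinLossSketch(std::function<double(const MatType&)> func,
                           const std::size_t patience = 10) :
      patience(patience), steps(0), bestLoss(DBL_MAX),
      localFunc(std::move(func)), hasLocalFunc(true) { }

  // Returns true once the loss has not improved for `patience` epochs.
  template<typename FunctionType>
  bool EndEpoch(FunctionType& function, const MatType& coordinates)
  {
    const double loss = hasLocalFunc ? localFunc(coordinates)
                                     : function.Evaluate(coordinates);
    if (loss < bestLoss)
    {
      bestLoss = loss;
      steps = 0;
      return false;
    }

    return (++steps >= patience);
  }

 private:
  std::size_t patience;
  std::size_t steps;
  double bestLoss;
  std::function<double(const MatType&)> localFunc;
  bool hasLocalFunc;
};
```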

@shrit
Copy link
Member Author

@shrit shrit commented Jun 29, 2020

@rcurtin I have applied the changes; your comments were very helpful, thanks.
I have also updated the tests; however, I think they need to be checked a little bit.

Copy link
Member

@rcurtin rcurtin left a comment

Hey there @shrit, thanks for continuing the work on this one. 👍 I think we're getting close (at least from my perspective). Let me know what you think of my comments.


{
  SMORMS3 smorms3;
  EarlyStopCallbacksFunctionTest(smorms3);
}


@rcurtin

rcurtin Jun 30, 2020
Member

It would probably be good if we could also add a test that makes sure that the given lambda will terminate the optimization. I would suggest this... use, e.g., the Rosenbrock function with a custom lambda callback that terminates whenever any of the coordinates has an absolute value of less than 10. Then once the optimization terminates, make sure that we are nowhere near the minimum, and that at least one of the coordinates has an absolute value of less than 10. Let me know if I can clarify that idea---it's not the only possible idea, of course.


@shrit

shrit Jul 1, 2020
Author Member

@rcurtin if you have a code sample of the idea, that would be great; I am truly not an expert in tests 👍.


@rcurtin

rcurtin Jul 2, 2020
Member

Ok, here is a rough outline. It is not perfect. In fact it may not even compile. But hopefully it will be enough to get the idea across.

// Maybe the name could be shorter...
TEST_CASE("EarlyStopAtMinLossWithCustomLambdaTest", "[CallbacksTest]")
{
  // Use the 50-dimensional Rosenbrock function.
  GeneralizedRosenbrockFunction f(50);
  // Start at some really large point.
  arma::mat coordinates = f.GetInitialPoint();
  coordinates.fill(100.0);

  EarlyStopAtMinLoss<arma::mat> cb(
      [&](const arma::mat& coordinates)
      {
        // Terminate if any coordinate has a value less than 10.
        return arma::any(arma::abs(coordinates) < 10.0);
      });

  SMORMS3 smorms3;
  smorms3.Optimize(f, coordinates, cb);

  // Make sure that we did not get to the optimum.
  for (size_t i = 0; i < coordinates.n_elem; ++i)
    REQUIRE(std::abs(coordinates[i]) >= 3.0);
}

Anyway, I think that will work. 👍

Copy link
Member

@rcurtin rcurtin left a comment

Hopefully these comments are helpful. :)

{
  SMORMS3 smorms3;
  EarlyStopCallbacksFunctionTest(smorms3);
}


shrit and others added 5 commits Jul 3, 2020
Co-authored-by: Marcus Edel <marcus.edel@fu-berlin.de>
Now we can evaluate any function loss directly using a lambda
function inside the constructor of early stop at min loss

Signed-off-by: Omar Shrit <omar@shrit.me>
Signed-off-by: Omar Shrit <omar@shrit.me>
@shrit shrit force-pushed the shrit:early_stopping branch from ce33583 to ea1c4e4 Jul 4, 2020
namespace ens {

/**
 * Early stopping to terminate the optimization process early if the loss stops
 * decreasing.
 */
template<typename MatType = arma::mat>


@rcurtin

rcurtin Jul 7, 2020
Member

Crap, actually, I just realized that making this change would force a major version bump:

https://github.com/mlpack/ensmallen/blob/master/UPDATING.txt

I think we can fix this by doing the following:

  • Rename this class to EarlyStopAtMinLossType.
  • Create a typedef so that EarlyStopAtMinLoss is EarlyStopAtMinLossType<arma::mat>.
  • Update the documentation (I'll provide a suggestion).
  • Add a comment here suggesting that we should change EarlyStopAtMinLossType to EarlyStopAtMinLoss for ensmallen 3.10.0.
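The rename-plus-typedef pattern described in those bullets can be sketched like this (`DefaultMat` stands in for arma::mat so the sketch stays self-contained; the real class body is omitted):

```cpp
#include <vector>

// Stand-in for arma::mat, so this sketch needs no external libraries.
using DefaultMat = std::vector<double>;

// The class gains a template parameter under a new name...
template<typename MatType = DefaultMat>
class EarlyStopAtMinLossType
{
 public:
  // ... callback implementation, parameterized on MatType ...
};

// ...and a typedef preserves the old non-template name, so existing user
// code keeps compiling without a major version bump.
// (Change EarlyStopAtMinLossType back to EarlyStopAtMinLoss for a future
// major release, as suggested above.)
typedef EarlyStopAtMinLossType<DefaultMat> EarlyStopAtMinLoss;
```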
@rcurtin
Copy link
Member

@rcurtin rcurtin commented Jul 7, 2020

Everything looks great @shrit. I just left a couple of final comments (I promise, no more! :)); basically, I realized that there are some backward compatibility issues that we do need to handle. Hopefully my suggestions are helpful; if you want to try it differently, go ahead, the ideas I posted are just ideas. :)

shrit and others added 4 commits Jul 7, 2020
Co-authored-by: Ryan Curtin <ryan@ratml.org>
Co-authored-by: Ryan Curtin <ryan@ratml.org>
Co-authored-by: Ryan Curtin <ryan@ratml.org>
Signed-off-by: Omar Shrit <omar@shrit.me>
@rcurtin
rcurtin approved these changes Jul 8, 2020
Copy link
Member

@rcurtin rcurtin left a comment

Awesome, this looks good to me now. Thank you! 👍 I'll leave this for a few days for any other comments, in case @zoq or anyone has any comments. :)

@mlpack-bot
mlpack-bot bot approved these changes Jul 9, 2020
Copy link

@mlpack-bot mlpack-bot bot left a comment

Second approval provided automatically after 24 hours. 👍

@mlpack-bot mlpack-bot bot removed the s: needs review label Jul 9, 2020
@shrit
Copy link
Member Author

@shrit shrit commented Jul 10, 2020

@rcurtin Perfect, you are welcome. I am very excited about this pull request.

@rcurtin
Copy link
Member

@rcurtin rcurtin commented Jul 14, 2020

Thanks @shrit! I'll try to do a release with this support tomorrow. 👍

@rcurtin rcurtin merged commit a63ea6d into mlpack:master Jul 14, 2020
2 of 4 checks passed

Memory Checks: Build finished.
Static Code Analysis Checks: Build finished.
continuous-integration/appveyor/pr: AppVeyor build succeeded
continuous-integration/travis-ci/pr: The Travis CI build passed
@shrit shrit deleted the shrit:early_stopping branch Jul 14, 2020
@zoq
Copy link
Member

@zoq zoq commented Jul 14, 2020

Maybe it makes sense to wait for #149 before we do a new release; from my side, that one is pretty close to getting merged.

@rcurtin
Copy link
Member

@rcurtin rcurtin commented Jul 14, 2020

The only thing is, the more releases we do, the closer I can get us to fully automatic releases. So maybe I can do one today (once I finish debugging the mlpack-bot bit I'm working on now), and then another a couple of days later?

@zoq
Copy link
Member

@zoq zoq commented Jul 14, 2020

I hadn't thought about that; makes sense to me.
