
[GSoC] Adding optimization features not related to my GSoC project #1070

Merged

zoq merged 10 commits into mlpack:master on Jul 29, 2017

Conversation

3 participants
17minutes (Contributor) commented Jul 21, 2017

This PR is part of my GSoC project "Augmented RNNs".

Implemented:

  • CrossEntropyLayer for evaluating the performance of the model on binary vector targets.
  • An initial version of gradient clipping (albeit very dirty, as @zoq and @rcurtin have already mentioned in #1005); see the sketch below.

As far as I understand, the conversation related to these two points (including, but not limited to, the reusable update API for gradient clipping) is transferred here.
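
For reference, a minimal sketch of the gradient clipping idea as an SGD update policy (class and member names here are illustrative assumptions, not the exact code of this PR):

    #include <armadillo>

    // Sketch: clamp each element of the gradient to [minGradient, maxGradient]
    // before delegating the actual step to a wrapped update policy.
    template<typename UpdatePolicyType>
    class GradientClipping
    {
     public:
      GradientClipping(const double minGradient,
                       const double maxGradient,
                       UpdatePolicyType& updatePolicy) :
          minGradient(minGradient),
          maxGradient(maxGradient),
          updatePolicy(updatePolicy)
      { }

      void Update(arma::mat& iterate,
                  const double stepSize,
                  const arma::mat& gradient)
      {
        // The gradient is const here, so clipping needs a copy; this copy is
        // exactly what the review discussion below is about.
        arma::mat clippedGradient = arma::clamp(gradient, minGradient,
            maxGradient);
        updatePolicy.Update(iterate, stepSize, clippedGradient);
      }

     private:
      double minGradient;
      double maxGradient;
      UpdatePolicyType& updatePolicy;
    };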

Review thread on src/mlpack/core/optimizers/sgd/update_policies/gradient_clipping.hpp (outdated):
  void Update(arma::mat& iterate,
              const double stepSize,
              const arma::mat& gradient)

zoq (Member) commented Jul 23, 2017

gradient.transform should fail here, since the gradient we pass in is const.

rcurtin (Member) commented Jul 25, 2017

Another option is to relax the const restriction and allow the UpdatePolicy to modify the gradient when Update() is called. I don't think it would be a problem to do that, as long as we document that the UpdatePolicy is allowed to do so. And it would avoid the copy too. :)
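
For illustration, a sketch of what that relaxed signature could look like (assuming the clipping bounds from the sketch above):

    // Sketch: taking the gradient by non-const reference lets the policy clip
    // it in place and avoid the copy; the UpdatePolicy contract would then
    // have to document that the gradient may be modified.
    void Update(arma::mat& iterate,
                const double stepSize,
                arma::mat& gradient)  // non-const: the policy may modify it
    {
      gradient = arma::clamp(gradient, minGradient, maxGradient);
      // ... perform the SGD step with the clipped gradient ...
    }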

zoq (Member) commented Jul 25, 2017

Agreed, that is another good option. @partobs-mdp, what do you think?

17minutes (Contributor) commented Jul 25, 2017

Sure, but should we do this? As we have seen, the clamp performance is not much of an issue, and (imho) we shouldn't break the natural assumption that the update policy doesn't modify the variable that stores the gradient.

zoq (Member) commented Jul 25, 2017

Right, the performance difference is insignificant, so we can use arma::clamp here; no need to change it if you don't think that's a good idea. We just wanted to point out that there is another solution.

Review threads on src/mlpack/tests/sgd_test.cpp (outdated).

Review thread on src/mlpack/methods/ann/layer/cross_entropy_error_impl.hpp:
  template<
      typename InputDataType = arma::mat,
      typename OutputDataType = arma::mat
  >
  class CrossEntropyError

zoq (Member) commented Jul 23, 2017

Can we add a simple test for the cross entropy error function?
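
A minimal sketch of such a test, assuming the usual mlpack Boost.Test setup and the layer's Forward(input, target) interface; the values are made up for illustration:

    // Sketch of a simple correctness test (hypothetical values).
    BOOST_AUTO_TEST_CASE(SimpleCrossEntropyErrorTest)
    {
      arma::mat input("0.5 0.5");
      arma::mat target("0 1");

      CrossEntropyError<> module;

      // Expected loss: -(0 * log(0.5) + (1 - 0) * log(0.5))
      //                - (1 * log(0.5) + (1 - 1) * log(0.5)) = 2 * log(2).
      const double error = module.Forward(std::move(input), std::move(target));
      BOOST_REQUIRE_CLOSE(error, 2.0 * std::log(2.0), 1e-3);
    }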

rcurtin (Member) reviewed

Looks good to me so far, just some minor comments to address from my end. Overall I think the design is fine, but we should definitely add a test for the CrossEntropyLayer like Marcus suggested.

Review thread on src/mlpack/core/optimizers/sgd/sgd.hpp (outdated).

Review thread on src/mlpack/methods/ann/layer/cross_entropy_error.hpp (outdated):
  double Forward(const arma::Mat<eT>&& input, const arma::Mat<eT>&& target)
  {
    return -arma::accu(target % arma::log(input + eps) +
        (1. - target) % arma::log(1. - input + eps));
  }

rcurtin (Member) commented Jul 25, 2017

It's a little late so I'm not sure I'm thinking 100% clearly, but it appears that this will work in the multiclass setting as long as the input and target matrices are one-hot encoded. So labels like [0 2 1] will not work, but labels like [[1 0 0] [0 0 1] [0 1 0]] will. Correct me if I am wrong. (i.e., I think this is right and works the way I would expect it to.)

17minutes (Contributor) commented Jul 25, 2017

Well, it won't work as it stands, but in the one-hot case you've mentioned the computation would actually be simpler: -arma::accu(target % arma::log(input + eps)). (The formula gets simpler because the one-hot representation is redundant.)
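
To make the representation concrete, here is a small standalone sketch (hypothetical values) of labels [0 2 1] encoded one-hot, one column per sample, together with the simplified formula:

    #include <armadillo>
    #include <iostream>

    int main()
    {
      const double eps = 1e-10;  // illustrative epsilon, not the layer's value

      // Labels [0 2 1] one-hot encoded: column i encodes sample i's class.
      arma::mat target = { { 1, 0, 0 },
                           { 0, 0, 1 },
                           { 0, 1, 0 } };

      // Hypothetical predicted class probabilities (each column sums to 1).
      arma::mat input = { { 0.7, 0.2, 0.1 },
                          { 0.2, 0.1, 0.6 },
                          { 0.1, 0.7, 0.3 } };

      // Multiclass cross-entropy for one-hot targets: the (1 - target) term
      // of the binary formula is dropped, since the one-hot representation
      // already encodes "not this class" in the other rows.
      const double loss = -arma::accu(target % arma::log(input + eps));
      std::cout << "loss: " << loss << std::endl;
      return 0;
    }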

rcurtin (Member) commented Jul 27, 2017

Fair enough; in this case, can you clarify in the documentation for the class how the labels should be represented? We should definitely support multiclass cross-entropy at some point, so if you don't want to do that here that's ok; but in that case could I ask you to open a new issue for it, detailing basically what needs to be done and where someone could look to get started?

17minutes added some commits Jul 25, 2017


Review thread on src/mlpack/methods/ann/layer/cross_entropy_error_impl.hpp (outdated):
  void Backward(const arma::Mat<eT>&& input,
                const arma::Mat<eT>&& target,
                arma::Mat<eT>&& output)
  {
    output = (1. - target) / (1. - input + eps) - target / (input + eps);
  }

zoq (Member) commented Jul 25, 2017

Not sure if the compiler would optimize that away, but we could rewrite the expression as output = (target - input) / ((x - 1) % x); and save one extra division and the eps addition.

17minutes (Contributor) commented Jul 25, 2017

Checked it on a piece of paper: the precise expression we get after simplifying the original expression, so that it is still the gradient of the loss function with log(x + eps), is

output = (input - target + eps * (1. - 2 * target)) / ((1. - input + eps) % (input + eps))

This one, however, doesn't really optimize much (or at least I think so): even though it performs only one element-wise division, it performs three multiplications, which are also slow compared with additions (that's why I didn't count the additions).
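
For reference, a sketch of the algebra behind that expression, writing x for input and t for target:

\frac{\partial}{\partial x}\Big[-t\log(x+\epsilon) - (1-t)\log(1-x+\epsilon)\Big]
  = \frac{1-t}{1-x+\epsilon} - \frac{t}{x+\epsilon}
  = \frac{(1-t)(x+\epsilon) - t(1-x+\epsilon)}{(1-x+\epsilon)(x+\epsilon)}
  = \frac{x - t + \epsilon(1-2t)}{(1-x+\epsilon)(x+\epsilon)}.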

zoq (Member) commented Jul 25, 2017

Right, I guess the only benefit I see is that we could avoid adding eps, since ((x - 1) % x) should be stable. Anyway, I don't mind leaving it as it is.

17minutes (Contributor) commented Jul 25, 2017

What's wrong with the Internet connection on Travis CI? It didn't even manage to install Boost from apt :'(

zoq (Member) commented Jul 25, 2017

Let me restart the build.

17minutes (Contributor) commented Jul 26, 2017

Is there anything else that should be done on this PR on my side?

zoq approved these changes Jul 26, 2017

Looks ready to me; I'll wait three days before merging, in case anyone has any more comments.

rcurtin (Member) commented

Code looks good, nothing more from my side. Thanks for splitting this out from the other PR so we can merge it more quickly. :)

zoq merged commit 5c68061 into mlpack:master on Jul 29, 2017

3 checks passed:

  • Style Checks: Build finished.
  • continuous-integration/appveyor/pr: AppVeyor build succeeded.
  • continuous-integration/travis-ci/pr: The Travis CI build passed.
zoq (Member) commented Jul 29, 2017

Thanks for the contributions!
