Implement Flexible ReLU #1341
Conversation
```cpp
 * author = {Suo Qiu, Xiangmin Xu and Bolun Cai},
 * title = {FReLU: Flexible Rectified Linear Units for Improving
 * Convolutional Neural Networks}
 * journal = {arxiv preprint},
```
If possible, can you add the URL too?
@Manthan-R-Sheth I have added very minor changes to this code. Can you please create a pull request from my repo?

@Manthan-R-Sheth I think this is complete. Is it?
```cpp
 * For more information, read the following paper:
 *
 * @code
 * @article{
```
This is missing a name; also, do you mind moving this into the class description block? That way, Doxygen can pick it up.
```cpp
 * title = {FReLU: Flexible Rectified Linear Units for Improving
 * Convolutional Neural Networks}
 * journal = {arxiv preprint},
 * URL = {https://arxiv.org/pdf/1706.08098.pdf},
```
Can you link to the arXiv abstract page instead of the PDF link here?
```cpp
 *
 *@tparam OutputDataType Type of the output data (arma::colvec, arma::mat,
 * arma::sp_mat or arma::cube)
 *
```
Looks like we could remove the extra line here.
```cpp
 *
 *@tparam InputDataType Type of the input data ( arma::colvec, arma::mar,
 * arma::sp_mat or arma::cube)
 *
```
Do you mind removing the extra line and the extra space before arma::colvec?
Thanks for taking over this PR; I'm glad to see it updated. Since `alpha` should be trainable, shouldn't we also have a `Gradient()` function and `Parameters()`? I also left a few comments about efficiency.
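(For context, a trainable layer in the mlpack 2.x ANN code exposes roughly the following interface. This is a sketch based on the conventions visible in the diff snippets below, with `alpha` assumed to be a 1x1 `arma::mat` holding the trainable shift.)

```cpp
// Sketch only: accessor pair so the optimizer can read/write alpha, plus a
// Gradient() overload computing d(loss)/d(alpha) from the backpropagated
// error for this layer.
OutputDataType const& Parameters() const { return alpha; }
OutputDataType& Parameters() { return alpha; }

template<typename eT>
void Gradient(const arma::Mat<eT>&& input,
              arma::Mat<eT>&& error,
              arma::Mat<eT>&& gradient);
```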
```cpp
const InputType&& input, OutputType&& output)
{
  output = arma::max(arma::zeros<InputType>(input.n_rows, input.n_cols), input)
  + alpha;
```
A minor style issue: the second line here should be doubly indented (four spaces, not two). Also, I think you might be able to do this with a lambda passed to `.transform()`, which could make it a little faster.
```cpp
DataType derivative;

//! Compute the first derivative of FlexibleReLU function.
derivative = input;
```
You can save a copy here if we do `derivative.set_size(input.n_rows, input.n_cols)`, which will just set the size.
```cpp
for (size_t i = 0; i < input.n_elem; i++)
{
  derivative(i) = input(i) > 0? 1 : 0;
}
```
You could also replace this with a call to `.transform()`.
@rcurtin
Yeah, even though `Gradient()` is simple, I think it is still a good idea to add a `CheckGradient()` test. I had a few other comments; sorry if they contradict what I wrote earlier about `transform()`.
```cpp
}

arma::mat zeros = arma::zeros<arma::Mat<eT>>(input.n_rows, input.n_cols);
gradient(0) = arma::accu(error % arma::min(zeros, input)) / input.n_cols;
```
It would be a lot better here if we could avoid allocating `zeros`; that will be time-consuming.
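(One way to sidestep the temporary, as a sketch: elementwise `arma::min(zeros, input)` is `min(0, x)`, which is the same as clamping `input` to `(-inf, 0]`; `DBL_MAX` here assumes `<cfloat>` is included.)

```cpp
// Equivalent to error % arma::min(zeros, input), but without materializing
// a zeros matrix.
gradient(0) = arma::accu(error % arma::clamp(input, -DBL_MAX, 0.0))
    / input.n_cols;
```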
```cpp
int i = -1;
output = arma::zeros<InputType>(input.n_rows, input.n_cols);
output.transform([input, &i, this](double val) { ++i;
    return (std::max(input(i), 0.0) + alpha(0)); } );
```
Ack, I did not realize that `transform()` is in-place. Why not use something like `arma::clamp()` instead then? I.e., `output = arma::clamp(input, 0.0, DBL_MAX) + alpha(0);`
Yes, we just needed a function like that.
```cpp
derivative.set_size(input.n_rows, input.n_cols);
int i = -1;
derivative.transform([input, &i](double val) { ++i;
    return (input(i) > 0? 1 : 0); } );
```
Here I think we could just use a boolean expression: `derivative = (input > 0);` (or some variant like that).
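(A caveat on that suggestion, as a sketch: in Armadillo, a relational expression such as `input > 0` yields an unsigned integer matrix, so assigning it to a floating-point `derivative` needs a conversion.)

```cpp
// (input > 0) produces an arma::umat of 0/1 flags; conv_to turns it into
// the floating-point derivative matrix.
derivative = arma::conv_to<arma::mat>::from(input > 0);
```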
I have used this:

```cpp
derivative = arma::sign(input);
derivative.elem(arma::find(derivative < 0.0)) += 1;
```

Is this good to go, or was the lambda faster?
@rcurtin
Force-pushed from dd5c8c8 to 00e2c9f
Looks like the gradient test is failing; can you look into it please?
I don't think that's the case here; I have to take a closer look at the code, especially the …
@zoq @rcurtin … and I found that the errors after …
```cpp
DataType derivative;
//! Compute the first derivative of FlexibleReLU function.
derivative = arma::sign(input);
derivative.elem(arma::find(derivative < 0.0)) += 1;
```
We could do something like `derivative.elem(find(input > 0)).ones();` here; I think it's easier to follow. Let me know what you think.
I think I will have to initialize `derivative` with `zeros()` or `input` before using `derivative.elem(find(input > 0)).ones();`. So isn't the present implementation faster than initializing the matrix and then operating on it? Tell me if I missed anything.
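(To make the comparison concrete, the two forms under discussion side by side; this is a sketch, not code from the PR:)

```cpp
// Suggested form: needs an explicit zero initialization first.
derivative.zeros(input.n_rows, input.n_cols);
derivative.elem(arma::find(input > 0)).ones();

// Present form: one pass for sign(), then a fix-up of the -1 entries.
derivative = arma::sign(input);
derivative.elem(arma::find(derivative < 0.0)) += 1;
```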
You are right.
It would be possible to also do `derivative = arma::clamp(arma::sign(input), 0.0, 1.0);`, but I don't think this will currently give any acceleration. It may be slightly easier to read, though.
```cpp
{
  gradient = arma::zeros<arma::Mat<eT>>(1, 1);
}
gradient(0) = arma::accu(error) / input.n_cols;
```
Shouldn't this be `gradient(0) = arma::sum(error)`?
`sum()` would give a row vector (the error for each instance in the input), but we need a double value for `alpha` (the trainable parameter), so I used `accu()`. It is `accu(error)` because the derivative of the output with respect to `alpha` is 1 in the case of flexible ReLU.
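(Spelled out, using the paper's definition of FReLU; a sketch of the chain-rule step behind that reasoning:)

```latex
% FReLU adds a trainable shift alpha after the rectification:
%   f(x) = \max(x, 0) + \alpha
% Every output element depends on alpha with derivative 1, so the chain
% rule collapses the gradient to a plain sum of backpropagated errors e_i:
\[
  \frac{\partial L}{\partial \alpha}
    = \sum_i \frac{\partial L}{\partial f_i}\,
             \frac{\partial f_i}{\partial \alpha}
    = \sum_i e_i \cdot 1,
\]
% which is exactly arma::accu(error), divided by input.n_cols to average
% over the batch.
```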
Ah, right, will run some tests later today.
Sure, @zoq, thanks.
```cpp
{
  if (gradient.n_elem == 0)
  {
    gradient = arma::zeros<arma::Mat<eT>>(1, 1);
```
Let's use `set_size` here to make this step somewhat faster.
Sorry for the slow response on this one. The main issue is that the output is clipped (zeroed if < 0), so a really small perturbation in the positive/negative direction ends up with the same result and has no effect. An easy solution is to rely on positive weights, which isn't perfect since we don't cover the complete range of the FReLU function. So, in addition, we could compare the gradient against a precomputed one with positive/negative weights. Let me know what you think. Here is the modified test:

```cpp
/**
 * Flexible ReLU layer numerical gradient test.
 */
BOOST_AUTO_TEST_CASE(GradientFlexibleReLULayerTest)
{
  // Add function gradient instantiation.
  struct GradientFunction
  {
    GradientFunction()
    {
      input = arma::randu(2, 1);
      target = arma::mat("1");

      model = new FFN<NegativeLogLikelihood<>, RandomInitialization>(
          NegativeLogLikelihood<>(), RandomInitialization(0.1, 0.5));
      model->Predictors() = input;
      model->Responses() = target;
      model->Add<LinearNoBias<> >(2, 5);
      model->Add<FlexibleReLU<> >(0.05);
      model->Add<LogSoftMax<> >();
    }

    ~GradientFunction()
    {
      delete model;
    }

    double Gradient(arma::mat& gradient) const
    {
      arma::mat output;
      double error = model->Evaluate(model->Parameters(), 0, 1);
      model->Gradient(model->Parameters(), 0, gradient, 1);
      return error;
    }

    arma::mat& Parameters() { return model->Parameters(); }

    FFN<NegativeLogLikelihood<>, RandomInitialization>* model;
    arma::mat input, target;
  } function;

  BOOST_REQUIRE_LE(CheckGradient(function), 1e-4);
}
```
Thanks for looking into all the issues; I'll go ahead and merge this in 3 days to leave time for any other comments, and I will fix some remaining minor style issues afterwards.
Looks good to me; I'm looking forward to seeing this merged. Thanks for the contribution. :)
@Manthan-R-Sheth thanks for the great contribution.
This is a continuation of #1281. Function call overheads have been removed, and the code is simplified to include the calculation in `Forward()` and `Backward()`.
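(Putting the thread's conclusions together, a rough sketch of the final shape of `Forward()` and `Backward()`; this is not the verbatim merged code, and it assumes the mlpack 2.x layer signatures used in the snippets above plus `<cfloat>` for `DBL_MAX`.)

```cpp
template<typename InputType, typename OutputType>
void Forward(const InputType&& input, OutputType&& output)
{
  // f(x) = max(0, x) + alpha, computed without a temporary zeros matrix.
  output = arma::clamp(input, 0.0, DBL_MAX) + alpha(0);
}

template<typename DataType>
void Backward(const DataType&& input, DataType&& gy, DataType&& g)
{
  // f'(x) = 1 for x > 0, else 0: sign() gives {-1, 0, 1}, and the fix-up
  // maps the -1 entries to 0.
  DataType derivative = arma::sign(input);
  derivative.elem(arma::find(derivative < 0.0)) += 1;
  g = gy % derivative;
}
```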