New style relu #3175
Conversation
Added some comments.
chainer/functions/activation/relu.py (outdated)
```python
        self.retain_outputs((0,))
        self._use_cudnn = False
```
I think it's simpler to make `_use_cudnn = False` a class attribute and only set it explicitly on the cuDNN path. I mean:
```python
class ReLU(function_node.FunctionNode):

    _use_cudnn = False
    ...

    def forward_gpu(self, x):
        if ...:
            self._use_cudnn = True
            ...
    ...
```
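As a plain-Python illustration of why this pattern works (a hypothetical `Example` class, not the Chainer code): an attribute assigned on `self` shadows the class-level default for that instance only.
```python
class Example:
    _use_cudnn = False  # class-level default shared by all instances

    def forward_gpu(self, use_cudnn):
        if use_cudnn:
            # Assignment creates an instance attribute that shadows
            # the class attribute, for this instance only.
            self._use_cudnn = True


a, b = Example(), Example()
b.forward_gpu(use_cudnn=True)
print(a._use_cudnn, b._use_cudnn)  # False True
```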
chainer/functions/activation/relu.py (outdated)
```python
        return ReLUGrad2().apply((y, gy[0]))


class Zero(function_node.FunctionNode):
```
Is it needed?
chainer/functions/activation/relu.py (outdated)
```python
        return Zero().apply(gy)


class Heaviside(function_node.FunctionNode):
```
The following code simplifies it:
```python
def heaviside(x):
    return utils.force_array((x.data > 0).astype(x.dtype))
```
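A rough standalone check of the same idea, using plain NumPy and a made-up input (outside Chainer, so no `utils.force_array` and no `.data` attribute):
```python
import numpy as np

x = np.array([-1.5, 0.0, 2.0], dtype=np.float32)
# Heaviside step: 1 where x > 0, else 0, cast back to the input dtype.
print((x > 0).astype(x.dtype))  # [0. 0. 1.]
```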
Thank you. I added some more comments.
chainer/functions/activation/relu.py (outdated)
```python
    def backward(self, indexes, gy):
        ret = []
        if 0 in indexes:
            ret.append(None)
```
Could you remove the first argument from the inputs and instead pass it as an argument of `__init__`? (It will simplify the backprop, which is good for performance.)
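For concreteness, a minimal sketch of the suggested pattern, with the class shape assumed from the surrounding diffs (not the final merged code): the array that needs no gradient is captured in `__init__`, so `apply` only receives inputs that actually participate in backprop.
```python
from chainer import function_node, utils


class ReLUGrad2(function_node.FunctionNode):

    def __init__(self, b):
        # `b` is a plain array captured at construction time; it is
        # not an input of `apply`, so backward never has to emit a
        # gradient (or a `None` placeholder) for it.
        self.b = b

    def forward(self, inputs):
        gy, = inputs
        return utils.force_array(gy * (self.b > 0).astype(gy.dtype)),
```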
chainer/functions/activation/relu.py (outdated)
```python
    def backward(self, indexes, gy):
        ret = []
        if 0 in indexes:
            ret.append(None)
```
ditto
chainer/functions/activation/relu.py (outdated)
```python
        if 0 in indexes:
            ret.append(None)
        if 1 in indexes:
            ret.append(None)
```
ditto
chainer/functions/activation/relu.py (outdated)
```diff
@@ -12,48 +12,103 @@
     _mode = cudnn.cudnn.CUDNN_ACTIVATION_RELU


-class ReLU(function.Function):
+class ReLU(function_node.FunctionNode):

     """Rectified Linear Unit."""
     # TODO(beam2d): Implement in-place version.
```
It is not directly related to this PR, but I found this TODO comment is obsolete. Can you remove it?
As #3096 has been merged into the master branch, I rebased the PR.
Thank you for your comments. I updated the PR. Note that although I wrote the docstrings of …
I added some more comments.
Note: I noted that "you do not need to check if `indexes` is empty", but I found that `Variable.backward()` does not check it correctly. I'll fix this point in another PR, so it's OK to proceed with removing the check.
chainer/functions/activation/relu.py (outdated)
```python
        return gx,

    def backward(self, indexes, gy):
        if 0 in indexes:
```
You do not need to check this; `indexes` is always non-empty (otherwise `backward` is not called).
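That is, something along these lines (a sketch that assumes the node retains its output, as in the diffs above; not necessarily the exact merged code):
```python
def backward(self, indexes, grad_outputs):
    # `indexes` is guaranteed non-empty here, so the gradient can be
    # computed unconditionally.
    y = self.get_retained_outputs()[0]
    return ReLUGrad2(y.data).apply(grad_outputs)
```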
OK. I have removed it.
chainer/functions/activation/relu.py (outdated)
```python
        return cudnn.activation_backward(a, b, inputs[0], _mode),

    def backward(self, indexes, gy):
        if 0 in indexes:
```
Remove the check (see the above comment)
chainer/functions/activation/relu.py (outdated)
```python
def _heaviside(x):
    return utils.force_array((x > 0).astype(x.dtype))
```
Is this `force_array` needed?
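For background, one reason `force_array` exists: NumPy comparisons on zero-dimensional arrays return scalars rather than arrays, so without the wrapper a 0-d input would not produce an ndarray. A minimal NumPy illustration:
```python
import numpy as np

x0 = np.array(3.0, dtype=np.float32)  # zero-dimensional array
r = (x0 > 0).astype(x0.dtype)
print(type(r))              # <class 'numpy.float32'> -- a scalar, not an array
print(type(np.asarray(r)))  # <class 'numpy.ndarray'> after wrapping
```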
chainer/functions/activation/relu.py (outdated)
```python
    def backward(self, indexes, gy):
        if 0 in indexes:
            xp = cuda.get_array_module(gy[0])
            b = xp.asarray(self.b)
```
Is this `asarray` needed?
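Background for this one too: `asarray` returns its argument unchanged when it is already an array of that module, so the call only matters if `self.b` can be of another type (e.g. a NumPy array on a CuPy path). A NumPy-only illustration:
```python
import numpy as np

b = np.array([1.0, 2.0], dtype=np.float32)
print(np.asarray(b) is b)  # True: no copy when already an ndarray
```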
Thank you. Updated.
Please resolve the flake8 errors.
Jenkins, test this please.
One more comment
chainer/functions/activation/relu.py (outdated)
```python
    """

    def __init__(self, b):
        super(ReLUGrad2).__init__()
```
`, self` is missing in `super(ReLUGrad2).__init__()` (or remove this line; it is allowed not to call the super `__init__` in `FunctionNode`).
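That is, if the line is kept it would read as below (the `self.b = b` body is assumed from the surrounding diff):
```python
def __init__(self, b):
    super(ReLUGrad2, self).__init__()  # `, self` added
    self.b = b  # assumed from the surrounding diff
```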
LGTM
Thank you!
This PR implements a new-style version of `F.relu`. It depends on #3096 and is a part of #3147.
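For context, the practical payoff of the new-style `FunctionNode` implementation is that backprop itself becomes differentiable, enabling higher-order gradients. A hedged usage sketch (assuming a Chainer version with this PR and #3096 merged):
```python
import numpy as np
import chainer
import chainer.functions as F

x = chainer.Variable(np.array([-1.0, 2.0], dtype=np.float32))
loss = F.sum(F.relu(x) ** 2)

# First-order gradient, keeping the graph so it can be differentiated again.
gx, = chainer.grad([loss], [x], enable_double_backprop=True)
print(gx.data)   # [0. 4.]  (d/dx relu(x)**2 = 2*x for x > 0)

# Second-order gradient -- only possible with new-style functions.
ggx, = chainer.grad([F.sum(gx)], [x])
print(ggx.data)  # [0. 2.]
```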