New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support default dtype in F.spatial_transformer_grid
#5114
Conversation
6098259
to
8520bff
Compare
8520bff
to
97fd517
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM except one comment.
class TestSpatialTransformerGrid(unittest.TestCase): | ||
|
||
def setUp(self): | ||
self._old_dtype = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about using using_config
by calling __enter__
and __exit__
directly?
def setUp(self):
self._config_user = chainer.using_config('dtype', self.dtype)
self._config_user.__enter__()
...
def tearDown(self):
self._config_user.__exit__(None, None, None)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(I added another comment below; that has higher priority so please look that first!)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed!
@@ -16,14 +16,15 @@ class SpatialTransformerGrid(function.Function): | |||
|
|||
def __init__(self, output_shape): | |||
self.output_shape = output_shape | |||
self._dtype = chainer.get_dtype() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I overlooked it; I think it's better to accept theta
of any dtype with kind f
like many other functions do instead of restricting it to the current dtype.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added one comment.
xp.linspace(-1, 1, H, dtype=numpy.float32), | ||
xp.linspace(-1, 1, W, dtype=numpy.float32), indexing='ij', | ||
xp.linspace(-1, 1, H, dtype=self._dtype), | ||
xp.linspace(-1, 1, W, dtype=self._dtype), indexing='ij', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use theta.dtype
instead of self._dtype
? (ditto for other parts, too.) We could remove self._dtype
, too.
@takagi ping? |
Now I touch this. |
Jenkins, test this please. |
Jenkins CI test (for commit f5aa400, target branch master) failed with status FAILURE. |
Jenkins, test this please. |
Jenkins CI test (for commit c882d63, target branch master) failed with status FAILURE. |
jenkins, test this please. |
Jenkins CI test (for commit c882d63, target branch master) failed with status FAILURE. |
LGTM! |
This PR is a part of #4582, makes
F.spatial_transformer_gird
use the default dtype.