New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Relax int type restriction #3466
Conversation
''' | ||
if (x != ignore) { | ||
int w_ind[] = {x, i % n_out}; | ||
ptrdiff_t w_ind[] = {x, i % n_out}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto. (ptrdiff_t
-> S
)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto.
@@ -45,7 +43,7 @@ def binary_accuracy(y, t): | |||
|
|||
t (:class:`~chainer.Variable` or :class:`numpy.ndarray` or \ | |||
:class:`cupy.ndarray`): | |||
Array holding an int32 vector of ground truth labels. | |||
Array holding an signed integer vector of ground truth labels. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
an -> a
@@ -75,8 +75,8 @@ def sigmoid_cross_entropy( | |||
x (Variable): A variable object holding a matrix whose (i, j)-th | |||
element indicates the unnormalized log probability of the j-th unit | |||
at the i-th example. | |||
t (Variable): Variable holding an int32 vector of ground truth labels. | |||
If ``t[i] == -1``, corresponding ``x[i]`` is ignored. | |||
t (Variable): Variable holding an signed integer vector of ground truth |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto. (an -> a)
@@ -294,7 +294,7 @@ def softmax_cross_entropy( | |||
dimensions is greater than 2. | |||
t (:class:`~chainer.Variable` or :class:`numpy.ndarray` or \ | |||
:class:`cupy.ndarray`): | |||
Variable holding an :class:`numpy.int32` vector of ground truth | |||
Variable holding an signed integer vector of ground truth |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto. (an -> a)
{'shape': (3,), 'dtype': 'f', 'axis': -1, 'inv': True}, | ||
{'shape': (3, 4), 'dtype': 'd', 'axis': 1, 'inv': True}, | ||
{'shape': (3, 4, 5), 'dtype': 'f', 'axis': 2, 'inv': False}], | ||
[{'label_dtype': numpy.int}, {'label_dtype': numpy.int32}] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about other int types (int8, int16, int64)?
@testing.parameterize(*testing.product({ | ||
'reduce': ['no', 'mean'], | ||
'norm': ['L1', 'L2'], | ||
'label_dtype': [numpy.int, numpy.int32], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto. (other int types)
# too large shape causes int32 -> float64 issue | ||
{'shape': (65536, 1), 'normalize': False}, | ||
{'shape': (65536, 1), 'normalize': False, 'label_dtype': numpy.int32}, | ||
{'shape': (8, 7), 'normalize': True, 'label_dtype': numpy.int} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about adding some cases to test other int types?
@@ -21,6 +21,7 @@ | |||
'dtype': [numpy.float32], | |||
'weight_apply': [False, True], | |||
'enable_double_backprop': [False, True], | |||
'label_dtype': [numpy.int32], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto. (other int types)
@@ -29,6 +30,7 @@ | |||
'dtype': [numpy.float16, numpy.float32, numpy.float64], | |||
'weight_apply': [False, True], | |||
'enable_double_backprop': [False, True], | |||
'label_dtype': [numpy.int, numpy.int32], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto. (other int types)
'T gy, int32 x, int32 n_out', 'raw T gW', | ||
'int w_ind[] = {x, i % n_out}; atomicAdd(&gW[w_ind], gy)', | ||
'T gy, S x, S n_out', 'raw T gW', | ||
'ptrdiff_t w_ind[] = {x, i % n_out};' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about ptrdiff_t
-> S
? It's guaranteed to be representable by S
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ptrdiff_t
is better for CuPy.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, I agree with you because it's used as an index.
Please fix AppVeyor result. |
979b518
to
b9840bf
Compare
I fixed. |
b9840bf
to
10cc5b6
Compare
I resolved conflict. |
@@ -43,6 +42,7 @@ def check_type_forward(self, in_types): | |||
def forward(self, inputs): | |||
xp = cuda.get_array_module(*inputs) | |||
y, t = inputs | |||
t = t.astype('i', copy=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this conversion necessary?
I think the actual type for 'i'
is environment-dependent.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. Some Numpy functions require int32 on Windows.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see. Can you write a brief comment on this line?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK
Please fix conflict. |
b3ca2cd
to
3352886
Compare
I resolved conflict. |
(Jenkins has passed. Waiting for travis) |
LGTM! |
Relax int type restriction
Relax int type restriction
Current int type restriction is too strict.
This PR is relax it.