-
Notifications
You must be signed in to change notification settings - Fork 1.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support of SelectItem
in ONNX-Chainer
#8450
Add support of SelectItem
in ONNX-Chainer
#8450
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The alternative logic of F.select_item
looks smart! SGTM,
So I set the argument to the decorator as (9, 11). It will be great if you can take a look at it
In this case, support
decorator does not have to be set, but I agree to better looking.
|
||
one_1 = onnx.helper.make_tensor('one_1', onnx.TensorProto.FLOAT, [1], [1]) | ||
ones = gb.op('ConstantOfShape', [n_rows], value=one_1) | ||
row_idxs = gb.op('Squeeze', [gb.op('NonZero', [ones])]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add the comment of "the process is equivalent to get range"
@@ -155,6 +155,11 @@ | |||
'args': {'slices': (slice(None), slice(0, 1), slice(None, 2))}, | |||
'name': 'get_item_start_from_none'}, | |||
|
|||
# select_item |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you split this test from TestArrayOperators
, because this function is for unary args.
class TestSelectItem(ONNXModelTest):
def test_output(self):
class Model(chainer.Chain):
def forward(self, x, t):
return F.select_item(x, t)
model = Model()
x = input_generator.increasing(3, 3)
t = np.array([2, 1, 0], dtype=np.int32)
self.expect(
model, (x, t), expected_num_initializers=0,
skip_opset_version=(7, 8))
- set
t
as input, not to put unnecessary initializer - set skip opset version
@tkanmae Do you have special reason not to use |
@disktnk The reason I did not use GatherElements is that CoreML 3.0 does not support it. Thanks for the comments above. I will make changes as you suggested, hopefully in a few days. |
I see. We'll make another PR to use |
c2345e1
to
a3b6bf1
Compare
* Remove a test data entry for select_item() from TestArrayOperators. * Add a test case for select_item().
a3b6bf1
to
12376f5
Compare
Got it. I fixed the bug in testing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LTGM!
Jenkins, test this please! |
Jenkins CI test (for commit 12376f5, target branch master) succeeded! |
thx! The travis fail is unrelated to this PR, so merge this manually. (The error has already resolved by another PR). |
SelectItem
in ONNX-Chainer
Jenkins CI test (for commit 12376f5, target branch master) succeeded! |
This Pull Request addresses #8449. The implementation encodes an equivalent of the following NumPy snippet:
I could have used
Range
to createrow_idxs
, but found that some of other ONNX conversion tools out there have not supported it yet. So I settled on the sequence ofConstantOfShape
followed byNonZero
to create the indices.I'm not sure if I correctly used
support
decorator. Among the ops used in the implementation,ConstantOfShape
is the newest one introduced in opset 9, and some ops have been updated in opset 11. So I set the argument to the decorator as(9, 11)
. It will be great if you can take a look at it.