Skip to content

Commit

Permalink
Supporting negative axes for all existing onnx ops (onnx#2281)
Browse files Browse the repository at this point in the history
* Added negative axes for slice and squeeze opset 11

* added negative axes support for squeeze, unsqueeze, flatten

* added support for negative axes to all the existing ops

* fixed minor if condition missed for axis attr in flatten

* fixed test name for flatten with negative axes

* updated unsqueeze and softmax tests with fix for failures

* fixed typo

* Updating Split op documentations and version

* fixed typo in unsqueeze model

* fixed dim check for unsqueeze

* fixed type cast

* test fix for build failure

* updating onnx model for unsqueeze test

* fixed minor error in type casting
  • Loading branch information
neginraoof authored and wschin committed Sep 10, 2019
1 parent 81a9503 commit e17df4e
Show file tree
Hide file tree
Showing 190 changed files with 5,620 additions and 284 deletions.
1,375 changes: 1,284 additions & 91 deletions docs/Changelog.md

Large diffs are not rendered by default.

835 changes: 767 additions & 68 deletions docs/Operators.md

Large diffs are not rendered by default.

695 changes: 657 additions & 38 deletions docs/TestCoverage.md

Large diffs are not rendered by default.

20 changes: 20 additions & 0 deletions onnx/backend/test/case/node/argmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,23 @@ def export_default_axes_keepdims(): # type: () -> None
# result's shape: [1, 3, 4]
result = argmax_use_numpy(data, keepdims=keepdims)
expect(node, inputs=[data], outputs=[result], name='test_argmax_default_axis_random')

@staticmethod
def export_negative_axis_keepdims(): # type: () -> None
data = np.array([[2, 1], [3, 10]], dtype=np.float32)
axis = -1
keepdims = 1
node = onnx.helper.make_node(
'ArgMax',
inputs=['data'],
outputs=['result'],
axis=axis,
keepdims=keepdims)
# result: [[0], [1]]
result = argmax_use_numpy(data, axis=axis, keepdims=keepdims)
expect(node, inputs=[data], outputs=[result], name='test_argmax_negative_axis_keepdims_example')

data = np.random.uniform(-10, 10, [2, 3, 4]).astype(np.float32)
# result's shape: [2, 3, 1]
result = argmax_use_numpy(data, axis=axis, keepdims=keepdims)
expect(node, inputs=[data], outputs=[result], name='test_argmax_negative_axis_keepdims_random')
20 changes: 20 additions & 0 deletions onnx/backend/test/case/node/argmin.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,23 @@ def export_default_axes_keepdims(): # type: () -> None
# result's shape: [1, 3, 4]
result = argmin_use_numpy(data, keepdims=keepdims)
expect(node, inputs=[data], outputs=[result], name='test_argmin_default_axis_random')

@staticmethod
def export_negative_axis_keepdims(): # type: () -> None
data = np.array([[2, 1], [3, 10]], dtype=np.float32)
axis = -1
keepdims = 1
node = onnx.helper.make_node(
'ArgMin',
inputs=['data'],
outputs=['result'],
axis=axis,
keepdims=keepdims)
# result: [[1], [0]]
result = argmin_use_numpy(data, axis=axis, keepdims=keepdims)
expect(node, inputs=[data], outputs=[result], name='test_argmin_negative_axis_keepdims_example')

data = np.random.uniform(-10, 10, [2, 3, 4]).astype(np.float32)
# result's shape: [2, 3, 1]
result = argmin_use_numpy(data, axis=axis, keepdims=keepdims)
expect(node, inputs=[data], outputs=[result], name='test_argmin_negative_axis_keepdims_random')
18 changes: 18 additions & 0 deletions onnx/backend/test/case/node/compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,21 @@ def export_compress_default_axis(): # type: () -> None

expect(node, inputs=[input, condition.astype(np.bool)], outputs=[output],
name='test_compress_default_axis')

@staticmethod
def export_compress_negative_axis(): # type: () -> None
node = onnx.helper.make_node(
'Compress',
inputs=['input', 'condition'],
outputs=['output'],
axis=-1,
)
input = np.array([[1, 2], [3, 4], [5, 6]]).astype(np.float32)
condition = np.array([0, 1])
output = np.compress(condition, input, axis=-1)
# print(output)
#[[ 2.]
# [ 4.]
# [ 6.]]
expect(node, inputs=[input, condition.astype(np.bool)], outputs=[output],
name='test_compress_negative_axis')
12 changes: 12 additions & 0 deletions onnx/backend/test/case/node/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,15 @@ def export(): # type: () -> None
output = np.concatenate(values, i)
expect(node, inputs=[v for v in values], outputs=[output],
name='test_concat_' + test_case + '_axis_' + str(i))

for i in range(-len(values[0].shape), 0):
in_args = ['value' + str(k) for k in range(len(values))]
node = onnx.helper.make_node(
'Concat',
inputs=[s for s in in_args],
outputs=['output'],
axis=i
)
output = np.concatenate(values, i)
expect(node, inputs=[v for v in values], outputs=[output],
name='test_concat_' + test_case + '_axis_negative_' + str(abs(i)))
13 changes: 13 additions & 0 deletions onnx/backend/test/case/node/cumsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,16 @@ def export_cumsum_2d_axis_1(): # type: () -> None
y = np.array([1., 3., 6., 4., 9., 15.]).astype(np.float64).reshape((2, 3))
expect(node, inputs=[x, axis], outputs=[y],
name='test_cumsum_2d_axis_1')

@staticmethod
def export_cumsum_2d_negative_axis(): # type: () -> None
node = onnx.helper.make_node(
'CumSum',
inputs=['x', 'axis'],
outputs=['y'],
)
x = np.array([1., 2., 3., 4., 5., 6.]).astype(np.float64).reshape((2, 3))
axis = np.array([-1]).astype(np.int32)
y = np.array([1., 3., 6., 4., 9., 15.]).astype(np.float64).reshape((2, 3))
expect(node, inputs=[x, axis], outputs=[y],
name='test_cumsum_2d_negative_axis')
18 changes: 18 additions & 0 deletions onnx/backend/test/case/node/flatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,21 @@ def export_flatten_with_default_axis(): # type: () -> None
b = np.reshape(a, new_shape)
expect(node, inputs=[a], outputs=[b],
name='test_flatten_default_axis')

@staticmethod
def export_flatten_negative_axis(): # type: () -> None
shape = (2, 3, 4, 5)
a = np.random.random_sample(shape).astype(np.float32)

for i in range(-len(shape), 0):
node = onnx.helper.make_node(
'Flatten',
inputs=['a'],
outputs=['b'],
axis=i,
)

new_shape = (np.prod(shape[0:i]).astype(int), -1)
b = np.reshape(a, new_shape)
expect(node, inputs=[a], outputs=[b],
name='test_flatten_negative_axis' + str(abs(i)))
10 changes: 10 additions & 0 deletions onnx/backend/test/case/node/hardmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,13 @@ def hardmax_2d(x): # type: (np.ndarray) -> np.ndarray
y = hardmax_2d(x.reshape(12, 5)).reshape(3, 4, 5)
expect(node, inputs=[x], outputs=[y],
name='test_hardmax_axis_2')

node = onnx.helper.make_node(
'Hardmax',
inputs=['x'],
outputs=['y'],
axis=-1,
)
y = hardmax_2d(x.reshape(12, 5)).reshape(3, 4, 5)
expect(node, inputs=[x], outputs=[y],
name='test_hardmax_negative_axis')
10 changes: 10 additions & 0 deletions onnx/backend/test/case/node/logsoftmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,13 @@ def logsoftmax_2d(x): # type: (np.ndarray) -> np.ndarray
y = logsoftmax_2d(x.reshape(12, 5)).reshape(3, 4, 5)
expect(node, inputs=[x], outputs=[y],
name='test_logsoftmax_axis_2')

node = onnx.helper.make_node(
'LogSoftmax',
inputs=['x'],
outputs=['y'],
axis=-1,
)
y = logsoftmax_2d(x.reshape(12, 5)).reshape(3, 4, 5)
expect(node, inputs=[x], outputs=[y],
name='test_logsoftmax_negative_axis')
20 changes: 20 additions & 0 deletions onnx/backend/test/case/node/onehot.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,23 @@ def export_with_axis(): # type: () -> None
y = one_hot(indices, depth, axis=axisValue, dtype=output_type)
y = y * (on_value - off_value) + off_value
expect(node, inputs=[indices, depth, values], outputs=[y], name='test_onehot_with_axis')

@staticmethod
def export_with_negative_axis(): # type: () -> None
axisValue = -2
on_value = 3
off_value = 1
output_type = np.float32
node = onnx.helper.make_node(
'OneHot',
inputs=['indices', 'depth', 'values'],
outputs=['y'],
axis=axisValue
)
indices = np.array([[1, 9],
[2, 4]], dtype=np.float32)
depth = np.array([10], dtype=np.float32)
values = np.array([off_value, on_value], dtype=output_type)
y = one_hot(indices, depth, axis=axisValue, dtype=output_type)
y = y * (on_value - off_value) + off_value
expect(node, inputs=[indices, depth, values], outputs=[y], name='test_onehot_with_negative_axis')
14 changes: 14 additions & 0 deletions onnx/backend/test/case/node/reduce_log_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,17 @@ def export_keepdims(): # type: () -> None
reduced = np.log(np.sum(data, keepdims=True))
expect(node, inputs=[data], outputs=[reduced],
name='test_reduce_log_sum_default')

@staticmethod
def export_negative_axes_keepdims(): # type: () -> None
node = onnx.helper.make_node(
'ReduceLogSum',
inputs=['data'],
outputs=["reduced"],
axes=[-2]
)
data = np.random.ranf([3, 4, 5]).astype(np.float32)
reduced = np.log(np.sum(data, axis=(-2), keepdims=True))
# print(reduced)
expect(node, inputs=[data], outputs=[reduced],
name='test_reduce_log_sum_negative_axes')
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,39 @@ def export_default_axes_keepdims(): # type: () -> None
keepdims=keepdims == 1))
expect(node, inputs=[data], outputs=[reduced],
name='test_reduce_log_sum_exp_default_axes_keepdims_random')

@staticmethod
def export_negative_axes_keepdims(): # type: () -> None
shape = [3, 2, 2]
axes = [-2]
keepdims = 1
node = onnx.helper.make_node(
'ReduceLogSumExp',
inputs=['data'],
outputs=['reduced'],
axes=axes,
keepdims=keepdims
)

data = np.array(
[[[5, 1], [20, 2]], [[30, 1], [40, 2]], [[55, 1], [60, 2]]],
dtype=np.float32)
reduced = np.log(np.sum(np.exp(data),
axis=tuple(axes),
keepdims=keepdims == 1))
# print(reduced)
# [[[20., 2.31326175]]
# [[40.00004578, 2.31326175]]
# [[60.00671387, 2.31326175]]]

expect(node, inputs=[data], outputs=[reduced],
name='test_reduce_log_sum_exp_negative_axes_keepdims_example')

np.random.seed(0)
data = np.random.uniform(-10, 10, shape).astype(np.float32)
reduced = np.log(np.sum(np.exp(data),
axis=tuple(axes),
keepdims=keepdims == 1))

expect(node, inputs=[data], outputs=[reduced],
name='test_reduce_log_sum_exp_negative_axes_keepdims_random')
32 changes: 32 additions & 0 deletions onnx/backend/test/case/node/reducel1.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,35 @@ def export_default_axes_keepdims(): # type: () -> None

expect(node, inputs=[data], outputs=[reduced],
name='test_reduce_l1_default_axes_keepdims_random')

@staticmethod
def export_negative_axes_keepdims(): # type: () -> None
shape = [3, 2, 2]
axes = [-1]
keepdims = 1

node = onnx.helper.make_node(
'ReduceL1',
inputs=['data'],
outputs=['reduced'],
axes=axes,
keepdims=keepdims
)

data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape)
# print(data)
#[[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]]

reduced = np.sum(a=np.abs(data), axis=tuple(axes), keepdims=keepdims == 1)
# print(reduced)
#[[[3.], [7.]], [[11.], [15.]], [[19.], [23.]]]

expect(node, inputs=[data], outputs=[reduced],
name='test_reduce_l1_negative_axes_keep_dims_example')

np.random.seed(0)
data = np.random.uniform(-10, 10, shape).astype(np.float32)
reduced = np.sum(a=np.abs(data), axis=tuple(axes), keepdims=keepdims == 1)

expect(node, inputs=[data], outputs=[reduced],
name='test_reduce_l1_negative_axes_keep_dims_random')
36 changes: 36 additions & 0 deletions onnx/backend/test/case/node/reducel2.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,39 @@ def export_default_axes_keepdims(): # type: () -> None

expect(node, inputs=[data], outputs=[reduced],
name='test_reduce_l2_default_axes_keepdims_random')

@staticmethod
def export_negative_axes_keepdims(): # type: () -> None
shape = [3, 2, 2]
axes = [-1]
keepdims = 1

node = onnx.helper.make_node(
'ReduceL2',
inputs=['data'],
outputs=['reduced'],
axes=axes,
keepdims=keepdims
)

data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape)
# print(data)
#[[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]]

reduced = np.sqrt(np.sum(
a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1))
# print(reduced)
#[[[2.23606798], [5.]]
# [[7.81024968], [10.63014581]]
# [[13.45362405], [16.2788206 ]]]

expect(node, inputs=[data], outputs=[reduced],
name='test_reduce_l2_negative_axes_keep_dims_example')

np.random.seed(0)
data = np.random.uniform(-10, 10, shape).astype(np.float32)
reduced = np.sqrt(np.sum(
a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1))

expect(node, inputs=[data], outputs=[reduced],
name='test_reduce_l2_negative_axes_keep_dims_random')
28 changes: 28 additions & 0 deletions onnx/backend/test/case/node/reducemax.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,31 @@ def export_default_axes_keepdims(): # type: () -> None
reduced = np.maximum.reduce(data, axis=axes, keepdims=keepdims == 1)

expect(node, inputs=[data], outputs=[reduced], name='test_reduce_max_default_axes_keepdims_random')

@staticmethod
def export_negative_axes_keepdims(): # type: () -> None
shape = [3, 2, 2]
axes = [-2]
keepdims = 1

node = onnx.helper.make_node(
'ReduceMax',
inputs=['data'],
outputs=['reduced'],
axes=axes,
keepdims=keepdims)

data = np.array([[[5, 1], [20, 2]], [[30, 1], [40, 2]], [[55, 1], [60, 2]]], dtype=np.float32)
reduced = np.maximum.reduce(data, axis=tuple(axes), keepdims=keepdims == 1)
# print(reduced)
#[[[20., 2.]]
# [[40., 2.]]
# [[60., 2.]]]

expect(node, inputs=[data], outputs=[reduced], name='test_reduce_max_negative_axes_keepdims_example')

np.random.seed(0)
data = np.random.uniform(-10, 10, shape).astype(np.float32)
reduced = np.maximum.reduce(data, axis=tuple(axes), keepdims=keepdims == 1)

expect(node, inputs=[data], outputs=[reduced], name='test_reduce_max_negative_axes_keepdims_random')
28 changes: 28 additions & 0 deletions onnx/backend/test/case/node/reducemean.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,31 @@ def export_default_axes_keepdims(): # type: () -> None
reduced = np.mean(data, axis=axes, keepdims=keepdims == 1)

expect(node, inputs=[data], outputs=[reduced], name='test_reduce_mean_default_axes_keepdims_random')

@staticmethod
def export_negative_axes_keepdims(): # type: () -> None
shape = [3, 2, 2]
axes = [-2]
keepdims = 1

node = onnx.helper.make_node(
'ReduceMean',
inputs=['data'],
outputs=['reduced'],
axes=axes,
keepdims=keepdims)

data = np.array([[[5, 1], [20, 2]], [[30, 1], [40, 2]], [[55, 1], [60, 2]]], dtype=np.float32)
reduced = np.mean(data, axis=tuple(axes), keepdims=keepdims == 1)
# print(reduced)
# [[[12.5, 1.5]]
# [[35., 1.5]]
# [[57.5, 1.5]]]

expect(node, inputs=[data], outputs=[reduced], name='test_reduce_mean_negative_axes_keepdims_example')

np.random.seed(0)
data = np.random.uniform(-10, 10, shape).astype(np.float32)
reduced = np.mean(data, axis=tuple(axes), keepdims=keepdims == 1)

expect(node, inputs=[data], outputs=[reduced], name='test_reduce_mean_negative_axes_keepdims_random')
Loading

0 comments on commit e17df4e

Please sign in to comment.