Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
75c4f0e
Add Opset v10 support for MaxPool
sdmonov Aug 25, 2019
48c1b14
Fixed few issues with the code
sdmonov Aug 25, 2019
c138bbf
Fixed Python 2 -inf use
sdmonov Aug 25, 2019
ee4c7d8
Fixed Python 2 -inf use
sdmonov Aug 25, 2019
ae4186c
Cleaned up the code and added comments. Added on more test case for m…
sdmonov Aug 28, 2019
b06b54f
Small code improvement and regenerated the ops version and doc
sdmonov Aug 29, 2019
611be2b
N-D max pooling support
sdmonov Sep 10, 2019
8bd21b8
Regenerated opset_version and docs
sdmonov Sep 10, 2019
1ed54c0
Fixed padding support for auto_pad
sdmonov Sep 11, 2019
c052a1b
Fixed small code formatting
sdmonov Sep 11, 2019
d32d80e
Fixed bug in the pooling algorithm
sdmonov Sep 11, 2019
4027905
Python 2.7 bug
sdmonov Sep 11, 2019
bdb2481
Added checks for the opset version in unit tests
sdmonov Sep 11, 2019
2fececc
Convert the input to tensor of not a tensorflow tensor yet
sdmonov Sep 17, 2019
13a3af2
Add Opset v10 support for MaxPool
sdmonov Aug 25, 2019
c9bdcbb
Fixed few issues with the code
sdmonov Aug 25, 2019
e152a31
Fixed Python 2 -inf use
sdmonov Aug 25, 2019
49b7413
Fixed Python 2 -inf use
sdmonov Aug 25, 2019
8e16c82
Cleaned up the code and added comments. Added on more test case for m…
sdmonov Aug 28, 2019
84602cf
Small code improvement and regenerated the ops version and doc
sdmonov Aug 29, 2019
4746240
N-D max pooling support
sdmonov Sep 10, 2019
9df80d1
Fixed padding support for auto_pad
sdmonov Sep 11, 2019
bffbd16
Fixed small code formatting
sdmonov Sep 11, 2019
687fac8
Fixed bug in the pooling algorithm
sdmonov Sep 11, 2019
ce92fbd
Python 2.7 bug
sdmonov Sep 11, 2019
0036344
Added checks for the opset version in unit tests
sdmonov Sep 11, 2019
2aa96a0
Convert the input to tensor of not a tensorflow tensor yet
sdmonov Sep 17, 2019
df97f1f
Rebase with master
sdmonov Sep 26, 2019
edb6b7f
Multiple changes
sdmonov Oct 17, 2019
c86ea74
Regenerated the ops and docs and added v11 support
sdmonov Oct 17, 2019
21feefa
Fixed few bugs when tensor shape is not known
sdmonov Oct 17, 2019
487e24e
Merge branch 'master' into maxpool_v10
sdmonov Oct 25, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/API.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ Changing to False is strongly discouraged.
Currently, the strict flag only affects the behavior of MaxPool and AveragePool ops.


`logging_level` : The logging level, default is INFO. Change it to DEBUG
to see more conversion details or to WARNING to see less


_returns_:

A TensorflowRep class object representing the ONNX model
Expand Down
6 changes: 5 additions & 1 deletion doc/CLI.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ optional arguments:
More information: `onnx-tf convert -h`
```
usage: onnx-tf [-h] --infile INFILE --outfile OUTFILE [--device DEVICE]
[--strict STRICT]
[--strict STRICT] [--logging_level LOGGING_LEVEL]

This is the converter for converting protocol buffer between tf and onnx.

Expand All @@ -47,4 +47,8 @@ backend arguments (onnx -> tf):
Changing to False is strongly discouraged. Currently,
the strict flag only affects the behavior of MaxPool
and AveragePool ops. (from onnx_tf.backend.prepare)
--logging_level LOGGING_LEVEL
The logging level, default is INFO. Change it to DEBUG
to see more conversion details or to WARNING to see
less (from onnx_tf.backend.prepare)
```
2 changes: 1 addition & 1 deletion doc/support_status.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ Notes:
|MatMul|**1**|1|1|1|1|1|1|1|**9**|9|9|
|MatMulInteger|-|-|-|-|-|-|-|-|-|**10**:small_red_triangle:|10:small_red_triangle:|
|Max|**1**|1|1|1|1|**6**|6|**8**|8|8|8|
|MaxPool|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**8**:small_orange_diamond:|8:small_orange_diamond:|**10**:small_red_triangle:|**11**:small_red_triangle:|
|MaxPool|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**8**:small_orange_diamond:|8:small_orange_diamond:|**10**:small_orange_diamond:|**11**:small_orange_diamond:|
|MaxRoiPool|**1**:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|1:small_red_triangle:|
|MaxUnpool|-|-|-|-|-|-|-|-|**9**|9|**11**:small_red_triangle:|
|Mean|**1**|1|1|1|1|**6**|6|**8**|8|8|8|
Expand Down
190 changes: 190 additions & 0 deletions onnx_tf/common/pooling_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
from __future__ import division

from collections import namedtuple
from numpy import inf
import numpy as np
import tensorflow as tf

import itertools


pad_ops = namedtuple("pad_ops",
["max_op", "ceil_op", "floor_op", "cast_int_op"])

pad_numpy_ops = pad_ops(np.maximum, np.ceil, np.floor,
lambda arr: arr.astype(np.int64))
pad_tf_ops = pad_ops(tf.maximum, tf.math.ceil, tf.math.floor,
lambda tensor: tf.cast(tensor, tf.int64))


def calc_pads_same(in_spatial_shape, kernel_shape, strides,
dilations, padding, padding_ops=pad_numpy_ops,
pads_order=1):
"""
Calculates the SAME paddings that need to be added to the input

Args:
in_spatial_shape: input spatial shape
kernel_shape: the size of the kernel along each axis
strides: stride along each spatial axis
dilations: dilations value along each spatial axis
padding: padding to calculate: SAME_UPPER or
SAME_LOWER
padding_ops: namedtuple with ops to be used during
calculations. there are two sets of ops
defined pad_numpy_ops and pad_tf_ops with
numpy and tensorflow ops
pads_order: order of returned pads. possible options are:
1 - b1, b2, ..., bn, e1, e2, ..., en
2 - b1, e1, b2, e2, ..., bn, en
where n = len(kernel_shape) * 2,
b1, b2, ..., bn define pads at the begging of
axis
e1, e2, ..., en define pads at the end of
axis
Return:
pads: array with calculated pads. the order of the
values is determined by `pads_order`

"""
spatial_size = len(kernel_shape)
pads = [0] * (spatial_size * 2)
for i in range(spatial_size):
in_size = in_spatial_shape[i]
filter_size = (kernel_shape[i] - 1) * dilations[i] + 1

out_size = padding_ops.ceil_op(in_size / strides[i])
out_size = padding_ops.cast_int_op(out_size)
pad_along_axis = \
padding_ops.max_op((out_size - 1) * strides[i] +
filter_size - in_size, 0)
if padding.lower() == "same_lower":
pad_op = padding_ops.ceil_op
else:
pad_op = padding_ops.floor_op
pad_begin = pad_op(pad_along_axis / 2)

pad_begin = padding_ops.cast_int_op(pad_begin)
pad_along_axis = padding_ops.cast_int_op(pad_along_axis)

pad_end = pad_along_axis - pad_begin

pads[i * pads_order] = pad_begin
pads[i * pads_order +
(spatial_size if pads_order == 1 else 1)] = pad_end

return pads


def py_maxpool(input, kernel_shape, strides=None, dilations=None,
padding=None, ceil_mode=False):
"""
Implementation of MaxPool operation in Python
Args:
input: input N-D data array in NC* format
kernel_shape: the size of the kernel along each axis
strides: stride along each spatial axis
dilations: dilations value along each spatial axis of filter
padding: padding for the beginning and ending along each
spatial axis. `padding` format should be as follow
[x1_begin, x2_begin...x1_end, x2_end,...]
ceil_mode: whether to use ceil or floor (default) to compute
the output shape.
Return:
pooled: output data from max pooling across the input
ind: indices of the selected max values from the input
"""

def _pooling_output_shape(input_size, ksize, stride,
dilation, pad, ceil_mode):
output_size = (input_size + pad - ((ksize - 1) * dilation + 1) +
((stride-1) if ceil_mode else 0)) // stride + 1
if (pad):
if ((output_size - 1) * stride >= input_size + pad):
output_size -= 1
return output_size

input_shape = np.shape(input)
inp_sp_shape = input_shape[2:]

def _loop_over_output(batch, channel):
dims = [range(output_sp_shape[d]) for d in range(spatial_size)]
for counters in itertools.product(*dims):
input_ranges = []
for dim in range(spatial_size):
dim_start = \
counters[dim] * strides[dim] - pads[dim * 2]
dim_end = \
min(dim_start + (kernel_shape[dim] - 1) * dilations[dim]
+ 1, inp_sp_shape[dim])
while dim_start < 0:
dim_start += dilations[dim]

cur_range = [i for i in range(dim_start,
dim_end, dilations[dim])]
input_ranges.append(cur_range)
maxval = -inf
maxind = -1
for input_ind in itertools.product(*input_ranges):
ind = (batch, channel) + input_ind
val = input[ind]
if val > maxval:
maxval = val
ind = 0
for i in range(spatial_size):
coef = 1
for j in range(i+1, spatial_size):
coef *= inp_sp_shape[j]
ind += input_ind[i] * coef
maxind = ind
ind = (batch, channel) + counters
out_pool[ind] = maxval
out_ind[ind] = maxind

spatial_size = len(kernel_shape)

batch_size = input_shape[0]
channels_num = input_shape[1]

if strides is None:
strides = kernel_shape

if dilations is None:
dilations = [1] * spatial_size

if padding is None:
padding = [0] * spatial_size * 2

if type(padding) is not list:
if padding.lower().startswith("same"):
padding = calc_pads_same(inp_sp_shape, kernel_shape, strides,
dilations, padding)
else:
padding = [0] * spatial_size * 2

pads = []
pad_along_axis = []
output_sp_shape = []

for dim in range(spatial_size):
pads.append(padding[dim])
pads.append(padding[dim + spatial_size])
pad_along_axis.append(padding[dim] + padding[dim + spatial_size])

input_size = input_shape[dim + 2]
output_size = \
_pooling_output_shape(input_size, kernel_shape[dim],
strides[dim], dilations[dim],
pad_along_axis[dim], ceil_mode)
output_sp_shape.append(output_size)

out_pool = np.zeros([input_shape[0], input_shape[1]] +
output_sp_shape, input.dtype)
out_ind = np.zeros([input_shape[0], input_shape[1]] +
output_sp_shape, np.int64)

for batch in range(batch_size):
for channel in range(channels_num):
_loop_over_output(batch, channel)

return out_pool, out_ind
45 changes: 45 additions & 0 deletions onnx_tf/common/tf_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import tensorflow as tf
import numpy as np


def tf_shape(tensor):
"""
Helper function returning the shape of a Tensor.
The function will check for fully defined shape and will return
numpy array or if the shape is not fully defined will use tf.shape()
to return the shape as a Tensor.
"""
if tensor.shape.is_fully_defined():
return np.array(tensor.shape.as_list(), dtype=np.int64)
else:
return tf.shape(tensor, out_type=tf.int64)


def tf_product(a, b):
"""
Calculates the cartesian product of two column vectors a and b

Example:

a = [[1]
[2]
[3]]

b = [[0]
[1]]

result = [[1 0]
[1 1]
[2 0]
[2 1]
[3 0]
[3 1]]
"""
tile_a = tf.tile(a, [1, tf.shape(b)[0]])
tile_a = tf.expand_dims(tile_a, 2)
tile_a = tf.reshape(tile_a, [-1, 1])

b = tf.tile(b, [tf.shape(a)[0], 1])
b = tf.concat([tile_a, b], axis=1)

return b
Loading