Skip to content

Commit

Permalink
[TFLite] Quantized version of unit test for Dense (apache#7113)
Browse files Browse the repository at this point in the history
Added quantized version of unit test for FullyConnected/Dense
Added check for -1 in case if bias not supplied
  • Loading branch information
d-smirnov authored and trevor-m committed Jan 21, 2021
1 parent 40e2e90 commit 71c1566
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 29 deletions.
19 changes: 10 additions & 9 deletions python/tvm/relay/frontend/tflite.py
Expand Up @@ -982,7 +982,7 @@ def convert_concatenation(self, op):

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) >= 1, "input tensors should greater than 1"
in_exprs = [self.get_expr(input_tensor.tensor_idx) for input_tensor in input_tensors]
in_exprs = [self.get_tensor_expr(_) for _ in input_tensors]

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"
Expand Down Expand Up @@ -1830,14 +1830,15 @@ def convert_fully_connected(self, op):
# if we have bias
if len(input_tensors) == 3:
bias_tensor = input_tensors[2]
bias_tensor_type = bias_tensor.tensor.Type()
# bias tensor type should be INT32 (quantization) or FLOAT32
assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32)
bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
bias_expr = self.exp_tab.new_const(
self.get_tensor_value(bias_tensor), dtype=bias_tensor_type_str
)
out = _op.nn.bias_add(out, bias_expr)
if bias_tensor.tensor_idx != -1:
bias_tensor_type = bias_tensor.tensor.Type()
# bias tensor type should be INT32 (quantization) or FLOAT32
assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32)
bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
bias_expr = self.exp_tab.new_const(
self.get_tensor_value(bias_tensor), dtype=bias_tensor_type_str
)
out = _op.nn.bias_add(out, bias_expr)

# Finally if the dense is quantized. Add a requantize at the end.
if output_tensor.qnn_params:
Expand Down
83 changes: 63 additions & 20 deletions tests/python/frontend/tflite/test_forward.py
Expand Up @@ -3342,9 +3342,9 @@ def test_forward_sparse_to_dense():
#######################################################################
# Fully Connected
# ---------------


def _test_fully_connected(tensor_in_sizes, const_input, filter_in_sizes, bias_in_size=None):
def _test_fully_connected(
tensor_in_sizes, const_input, filter_in_sizes, bias_in_size=None, quantized=False
):
""" One iteration of fully connected """

total_size_1 = np.prod(tensor_in_sizes)
Expand All @@ -3356,42 +3356,85 @@ def _test_fully_connected(tensor_in_sizes, const_input, filter_in_sizes, bias_in

# Initializes the input tensor with array containing incrementing
# numbers from 1.
data_array = np.arange(1, total_size_1 + 1, dtype=np.float32)
filter_array = np.arange(1, total_size_2 + 1, dtype=np.float32)
data_array = np.arange(1, total_size_1 + 1, dtype=np.uint8 if quantized else np.float32)
filter_array = np.arange(1, total_size_2 + 1, dtype=np.uint8 if quantized else np.float32)
in_name = "input"

with tf.Graph().as_default():
in_name = "input"
in_data = (
constant_op.constant(data_array, shape=tensor_in_sizes, dtype=np.float32, name=in_name)
if const_input
else array_ops.placeholder(shape=tensor_in_sizes, dtype=np.float32, name=in_name)
)

in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype=np.float32)

# reshape N H W C into N H*W*C
in_data_reshape = array_ops.reshape(in_data, [tensor_in_sizes[0], -1])

out = math_ops.mat_mul(in_data_reshape, in_filter)
data_array = np.reshape(data_array, tensor_in_sizes)

# if we have bias
if bias_in_size:
assert bias_in_size[0] == filter_in_sizes[1], "bias and filter size are mismatched"
bias_array = np.arange(1, bias_in_size[0] + 1, dtype=np.float32)
bias_array = np.arange(
1, bias_in_size[0] + 1, dtype=np.uint8 if quantized else np.float32
)
in_bias = constant_op.constant(bias_array, shape=bias_in_size, dtype=np.float32)
out = nn_ops.bias_add(out, in_bias)

data_array = np.reshape(data_array, tensor_in_sizes).astype(np.float32)
compare_tflite_with_tvm(data_array, [] if const_input else in_data.name, [in_data], [out])
if quantized:
inq_data = tf.quantization.fake_quant_with_min_max_args(
in_data, min=-100, max=100, name="inq_0"
)
input_range = {"inq_0": (-100, 100)}
inq_filter = tf.quantization.fake_quant_with_min_max_args(
in_filter, min=-100, max=100, name="inq_1"
)
input_range = {"inq_0": (-100, 100), "inq_1": (-100, 100)}
# reshape N H W C into N H*W*C
inq_data_reshape = array_ops.reshape(inq_data, [tensor_in_sizes[0], -1])
out = math_ops.mat_mul(inq_data_reshape, inq_filter)
out = tf.quantization.fake_quant_with_min_max_args(out, min=-100, max=100, name="out")

# if we have bias
if bias_in_size:
out = nn_ops.bias_add(out, in_bias)

compare_tflite_with_tvm(
data_array,
inq_data.name,
[inq_data],
[out],
quantized=True,
input_range=input_range,
experimental_new_converter=True,
)
else:
# reshape N H W C into N H*W*C
in_data_reshape = array_ops.reshape(in_data, [tensor_in_sizes[0], -1])
out = math_ops.mat_mul(in_data_reshape, in_filter)

# if we have bias
if bias_in_size:
out = nn_ops.bias_add(out, in_bias)

compare_tflite_with_tvm(
data_array, in_data.name, [in_data], [out], experimental_new_converter=True
)


def test_forward_fully_connected():
""" Fully Connected """
for const_input in [False, True]:
_test_fully_connected([1, 1, 1, 150], const_input, [150, 100])
_test_fully_connected([1, 1, 1, 150], const_input, [150, 100], [100])
_test_fully_connected([5, 1, 1, 150], const_input, [150, 100])
_test_fully_connected([5, 1, 1, 150], const_input, [150, 100], [100])
for input_shape, weight_shape, bias_shape in [
([1, 4], [4, 4], None),
([1, 4], [4, 4], [4]),
([1, 1, 1, 5], [5, 5], None),
([1, 1, 10], [10, 103], None),
([1, 1, 1, 150], [150, 100], None),
([1, 1, 1, 150], [150, 100], None),
([1, 1, 1, 150], [150, 100], [100]),
([5, 1, 1, 150], [150, 100], None),
([5, 1, 1, 150], [150, 100], [100]),
]:
for const_input in [False, True]:
for quantized in [False, True]:
_test_fully_connected(input_shape, const_input, weight_shape, bias_shape, quantized)


#######################################################################
Expand Down

0 comments on commit 71c1566

Please sign in to comment.