Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
5bdff7f
commit 4aae998
Showing
3 changed files
with
140 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Copyright (C) 2018-2023 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from common.tf_layer_test_class import CommonTFLayerTest | ||
from common.utils.tf_utils import permute_nchw_to_nhwc | ||
|
||
|
||
class TestFloor(CommonTFLayerTest): | ||
def create_add_placeholder_const_net(self, x_shape, dtype, ir_version, use_new_frontend): | ||
import tensorflow as tf | ||
|
||
tf.compat.v1.reset_default_graph() | ||
|
||
# Create the graph and model | ||
with tf.compat.v1.Session() as sess: | ||
x = tf.compat.v1.placeholder(dtype, x_shape, 'Input') | ||
res = tf.raw_ops.Floor(x=x) | ||
|
||
tf.compat.v1.global_variables_initializer() | ||
tf_net = sess.graph_def | ||
|
||
ref_net = None | ||
|
||
return tf_net, ref_net | ||
|
||
def _prepare_input(self, inputs_dict): | ||
for input in inputs_dict.keys(): | ||
inputs_dict[input] = np.array([0.1, 0.2, 0.5, 0.55, 0.9, -0.1, -0.6, -0.9]).astype(np.float32) | ||
return inputs_dict | ||
|
||
# TODO: implement tests for 2 Consts + Add | ||
|
||
test_data_1D = [ | ||
dict(x_shape=[8], dtype=np.float32), | ||
# dict(x_shape=[], dtype=np.int32), | ||
# dict(x_shape=[2], dtype=np.int64), | ||
# dict(x_shape=[2, 4, 5], dtype=np.int32), | ||
# dict(x_shape=[], dtype=np.float32), | ||
# dict(x_shape=[2], dtype=np.float64), | ||
# dict(x_shape=[2, 4, 5], dtype=np.float32), | ||
] | ||
|
||
@pytest.mark.parametrize("params", test_data_1D) | ||
@pytest.mark.nightly | ||
@pytest.mark.precommit_tf_fe | ||
def test_add_placeholder_const_1D(self, params, ie_device, precision, ir_version, temp_dir, | ||
use_new_frontend): | ||
self._test(*self.create_add_placeholder_const_net(**params, ir_version=ir_version, | ||
use_new_frontend=use_new_frontend), | ||
ie_device, precision, ir_version, temp_dir=temp_dir, | ||
use_new_frontend=use_new_frontend) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters