Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
hejunchao committed Sep 6, 2023
1 parent ced3dbb commit 82f40a8
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 6 deletions.
11 changes: 9 additions & 2 deletions tests/kernels/test_gather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,21 @@ class GatherTest : public KernelTest,

auto shape = GetShapeArray("lhs_shape");
auto indices_shape = GetShapeArray("indices_shape");
auto indices_value = GetDataArray("indices_value");
auto value = GetNumber("axis");
auto typecode = GetDataType("lhs_type");

input = hrt::create(typecode, shape, host_runtime_tensor::pool_cpu_only)
.expect("create tensor failed");
init_tensor(input);

int64_t indices_array[] = {0, 0, -1, -1};
size_t indices_value_size = indices_value.size();
auto *indices_array =
(int64_t *)malloc(indices_value_size * sizeof(int64_t));
std::copy(indices_value.begin(), indices_value.end(), indices_array);
indices = hrt::create(dt_int64, indices_shape,
{reinterpret_cast<gsl::byte *>(indices_array),
sizeof(indices_array)},
indices_value_size * sizeof(int64_t)},
true, host_runtime_tensor::pool_cpu_only)
.expect("create tensor failed");

Expand Down Expand Up @@ -114,17 +118,20 @@ int main(int argc, char *argv[]) {
READY_TEST_CASE_GENERATE()
FOR_LOOP(lhs_shape, i)
FOR_LOOP(indices_shape, l)
FOR_LOOP(indices_value, h)
FOR_LOOP(axis, j)
FOR_LOOP(lhs_type, k)
SPLIT_ELEMENT(lhs_shape, i)
SPLIT_ELEMENT(indices_shape, l)
SPLIT_ELEMENT(indices_value, h)
SPLIT_ELEMENT(axis, j)
SPLIT_ELEMENT(lhs_type, k)
WRITE_SUB_CASE()
FOR_LOOP_END()
FOR_LOOP_END()
FOR_LOOP_END()
FOR_LOOP_END()
FOR_LOOP_END()

::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
Expand Down
1 change: 1 addition & 0 deletions tests/kernels/test_gather.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
"lhs_shape":[[2, 3, 5, 7], [2, 2], [2, 3, 1], [5, 5, 7, 7], [11]],
"indices_shape":[[4], [2, 2], [4, 1]],
"axis":[0, 1, -1, 2, 3, -2, -3, -4],
"indices_value": [[0, 0, -1, -1]],
"lhs_type":["dt_float32", "dt_int8", "dt_int32", "dt_uint8", "dt_int16", "dt_uint16", "dt_uint32", "dt_uint64", "dt_int64", "dt_float16", "dt_float64", "dt_bfloat16", "dt_boolean"]
}
11 changes: 9 additions & 2 deletions tests/kernels/test_gather_elements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,21 @@ class GatherElementsTest : public KernelTest,

auto shape = GetShapeArray("lhs_shape");
auto indices_shape = GetShapeArray("indices_shape");
auto indices_value = GetDataArray("indices_value");
auto value = GetNumber("axis");
auto typecode = GetDataType("lhs_type");

input = hrt::create(typecode, shape, host_runtime_tensor::pool_cpu_only)
.expect("create tensor failed");
init_tensor(input);

int64_t indices_array[] = {0, 0, 1, 1};
size_t indices_value_size = indices_value.size();
auto *indices_array =
(int64_t *)malloc(indices_value_size * sizeof(int64_t));
std::copy(indices_value.begin(), indices_value.end(), indices_array);
indices = hrt::create(dt_int64, indices_shape,
{reinterpret_cast<gsl::byte *>(indices_array),
sizeof(indices_array)},
indices_value_size * sizeof(int64_t)},
true, host_runtime_tensor::pool_cpu_only)
.expect("create tensor failed");

Expand Down Expand Up @@ -112,17 +116,20 @@ int main(int argc, char *argv[]) {
READY_TEST_CASE_GENERATE()
FOR_LOOP(lhs_shape, i)
FOR_LOOP(indices_shape, l)
FOR_LOOP(indices_value, h)
FOR_LOOP(axis, j)
FOR_LOOP(lhs_type, k)
SPLIT_ELEMENT(lhs_shape, i)
SPLIT_ELEMENT(indices_shape, l)
SPLIT_ELEMENT(indices_value, h)
SPLIT_ELEMENT(axis, j)
SPLIT_ELEMENT(lhs_type, k)
WRITE_SUB_CASE()
FOR_LOOP_END()
FOR_LOOP_END()
FOR_LOOP_END()
FOR_LOOP_END()
FOR_LOOP_END()

::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
Expand Down
1 change: 1 addition & 0 deletions tests/kernels/test_gather_elements.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
"lhs_shape":[[2, 2]],
"axis":[0],
"indices_shape":[[2, 2], [4, 1]],
"indices_value": [[0, 0, 1, 1]],
"lhs_type":["dt_float32", "dt_int8", "dt_int32", "dt_uint8", "dt_int16", "dt_uint16", "dt_uint32", "dt_uint64", "dt_int64", "dt_float16", "dt_float64", "dt_bfloat16", "dt_boolean"]
}
11 changes: 9 additions & 2 deletions tests/kernels/test_gather_nd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,21 @@ class GatherNDTest : public KernelTest,

auto shape = GetShapeArray("lhs_shape");
auto indices_shape = GetShapeArray("indices_shape");
auto indices_value = GetDataArray("indices_value");
auto value = GetNumber("axis");
auto typecode = GetDataType("lhs_type");

input = hrt::create(typecode, shape, host_runtime_tensor::pool_cpu_only)
.expect("create tensor failed");
init_tensor(input);

int64_t indices_array[] = {0, 0, 0, 0};
size_t indices_value_size = indices_value.size();
auto *indices_array =
(int64_t *)malloc(indices_value_size * sizeof(int64_t));
std::copy(indices_value.begin(), indices_value.end(), indices_array);
indices = hrt::create(dt_int64, indices_shape,
{reinterpret_cast<gsl::byte *>(indices_array),
sizeof(indices_array)},
indices_value_size * sizeof(int64_t)},
true, host_runtime_tensor::pool_cpu_only)
.expect("create tensor failed");

Expand Down Expand Up @@ -111,17 +115,20 @@ int main(int argc, char *argv[]) {
READY_TEST_CASE_GENERATE()
FOR_LOOP(lhs_shape, i)
FOR_LOOP(indices_shape, l)
FOR_LOOP(indices_value, h)
FOR_LOOP(axis, j)
FOR_LOOP(lhs_type, k)
SPLIT_ELEMENT(lhs_shape, i)
SPLIT_ELEMENT(indices_shape, l)
SPLIT_ELEMENT(indices_value, h)
SPLIT_ELEMENT(axis, j)
SPLIT_ELEMENT(lhs_type, k)
WRITE_SUB_CASE()
FOR_LOOP_END()
FOR_LOOP_END()
FOR_LOOP_END()
FOR_LOOP_END()
FOR_LOOP_END()

::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
Expand Down
1 change: 1 addition & 0 deletions tests/kernels/test_gather_nd.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
"lhs_shape":[[3, 5], [2, 2], [2, 3, 1], [5, 5, 7, 7]],
"axis":[0],
"indices_shape":[[2, 2], [4, 1]],
"indices_value": [[0, 0, 0, 0]],
"lhs_type":["dt_float32", "dt_int8", "dt_int32", "dt_uint8", "dt_int16", "dt_uint16", "dt_uint32", "dt_uint64", "dt_int64", "dt_float16", "dt_float64", "dt_bfloat16", "dt_boolean"]
}

0 comments on commit 82f40a8

Please sign in to comment.