Skip to content

Commit

Permalink
Merge branch 'master' into fix/gather_test
Browse files Browse the repository at this point in the history
  • Loading branch information
HeJunchao100813 committed Sep 6, 2023
2 parents 82f40a8 + 7d4c9a8 commit e968882
Show file tree
Hide file tree
Showing 26 changed files with 2,512 additions and 64 deletions.
42 changes: 33 additions & 9 deletions tests/kernels/kernel_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -1080,7 +1080,10 @@ class KernelTest {
<< (unsigned)document.GetErrorOffset() << " "
<< GetParseError_En(document.GetParseError())
<< std::endl;
assert(document.IsObject());

if (!document.IsObject()) {
throw std::runtime_error("type error! it should be Object.");
}
}

void ParseJson(std::string js_str) {
Expand All @@ -1089,7 +1092,10 @@ class KernelTest {
<< (unsigned)_document.GetErrorOffset() << " "
<< GetParseError_En(_document.GetParseError())
<< std::endl;
assert(_document.IsObject());

if (!_document.IsObject()) {
throw std::runtime_error("type error! it should be Object.");
}
}

typecode_t Str2DataType(std::string type) {
Expand All @@ -1102,27 +1108,41 @@ class KernelTest {
}

int64_t GetNumber(const char *key) {
assert(_document[key].IsInt64());
if (!_document[key].IsInt64()) {
throw std::runtime_error("type error! it should be int64.");
}

return _document[key].GetInt64();
}

float GetFloatNumber(const char *key) {
assert(_document[key].IsDouble());
if (!_document[key].IsDouble()) {
throw std::runtime_error("type error! it should be double.");
}

return _document[key].GetFloat();
}

typecode_t GetDataType(const char *key) {
assert(_document[key].IsString());
if (!_document[key].IsString()) {
throw std::runtime_error("type error! it should be string.");
}

return Str2DataType(_document[key].GetString());
}

std::string GetString(const char *key) {
assert(_document[key].IsString());
if (!_document[key].IsString()) {
throw std::runtime_error("type error! it should be string.");
}

return _document[key].GetString();
}

dims_t GetShapeArray(const char *key) {
assert(_document[key].IsArray());
if (!_document[key].IsArray()) {
throw std::runtime_error("type error! it should be array.");
}

Value &array = _document[key];
size_t arraySize = array.Size();
Expand All @@ -1140,7 +1160,9 @@ class KernelTest {
}

std::vector<int64_t> GetDataArray(const char *key) {
assert(_document[key].IsArray());
if (!_document[key].IsArray()) {
throw std::runtime_error("type error! it should be array.");
}

Value &array = _document[key];
size_t arraySize = array.Size();
Expand All @@ -1158,7 +1180,9 @@ class KernelTest {
}

axes_t GetAxesArray(const char *key) {
assert(_document[key].IsArray());
if (!_document[key].IsArray()) {
throw std::runtime_error("type error! it should be array.");
}

Value &array = _document[key];
size_t arraySize = array.Size();
Expand Down
2 changes: 1 addition & 1 deletion tests/kernels/test_unary.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"i_shape":[[1, 3, 16, 16], [3, 16, 16], [3, 16, 1], [16, 16], [16, 1], [1, 16, 1], [16], [1], []],
"i_shape":[[1, 3, 16, 16], [3, 16, 16], [3, 16, 1], [16, 16], [16, 1], [1, 16, 1], [16], [1], [], [1, 30, 288, 288], [1, 30, 288], [288, 288], [288]],
"lhs_type":["dt_float32", "dt_int32", "dt_int64", "dt_float64", "dt_float16"]
}
148 changes: 148 additions & 0 deletions tests/kernels/test_unary_abs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,152 @@ class UnaryTest : public KernelTest,

void TearDown() override { CLEAR_SUBCASE() }

void init_tensor(runtime_tensor &tensor) override {
auto dtype = tensor.datatype();
switch (dtype) {
case dt_int8: {
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> dis(-100000, 100000);
NNCASE_UNUSED auto res = kernels::stackvm::apply(
tensor.shape(),
[&](gsl::span<const size_t> index) -> result<void> {
get<int8_t>(tensor, index) = static_cast<int8_t>(dis(gen));
return ok();
});
break;
}
case dt_int16: {
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> dis(-100000, 100000);
NNCASE_UNUSED auto res = kernels::stackvm::apply(
tensor.shape(),
[&](gsl::span<const size_t> index) -> result<void> {
get<int16_t>(tensor, index) =
static_cast<int16_t>(dis(gen));
return ok();
});
break;
}
case dt_int32: {
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> dis(-100000, 100000);
NNCASE_UNUSED auto res = kernels::stackvm::apply(
tensor.shape(),
[&](gsl::span<const size_t> index) -> result<void> {
get<int32_t>(tensor, index) = dis(gen);
return ok();
});
break;
}
case dt_int64: {
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> dis(-100000, 100000);
NNCASE_UNUSED auto res = kernels::stackvm::apply(
tensor.shape(),
[&](gsl::span<const size_t> index) -> result<void> {
get<int64_t>(tensor, index) =
static_cast<int64_t>(dis(gen));
return ok();
});
break;
}
case dt_uint8: {
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> dis(-100000, 100000);
NNCASE_UNUSED auto res = kernels::stackvm::apply(
tensor.shape(),
[&](gsl::span<const size_t> index) -> result<void> {
get<uint8_t>(tensor, index) =
static_cast<uint8_t>(dis(gen));
return ok();
});
break;
}
case dt_uint16: {
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> dis(-100000, 100000);
NNCASE_UNUSED auto res = kernels::stackvm::apply(
tensor.shape(),
[&](gsl::span<const size_t> index) -> result<void> {
get<uint16_t>(tensor, index) =
static_cast<uint16_t>(dis(gen));
return ok();
});
break;
}
case dt_uint32: {
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> dis(-100000, 100000);
NNCASE_UNUSED auto res = kernels::stackvm::apply(
tensor.shape(),
[&](gsl::span<const size_t> index) -> result<void> {
get<uint32_t>(tensor, index) =
static_cast<uint32_t>(dis(gen));
return ok();
});
break;
}
case dt_uint64: {
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<uint64_t> dis(-100000, 100000);
NNCASE_UNUSED auto res = kernels::stackvm::apply(
tensor.shape(),
[&](gsl::span<const size_t> index) -> result<void> {
get<uint64_t>(tensor, index) =
static_cast<uint64_t>(dis(gen));
return ok();
});
break;
}
case dt_float16: {
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<float> dis(-10000.0f, 10000.0f);
NNCASE_UNUSED auto res = kernels::stackvm::apply(
tensor.shape(),
[&](gsl::span<const size_t> index) -> result<void> {
get<half>(tensor, index) = static_cast<half>(dis(gen));
return ok();
});
break;
}
case dt_float32: {
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<float> dis(-100000.0f, 100000.0f);
NNCASE_UNUSED auto res = kernels::stackvm::apply(
tensor.shape(),
[&](gsl::span<const size_t> index) -> result<void> {
get<float>(tensor, index) = static_cast<float>(dis(gen));
return ok();
});
break;
}
case dt_float64: {
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<double> dis(-100000.0, 100000.0);
NNCASE_UNUSED auto res = kernels::stackvm::apply(
tensor.shape(),
[&](gsl::span<const size_t> index) -> result<void> {
get<double>(tensor, index) = static_cast<double>(dis(gen));
return ok();
});
break;
}
default: {
}
}
}

protected:
runtime_tensor input;
};
Expand Down Expand Up @@ -77,6 +223,8 @@ TEST_P(UnaryTest, abs) {
cosine_similarity_tensor(expected, actual);

if (!result) {
std::cout << "input ";
print_runtime_tensor(input);
std::cout << "actual ";
print_runtime_tensor(actual);
std::cout << "expected ";
Expand Down

0 comments on commit e968882

Please sign in to comment.