Skip to content

Commit

Permalink
Merge branch 'main' into justinchu/remove-test-globals
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed Oct 17, 2023
2 parents 2aeba4d + 5f908a9 commit e4677c0
Show file tree
Hide file tree
Showing 13 changed files with 22 additions and 803 deletions.
2 changes: 1 addition & 1 deletion docs/Operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -23094,7 +23094,7 @@ expect(
node,
inputs=[data, axes],
outputs=[reduced],
name="test_reduce_sum_empty_set",
name="test_reduce_sum_empty_set_non_reduced_axis_zero",
)
```

Expand Down
2 changes: 1 addition & 1 deletion docs/TestCoverage.md
Original file line number Diff line number Diff line change
Expand Up @@ -15681,7 +15681,7 @@ expect(
node,
inputs=[data, axes],
outputs=[reduced],
name="test_reduce_sum_empty_set",
name="test_reduce_sum_empty_set_non_reduced_axis_zero",
)
```

Expand Down
2 changes: 1 addition & 1 deletion onnx/backend/test/case/node/reducesum.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,5 +242,5 @@ def export_non_reduced_axis_zero() -> None:
node,
inputs=[data, axes],
outputs=[reduced],
name="test_reduce_sum_empty_set",
name="test_reduce_sum_empty_set_non_reduced_axis_zero",
)
Binary file modified onnx/backend/test/data/node/test_reduce_sum_empty_set/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
6 changes: 3 additions & 3 deletions onnx/checker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,11 @@ void check_tensor(const TensorProto& tensor, const CheckerContext& ctx) {
}
std::string data_path = path_join(ctx.get_model_dir(), relative_path);
// use stat64 to check whether the file exists
#if defined(__APPLE__) || defined(__wasm__)
struct stat buffer; // APPLE does not have stat64
#if defined(__APPLE__) || defined(__wasm__) || !defined(__GLIBC__)
struct stat buffer; // APPLE, wasm and non-glic stdlibs do not have stat64
if (stat((data_path).c_str(), &buffer) != 0) {
#else
struct stat64 buffer; // All POSIX except APPLE have stat64
struct stat64 buffer; // All POSIX under glibc except APPLE and wasm have stat64
if (stat64((data_path).c_str(), &buffer) != 0) {
#endif
fail_check(
Expand Down
38 changes: 16 additions & 22 deletions onnx/defs/traditionalml/defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -382,28 +382,6 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
fail_shape_inference(
"At least one of values_tensor, values_strings, values_int64s, values_floats must be set.");
}

int default_length, default_type;
std::tie(default_type, default_length) = getAttributeElementTypeAndLength(
ctx, {"default_tensor", "default_string", "default_int64", "default_float"});
if (default_type != TensorProto::UNDEFINED) {
if (value_type != default_type) {
fail_shape_inference(
"The value type ",
value_type,
" and the default type ",
default_type,
" are different, which is not permitted for LabelEncoders.");
}

// Ensure default_tensor is a singleton if set
const AttributeProto* default_tensor = ctx.getAttribute("default_tensor");
if (default_tensor != nullptr &&
(default_tensor->t().dims_size() != 1 || default_tensor->t().dims(0) != 1)) {
fail_shape_inference("default_tensor must be a singleton if set.");
}
}

if (value_length != key_length) {
fail_shape_inference(
"The number of keys ",
Expand All @@ -413,6 +391,22 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
" must be the same in the LabelEncoder.");
}

auto default_attr = ctx.getAttribute("default_tensor");
if (nullptr != default_attr && default_attr->has_t() && default_attr->t().has_data_type() &&
default_attr->t().data_type() != TensorProto_DataType_UNDEFINED) {
auto default_tensor = default_attr->t();
if (default_tensor.data_type() != value_type) {
fail_shape_inference(
"The default tensor type ",
default_tensor.data_type(),
" and the value type ",
value_type,
" must be the same in the LabelEncoder.");
}
if (1 != default_tensor.dims_size() || 1 != default_tensor.dims(0)) {
fail_shape_inference("The default tensor must be a singleton 1D tensor.");
}
}
// Propagate shape from input type and assign output type based on value type
ctx.getOutputType(0)->mutable_tensor_type()->set_elem_type(value_type);
propagateShapeFromInputToOutput(ctx, 0, 0);
Expand Down
Loading

0 comments on commit e4677c0

Please sign in to comment.