Skip to content

Commit

Permalink
Check tensor type and dimensions if possible
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Oct 16, 2023
1 parent 9b90f6c commit fb86689
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions onnx/defs/traditionalml/defs.cc
Expand Up @@ -391,6 +391,21 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
" must be the same in the LabelEncoder.");
}

auto default_attr = ctx.getAttribute("default_tensor");
if (nullptr != default_attr) {
auto default_tensor = default_attr->t();
if (default_tensor.data_type() != value_type) {
fail_shape_inference(
"The default tensor type ",
default_tensor.data_type(),
" and the value type ",
value_type,
" must be the same in the LabelEncoder.");
}
if (1 != default_tensor.dims_size()) {
fail_shape_inference("The default tensor must be a 1D tensor.");
}
}
// Propagate shape from input type and assign output type based on value type
ctx.getOutputType(0)->mutable_tensor_type()->set_elem_type(value_type);
propagateShapeFromInputToOutput(ctx, 0, 0);
Expand Down

0 comments on commit fb86689

Please sign in to comment.