Skip to content

Commit

Permalink
Update error messages for NegativeLogLikelihoodLoss inference function (
Browse files Browse the repository at this point in the history
#6021)

### Description
Add the invalid shape information in errors message of
NegativeLogLikelihoodLoss's inference function.

### Motivation and Context
Better errors

Signed-off-by: Justin Chu <justinchu@microsoft.com>
  • Loading branch information
justinchuby committed Mar 18, 2024
1 parent 9cc907f commit 17dbae7
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 10 deletions.
19 changes: 14 additions & 5 deletions onnx/defs/math/defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2519,10 +2519,14 @@ ONNX_OPERATOR_SET_SCHEMA(
const int target_rank = static_cast<int>(target_shape.dim_size());

if (input_rank < 2) {
fail_shape_inference("Input rank must be >= 2.")
fail_shape_inference("Input rank must be >= 2. input_rank=", input_rank);
}
if (target_rank != input_rank - 1) {
fail_shape_inference("Target rank must be 1 less than the input rank.");
fail_shape_inference(
"Target rank must be 1 less than the input rank. input_rank=",
input_rank,
", target_rank=",
target_rank);
}

// match input dimensions (N, C, d1, ..., dk) with target
Expand All @@ -2532,13 +2536,18 @@ ONNX_OPERATOR_SET_SCHEMA(
const auto target_dim = target_shape.dim(dim);
if (input_dim.has_dim_value() && target_dim.has_dim_value() &&
input_dim.dim_value() != target_dim.dim_value())
fail_shape_inference("Input and target dimension value mismatch.");
fail_shape_inference(
"Input and target dimension value mismatch. input_dim_value=",
input_dim.dim_value(),
" target_dim_value=",
target_dim.dim_value());
}

if (ctx.getNumInputs() == 3 && hasInputShape(ctx, 2)) {
const TensorShapeProto& weight_shape = ctx.getInputType(2)->tensor_type().shape();
if (weight_shape.dim_size() != 1) {
fail_shape_inference("Weight rank must be 1.");
const auto weight_rank = weight_shape.dim_size();
if (weight_rank != 1) {
fail_shape_inference("Weight rank must be 1. weight_rank=", weight_rank);
}
}

Expand Down
19 changes: 14 additions & 5 deletions onnx/defs/math/old.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1294,10 +1294,14 @@ ONNX_OPERATOR_SET_SCHEMA(
const int target_rank = static_cast<int>(target_shape.dim_size());

if (input_rank < 2) {
fail_shape_inference("Input rank must be >= 2.");
fail_shape_inference("Input rank must be >= 2. input_rank=", input_rank);
}
if (target_rank != input_rank - 1) {
fail_shape_inference("Target rank must be 1 less than the input rank.");
fail_shape_inference(
"Target rank must be 1 less than the input rank. input_rank=",
input_rank,
", target_rank=",
target_rank);
}

// match input dimensions (N, C, d1, ..., dk) with target
Expand All @@ -1307,13 +1311,18 @@ ONNX_OPERATOR_SET_SCHEMA(
const auto target_dim = target_shape.dim(dim);
if (input_dim.has_dim_value() && target_dim.has_dim_value() &&
input_dim.dim_value() != target_dim.dim_value())
fail_shape_inference("Input and target dimension value mismatch.");
fail_shape_inference(
"Input and target dimension value mismatch. input_dim_value=",
input_dim.dim_value(),
" target_dim_value=",
target_dim.dim_value());
}

if (ctx.getNumInputs() == 3 && hasInputShape(ctx, 2)) {
const TensorShapeProto& weight_shape = ctx.getInputType(2)->tensor_type().shape();
if (weight_shape.dim_size() != 1) {
fail_shape_inference("Weight rank must be 1.");
const auto weight_rank = weight_shape.dim_size();
if (weight_rank != 1) {
fail_shape_inference("Weight rank must be 1. weight_rank=", weight_rank);
}
}

Expand Down

0 comments on commit 17dbae7

Please sign in to comment.