Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support align_corners for Resize operator #418

Closed
wants to merge 4 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions builtin_op_importers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2445,11 +2445,11 @@ DEFINE_BUILTIN_OP_IMPORTER(Resize)

auto mode = attrs.get<std::string>("mode", "nearest");
auto resizeMode = mode == "nearest" ? nvinfer1::ResizeMode::kNEAREST : nvinfer1::ResizeMode::kLINEAR;
auto transformationMode = attrs.get<std::string>("coordinate_transformation_mode", "half_pixel");

if (ctx->getOpsetVersion() >= 11)
{
auto transformationMode = attrs.get<std::string>("coordinate_transformation_mode", "half_pixel");
ASSERT((transformationMode == "asymmetric") && "This version of TensorRT only supports asymmetric resize!",
ASSERT(((transformationMode == "asymmetric") || (transformationMode == "align_corners")) && "This version of TensorRT only supports asymmetric and align_corners resize!",
ErrorCode::kUNSUPPORTED_NODE);
ASSERT(mode != "cubic" && "This version of TensorRT does not support cubic interpolation!",
ErrorCode::kUNSUPPORTED_NODE);
Expand All @@ -2464,11 +2464,14 @@ DEFINE_BUILTIN_OP_IMPORTER(Resize)
auto* resizeShape = &convertToTensor(inputs.at(3), ctx);
layer->setInput(1, *resizeShape);
layer->setResizeMode(resizeMode);
if (transformationMode=="align_corners")
layer->setAlignCorners(true);
RETURN_FIRST_OUTPUT(layer);
}
}

// Resizes that use scale factors have the same import logic between opsets
ASSERT(transformationMode != "align_corners" && "Align_corners should use size information not scale factors!", ErrorCode::kUNSUPPORTED_NODE);
auto scales = ctx->getOpsetVersion() >= 11 ? inputs.at(2) : inputs.at(1);
ASSERT(scales.is_weights() && "Resize scales must be an initializer!", ErrorCode::kUNSUPPORTED_NODE);
ShapedWeights scales_weights = scales.weights();
Expand Down