Skip to content

Commit

Permalink
Add support for Upsample with opset_ver >= 9 (#80)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yinghai Lu committed Nov 28, 2018
1 parent a9a47bb commit e08f047
Showing 1 changed file with 27 additions and 11 deletions.
38 changes: 27 additions & 11 deletions builtin_op_importers.cpp
Expand Up @@ -1816,25 +1816,41 @@ DEFINE_BUILTIN_OP_IMPORTER(Unsqueeze) {

DEFINE_BUILTIN_OP_IMPORTER(Upsample) {
ASSERT(inputs.at(0).is_tensor(), ErrorCode::kUNSUPPORTED_NODE);
nvinfer1::ITensor& tensor = inputs.at(0).tensor();
nvinfer1::ITensor &tensor = inputs.at(0).tensor();
ASSERT(tensor.getDimensions().nbDims == 3, ErrorCode::kUNSUPPORTED_NODE);
OnnxAttrs attrs(node);
float height_scale, width_scale;
if( !attrs.count("scales") ) {
height_scale = attrs.get<float>("height_scale");
width_scale = attrs.get<float>("width_scale");
if (ctx->getOpsetVersion() >= 9) {
ASSERT(inputs.size() == 2, ErrorCode::kINVALID_NODE);
auto scales_input = inputs.at(1);
ASSERT(scales_input.is_weights(), ErrorCode::kUNSUPPORTED_NODE);
ShapedWeights scales_weights = scales_input.weights();
ASSERT(scales_weights.shape.nbDims == 1, ErrorCode::kUNSUPPORTED_NODE);
ASSERT(scales_weights.count() == 4, ErrorCode::kUNSUPPORTED_NODE);
ASSERT(scales_weights.type == ::ONNX_NAMESPACE::TensorProto::FLOAT,
ErrorCode::kINVALID_NODE);
float const *scales_ptr = static_cast<float const *>(scales_weights.values);
ASSERT(scales_ptr[0] == 1 && scales_ptr[1] == 1,
ErrorCode::kUNSUPPORTED_NODE);
height_scale = scales_ptr[2];
width_scale = scales_ptr[3];
} else {
auto scales = attrs.get<std::vector<float>>("scales");
ASSERT(scales.size() == 4, ErrorCode::kUNSUPPORTED_NODE);
ASSERT(scales[0] == 1 && scales[1] == 1, ErrorCode::kUNSUPPORTED_NODE);
height_scale = scales[2];
width_scale = scales[3];
if (!attrs.count("scales")) {
height_scale = attrs.get<float>("height_scale");
width_scale = attrs.get<float>("width_scale");
} else {
auto scales = attrs.get<std::vector<float>>("scales");
ASSERT(scales.size() == 4, ErrorCode::kUNSUPPORTED_NODE);
ASSERT(scales[0] == 1 && scales[1] == 1, ErrorCode::kUNSUPPORTED_NODE);
height_scale = scales[2];
width_scale = scales[3];
}
}
auto scale = {height_scale, width_scale};
auto mode = attrs.get<std::string>("mode", "nearest");
ASSERT(mode == "nearest", ErrorCode::kUNSUPPORTED_NODE);
RETURN_FIRST_OUTPUT(ctx->addPlugin(new ResizeNearestPlugin(scale),
{&inputs.at(0).tensor()}));
RETURN_FIRST_OUTPUT(
ctx->addPlugin(new ResizeNearestPlugin(scale), {&inputs.at(0).tensor()}));
}

} // namespace
Expand Down

0 comments on commit e08f047

Please sign in to comment.