diff --git a/builtin_op_importers.cpp b/builtin_op_importers.cpp index ef134ffa..2ca1ac3d 100644 --- a/builtin_op_importers.cpp +++ b/builtin_op_importers.cpp @@ -1816,25 +1816,41 @@ DEFINE_BUILTIN_OP_IMPORTER(Unsqueeze) { DEFINE_BUILTIN_OP_IMPORTER(Upsample) { ASSERT(inputs.at(0).is_tensor(), ErrorCode::kUNSUPPORTED_NODE); - nvinfer1::ITensor& tensor = inputs.at(0).tensor(); + nvinfer1::ITensor &tensor = inputs.at(0).tensor(); ASSERT(tensor.getDimensions().nbDims == 3, ErrorCode::kUNSUPPORTED_NODE); OnnxAttrs attrs(node); float height_scale, width_scale; - if( !attrs.count("scales") ) { - height_scale = attrs.get("height_scale"); - width_scale = attrs.get("width_scale"); + if (ctx->getOpsetVersion() >= 9) { + ASSERT(inputs.size() == 2, ErrorCode::kINVALID_NODE); + auto scales_input = inputs.at(1); + ASSERT(scales_input.is_weights(), ErrorCode::kUNSUPPORTED_NODE); + ShapedWeights scales_weights = scales_input.weights(); + ASSERT(scales_weights.shape.nbDims == 1, ErrorCode::kUNSUPPORTED_NODE); + ASSERT(scales_weights.count() == 4, ErrorCode::kUNSUPPORTED_NODE); + ASSERT(scales_weights.type == ::ONNX_NAMESPACE::TensorProto::FLOAT, + ErrorCode::kINVALID_NODE); + float const *scales_ptr = static_cast(scales_weights.values); + ASSERT(scales_ptr[0] == 1 && scales_ptr[1] == 1, + ErrorCode::kUNSUPPORTED_NODE); + height_scale = scales_ptr[2]; + width_scale = scales_ptr[3]; } else { - auto scales = attrs.get>("scales"); - ASSERT(scales.size() == 4, ErrorCode::kUNSUPPORTED_NODE); - ASSERT(scales[0] == 1 && scales[1] == 1, ErrorCode::kUNSUPPORTED_NODE); - height_scale = scales[2]; - width_scale = scales[3]; + if (!attrs.count("scales")) { + height_scale = attrs.get("height_scale"); + width_scale = attrs.get("width_scale"); + } else { + auto scales = attrs.get>("scales"); + ASSERT(scales.size() == 4, ErrorCode::kUNSUPPORTED_NODE); + ASSERT(scales[0] == 1 && scales[1] == 1, ErrorCode::kUNSUPPORTED_NODE); + height_scale = scales[2]; + width_scale = scales[3]; + } } auto scale = {height_scale, width_scale}; auto mode = attrs.get("mode", "nearest"); ASSERT(mode == "nearest", ErrorCode::kUNSUPPORTED_NODE); - RETURN_FIRST_OUTPUT(ctx->addPlugin(new ResizeNearestPlugin(scale), - {&inputs.at(0).tensor()})); + RETURN_FIRST_OUTPUT( + ctx->addPlugin(new ResizeNearestPlugin(scale), {&inputs.at(0).tensor()})); } } // namespace