From 56e32fdf63bdd0f999f7d5ae82e5314ac7491079 Mon Sep 17 00:00:00 2001 From: lzhangzz Date: Tue, 21 Dec 2021 20:16:40 +0800 Subject: [PATCH] improve shape checking (#315) --- csrc/codebase/mmcls/linear_cls.cpp | 12 +++++---- csrc/codebase/mmdet/instance_segmentation.cpp | 25 +++++++++++++------ csrc/codebase/mmdet/object_detection.cpp | 15 ++++++----- csrc/codebase/mmedit/restorer.cpp | 4 ++- csrc/codebase/mmocr/crnn.cpp | 4 ++- csrc/codebase/mmocr/dbnet.cpp | 6 +++++ csrc/codebase/mmseg/segment.cpp | 18 +++++++------ 7 files changed, 56 insertions(+), 28 deletions(-) diff --git a/csrc/codebase/mmcls/linear_cls.cpp b/csrc/codebase/mmcls/linear_cls.cpp index 62ac58fa0..07704cd08 100644 --- a/csrc/codebase/mmcls/linear_cls.cpp +++ b/csrc/codebase/mmcls/linear_cls.cpp @@ -26,15 +26,17 @@ class LinearClsHead : public MMClassification { Result operator()(const Value& infer_res) { DEBUG("infer_res: {}", infer_res); - auto output_tensor = infer_res["output"].get(); - assert(output_tensor.shape().size() >= 2); - auto class_num = (int)output_tensor.shape()[1]; + auto output = infer_res["output"].get(); - if (output_tensor.data_type() != DataType::kFLOAT) { + if (!(output.shape().size() >= 2 && output.data_type() == DataType::kFLOAT)) { + ERROR("unsupported `output` tensor, shape: {}, dtype: {}", output.shape(), + (int)output.data_type()); return Status(eNotSupported); } - OUTCOME_TRY(auto _scores, MakeAvailableOnDevice(output_tensor, kHost, stream())); + auto class_num = (int)output.shape(1); + + OUTCOME_TRY(auto _scores, MakeAvailableOnDevice(output, kHost, stream())); OUTCOME_TRY(stream().Wait()); return GetLabels(_scores, class_num); diff --git a/csrc/codebase/mmdet/instance_segmentation.cpp b/csrc/codebase/mmdet/instance_segmentation.cpp index ae0dfd4f6..481e0b1e2 100644 --- a/csrc/codebase/mmdet/instance_segmentation.cpp +++ b/csrc/codebase/mmdet/instance_segmentation.cpp @@ -17,12 +17,10 @@ class ResizeInstanceMask : public ResizeBBox { } } + // TODO: remove duplication Result operator()(const Value& prep_res, const Value& infer_res) { DEBUG("prep_res: {}\ninfer_res: {}", prep_res, infer_res); try { - assert(prep_res.contains("img_metas")); - // Value res = prep_res; - auto dets = infer_res["dets"].get(); auto labels = infer_res["labels"].get(); auto masks = infer_res["masks"].get(); @@ -33,14 +31,25 @@ class ResizeInstanceMask : public ResizeBBox { // `dets` is supposed to have 3 dims. They are 'batch', 'bboxes_number' // and 'channels' respectively - assert(dets.shape().size() == 3); - assert(dets.data_type() == DataType::kFLOAT); - - assert(masks.data_type() == DataType::kFLOAT); + if (!(dets.shape().size() == 3 && dets.data_type() == DataType::kFLOAT)) { + ERROR("unsupported `dets` tensor, shape: {}, dtype: {}", dets.shape(), + (int)dets.data_type()); + return Status(eNotSupported); + } // `labels` is supposed to have 2 dims, which are 'batch' and // 'bboxes_number' - assert(labels.shape().size() == 2); + if (labels.shape().size() != 2) { + ERROR("unsupported `labels`, tensor, shape: {}, dtype: {}", labels.shape(), + (int)labels.data_type()); + return Status(eNotSupported); + } + + if (!(masks.shape().size() == 4 && masks.data_type() == DataType::kFLOAT)) { + ERROR("unsupported `mask` tensor, shape: {}, dtype: {}", masks.shape(), + (int)masks.data_type()); + return Status(eNotSupported); + } OUTCOME_TRY(auto _dets, MakeAvailableOnDevice(dets, kHost, stream())); OUTCOME_TRY(auto _labels, MakeAvailableOnDevice(labels, kHost, stream())); diff --git a/csrc/codebase/mmdet/object_detection.cpp b/csrc/codebase/mmdet/object_detection.cpp index 519cac561..a839cf7e6 100644 --- a/csrc/codebase/mmdet/object_detection.cpp +++ b/csrc/codebase/mmdet/object_detection.cpp @@ -19,9 +19,6 @@ ResizeBBox::ResizeBBox(const Value& cfg) : MMDetection(cfg) { Result ResizeBBox::operator()(const Value& prep_res, const Value& infer_res) { DEBUG("prep_res: {}\ninfer_res: {}", prep_res, infer_res); try { - assert(prep_res.contains("img_metas")); - // Value res = prep_res; - auto dets = infer_res["dets"].get(); auto labels = infer_res["labels"].get(); @@ -30,12 +27,18 @@ Result ResizeBBox::operator()(const Value& prep_res, const Value& infer_r // `dets` is supposed to have 3 dims. They are 'batch', 'bboxes_number' // and 'channels' respectively - assert(dets.shape().size() == 3); - assert(dets.data_type() == DataType::kFLOAT); + if (!(dets.shape().size() == 3 && dets.data_type() == DataType::kFLOAT)) { + ERROR("unsupported `dets` tensor, shape: {}, dtype: {}", dets.shape(), (int)dets.data_type()); + return Status(eNotSupported); + } // `labels` is supposed to have 2 dims, which are 'batch' and // 'bboxes_number' - assert(labels.shape().size() == 2); + if (labels.shape().size() != 2) { + ERROR("unsupported `labels`, tensor, shape: {}, dtype: {}", labels.shape(), + (int)labels.data_type()); + return Status(eNotSupported); + } OUTCOME_TRY(auto _dets, MakeAvailableOnDevice(dets, kHost, stream())); OUTCOME_TRY(auto _labels, MakeAvailableOnDevice(labels, kHost, stream())); diff --git a/csrc/codebase/mmedit/restorer.cpp b/csrc/codebase/mmedit/restorer.cpp index 36ed51a8e..da06075a4 100644 --- a/csrc/codebase/mmedit/restorer.cpp +++ b/csrc/codebase/mmedit/restorer.cpp @@ -16,7 +16,7 @@ class TensorToImg : public MMEdit { auto upscale = input["output"].get(); OUTCOME_TRY(auto upscale_cpu, MakeAvailableOnDevice(upscale, kHOST, stream())); OUTCOME_TRY(stream().Wait()); - if (upscale.data_type() == DataType::kFLOAT) { + if (upscale.shape().size() == 4 && upscale.data_type() == DataType::kFLOAT) { auto channels = static_cast(upscale.shape(1)); auto height = static_cast(upscale.shape(2)); auto width = static_cast(upscale.shape(3)); @@ -32,6 +32,8 @@ class TensorToImg : public MMEdit { mat_hwc.convertTo(rescale_uint8, CV_8UC(channels), 255.f); return mat; } else { + ERROR("unsupported `output` tensor, shape: {}, dtype: {}", upscale.shape(), + (int)upscale.data_type()); return Status(eNotSupported); } } diff --git a/csrc/codebase/mmocr/crnn.cpp b/csrc/codebase/mmocr/crnn.cpp index 7b7a6f265..bd6c4a617 100644 --- a/csrc/codebase/mmocr/crnn.cpp +++ b/csrc/codebase/mmocr/crnn.cpp @@ -61,7 +61,9 @@ class CTCConvertor : public MMOCR { Result operator()(const Value& _data, const Value& _prob) { auto d_conf = _prob["output"].get(); - if (d_conf.data_type() != DataType::kFLOAT) { + if (!(d_conf.shape().size() == 3 && d_conf.data_type() == DataType::kFLOAT)) { + ERROR("unsupported `output` tensor, shape: {}, dtype: {}", d_conf.shape(), + (int)d_conf.data_type()); return Status(eNotSupported); } diff --git a/csrc/codebase/mmocr/dbnet.cpp b/csrc/codebase/mmocr/dbnet.cpp index e4abf577c..93a3d0400 100644 --- a/csrc/codebase/mmocr/dbnet.cpp +++ b/csrc/codebase/mmocr/dbnet.cpp @@ -49,6 +49,12 @@ class DBHead : public MMOCR { OUTCOME_TRY(stream_.Wait()); DEBUG("shape: {}", conf.shape()); + if (!(conf.shape().size() == 4 && conf.data_type() == DataType::kFLOAT)) { + ERROR("unsupported `output` tensor, shape: {}, dtype: {}", conf.shape(), + (int)conf.data_type()); + return Status(eNotSupported); + } + auto h = conf.shape(2); auto w = conf.shape(3); auto data = conf.buffer().GetNative(); diff --git a/csrc/codebase/mmseg/segment.cpp b/csrc/codebase/mmseg/segment.cpp index eced3098b..48afa9b57 100644 --- a/csrc/codebase/mmseg/segment.cpp +++ b/csrc/codebase/mmseg/segment.cpp @@ -24,11 +24,12 @@ class ResizeMask : public MMSegmentation { DEBUG("preprocess: {}\ninference: {}", preprocess_result, inference_result); auto mask = inference_result["output"].get(); - INFO("tensor.name: {}, tensor.shape: {}, tensor.data_type: {}", mask.name(), mask.shape(), - mask.data_type()); - assert(mask.data_type() == DataType::kINT32 || mask.data_type() == DataType::kINT64); - assert(mask.shape(0) == 1); - assert(mask.shape(1) == 1); + DEBUG("tensor.name: {}, tensor.shape: {}, tensor.data_type: {}", mask.name(), mask.shape(), + mask.data_type()); + if (!(mask.shape().size() == 4 && mask.shape(0) == 1 && mask.shape(1) == 1)) { + ERROR("unsupported `output` tensor, shape: {}", mask.shape()); + return Status(eNotSupported); + } auto height = (int)mask.shape(2); auto width = (int)mask.shape(3); @@ -36,7 +37,7 @@ class ResizeMask : public MMSegmentation { auto input_width = preprocess_result["img_metas"]["ori_shape"][2].get(); Device host{"cpu"}; OUTCOME_TRY(auto host_tensor, MakeAvailableOnDevice(mask, host, stream_)); - stream_.Wait().value(); + OUTCOME_TRY(stream_.Wait()); if (mask.data_type() == DataType::kINT64) { // change kINT64 to 2 INT32 TensorDesc desc{.device = host_tensor.device(), @@ -45,8 +46,11 @@ class ResizeMask : public MMSegmentation { .name = host_tensor.name()}; Tensor _host_tensor(desc, mask.buffer()); return MaskResize(_host_tensor, input_height, input_width); - } else { + } else if (mask.data_type() == DataType::kINT32) { return MaskResize(host_tensor, input_height, input_width); + } else { + ERROR("unsupported `output` tensor, dtype: {}", (int)mask.data_type()); + return Status(eNotSupported); } }