Skip to content

Commit

Permalink
improve shape checking (open-mmlab#315)
Browse files (browse the repository at this point in the history)
  • Loading branch information
lzhangzz committed Dec 21, 2021
1 parent ce2b778 commit 56e32fd
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 28 deletions.
12 changes: 7 additions & 5 deletions csrc/codebase/mmcls/linear_cls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,17 @@ class LinearClsHead : public MMClassification {

Result<Value> operator()(const Value& infer_res) {
DEBUG("infer_res: {}", infer_res);
auto output_tensor = infer_res["output"].get<Tensor>();
assert(output_tensor.shape().size() >= 2);
auto class_num = (int)output_tensor.shape()[1];
auto output = infer_res["output"].get<Tensor>();

if (output_tensor.data_type() != DataType::kFLOAT) {
if (!(output.shape().size() >= 2 && output.data_type() == DataType::kFLOAT)) {
ERROR("unsupported `output` tensor, shape: {}, dtype: {}", output.shape(),
(int)output.data_type());
return Status(eNotSupported);
}

OUTCOME_TRY(auto _scores, MakeAvailableOnDevice(output_tensor, kHost, stream()));
auto class_num = (int)output.shape(1);

OUTCOME_TRY(auto _scores, MakeAvailableOnDevice(output, kHost, stream()));
OUTCOME_TRY(stream().Wait());

return GetLabels(_scores, class_num);
Expand Down
25 changes: 17 additions & 8 deletions csrc/codebase/mmdet/instance_segmentation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@ class ResizeInstanceMask : public ResizeBBox {
}
}

// TODO: remove duplication
Result<Value> operator()(const Value& prep_res, const Value& infer_res) {
DEBUG("prep_res: {}\ninfer_res: {}", prep_res, infer_res);
try {
assert(prep_res.contains("img_metas"));
// Value res = prep_res;

auto dets = infer_res["dets"].get<Tensor>();
auto labels = infer_res["labels"].get<Tensor>();
auto masks = infer_res["masks"].get<Tensor>();
Expand All @@ -33,14 +31,25 @@ class ResizeInstanceMask : public ResizeBBox {

// `dets` is supposed to have 3 dims. They are 'batch', 'bboxes_number'
// and 'channels' respectively
assert(dets.shape().size() == 3);
assert(dets.data_type() == DataType::kFLOAT);

assert(masks.data_type() == DataType::kFLOAT);
if (!(dets.shape().size() == 3 && dets.data_type() == DataType::kFLOAT)) {
ERROR("unsupported `dets` tensor, shape: {}, dtype: {}", dets.shape(),
(int)dets.data_type());
return Status(eNotSupported);
}

// `labels` is supposed to have 2 dims, which are 'batch' and
// 'bboxes_number'
assert(labels.shape().size() == 2);
if (labels.shape().size() != 2) {
ERROR("unsupported `labels`, tensor, shape: {}, dtype: {}", labels.shape(),
(int)labels.data_type());
return Status(eNotSupported);
}

if (!(masks.shape().size() == 4 && masks.data_type() == DataType::kFLOAT)) {
ERROR("unsupported `mask` tensor, shape: {}, dtype: {}", masks.shape(),
(int)masks.data_type());
return Status(eNotSupported);
}

OUTCOME_TRY(auto _dets, MakeAvailableOnDevice(dets, kHost, stream()));
OUTCOME_TRY(auto _labels, MakeAvailableOnDevice(labels, kHost, stream()));
Expand Down
15 changes: 9 additions & 6 deletions csrc/codebase/mmdet/object_detection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@ ResizeBBox::ResizeBBox(const Value& cfg) : MMDetection(cfg) {
Result<Value> ResizeBBox::operator()(const Value& prep_res, const Value& infer_res) {
DEBUG("prep_res: {}\ninfer_res: {}", prep_res, infer_res);
try {
assert(prep_res.contains("img_metas"));
// Value res = prep_res;

auto dets = infer_res["dets"].get<Tensor>();
auto labels = infer_res["labels"].get<Tensor>();

Expand All @@ -30,12 +27,18 @@ Result<Value> ResizeBBox::operator()(const Value& prep_res, const Value& infer_r

// `dets` is supposed to have 3 dims. They are 'batch', 'bboxes_number'
// and 'channels' respectively
assert(dets.shape().size() == 3);
assert(dets.data_type() == DataType::kFLOAT);
if (!(dets.shape().size() == 3 && dets.data_type() == DataType::kFLOAT)) {
ERROR("unsupported `dets` tensor, shape: {}, dtype: {}", dets.shape(), (int)dets.data_type());
return Status(eNotSupported);
}

// `labels` is supposed to have 2 dims, which are 'batch' and
// 'bboxes_number'
assert(labels.shape().size() == 2);
if (labels.shape().size() != 2) {
ERROR("unsupported `labels`, tensor, shape: {}, dtype: {}", labels.shape(),
(int)labels.data_type());
return Status(eNotSupported);
}

OUTCOME_TRY(auto _dets, MakeAvailableOnDevice(dets, kHost, stream()));
OUTCOME_TRY(auto _labels, MakeAvailableOnDevice(labels, kHost, stream()));
Expand Down
4 changes: 3 additions & 1 deletion csrc/codebase/mmedit/restorer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class TensorToImg : public MMEdit {
auto upscale = input["output"].get<Tensor>();
OUTCOME_TRY(auto upscale_cpu, MakeAvailableOnDevice(upscale, kHOST, stream()));
OUTCOME_TRY(stream().Wait());
if (upscale.data_type() == DataType::kFLOAT) {
if (upscale.shape().size() == 4 && upscale.data_type() == DataType::kFLOAT) {
auto channels = static_cast<int>(upscale.shape(1));
auto height = static_cast<int>(upscale.shape(2));
auto width = static_cast<int>(upscale.shape(3));
Expand All @@ -32,6 +32,8 @@ class TensorToImg : public MMEdit {
mat_hwc.convertTo(rescale_uint8, CV_8UC(channels), 255.f);
return mat;
} else {
ERROR("unsupported `output` tensor, shape: {}, dtype: {}", upscale.shape(),
(int)upscale.data_type());
return Status(eNotSupported);
}
}
Expand Down
4 changes: 3 additions & 1 deletion csrc/codebase/mmocr/crnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ class CTCConvertor : public MMOCR {
Result<Value> operator()(const Value& _data, const Value& _prob) {
auto d_conf = _prob["output"].get<Tensor>();

if (d_conf.data_type() != DataType::kFLOAT) {
if (!(d_conf.shape().size() == 3 && d_conf.data_type() == DataType::kFLOAT)) {
ERROR("unsupported `output` tensor, shape: {}, dtype: {}", d_conf.shape(),
(int)d_conf.data_type());
return Status(eNotSupported);
}

Expand Down
6 changes: 6 additions & 0 deletions csrc/codebase/mmocr/dbnet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ class DBHead : public MMOCR {
OUTCOME_TRY(stream_.Wait());
DEBUG("shape: {}", conf.shape());

if (!(conf.shape().size() == 4 && conf.data_type() == DataType::kFLOAT)) {
ERROR("unsupported `output` tensor, shape: {}, dtype: {}", conf.shape(),
(int)conf.data_type());
return Status(eNotSupported);
}

auto h = conf.shape(2);
auto w = conf.shape(3);
auto data = conf.buffer().GetNative();
Expand Down
18 changes: 11 additions & 7 deletions csrc/codebase/mmseg/segment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,20 @@ class ResizeMask : public MMSegmentation {
DEBUG("preprocess: {}\ninference: {}", preprocess_result, inference_result);

auto mask = inference_result["output"].get<Tensor>();
INFO("tensor.name: {}, tensor.shape: {}, tensor.data_type: {}", mask.name(), mask.shape(),
mask.data_type());
assert(mask.data_type() == DataType::kINT32 || mask.data_type() == DataType::kINT64);
assert(mask.shape(0) == 1);
assert(mask.shape(1) == 1);
DEBUG("tensor.name: {}, tensor.shape: {}, tensor.data_type: {}", mask.name(), mask.shape(),
mask.data_type());
if (!(mask.shape().size() == 4 && mask.shape(0) == 1 && mask.shape(1) == 1)) {
ERROR("unsupported `output` tensor, shape: {}", mask.shape());
return Status(eNotSupported);
}

auto height = (int)mask.shape(2);
auto width = (int)mask.shape(3);
auto input_height = preprocess_result["img_metas"]["ori_shape"][1].get<int>();
auto input_width = preprocess_result["img_metas"]["ori_shape"][2].get<int>();
Device host{"cpu"};
OUTCOME_TRY(auto host_tensor, MakeAvailableOnDevice(mask, host, stream_));
stream_.Wait().value();
OUTCOME_TRY(stream_.Wait());
if (mask.data_type() == DataType::kINT64) {
// change kINT64 to 2 INT32
TensorDesc desc{.device = host_tensor.device(),
Expand All @@ -45,8 +46,11 @@ class ResizeMask : public MMSegmentation {
.name = host_tensor.name()};
Tensor _host_tensor(desc, mask.buffer());
return MaskResize(_host_tensor, input_height, input_width);
} else {
} else if (mask.data_type() == DataType::kINT32) {
return MaskResize(host_tensor, input_height, input_width);
} else {
ERROR("unsupported `output` tensor, dtype: {}", (int)mask.data_type());
return Status(eNotSupported);
}
}

Expand Down

0 comments on commit 56e32fd

Please sign in to comment.