Skip to content

Commit

Permalink
[Fix]: fix rtmdet-inst sdk (#2343)
Browse files Browse the repository at this point in the history
* support rtmdet-inst sdk

* fix batch infer

* fix

* fix mask resize

* fix

* update

* fix segment fault

* fix

* fix lint

* fix

* fix

* fix

* resolve comments
  • Loading branch information
RunningLeon committed Sep 4, 2023
1 parent 5479c87 commit 123b9fc
Show file tree
Hide file tree
Showing 10 changed files with 168 additions and 227 deletions.
101 changes: 63 additions & 38 deletions csrc/mmdeploy/codebase/mmdet/instance_segmentation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class ResizeInstanceMask : public ResizeBBox {
}
operation::Context ctx(device_, stream_);
warp_affine_ = operation::Managed<operation::WarpAffine>::Create("bilinear");
permute_ = operation::Managed<::mmdeploy::operation::Permute>::Create();
}

// TODO: remove duplication
Expand Down Expand Up @@ -65,9 +66,9 @@ class ResizeInstanceMask : public ResizeBBox {
// OUTCOME_TRY(stream().Wait());

OUTCOME_TRY(auto result, DispatchGetBBoxes(prep_res["img_metas"], _dets, _labels));

auto ori_w = prep_res["img_metas"]["ori_shape"][2].get<int>();
auto ori_h = prep_res["img_metas"]["ori_shape"][1].get<int>();
from_value(prep_res["img_metas"]["scale_factor"], scale_factor_);

ProcessMasks(result, masks, _dets, ori_w, ori_h);

Expand All @@ -92,49 +93,71 @@ class ResizeInstanceMask : public ResizeBBox {
std::vector<Tensor> h_warped_masks;
h_warped_masks.reserve(result.size());

for (auto& det : result) {
auto mask = d_mask.Slice(det.index);
auto mask_height = (int)mask.shape(1);
auto mask_width = (int)mask.shape(2);
mask.Reshape({1, mask_height, mask_width, 1});
if (is_resize_mask_) {
auto& bbox = det.bbox;
// same as mmdet with skip_empty = True
auto x0 = std::max(std::floor(bbox[0]) - 1, 0.f);
auto y0 = std::max(std::floor(bbox[1]) - 1, 0.f);
auto x1 = std::min(std::ceil(bbox[2]) + 1, (float)img_w);
auto y1 = std::min(std::ceil(bbox[3]) + 1, (float)img_h);
auto width = static_cast<int>(x1 - x0);
auto height = static_cast<int>(y1 - y0);
// params align_corners = False
float fx;
float fy;
float tx;
float ty;
if (is_rcnn_) { // mask r-cnn
if (is_rcnn_) { // mask r-cnn
for (auto& det : result) {
auto mask = d_mask.Slice(det.index);
auto mask_height = (int)mask.shape(1);
auto mask_width = (int)mask.shape(2);
mask.Reshape({1, mask_height, mask_width, 1});
// resize masks to the original image shape instead of the input image shape
// default is true
if (is_resize_mask_) {
auto& bbox = det.bbox;
// same as mmdet with skip_empty = True
auto x0 = std::max(std::floor(bbox[0]) - 1, 0.f);
auto y0 = std::max(std::floor(bbox[1]) - 1, 0.f);
auto x1 = std::min(std::ceil(bbox[2]) + 1, (float)img_w);
auto y1 = std::min(std::ceil(bbox[3]) + 1, (float)img_h);
auto width = static_cast<int>(x1 - x0);
auto height = static_cast<int>(y1 - y0);
// params align_corners = False
float fx;
float fy;
float tx;
float ty;
fx = (float)mask_width / (bbox[2] - bbox[0]);
fy = (float)mask_height / (bbox[3] - bbox[1]);
tx = (x0 + .5f - bbox[0]) * fx - .5f;
ty = (y0 + .5f - bbox[1]) * fy - .5f;
} else { // rtmdet-ins
auto raw_bbox = cpu_dets.Slice(det.index);
auto raw_bbox_data = raw_bbox.data<float>();
fx = (raw_bbox_data[2] - raw_bbox_data[0]) / (bbox[2] - bbox[0]);
fy = (raw_bbox_data[3] - raw_bbox_data[1]) / (bbox[3] - bbox[1]);
tx = (x0 + .5f - bbox[0]) * fx - .5f + raw_bbox_data[0];
ty = (y0 + .5f - bbox[1]) * fy - .5f + raw_bbox_data[1];
}

float affine_matrix[] = {fx, 0, tx, 0, fy, ty};
float affine_matrix[] = {fx, 0, tx, 0, fy, ty};

cv::Mat_<float> m(2, 3, affine_matrix);
cv::invertAffineTransform(m, m);
Tensor& warped_mask = warped_masks.emplace_back();
OUTCOME_TRY(warp_affine_.Apply(mask, warped_mask, affine_matrix, height, width));
OUTCOME_TRY(CopyToHost(warped_mask, h_warped_masks.emplace_back()));

cv::Mat_<float> m(2, 3, affine_matrix);
cv::invertAffineTransform(m, m);
Tensor& warped_mask = warped_masks.emplace_back();
OUTCOME_TRY(warp_affine_.Apply(mask, warped_mask, affine_matrix, height, width));
OUTCOME_TRY(CopyToHost(warped_mask, h_warped_masks.emplace_back()));
} else {
OUTCOME_TRY(CopyToHost(mask, h_warped_masks.emplace_back()));
}
}

} else {
OUTCOME_TRY(CopyToHost(mask, h_warped_masks.emplace_back()));
} else { // rtmdet-inst
auto mask_channel = (int)d_mask.shape(0);
auto mask_height = (int)d_mask.shape(1);
auto mask_width = (int)d_mask.shape(2);
// (C, H, W) -> (H, W, C)
std::vector<int> axes = {1, 2, 0};
OUTCOME_TRY(permute_.Apply(d_mask, d_mask, axes));
Device host{"cpu"};
OUTCOME_TRY(auto cpu_mask, MakeAvailableOnDevice(d_mask, host, stream_));
OUTCOME_TRY(stream().Wait());
cv::Mat mask_mat(mask_height, mask_width, CV_32FC(mask_channel), cpu_mask.data());
int resize_height = int(mask_height / scale_factor_[0] + 0.5);
int resize_width = int(mask_width / scale_factor_[1] + 0.5);
// skip resize if scale_factor is 1.0
if (resize_height != mask_height || resize_width != mask_width) {
cv::resize(mask_mat, mask_mat, cv::Size(resize_height, resize_width), cv::INTER_LINEAR);
}
// crop masks
mask_mat = mask_mat(cv::Range(0, img_h), cv::Range(0, img_w)).clone();

for (int i = 0; i < (int)result.size(); i++) {
cv::Mat mask_;
cv::extractChannel(mask_mat, mask_, i);
Tensor mask_t = cpu::CVMat2Tensor(mask_);
h_warped_masks.emplace_back(mask_t);
}
}

Expand Down Expand Up @@ -166,9 +189,11 @@ class ResizeInstanceMask : public ResizeBBox {

private:
operation::Managed<operation::WarpAffine> warp_affine_;
::mmdeploy::operation::Managed<::mmdeploy::operation::Permute> permute_;
float mask_thr_binary_{.5f};
bool is_rcnn_{true};
bool is_resize_mask_{false};
bool is_resize_mask_{true};
std::vector<float> scale_factor_{1.0, 1.0, 1.0, 1.0};
};

MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMDetection, ResizeInstanceMask);
Expand Down
25 changes: 15 additions & 10 deletions demo/csrc/c/object_detection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ int main(int argc, char* argv[]) {
fprintf(stderr, "failed to load image: %s\n", image_path);
return 1;
}

cv::Size img_size = img.size();
mmdeploy_detector_t detector{};
int status{};
status = mmdeploy_detector_create_by_path(model_path, device_name, 0, &detector);
Expand Down Expand Up @@ -60,17 +60,22 @@ int main(int argc, char* argv[]) {
// generate mask overlay if model exports masks
if (mask != nullptr) {
fprintf(stdout, "mask %d, height=%d, width=%d\n", i, mask->height, mask->width);

cv::Mat imgMask(mask->height, mask->width, CV_8UC1, &mask->data[0]);
auto x0 = std::max(std::floor(box.left) - 1, 0.f);
auto y0 = std::max(std::floor(box.top) - 1, 0.f);
cv::Rect roi((int)x0, (int)y0, mask->width, mask->height);

// split the RGB channels, overlay mask to a specific color channel
cv::Mat ch[3];
split(img, ch);
cv::Mat ch[3], mask_img;
int col = 0; // int col = i % 3;
cv::bitwise_or(imgMask, ch[col](roi), ch[col](roi));
split(img, ch);
cv::Mat imgMask(mask->height, mask->width, CV_8UC1, &mask->data[0]);
// rtmdet-inst
if (img_size.height == mask->height && img_size.width == mask->width) {
mask_img = ch[col];
}
else {
auto x0 = std::max(std::floor(box.left) - 1, 0.f);
auto y0 = std::max(std::floor(box.top) - 1, 0.f);
cv::Rect roi((int)x0, (int)y0, mask->width, mask->height);
mask_img = ch[col](roi);
}
cv::bitwise_or(imgMask, mask_img, mask_img);
merge(ch, 3, img);
}

Expand Down
7 changes: 5 additions & 2 deletions demo/csrc/cpp/utils/visualize.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,11 @@ class Visualize {
rect.left, rect.top, rect.right, rect.bottom, label_id, score);
if (mask) {
fprintf(stdout, "mask %d, height=%d, width=%d\n", index, mask->height, mask->width);
auto x0 = (int)std::max(std::floor(rect.left) - 1, 0.f);
auto y0 = (int)std::max(std::floor(rect.top) - 1, 0.f);
int x0 = 0, y0 = 0, img_h=img_.size().height, img_w =img_.size().width ;
if (img_h != (int)mask->height || img_w != (int)mask->width ) { // maskrcnn
x0 = (int)std::max(std::floor(rect.left) - 1, 0.f);
y0 = (int)std::max(std::floor(rect.top) - 1, 0.f);
}
add_instance_mask({x0, y0}, rand(), mask->data, mask->height, mask->width);
}
add_bbox(rect, label_id, score);
Expand Down
10 changes: 6 additions & 4 deletions demo/python/object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@ def main():
if masks[index].size:
mask = masks[index]
blue, green, red = cv2.split(img)

x0 = int(max(math.floor(bbox[0]) - 1, 0))
y0 = int(max(math.floor(bbox[1]) - 1, 0))
mask_img = blue[y0:y0 + mask.shape[0], x0:x0 + mask.shape[1]]
if mask.shape == img.shape[:2]: # rtmdet-inst
mask_img = blue
else: # maskrcnn
x0 = int(max(math.floor(bbox[0]) - 1, 0))
y0 = int(max(math.floor(bbox[1]) - 1, 0))
mask_img = blue[y0:y0 + mask.shape[0], x0:x0 + mask.shape[1]]
cv2.bitwise_or(mask, mask_img, mask_img)
img = cv2.merge([blue, green, red])

Expand Down
2 changes: 1 addition & 1 deletion docs/en/07-developer-guide/regression_test.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ python ./tools/regression_test.py \
--log-level INFO
```

## 3. Regression Test Tonfiguration
## 3. Regression Test Configuration

### Example and parameter description

Expand Down
4 changes: 3 additions & 1 deletion mmdeploy/codebase/mmdet/deploy/object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,9 @@ def get_postprocess(self, *args, **kwargs) -> Dict:
params['mask_thr_binary'] = params['rcnn']['mask_thr_binary']
if 'mask_thr_binary' in params:
type = 'ResizeInstanceMask' # for instance-seg
params['is_resize_mask'] = False # resize and crop mask default
# resize and crop mask to the original image
params['is_resize_mask'] = True

if get_backend(self.deploy_cfg) == Backend.RKNN:
if 'YOLO' in self.model_cfg.model.type or \
'RTMDet' in self.model_cfg.model.type:
Expand Down
39 changes: 23 additions & 16 deletions mmdeploy/codebase/mmdet/deploy/object_detection_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,25 +241,30 @@ def postprocessing_results(self,
masks = batch_masks[i]
img_h, img_w = img_metas[i]['img_shape'][:2]
ori_h, ori_w = img_metas[i]['ori_shape'][:2]
export_postprocess_mask = False
if self.deploy_cfg is not None:
mmdet_deploy_cfg = get_post_processing_params(
self.deploy_cfg)
# this flag enables postprocessing at export time.
export_postprocess_mask = mmdet_deploy_cfg.get(
'export_postprocess_mask', False)
if not export_postprocess_mask:
masks = End2EndModel.postprocessing_masks(
dets[:, :4], masks, ori_w, ori_h, self.device)
if model_type == 'RTMDet':
export_postprocess_mask = True
else:
masks = masks[:, :img_h, :img_w]
export_postprocess_mask = False
if self.deploy_cfg is not None:
mmdet_deploy_cfg = get_post_processing_params(
self.deploy_cfg)
# this flag enables postprocessing at export time.
export_postprocess_mask = mmdet_deploy_cfg.get(
'export_postprocess_mask', False)
if not export_postprocess_mask:
masks = End2EndModel.postprocessing_masks(
dets[:, :4], masks, ori_w, ori_h, self.device)
else:
masks = masks[:, :img_h, :img_w]
# avoid resizing masks that have a zero-sized dim
if export_postprocess_mask and rescale and masks.shape[0] != 0:
masks = F.interpolate(
masks = torch.nn.functional.interpolate(
masks.unsqueeze(0),
size=[
math.ceil(masks.shape[-2] / scale_factor[0]),
math.ceil(masks.shape[-1] / scale_factor[1])
math.ceil(masks.shape[-2] /
img_metas[i]['scale_factor'][0]),
math.ceil(masks.shape[-1] /
img_metas[i]['scale_factor'][1])
])[..., :ori_h, :ori_w]
masks = masks.squeeze(0)
if masks.dtype != bool:
Expand Down Expand Up @@ -872,8 +877,10 @@ def forward(self,
ori_h, ori_w = data_samples[0].ori_shape[:2]
for bbox, mask in zip(dets, masks):
img_mask = np.zeros((ori_h, ori_w), dtype=np.uint8)
left = int(max(np.floor(bbox[0]) - 1, 0))
top = int(max(np.floor(bbox[1]) - 1, 0))
left, top = 0, 0
if not (ori_h == mask.shape[0] and ori_w == mask.shape[1]):
left = int(max(np.floor(bbox[0]) - 1, 0))
top = int(max(np.floor(bbox[1]) - 1, 0))
img_mask[top:top + mask.shape[0],
left:left + mask.shape[1]] = mask
segm_results.append(torch.from_numpy(img_mask))
Expand Down
Loading

0 comments on commit 123b9fc

Please sign in to comment.