Skip to content

Commit

Permalink
feat(chain): add action to draw bboxes as trailing action
Browse files Browse the repository at this point in the history
  • Loading branch information
Bycob authored and beniz committed Jan 25, 2022
1 parent a470c7b commit ae0a05f
Show file tree
Hide file tree
Showing 12 changed files with 377 additions and 33 deletions.
12 changes: 9 additions & 3 deletions demo/objdetect/README.md
Expand Up @@ -8,7 +8,7 @@ To run the code on your own image:

- Install the pre-trained model:

```
```bash
mkdir model
cd model
wget https://deepdetect.com/models/voc0712_dd.tar.gz
Expand All @@ -18,16 +18,22 @@ cd ..

- Start a DeepDetect server:

```
```bash
./dede
```

- Try object detection on an image

```
```bash
python objdetect.py --image /path/to/yourimage.jpg --confidence-threshold 0.1
```

- Alternatively, try the chain version. It will draw the bounding boxes on the image and return it.

```bash
python chaindetect.py --image /path/to/yourimage.jpg --save-path image.png
```

Notes:

- The VOC0712 model originates from https://github.com/weiliu89/caffe/tree/ssd and may not be very accurate on standard pictures
Expand Down
64 changes: 64 additions & 0 deletions demo/objdetect/chaindetect.py
@@ -0,0 +1,64 @@
import os, sys, argparse
from dd_client import DD
import cv2
import numpy as np

# Demo: run an object detection model through a DeepDetect chain whose
# trailing "draw_bbox" action returns the input image with the detected
# bounding boxes drawn on it, then display (and optionally save) the result.
parser = argparse.ArgumentParser()
# The image is mandatory: without it there is nothing to send to the server.
parser.add_argument("--image", required=True, help="path to image")
parser.add_argument("--port", help="DeepDetect port", type=int, default=8080)
parser.add_argument("--confidence-threshold", help="keep detections with confidence above threshold", type=float, default=0.1)
parser.add_argument("--save-path", help="Where to save resulting image")
args = parser.parse_args()

host = 'localhost'
sname = 'imgserv'
description = 'image classification'
mllib = 'caffe'
mltype = 'supervised'
nclasses = 21
# input size expected by the model (passed to the image connector below)
width = height = 300
dd = DD(host, port=args.port)
dd.set_return_format(dd.RETURN_PYTHON)

# creating ML service
model_repo = os.getcwd() + '/model'
model = {'repository': model_repo}
parameters_input = {'connector': 'image', 'width': width, 'height': height}
parameters_mllib = {'nclasses': nclasses}
parameters_output = {}
dd.put_service(sname, model, description, mllib,
               parameters_input, parameters_mllib, parameters_output, mltype)

# chain call
calls = []

# first call: the detection model; keep_orig is requested so the action
# can draw on the original image
parameters_input = {"keep_orig": True}
parameters_mllib = {'gpu': True}
parameters_output = {'bbox': True, 'confidence_threshold': args.confidence_threshold}
data = [args.image]
calls.append(dd.make_call(sname, data, parameters_input, parameters_mllib, parameters_output))

# trailing action: draw the boxes and return the pixels under "img_bbox"
parameters_action = {"output_images": True, "write_prob": True}
# parameters_action["save_path"] = os.getcwd()
# parameters_action["save_img"] = True
calls.append(dd.make_action("draw_bbox", parameters_action, "img_bbox"))

detect = dd.post_chain("chain_ddetect", calls)
# print(detect)
if detect['status']['code'] != 200:
    print('error', detect['status']['code'])
    # exit with a non-zero status so callers can detect the failure
    sys.exit(1)

predictions = detect['body']['predictions']
for p in predictions:
    # get orig img dimensions; the returned pixel buffer is flat, so the
    # original image is re-read only to recover its shape
    orig_img = cv2.imread(p['uri'])
    if orig_img is None:
        print('error: cannot read image', p['uri'])
        sys.exit(1)
    # ndarray shape is (rows, cols, channels), i.e. (height, width, channels)
    rows, cols, _ = orig_img.shape

    # values are raw bytes; make the dtype explicit so OpenCV accepts the array
    img = np.array(p["img_bbox"]["vals"], dtype=np.uint8)
    img = img.reshape((rows, cols, 3))

    if args.save_path:
        cv2.imwrite(args.save_path, img)
    cv2.imshow('img', img)
    cv2.waitKey(0)
1 change: 0 additions & 1 deletion demo/objdetect/dd_client.py

This file was deleted.

66 changes: 58 additions & 8 deletions src/chain.cc
Expand Up @@ -25,6 +25,28 @@

namespace dd
{

/**
 * \brief Attaches to `dest` every prediction stored in `other_models_out`
 *        under the key `uri`.
 *
 * Each matching prediction is inserted into `dest` under the name currently
 * held in its own `uri` field; that field is then reset to null before the
 * prediction is stored.
 *
 * @param dest             fields object receiving the nested outputs
 * @param other_models_out multimap of predictions, keyed by the uri they
 *                         relate to
 * @param uri              key to look up in `other_models_out`
 * @throws ChainBadParamException if `dest` already contains a field with
 *         the same name as one of the outputs to embed
 */
void embed_model_output(
    oatpp::UnorderedFields<oatpp::Any> &dest,
    std::unordered_multimap<std::string, oatpp::Object<DTO::Prediction>>
        &other_models_out,
    const std::string &uri)
{
  auto matches = other_models_out.equal_range(uri);
  auto it = matches.first;

  while (it != matches.second)
    {
      auto &nested_pred = it->second;

      // the uri field carries the output name; take it and clear the
      // field before checking for a clash, as the stored object must not
      // keep it
      oatpp::String key = nested_pred->uri;
      nested_pred->uri = nullptr;

      const bool key_taken = dest->find(key) != dest->end();
      if (key_taken)
        throw ChainBadParamException(
            "This key already exists and cannot be used to reference "
            "model output: "
            + key->std_str());

      (*dest)[key] = nested_pred;
      ++it;
    }
}

oatpp::Object<DTO::ChainBody> ChainData::nested_chain_output()
{
// pre-compile models != first model
Expand All @@ -38,6 +60,7 @@ namespace dd
{
std::string model_id = (*hit).first;
std::string model_name = get_model_sname(model_id);

if (model_id == _first_id)
{
if ((*hit).second.has("dto"))
Expand Down Expand Up @@ -84,29 +107,56 @@ namespace dd
++hit;
}

// actions
std::unordered_map<std::string, APIData>::const_iterator ahit
= _action_data.begin();

while (ahit != _action_data.end())
{
std::string action_id = ahit->first;
const APIData &action_data = ahit->second;

if (action_data.has("output"))
{
auto out_body = ahit->second.get("output")
.get<oatpp::Any>()
.retrieve<oatpp::Object<DTO::PredictBody>>();

for (auto p : *out_body->predictions)
{
std::string uri = p->uri->std_str();
p->uri = action_id.c_str();
other_models_out.insert(
std::pair<std::string, oatpp::Object<DTO::Prediction>>(uri,
p));
}
}
++ahit;
}

// Return a DTO
auto chain_dto = DTO::ChainBody::createShared();
for (auto pred : *first_model_out->predictions)
{
oatpp::UnorderedFields<oatpp::Any> chain_pred
= oatpp_utils::dtoToUFields(pred);

// chain result at uri level
embed_model_output(chain_pred, other_models_out, pred->uri->std_str());

// chain results at prediction level
auto classes = oatpp::Vector<oatpp::Any>::createShared();

for (auto cls : *pred->classes)
{
std::string uri = cls->class_id->std_str();
cls->class_id = nullptr;
oatpp::UnorderedFields<oatpp::Any> class_preds
= oatpp_utils::dtoToUFields(cls);
auto rhit_range = other_models_out.equal_range(uri);

for (auto rhit = rhit_range.first; rhit != rhit_range.second;
++rhit)
if (cls->class_id != nullptr)
{
oatpp::String model_name = rhit->second->uri;
rhit->second->uri = nullptr;
(*class_preds)[model_name] = rhit->second;
std::string uri = cls->class_id->std_str();
cls->class_id = nullptr;
embed_model_output(class_preds, other_models_out, uri);
}
classes->push_back(class_preds);
}
Expand Down
125 changes: 125 additions & 0 deletions src/chain_actions.cc
Expand Up @@ -30,6 +30,7 @@
#endif
#include <unordered_set>
#include "utils/utils.hpp"
#include "dto/predict_out.hpp"

#ifdef USE_DLIB
#include "backends/dlib/dlib_actions.h"
Expand Down Expand Up @@ -292,6 +293,129 @@ namespace dd
cdata.add_action_data(_action_id, action_out);
}

cv::Scalar bbox_palette[]
= { { 82, 188, 227 }, { 196, 110, 49 }, { 39, 54, 227 },
{ 68, 227, 81 }, { 77, 157, 255 }, { 255, 112, 207 },
{ 240, 228, 65 }, { 94, 242, 151 }, { 236, 121, 242 },
{ 28, 77, 120 } };
size_t bbox_palette_size = 10;

/**
 * \brief Draws the predicted bounding boxes, and optionally their class
 *        and confidence, on a copy of every input image.
 *
 * Reads "predictions" and "input" ("imgs", "imgs_size") from `model_out`.
 * Box coordinates are rescaled from the size stored in "imgs_size"
 * (first = rows, second = cols) to the size of the image being drawn on.
 * The annotated images are stored in the chain data under this action id
 * ("data_raw_img", "cids") and, when `output_images` is set, also as an
 * "output" prediction body holding the raw pixel bytes of each image.
 *
 * @param model_out output of the parent model call
 * @param cdata     chain data receiving the result of the action
 * @throws ActionBadParamException if a class entry has no bbox object
 */
void ImgsDrawBBoxAction::apply(APIData &model_out, ChainData &cdata)
{
  std::vector<APIData> vad = model_out.getv("predictions");
  // fetch the input object once and read both fields from it
  APIData input_ad = model_out.getobj("input");

  std::vector<cv::Mat> imgs
      = input_ad.get("imgs").get<std::vector<cv::Mat>>();
  std::vector<std::pair<int, int>> imgs_size
      = input_ad.get("imgs_size").get<std::vector<std::pair<int, int>>>();
  std::vector<cv::Mat> rimgs;
  std::vector<std::string> uris;
  rimgs.reserve(vad.size());
  uris.reserve(vad.size());
  auto pred_body = DTO::PredictBody::createShared();

  const bool save_img = _params->save_img;
  const int ref_thickness = _params->thickness;

  std::string save_path = _params->save_path->std_str();
  if (!save_path.empty())
    save_path += "/";

  for (size_t i = 0; i < vad.size(); i++)
    {
      std::string uri = vad.at(i).get("uri").get<std::string>();
      uris.push_back(uri);

      const int im_cols = imgs.at(i).cols;
      const int im_rows = imgs.at(i).rows;
      const int orig_cols = imgs_size.at(i).second;
      const int orig_rows = imgs_size.at(i).first;

      std::vector<APIData> ad_cls = vad.at(i).getv("classes");
      // draw on a copy so that the input image is left untouched
      cv::Mat rimg = imgs.at(i).clone();

      // iterate bboxes per image
      for (size_t j = 0; j < ad_cls.size(); j++)
        {
          APIData bbox = ad_cls.at(j).getobj("bbox");
          std::string cat = ad_cls.at(j).get("cat").get<std::string>();
          if (bbox.empty())
            throw ActionBadParamException(
                "draw action cannot find bbox object for uri " + uri);

          // rescale the box to the size of the image being drawn on
          const double xmin
              = bbox.get("xmin").get<double>() / orig_cols * im_cols;
          const double ymin
              = bbox.get("ymin").get<double>() / orig_rows * im_rows;
          const double xmax
              = bbox.get("xmax").get<double>() / orig_cols * im_cols;
          const double ymax
              = bbox.get("ymax").get<double>() / orig_rows * im_rows;

          // draw bbox: a wider white rectangle first, then the colored
          // one on top, so that the box gets a white outline
          const cv::Point pt1{ static_cast<int>(xmin),
                               static_cast<int>(ymin) };
          const cv::Point pt2{ static_cast<int>(xmax),
                               static_cast<int>(ymax) };
          // the color only depends on the class name
          const size_t cls_hash = std::hash<std::string>{}(cat);
          const cv::Scalar color = bbox_palette[cls_hash % bbox_palette_size];
          cv::rectangle(rimg, pt1, pt2, cv::Scalar(255, 255, 255),
                        ref_thickness + 2);
          cv::rectangle(rimg, pt1, pt2, color, ref_thickness);

          // draw class & confidences
          std::string label;
          if (_params->write_cat)
            label = cat;
          if (_params->write_cat && _params->write_prob)
            label += " - ";
          if (_params->write_prob)
            label += std::to_string(ad_cls.at(j).get("prob").get<double>());

          // font size relatively to base opencv font size
          const float font_size = 2;
          // text starts below the bottom-left corner of the box, capped so
          // that it does not start past the right / bottom image borders
          int x_txt = static_cast<int>(xmin + 5);
          if (x_txt > im_cols - 15)
            x_txt = im_cols - 15;
          const int y_txt
              = std::min(im_rows - 20,
                         static_cast<int>(ymax + 2 + font_size * 12));

          cv::putText(rimg, label, cv::Point(x_txt, y_txt),
                      cv::FONT_HERSHEY_PLAIN, font_size,
                      cv::Scalar(255, 255, 255), ref_thickness + 2);
          cv::putText(rimg, label, cv::Point(x_txt, y_txt),
                      cv::FONT_HERSHEY_PLAIN, font_size, color,
                      ref_thickness);
        }

      rimgs.push_back(rimg);

      // save image if requested
      if (save_img)
        {
          std::string puri = dd_utils::split(uri, '/').back();
          std::string rimg_path = save_path + "bbox_" + puri + ".png";
          // only report success once the file has actually been written
          if (cv::imwrite(rimg_path, rimg))
            this->_chain_logger->info("draw_bbox: Saved image to path {}",
                                      rimg_path);
          else
            this->_chain_logger->error(
                "draw_bbox: Failed to save image to path {}", rimg_path);
        }

      if (_params->output_images)
        {
          auto action_pred = DTO::Prediction::createShared();

          // copy the raw pixel buffer of the annotated image
          action_pred->vals = DTO::DTOVector<uint8_t>(std::vector<uint8_t>(
              rimg.data, rimg.data + (rimg.total() * rimg.elemSize())));
          action_pred->uri = uri.c_str();
          pred_body->predictions->push_back(action_pred);
        }
    }

  APIData action_out;
  action_out.add("data_raw_img", rimgs);
  action_out.add("cids", uris);
  if (_params->output_images)
    {
      action_out.add("output", pred_body);
    }
  cdata.add_action_data(_action_id, action_out);
}

void ClassFilter::apply(APIData &model_out, ChainData &cdata)
{
if (_params->classes == nullptr)
Expand Down Expand Up @@ -376,6 +500,7 @@ namespace dd

CHAIN_ACTION("crop", ImgsCropAction)
CHAIN_ACTION("rotate", ImgsRotateAction)
CHAIN_ACTION("draw_bbox", ImgsDrawBBoxAction)
CHAIN_ACTION("filter", ClassFilter)
#ifdef USE_DLIB
CHAIN_ACTION("dlib_align_crop", DlibAlignCropAction)
Expand Down
16 changes: 16 additions & 0 deletions src/chain_actions.h
Expand Up @@ -129,6 +129,22 @@ namespace dd
void apply(APIData &model_out, ChainData &cdata);
};

/**
 * \brief Chain action registered as "draw_bbox".
 *
 * Construction only forwards the call description and the logger to
 * ChainAction; the work is done in apply().
 */
class ImgsDrawBBoxAction : public ChainAction
{
public:
  /**
   * @param call_dto     description of the chain call for this action
   * @param chain_logger logger handed over to the base action
   */
  ImgsDrawBBoxAction(oatpp::Object<DTO::ChainCall> call_dto,
                     const std::shared_ptr<spdlog::logger> chain_logger)
      : ChainAction(call_dto, chain_logger)
  {
  }

  ~ImgsDrawBBoxAction() = default;

  /**
   * \brief Runs the action on the output of the parent model.
   * @param model_out output of the parent model call
   * @param cdata     chain data in which the action stores its result
   */
  void apply(APIData &model_out, ChainData &cdata);
};

class ClassFilter : public ChainAction
{
public:
Expand Down

0 comments on commit ae0a05f

Please sign in to comment.