From dbb5e75fad648514583da72e9724492a21a30c50 Mon Sep 17 00:00:00 2001 From: Mehdi Abaakouk Date: Fri, 25 Sep 2020 16:46:03 +0200 Subject: [PATCH] feat: use DTO for NCNN init parameters --- src/apidata.h | 26 +++++++++++++ src/backends/ncnn/ncnnlib.cc | 54 +++++++-------------------- src/backends/ncnn/ncnnlib.h | 14 +++---- src/http/dto/ncnn.hpp | 71 ++++++++++++++++++++++++++++++++++++ 4 files changed, 118 insertions(+), 47 deletions(-) create mode 100644 src/http/dto/ncnn.hpp diff --git a/src/apidata.h b/src/apidata.h index 9ec0e2f763..cd53b5a2f4 100644 --- a/src/apidata.h +++ b/src/apidata.h @@ -38,6 +38,7 @@ #include #include #include +#include "oatpp/parser/json/mapping/ObjectMapper.hpp" namespace dd { @@ -288,6 +289,31 @@ namespace dd */ void toJDoc(JDoc &jd) const; + /** + * \brief converts APIData to oat++ DTO + */ + template inline std::shared_ptr createSharedDTO() const + { + rapidjson::Document d; + d.SetObject(); + toJDoc(reinterpret_cast(d)); + + rapidjson::StringBuffer buffer; + rapidjson::Writer, + rapidjson::UTF8<>, rapidjson::CrtAllocator, + rapidjson::kWriteNanAndInfFlag> + writer(buffer); + bool done = d.Accept(writer); + if (!done) + throw DataConversionException("JSON rendering failed"); + + std::shared_ptr object_mapper + = oatpp::parser::json::mapping::ObjectMapper::createShared(); + return object_mapper + ->readFromString>(buffer.GetString()) + .getPtr(); + } + /** * \brief converts APIData to rapidjson JSON value * @param jd JSON Document hosting the destination JSON value diff --git a/src/backends/ncnn/ncnnlib.cc b/src/backends/ncnn/ncnnlib.cc index b6732b58ec..0c6dc00292 100644 --- a/src/backends/ncnn/ncnnlib.cc +++ b/src/backends/ncnn/ncnnlib.cc @@ -22,7 +22,6 @@ #include "outputconnectorstrategy.h" #include #include -#include "utils/utils.hpp" // NCNN #include "ncnnlib.h" @@ -53,10 +52,10 @@ namespace dd { this->_libname = "ncnn"; _net = new ncnn::Net(); - _net->opt.num_threads = _threads; + _net->opt.num_threads = 1; _net->opt.blob_allocator = &_blob_pool_allocator; _net->opt.workspace_allocator = &_workspace_pool_allocator; - _net->opt.lightmode = _lightmode; + _net->opt.lightmode = true; } template _libname = "ncnn"; _net = tl._net; tl._net = nullptr; - _nclasses = tl._nclasses; - _threads = tl._threads; _timeserie = tl._timeserie; _old_height = tl._old_height; - _inputBlob = tl._inputBlob; - _outputBlob = tl._outputBlob; + _init_dto = tl._init_dto; } template ::init_mllib(const APIData &ad) { + _init_dto = ad.createSharedDTO(); + bool use_fp32 = (ad.has("datatype") && ad.get("datatype").get() == "fp32"); // default is fp16 @@ -124,35 +122,11 @@ namespace dd _old_height = this->_inputc.height(); _net->set_input_h(_old_height); - if (ad.has("nclasses")) - _nclasses = ad.get("nclasses").get(); - - if (ad.has("threads")) - _threads = ad.get("threads").get(); - else - _threads = dd_utils::my_hardware_concurrency(); - _timeserie = this->_inputc._timeserie; if (_timeserie) this->_mltype = "timeserie"; - if (ad.has("lightmode")) - { - _lightmode = ad.get("lightmode").get(); - _net->opt.lightmode = _lightmode; - } - - // setting the value of Input Layer - if (ad.has("inputblob")) - { - _inputBlob = ad.get("inputblob").get(); - } - // setting the final Output Layer - if (ad.has("outputblob")) - { - _outputBlob = ad.get("outputblob").get(); - } - + _net->opt.lightmode = _init_dto->lightmode; _blob_pool_allocator.set_size_compare_ratio(0.0f); _workspace_pool_allocator.set_size_compare_ratio(0.5f); model_type(this->_mlmodel._params, this->_mltype); @@ -213,8 +187,8 @@ namespace dd ncnn::Extractor ex = _net->create_extractor(); - ex.set_num_threads(_threads); - ex.input(_inputBlob.c_str(), inputc._in); + ex.set_num_threads(_init_dto->threads); + ex.input(_init_dto->inputBlob->c_str(), inputc._in); APIData ad_output = ad.getobj("parameters").getobj("output"); @@ -237,8 +211,7 @@ namespace dd } // Extract detection or classification - int ret = 0; - std::string out_blob = _outputBlob; + std::string out_blob = _init_dto->outputBlob.std_str(); if (out_blob.empty()) { if (bbox == true) @@ -250,7 +223,7 @@ namespace dd else out_blob = "prob"; } - ret = ex.extract(out_blob.c_str(), inputc._out); + int ret = ex.extract(out_blob.c_str(), inputc._out); if (ret == -1) { throw MLLibInternalException("NCNN internal error"); @@ -277,8 +250,8 @@ namespace dd { best = ad_output.get("best").get(); } - if (best == -1 || best > _nclasses) - best = _nclasses; + if (best == -1 || best > _init_dto->nclasses) + best = _init_dto->nclasses; if (bbox == true) { @@ -408,7 +381,8 @@ namespace dd vrad.push_back(rad); tout.add_results(vrad); - out.add("nclasses", this->_nclasses); + int nclasses = this->_init_dto->nclasses; + out.add("nclasses", nclasses); if (bbox == true) out.add("bbox", true); out.add("roi", false); diff --git a/src/backends/ncnn/ncnnlib.h b/src/backends/ncnn/ncnnlib.h index 513d06cd84..9ed5c270be 100644 --- a/src/backends/ncnn/ncnnlib.h +++ b/src/backends/ncnn/ncnnlib.h @@ -22,12 +22,15 @@ #ifndef NCNNLIB_H #define NCNNLIB_H +#include "apidata.h" +#include "utils/utils.hpp" + +#include "http/dto/ncnn.hpp" + // NCNN #include "net.h" #include "ncnnmodel.h" -#include "apidata.h" - namespace dd { template _init_dto; static ncnn::UnlockedPoolAllocator _blob_pool_allocator; static ncnn::PoolAllocator _workspace_pool_allocator; protected: - int _threads = 1; int _old_height = -1; - std::string _inputBlob = "data"; - std::string _outputBlob; }; + } #endif diff --git a/src/http/dto/ncnn.hpp b/src/http/dto/ncnn.hpp new file mode 100644 index 0000000000..4828dea8c4 --- /dev/null +++ b/src/http/dto/ncnn.hpp @@ -0,0 +1,71 @@ +/** + * DeepDetect + * Copyright (c) 2021 Jolibrain SASU + * Author: Mehdi Abaakouk + * + * This file is part of deepdetect. + * + * deepdetect is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * deepdetect is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with deepdetect. If not, see . + */ + +#ifndef HTTP_DTO_NCNN_H +#define HTTP_DTO_NCNN_H + +#include "dd_config.h" +#include "oatpp/core/Types.hpp" +#include "oatpp/core/macro/codegen.hpp" + +#include OATPP_CODEGEN_BEGIN(DTO) ///< Begin DTO codegen section + +class NcnnInitDto : public oatpp::DTO +{ + DTO_INIT(NcnnInitDto, DTO /* extends */) + + DTO_FIELD_INFO(nclasses) + { + info->description = "number of output classes (`supervised` service " + "type), classification only"; + }; + DTO_FIELD(Int32, nclasses) = 0; + + DTO_FIELD_INFO(threads) + { + info->description = "number of threads"; + }; + DTO_FIELD(Int32, threads) = dd::dd_utils::my_hardware_concurrency(); + + DTO_FIELD_INFO(lightmode) + { + info->description = "enable light mode"; + }; + DTO_FIELD(Boolean, lightmode) = true; + + DTO_FIELD_INFO(inputBlob) + { + info->description = "network input blob name"; + }; + DTO_FIELD(String, inputBlob) = "data"; + + DTO_FIELD_INFO(outputBlob) + { + info->description = "network output blob name (default depends on " + "network type(ie prob or " + "rnn_pred or probs or detection_out)"; + }; + DTO_FIELD(String, outputBlob); +}; + +#include OATPP_CODEGEN_END(DTO) ///< End DTO codegen section + +#endif