Skip to content

Commit

Permalink
feat: use DTO for NCNN init parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
sileht committed Jan 21, 2021
1 parent 566e5fb commit dbb5e75
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 47 deletions.
26 changes: 26 additions & 0 deletions src/apidata.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include <vector>
#include <sstream>
#include <typeinfo>
#include "oatpp/parser/json/mapping/ObjectMapper.hpp"

namespace dd
{
Expand Down Expand Up @@ -288,6 +289,31 @@ namespace dd
*/
void toJDoc(JDoc &jd) const;

/**
* \brief converts APIData to oat++ DTO
*/
template <typename T> inline std::shared_ptr<T> createSharedDTO() const
{
rapidjson::Document d;
d.SetObject();
toJDoc(reinterpret_cast<JDoc &>(d));

rapidjson::StringBuffer buffer;
rapidjson::Writer<rapidjson::StringBuffer, rapidjson::UTF8<>,
rapidjson::UTF8<>, rapidjson::CrtAllocator,
rapidjson::kWriteNanAndInfFlag>
writer(buffer);
bool done = d.Accept(writer);
if (!done)
throw DataConversionException("JSON rendering failed");

std::shared_ptr<oatpp::data::mapping::ObjectMapper> object_mapper
= oatpp::parser::json::mapping::ObjectMapper::createShared();
return object_mapper
->readFromString<oatpp::Object<T>>(buffer.GetString())
.getPtr();
}

/**
* \brief converts APIData to rapidjson JSON value
* @param jd JSON Document hosting the destination JSON value
Expand Down
54 changes: 14 additions & 40 deletions src/backends/ncnn/ncnnlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include "outputconnectorstrategy.h"
#include <thread>
#include <algorithm>
#include "utils/utils.hpp"

// NCNN
#include "ncnnlib.h"
Expand Down Expand Up @@ -53,10 +52,10 @@ namespace dd
{
this->_libname = "ncnn";
_net = new ncnn::Net();
_net->opt.num_threads = _threads;
_net->opt.num_threads = 1;
_net->opt.blob_allocator = &_blob_pool_allocator;
_net->opt.workspace_allocator = &_workspace_pool_allocator;
_net->opt.lightmode = _lightmode;
_net->opt.lightmode = true;
}

template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
Expand All @@ -69,12 +68,9 @@ namespace dd
this->_libname = "ncnn";
_net = tl._net;
tl._net = nullptr;
_nclasses = tl._nclasses;
_threads = tl._threads;
_timeserie = tl._timeserie;
_old_height = tl._old_height;
_inputBlob = tl._inputBlob;
_outputBlob = tl._outputBlob;
_init_dto = tl._init_dto;
}

template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
Expand All @@ -94,6 +90,8 @@ namespace dd
void NCNNLib<TInputConnectorStrategy, TOutputConnectorStrategy,
TMLModel>::init_mllib(const APIData &ad)
{
_init_dto = ad.createSharedDTO<NcnnInitDto>();

bool use_fp32 = (ad.has("datatype")
&& ad.get("datatype").get<std::string>()
== "fp32"); // default is fp16
Expand Down Expand Up @@ -124,35 +122,11 @@ namespace dd
_old_height = this->_inputc.height();
_net->set_input_h(_old_height);

if (ad.has("nclasses"))
_nclasses = ad.get("nclasses").get<int>();

if (ad.has("threads"))
_threads = ad.get("threads").get<int>();
else
_threads = dd_utils::my_hardware_concurrency();

_timeserie = this->_inputc._timeserie;
if (_timeserie)
this->_mltype = "timeserie";

if (ad.has("lightmode"))
{
_lightmode = ad.get("lightmode").get<bool>();
_net->opt.lightmode = _lightmode;
}

// setting the value of Input Layer
if (ad.has("inputblob"))
{
_inputBlob = ad.get("inputblob").get<std::string>();
}
// setting the final Output Layer
if (ad.has("outputblob"))
{
_outputBlob = ad.get("outputblob").get<std::string>();
}

_net->opt.lightmode = _init_dto->lightmode;
_blob_pool_allocator.set_size_compare_ratio(0.0f);
_workspace_pool_allocator.set_size_compare_ratio(0.5f);
model_type(this->_mlmodel._params, this->_mltype);
Expand Down Expand Up @@ -213,8 +187,8 @@ namespace dd

ncnn::Extractor ex = _net->create_extractor();

ex.set_num_threads(_threads);
ex.input(_inputBlob.c_str(), inputc._in);
ex.set_num_threads(_init_dto->threads);
ex.input(_init_dto->inputBlob->c_str(), inputc._in);

APIData ad_output = ad.getobj("parameters").getobj("output");

Expand All @@ -237,8 +211,7 @@ namespace dd
}

// Extract detection or classification
int ret = 0;
std::string out_blob = _outputBlob;
std::string out_blob = _init_dto->outputBlob.std_str();
if (out_blob.empty())
{
if (bbox == true)
Expand All @@ -250,7 +223,7 @@ namespace dd
else
out_blob = "prob";
}
ret = ex.extract(out_blob.c_str(), inputc._out);
int ret = ex.extract(out_blob.c_str(), inputc._out);
if (ret == -1)
{
throw MLLibInternalException("NCNN internal error");
Expand All @@ -277,8 +250,8 @@ namespace dd
{
best = ad_output.get("best").get<int>();
}
if (best == -1 || best > _nclasses)
best = _nclasses;
if (best == -1 || best > _init_dto->nclasses)
best = _init_dto->nclasses;

if (bbox == true)
{
Expand Down Expand Up @@ -408,7 +381,8 @@ namespace dd

vrad.push_back(rad);
tout.add_results(vrad);
out.add("nclasses", this->_nclasses);
int nclasses = this->_init_dto->nclasses;
out.add("nclasses", nclasses);
if (bbox == true)
out.add("bbox", true);
out.add("roi", false);
Expand Down
14 changes: 7 additions & 7 deletions src/backends/ncnn/ncnnlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@
#ifndef NCNNLIB_H
#define NCNNLIB_H

#include "apidata.h"
#include "utils/utils.hpp"

#include "http/dto/ncnn.hpp"

// NCNN
#include "net.h"
#include "ncnnmodel.h"

#include "apidata.h"

namespace dd
{
template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
Expand All @@ -53,20 +56,17 @@ namespace dd

public:
ncnn::Net *_net = nullptr;
int _nclasses = 0;
bool _timeserie = false;
bool _lightmode = true;

private:
std::shared_ptr<NcnnInitDto> _init_dto;
static ncnn::UnlockedPoolAllocator _blob_pool_allocator;
static ncnn::PoolAllocator _workspace_pool_allocator;

protected:
int _threads = 1;
int _old_height = -1;
std::string _inputBlob = "data";
std::string _outputBlob;
};

}

#endif
71 changes: 71 additions & 0 deletions src/http/dto/ncnn.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/**
* DeepDetect
* Copyright (c) 2021 Jolibrain SASU
* Author: Mehdi Abaakouk <mehdi.abaakouk@jolibrain.com>
*
* This file is part of deepdetect.
*
* deepdetect is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* deepdetect is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with deepdetect. If not, see <http://www.gnu.org/licenses/>.
*/

#ifndef HTTP_DTO_NCNN_H
#define HTTP_DTO_NCNN_H

#include "dd_config.h"
#include "oatpp/core/Types.hpp"
#include "oatpp/core/macro/codegen.hpp"

#include OATPP_CODEGEN_BEGIN(DTO) ///< Begin DTO codegen section

class NcnnInitDto : public oatpp::DTO
{
DTO_INIT(NcnnInitDto, DTO /* extends */)

DTO_FIELD_INFO(nclasses)
{
info->description = "number of output classes (`supervised` service "
"type), classification only";
};
DTO_FIELD(Int32, nclasses) = 0;

DTO_FIELD_INFO(threads)
{
info->description = "number of threads";
};
DTO_FIELD(Int32, threads) = dd::dd_utils::my_hardware_concurrency();

DTO_FIELD_INFO(lightmode)
{
info->description = "enable light mode";
};
DTO_FIELD(Boolean, lightmode) = true;

DTO_FIELD_INFO(inputBlob)
{
info->description = "network input blob name";
};
DTO_FIELD(String, inputBlob) = "data";

DTO_FIELD_INFO(outputBlob)
{
info->description = "network output blob name (default depends on "
"network type(ie prob or "
"rnn_pred or probs or detection_out)";
};
DTO_FIELD(String, outputBlob);
};

#include OATPP_CODEGEN_END(DTO) ///< End DTO codegen section

#endif

0 comments on commit dbb5e75

Please sign in to comment.