feat(tensorrt): Add support for onnx image classification models
Bycob authored and sileht committed Oct 21, 2020
1 parent b43a6bc commit a8b81f2
Showing 7 changed files with 215 additions and 63 deletions.
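The heart of this commit is a second engine-build path that parses ONNX models through TensorRT's nvonnxparser. As a standalone illustration of that flow — assuming TensorRT 7.x, with the function name, logger and 1 GiB workspace size as placeholders rather than the backend's actual API (the real integration is read_engine_from_onnx() in tensorrtlib.cc below):

#include <NvInfer.h>
#include <NvOnnxParser.h>
#include <string>

// Build a TensorRT engine from an ONNX file — a minimal sketch of the
// flow this commit integrates. ONNX networks carry an explicit batch
// dimension, hence the kEXPLICIT_BATCH creation flag.
nvinfer1::ICudaEngine *build_engine_from_onnx(const std::string &onnx_path,
                                              nvinfer1::ILogger &logger)
{
  const auto explicit_batch
      = 1U << static_cast<uint32_t>(
            nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);

  nvinfer1::IBuilder *builder = nvinfer1::createInferBuilder(logger);
  nvinfer1::INetworkDefinition *network
      = builder->createNetworkV2(explicit_batch);
  nvonnxparser::IParser *parser
      = nvonnxparser::createParser(*network, logger);

  if (!parser->parseFromFile(onnx_path.c_str(),
                             int(nvinfer1::ILogger::Severity::kWARNING)))
    return nullptr; // details available via parser->getError(i)->desc()

  nvinfer1::IBuilderConfig *config = builder->createBuilderConfig();
  config->setMaxWorkspaceSize(1 << 30); // arbitrary 1 GiB for the sketch

  nvinfer1::ICudaEngine *engine
      = builder->buildEngineWithConfig(*network, *config);

  // TensorRT 7 objects are released with destroy(), not delete.
  config->destroy();
  parser->destroy();
  network->destroy();
  builder->destroy();
  return engine;
}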
3 changes: 2 additions & 1 deletion src/backends/tensorrt/tensorrtinputconns.cc
@@ -38,7 +38,8 @@ namespace dd
for (int c = 0; c < channels; ++c)
for (int h = 0; h < _height; ++h)
for (int w = 0; w < _width; ++w)
fbuf[offset++] = cvbuf[(converted.cols * h + w) * channels + c];
fbuf[offset++]
= _scale * cvbuf[(converted.cols * h + w) * channels + c];
}

void ImgTensorRTInputFileConn::applyMeanToRTBuf(int channels, int i)
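The change above folds the input scaling factor into the existing HWC-to-CHW repacking loop. In isolation, and with illustrative names rather than the connector's actual members, the transform amounts to:

#include <opencv2/core.hpp>

// Repack an OpenCV HWC image into the planar CHW float layout TensorRT
// expects, applying a scale factor (e.g. 1/255) on the fly. Assumes a
// continuous CV_32F matrix; names are illustrative.
void hwc_to_chw_scaled(const cv::Mat &img, float *fbuf, float scale)
{
  const int channels = img.channels();
  const float *cvbuf = img.ptr<float>(0);
  int offset = 0;
  for (int c = 0; c < channels; ++c)
    for (int h = 0; h < img.rows; ++h)
      for (int w = 0; w < img.cols; ++w)
        fbuf[offset++] = scale * cvbuf[(img.cols * h + w) * channels + c];
}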
196 changes: 138 additions & 58 deletions src/backends/tensorrt/tensorrtlib.cc
@@ -24,6 +24,7 @@
#include "tensorrtinputconns.h"
#include "utils/apitools.h"
#include "NvInferPlugin.h"
#include "NvOnnxParser.h"
#include "protoUtils.h"
#include <cuda_runtime_api.h>
#include <string>
@@ -39,7 +40,12 @@ namespace dd
fileops::list_directory(repo, true, false, false, lfiles);
for (std::string s : lfiles)
{
if (s.find(engineFileName) != std::string::npos)
// Omitting directory name
auto fstart = s.find_last_of("/");
if (fstart == std::string::npos)
fstart = 0;

if (s.find(engineFileName, fstart) != std::string::npos)
{
std::string bs_str;
for (auto it = s.crbegin(); it != s.crend(); ++it)
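The reverse iteration over the file name collects its trailing digits, which encode the batch size the cached engine was built with (assuming engine files named along the lines of TRTengine_bs16). A self-contained sketch of that extraction, with a hypothetical helper name:

#include <algorithm>
#include <cctype>
#include <string>

// Read the batch size back out of an engine file name,
// e.g. "TRTengine_bs16" -> 16. Hypothetical helper; returns -1 when the
// name carries no trailing digits.
int batch_size_from_name(const std::string &name)
{
  std::string bs_str;
  for (auto it = name.crbegin(); it != name.crend(); ++it)
    {
      if (!std::isdigit(static_cast<unsigned char>(*it)))
        break;
      bs_str += *it;
    }
  std::reverse(bs_str.begin(), bs_str.end());
  return bs_str.empty() ? -1 : std::stoi(bs_str);
}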
@@ -134,6 +140,10 @@ namespace dd
_max_batch_size = nmbs;
this->_logger->info("setting max batch size to {}", _max_batch_size);
}
if (ad.has("nclasses"))
{
_nclasses = ad.get("nclasses").get<int>();
}

if (ad.has("dla"))
_dla = ad.get("dla").get<int>();
@@ -244,6 +254,114 @@ namespace dd
return 0;
}

template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
class TMLModel>
nvinfer1::ICudaEngine *
TensorRTLib<TInputConnectorStrategy, TOutputConnectorStrategy,
TMLModel>::read_engine_from_caffe(const std::string &out_blob)
{
int fixcode = fixProto(this->_mlmodel._repo + "/" + "net_tensorRT.proto",
this->_mlmodel._def);
switch (fixcode)
{
case 1:
this->_logger->error("TRT backend could not open model prototxt");
break;
case 2:
this->_logger->error("TRT backend could not write "
"transformed model prototxt");
break;
default:
break;
}

nvinfer1::INetworkDefinition *network = _builder->createNetworkV2(0U);
nvcaffeparser1::ICaffeParser *caffeParser
= nvcaffeparser1::createCaffeParser();

const nvcaffeparser1::IBlobNameToTensor *blobNameToTensor
= caffeParser->parse(
std::string(this->_mlmodel._repo + "/" + "net_tensorRT.proto")
.c_str(),
this->_mlmodel._weights.c_str(), *network, _datatype);
if (!blobNameToTensor)
throw MLLibInternalException("Error while parsing caffe model "
"for conversion to TensorRT");

network->markOutput(*blobNameToTensor->find(out_blob.c_str()));

if (out_blob == "detection_out")
network->markOutput(*blobNameToTensor->find("keep_count"));
_builder->setMaxBatchSize(_max_batch_size);
_builderc->setMaxWorkspaceSize(_max_workspace_size);

network->getLayer(0)->setPrecision(nvinfer1::DataType::kFLOAT);

nvinfer1::ILayer *outl = NULL;
int idx = network->getNbLayers() - 1;
while (outl == NULL)
{
nvinfer1::ILayer *l = network->getLayer(idx);
if (strcmp(l->getName(), out_blob.c_str()) == 0)
{
outl = l;
break;
}
idx--;
}
// force output to be float32
outl->setPrecision(nvinfer1::DataType::kFLOAT);
nvinfer1::ICudaEngine *engine
= _builder->buildEngineWithConfig(*network, *_builderc);

network->destroy();
if (caffeParser != nullptr)
caffeParser->destroy();

return engine;
}

template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
class TMLModel>
nvinfer1::ICudaEngine *
TensorRTLib<TInputConnectorStrategy, TOutputConnectorStrategy,
TMLModel>::read_engine_from_onnx()
{
const auto explicitBatch
= 1U << static_cast<uint32_t>(
nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);

nvinfer1::INetworkDefinition *network
= _builder->createNetworkV2(explicitBatch);

nvonnxparser::IParser *onnxParser
= nvonnxparser::createParser(*network, trtLogger);
onnxParser->parseFromFile(this->_mlmodel._model.c_str(),
int(nvinfer1::ILogger::Severity::kWARNING));

if (onnxParser->getNbErrors() != 0)
{
for (int i = 0; i < onnxParser->getNbErrors(); ++i)
{
this->_logger->error(onnxParser->getError(i)->desc());
}
throw MLLibInternalException(
"Error while parsing onnx model for conversion to "
"TensorRT");
}
_builder->setMaxBatchSize(_max_batch_size);
_builderc->setMaxWorkspaceSize(_max_workspace_size);

nvinfer1::ICudaEngine *engine
= _builder->buildEngineWithConfig(*network, *_builderc);

network->destroy();
if (onnxParser != nullptr)
onnxParser->destroy();

return engine;
}

template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
class TMLModel>
int TensorRTLib<TInputConnectorStrategy, TOutputConnectorStrategy,
@@ -293,7 +411,12 @@ namespace dd
"timeseries not yet implemented over tensorRT backend");
}

_nclasses = findNClasses(this->_mlmodel._def, _bbox);
if (_nclasses == 0)
{
this->_logger->info("try to determine number of classes...");
_nclasses = findNClasses(this->_mlmodel._def, _bbox);
}

if (_bbox)
_top_k = findTopK(this->_mlmodel._def);

@@ -335,65 +458,25 @@ namespace dd

if (!engineRead)
{
nvinfer1::ICudaEngine *le = nullptr;

int fixcode
= fixProto(this->_mlmodel._repo + "/" + "net_tensorRT.proto",
this->_mlmodel._def);
switch (fixcode)
if (this->_mlmodel._model.find("net_tensorRT.proto")
!= std::string::npos
|| !this->_mlmodel._def.empty())
{
case 1:
this->_logger->error(
"TRT backend could not open model prototxt");
break;
case 2:
this->_logger->error("TRT backend could not write "
"transformed model prototxt");
break;
default:
break;
le = read_engine_from_caffe(out_blob);
}

nvinfer1::INetworkDefinition *network
= _builder->createNetworkV2(0U);
nvcaffeparser1::ICaffeParser *caffeParser
= nvcaffeparser1::createCaffeParser();

const nvcaffeparser1::IBlobNameToTensor *blobNameToTensor
= caffeParser->parse(std::string(this->_mlmodel._repo + "/"
+ "net_tensorRT.proto")
.c_str(),
this->_mlmodel._weights.c_str(), *network,
_datatype);
if (!blobNameToTensor)
throw MLLibInternalException("Error while parsing caffe model "
"for conversion to TensorRT");

network->markOutput(*blobNameToTensor->find(out_blob.c_str()));

if (out_blob == "detection_out")
network->markOutput(*blobNameToTensor->find("keep_count"));
_builder->setMaxBatchSize(_max_batch_size);
_builderc->setMaxWorkspaceSize(_max_workspace_size);

network->getLayer(0)->setPrecision(nvinfer1::DataType::kFLOAT);

nvinfer1::ILayer *outl = NULL;
int idx = network->getNbLayers() - 1;
while (outl == NULL)
else if (this->_mlmodel._model.find("net_tensorRT.onnx")
!= std::string::npos)
{
nvinfer1::ILayer *l = network->getLayer(idx);
if (strcmp(l->getName(), out_blob.c_str()) == 0)
{
outl = l;
break;
}
idx--;
le = read_engine_from_onnx();
}
else
{
throw MLLibInternalException(
"No model to parse for conversion to TensorRT");
}
// force output to be float32
outl->setPrecision(nvinfer1::DataType::kFLOAT);

nvinfer1::ICudaEngine *le
= _builder->buildEngineWithConfig(*network, *_builderc);
_engine = std::shared_ptr<nvinfer1::ICudaEngine>(
le, [=](nvinfer1::ICudaEngine *e) { e->destroy(); });

@@ -407,9 +490,6 @@ namespace dd
trtModelStream->size());
trtModelStream->destroy();
}

network->destroy();
caffeParser->destroy();
}

_context = std::shared_ptr<nvinfer1::IExecutionContext>(
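The unchanged tail of this function (partially visible above) caches a freshly built engine by serializing it to disk, so later service starts can skip parsing entirely. Roughly, and with cleanup and error handling elided, the round-trip looks like:

#include <NvInfer.h>
#include <fstream>
#include <string>
#include <vector>

// Serialize a built engine to disk...
void save_engine(nvinfer1::ICudaEngine &engine, const std::string &path)
{
  nvinfer1::IHostMemory *stream = engine.serialize();
  std::ofstream out(path, std::ios::binary);
  out.write(static_cast<const char *>(stream->data()), stream->size());
  stream->destroy();
}

// ...and load it back without re-parsing the original model.
nvinfer1::ICudaEngine *load_engine(const std::string &path,
                                   nvinfer1::ILogger &logger)
{
  std::ifstream in(path, std::ios::binary | std::ios::ate);
  std::vector<char> blob(in.tellg());
  in.seekg(0);
  in.read(blob.data(), blob.size());
  nvinfer1::IRuntime *runtime = nvinfer1::createInferRuntime(logger);
  return runtime->deserializeCudaEngine(blob.data(), blob.size(), nullptr);
}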
4 changes: 4 additions & 0 deletions src/backends/tensorrt/tensorrtlib.h
@@ -143,6 +143,10 @@ namespace dd
std::mutex
_net_mutex; /**< mutex around net, e.g. no concurrent predict calls as
net is not re-instantiated. Use batches instead. */

nvinfer1::ICudaEngine *read_engine_from_caffe(const std::string &out_blob);

nvinfer1::ICudaEngine *read_engine_from_onnx();
};

}
25 changes: 22 additions & 3 deletions src/backends/tensorrt/tensorrtmodel.cc
@@ -28,16 +28,21 @@ namespace dd
static std::string weights = ".caffemodel";
static std::string corresp = "corresp";
static std::string meanf = "mean.binaryproto";

static std::string model_name = "net_tensorRT";
static std::string caffe_model_name = model_name + ".proto";
static std::string onnx_model_name = model_name + ".onnx";

std::unordered_set<std::string> lfiles;
int e = fileops::list_directory(_repo, true, false, false, lfiles);
if (e != 0)
{
logger->error("error reading or listing caffe models in repository {}",
logger->error("error reading or listing models in repository {}",
_repo);
return 1;
}
std::string deployf, weightsf, correspf;
long int weight_t = -1;
std::string deployf, weightsf, correspf, modelf;
long int weight_t = -1, model_t = -1;
auto hit = lfiles.begin();
while (hit != lfiles.end())
{
@@ -57,6 +62,16 @@
}
else if ((*hit).find(corresp) != std::string::npos)
correspf = (*hit);
else if ((*hit).find(caffe_model_name) != std::string::npos
|| (*hit).find(onnx_model_name) != std::string::npos)
{
long int wt = fileops::file_last_modif(*hit);
if (wt > model_t)
{
modelf = (*hit);
model_t = wt;
}
}
else if ((*hit).find("~") != std::string::npos
|| (*hit).find(".prototxt") == std::string::npos)
{
@@ -67,12 +82,16 @@
deployf = (*hit);
++hit;
}

if (_def.empty())
_def = deployf;
if (_weights.empty())
_weights = weightsf;
if (_corresp.empty())
_corresp = correspf;
if (_model.empty())
_model = modelf;

return 0;
}
}
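The new bookkeeping above keeps whichever of net_tensorRT.proto / net_tensorRT.onnx was modified most recently. Reduced to a standalone sketch (using std::filesystem instead of the project's fileops helpers):

#include <filesystem>
#include <string>

namespace fs = std::filesystem;

// Pick the most recently modified TensorRT model file in a repository,
// mirroring the model_t bookkeeping above (the real code goes through
// fileops::list_directory / fileops::file_last_modif).
std::string latest_model(const std::string &repo)
{
  std::string best;
  fs::file_time_type best_t = fs::file_time_type::min();
  for (const auto &entry : fs::directory_iterator(repo))
    {
      const std::string name = entry.path().filename().string();
      if (name.find("net_tensorRT.proto") == std::string::npos
          && name.find("net_tensorRT.onnx") == std::string::npos)
        continue;
      if (entry.last_write_time() > best_t)
        {
          best_t = entry.last_write_time();
          best = entry.path().string();
        }
    }
  return best;
}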
1 change: 1 addition & 0 deletions src/backends/tensorrt/tensorrtmodel.h
@@ -53,6 +53,7 @@ namespace dd

int read_from_repository(const std::shared_ptr<spdlog::logger> &logger);

std::string _model;
std::string _def;
std::string _weights;
bool _has_mean_file = false;
9 changes: 8 additions & 1 deletion tests/CMakeLists.txt
@@ -267,12 +267,19 @@ if (GTEST_FOUND)
"squeezenet_ssd_trt"
)
DOWNLOAD_DATASET(
"Downloading age test set"
"Age test set"
"https://deepdetect.com/models/init/desktop/images/classification/age_real.tar.gz"
"examples/trt/age_real"
"age_real.tar.gz"
"deploy.prototxt"
)
DOWNLOAD_DATASET(
"ONNX resnet model"
"https://deepdetect.com/models/init/desktop/images/classification/resnet_onnx_trt.tar.gz"
"examples/trt"
"resnet_onnx_trt.tar.gz"
"resnet_onnx_trt"
)

if(USE_JSON_API)
REGISTER_TEST(ut_tensorrtapi ut-tensorrtapi.cc)
