fix: ONNX TensorRT engine with correct enqueueV2
beniz authored and mergify[bot] committed Dec 21, 2020
1 parent 41d5375 commit 1aede85
Showing 2 changed files with 47 additions and 10 deletions.
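
Note on the change: TensorRT engines built from ONNX models are created in explicit-batch mode and must be launched with IExecutionContext::enqueueV2, whereas implicit-batch engines (e.g. parsed from Caffe) keep using enqueue with a runtime batch size. Below is a minimal illustrative sketch of that dispatch; the names run_inference, explicit_batch, buffers and stream are hypothetical and this is not the exact code from the commit:

#include <NvInfer.h>
#include <cuda_runtime_api.h>

// Sketch: pick the launch call that matches how the engine was built.
bool run_inference(nvinfer1::IExecutionContext *context, bool explicit_batch,
                   int batch_size, void **buffers, cudaStream_t stream)
{
  bool ok = false;
  if (!explicit_batch)
    ok = context->enqueue(batch_size, buffers, stream, nullptr); // implicit batch
  else
    ok = context->enqueueV2(buffers, stream, nullptr); // explicit batch (ONNX)
  return ok; // the caller throws MLLibInternalException on failure, as below
}
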
54 changes: 44 additions & 10 deletions src/backends/tensorrt/tensorrtlib.cc
@@ -102,6 +102,7 @@ namespace dd
_inputIndex = tl._inputIndex;
_outputIndex0 = tl._outputIndex0;
_outputIndex1 = tl._outputIndex1;
_explicit_batch = tl._explicit_batch;
_floatOut = tl._floatOut;
_keepCount = tl._keepCount;
}
@@ -340,12 +341,13 @@ namespace dd
TensorRTLib<TInputConnectorStrategy, TOutputConnectorStrategy,
TMLModel>::read_engine_from_onnx()
{
// XXX: TensorRT at the moment only supports explicitBatch models with ONNX
const auto explicitBatch
= 1U << static_cast<uint32_t>(
nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);

nvinfer1::INetworkDefinition *network
= _builder->createNetworkV2(explicitBatch);
_explicit_batch = true;

nvonnxparser::IParser *onnxParser
= nvonnxparser::createParser(*network, trtLogger);
@@ -447,9 +449,14 @@ namespace dd
{
int bs = findEngineBS(this->_mlmodel._repo, _engineFileName);
if (bs != _max_batch_size && bs != -1)
this->_logger->warn(
"found existing engine with max_batch_size {}, using it",
bs);
{
throw MLLibBadParamException(
"found existing engine with max_batch_size "
+ std::to_string(bs) + " instead of "
+ std::to_string(_max_batch_size)
+ " / either delete it or set your maxBatchSize to "
+ std::to_string(bs));
}
std::ifstream file(this->_mlmodel._repo + "/" + _engineFileName
+ "_bs" + std::to_string(bs),
std::ios::binary);
@@ -509,14 +516,28 @@ namespace dd
trtModelStream->destroy();
}
}
else
{
if (this->_mlmodel._model.find("net_tensorRT.onnx")
!= std::string::npos)
_explicit_batch = true;
}

_context = std::shared_ptr<nvinfer1::IExecutionContext>(
_engine->createExecutionContext(),
[=](nvinfer1::IExecutionContext *e) { e->destroy(); });
_TRTContextReady = true;

_inputIndex = _engine->getBindingIndex("data");
_outputIndex0 = _engine->getBindingIndex(out_blob.c_str());
try
{
_inputIndex = _engine->getBindingIndex("data");
_outputIndex0 = _engine->getBindingIndex(out_blob.c_str());
}
catch (...)
{
throw MLLibInternalException("Cannot find or bind output layer "
+ out_blob);
}

if (_bbox)
{
@@ -590,6 +611,7 @@ namespace dd
cudaStream_t cstream;
cudaStreamCreate(&cstream);

bool enqueue_success = false;
while (true)
{

@@ -611,8 +633,14 @@ namespace dd
num_processed * 3 * inputc._height
* inputc._width * sizeof(float),
cudaMemcpyHostToDevice, cstream);
_context->enqueue(num_processed, _buffers.data(), cstream,
nullptr);
if (!_explicit_batch)
enqueue_success = _context->enqueue(
num_processed, _buffers.data(), cstream, nullptr);
else
enqueue_success
= _context->enqueueV2(_buffers.data(), cstream, nullptr);
if (!enqueue_success)
throw MLLibInternalException("Failed TRT enqueue call");
cudaMemcpyAsync(_floatOut.data(),
_buffers.data()[_outputIndex0],
num_processed * _top_k * 7 * sizeof(float),
@@ -645,8 +673,14 @@ namespace dd
num_processed * 3 * inputc._height
* inputc._width * sizeof(float),
cudaMemcpyHostToDevice, cstream);
_context->enqueue(num_processed, _buffers.data(), cstream,
nullptr);
if (!_explicit_batch)
enqueue_success = _context->enqueue(
num_processed, _buffers.data(), cstream, nullptr);
else
enqueue_success
= _context->enqueueV2(_buffers.data(), cstream, nullptr);
if (!enqueue_success)
throw MLLibInternalException("Failed TRT enqueue call");
cudaMemcpyAsync(_floatOut.data(),
_buffers.data()[_outputIndex0],
num_processed * _nclasses * sizeof(float),
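
For context, the explicit-batch path in read_engine_from_onnx above follows the usual TensorRT 7-era ONNX workflow. A stripped-down sketch is shown below, assuming a caller that supplies the builder, logger and ONNX path (parse_onnx and its parameters are placeholders; engine building, cleanup and error reporting are omitted):

#include <NvInfer.h>
#include <NvOnnxParser.h>
#include <string>

// Sketch: ONNX models require an explicit-batch network definition.
nvinfer1::INetworkDefinition *parse_onnx(nvinfer1::IBuilder *builder,
                                         nvinfer1::ILogger &logger,
                                         const std::string &onnx_path)
{
  const auto explicitBatch = 1U << static_cast<uint32_t>(
      nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
  nvinfer1::INetworkDefinition *network = builder->createNetworkV2(explicitBatch);
  nvonnxparser::IParser *parser = nvonnxparser::createParser(*network, logger);
  if (!parser->parseFromFile(onnx_path.c_str(),
                             static_cast<int>(nvinfer1::ILogger::Severity::kWARNING)))
    return nullptr; // inspect parser->getError(i) for details
  return network;
}
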
3 changes: 3 additions & 0 deletions src/backends/tensorrt/tensorrtlib.h
@@ -138,6 +138,9 @@ namespace dd
int _outputIndex0;
int _outputIndex1;

bool _explicit_batch
= false; /**< whether TRT uses explicit batch model (ONNX). */

std::vector<float> _floatOut;
std::vector<int> _keepCount;

