Skip to content

Commit

Permalink
feat(ml): inference for GAN generators with TensorRT backend
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz committed Jun 2, 2021
1 parent 9dadb1c commit c93188c
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 49 deletions.
157 changes: 108 additions & 49 deletions src/backends/tensorrt/tensorrtlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ namespace dd
_explicit_batch = tl._explicit_batch;
_floatOut = tl._floatOut;
_keepCount = tl._keepCount;
_dims = tl._dims;
}

template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
Expand Down Expand Up @@ -425,6 +426,13 @@ namespace dd
auto output_params = predict_dto->parameters->output;

std::string out_blob = "prob";
std::string extract_layer;
if (predict_dto->parameters->mllib->extract_layer != nullptr)
{
extract_layer
= predict_dto->parameters->mllib->extract_layer->std_str();
}

TInputConnectorStrategy inputc(this->_inputc);

if (!_TRTContextReady)
Expand Down Expand Up @@ -457,6 +465,8 @@ namespace dd
}
else if (_regression)
out_blob = "pred";
else if (!extract_layer.empty())
out_blob = extract_layer;

if (_nclasses == 0 && this->_mlmodel.is_caffe_source())
{
Expand Down Expand Up @@ -561,8 +571,17 @@ namespace dd

try
{
_inputIndex = _engine->getBindingIndex("data");
_outputIndex0 = _engine->getBindingIndex(out_blob.c_str());
_inputIndex = 0;
if (out_blob == "last")
_outputIndex0 = _engine->getNbBindings() - 1;
else
_outputIndex0 = _engine->getBindingIndex(out_blob.c_str());
_dims = _engine->getBindingDimensions(_outputIndex0);
if (_dims.nbDims >= 2)
this->_logger->info("detected output dimensions: [{}, {} {} {}]",
_dims.d[0], _dims.d[1],
_dims.nbDims > 2 ? _dims.d[2] : 0,
_dims.nbDims > 3 ? _dims.d[3] : 0);
}
catch (...)
{
Expand Down Expand Up @@ -603,6 +622,28 @@ namespace dd
throw MLLibBadParamException(
"timeseries not yet implemented over tensorRT backend");
}
// GAN / raw output
else if (!extract_layer.empty())
{
_buffers.resize(2);
if (_dims.nbDims == 4)
_floatOut.resize(_max_batch_size * _dims.d[1] * _dims.d[2]
* _dims.d[3]);
else
throw MLLibBadParamException(
"raw/image output model requires 4 output dimensions");
if (inputc._bw)
cudaMalloc(&_buffers.data()[_inputIndex],
_max_batch_size * inputc._height * inputc._width
* sizeof(float));
else
cudaMalloc(&_buffers.data()[_inputIndex],
_max_batch_size * 3 * inputc._height
* inputc._width * sizeof(float));
cudaMalloc(&_buffers.data()[_outputIndex0],
_max_batch_size * _dims.d[1] * _dims.d[2]
* _dims.d[3] * sizeof(float));
}
else // classification / regression
{
_buffers.resize(2);
Expand Down Expand Up @@ -665,26 +706,27 @@ namespace dd

try
{
if (inputc._bw)
cudaMemcpyAsync(_buffers.data()[_inputIndex], inputc.data(),
num_processed * inputc._height * inputc._width
* sizeof(float),
cudaMemcpyHostToDevice, cstream);
else
cudaMemcpyAsync(_buffers.data()[_inputIndex], inputc.data(),
num_processed * 3 * inputc._height
* inputc._width * sizeof(float),
cudaMemcpyHostToDevice, cstream);
if (!_explicit_batch)
enqueue_success = _context->enqueue(
num_processed, _buffers.data(), cstream, nullptr);
else
enqueue_success
= _context->enqueueV2(_buffers.data(), cstream, nullptr);
if (!enqueue_success)
throw MLLibInternalException("Failed TRT enqueue call");

if (_bbox)
{
if (inputc._bw)
cudaMemcpyAsync(_buffers.data()[_inputIndex], inputc.data(),
num_processed * inputc._height
* inputc._width * sizeof(float),
cudaMemcpyHostToDevice, cstream);
else
cudaMemcpyAsync(_buffers.data()[_inputIndex], inputc.data(),
num_processed * 3 * inputc._height
* inputc._width * sizeof(float),
cudaMemcpyHostToDevice, cstream);
if (!_explicit_batch)
enqueue_success = _context->enqueue(
num_processed, _buffers.data(), cstream, nullptr);
else
enqueue_success
= _context->enqueueV2(_buffers.data(), cstream, nullptr);
if (!enqueue_success)
throw MLLibInternalException("Failed TRT enqueue call");
cudaMemcpyAsync(_floatOut.data(),
_buffers.data()[_outputIndex0],
num_processed * _top_k * 7 * sizeof(float),
Expand All @@ -705,26 +747,17 @@ namespace dd
throw MLLibBadParamException(
"timeseries not yet implemented over tensorRT backend");
}
// GAN/raw output
else if (!extract_layer.empty())
{
cudaMemcpyAsync(
_floatOut.data(), _buffers.data()[_outputIndex0],
num_processed * _floatOut.size() * sizeof(float),
cudaMemcpyDeviceToHost, cstream);
cudaStreamSynchronize(cstream);
}
else // classification / regression
{
if (inputc._bw)
cudaMemcpyAsync(_buffers.data()[_inputIndex], inputc.data(),
num_processed * inputc._height
* inputc._width * sizeof(float),
cudaMemcpyHostToDevice, cstream);
else
cudaMemcpyAsync(_buffers.data()[_inputIndex], inputc.data(),
num_processed * 3 * inputc._height
* inputc._width * sizeof(float),
cudaMemcpyHostToDevice, cstream);
if (!_explicit_batch)
enqueue_success = _context->enqueue(
num_processed, _buffers.data(), cstream, nullptr);
else
enqueue_success
= _context->enqueueV2(_buffers.data(), cstream, nullptr);
if (!enqueue_success)
throw MLLibInternalException("Failed TRT enqueue call");
cudaMemcpyAsync(_floatOut.data(),
_buffers.data()[_outputIndex0],
num_processed * _nclasses * sizeof(float),
Expand Down Expand Up @@ -836,6 +869,21 @@ namespace dd
throw MLLibBadParamException(
"timeseries not yet implemented over tensorRT backend");
}
else if (!extract_layer.empty())
{
for (int j = 0; j < num_processed; j++)
{
APIData rad;
if (!inputc._ids.empty())
rad.add("uri", inputc._ids.at(idoffset + j));
else
rad.add("uri", std::to_string(idoffset + j));
rad.add("loss", 0.0);
std::vector<double> vals(_floatOut.begin(), _floatOut.end());
rad.add("vals", vals);
vrad.push_back(rad);
}
}
else // classification / regression
{
for (int j = 0; j < num_processed; j++)
Expand Down Expand Up @@ -869,17 +917,28 @@ namespace dd

cudaStreamDestroy(cstream);

tout.add_results(vrad);

out.add("nclasses", this->_nclasses);
if (_bbox)
out.add("bbox", true);
if (_regression)
out.add("regression", true);
out.add("roi", false);
out.add("multibox_rois", false);
tout.finalize(ad.getobj("parameters").getobj("output"), out,
static_cast<MLModel *>(&this->_mlmodel));
if (extract_layer.empty())
{
tout.add_results(vrad);
out.add("nclasses", this->_nclasses);
if (_bbox)
out.add("bbox", true);
if (_regression)
out.add("regression", true);
out.add("roi", false);
out.add("multibox_rois", false);
tout.finalize(ad.getobj("parameters").getobj("output"),
out, // TODO; to output_params DTO
static_cast<MLModel *>(&this->_mlmodel));
}
else
{
UnsupervisedOutput unsupo;
unsupo.add_results(vrad);
unsupo.finalize(ad.getobj("parameters").getobj("output"),
out, // TODO: to output_params DTO
static_cast<MLModel *>(&this->_mlmodel));
}

if (ad.has("chain") && ad.get("chain").get<bool>())
{
Expand Down
2 changes: 2 additions & 0 deletions src/backends/tensorrt/tensorrtlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ namespace dd
std::vector<float> _floatOut;
std::vector<int> _keepCount;

nvinfer1::Dims _dims;

std::mutex
_net_mutex; /**< mutex around net, e.g. no concurrent predict calls as
net is not re-instantiated. Use batches instead. */
Expand Down
9 changes: 9 additions & 0 deletions src/dto/mllib.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,15 @@ namespace dd

DTO_FIELD(String, datatype) = "fp32";

DTO_FIELD_INFO(extract_layer)
{
info->description
= "Returns tensor values from an intermediate layer. If set to "
"'last', returns the values from last layer.";
}

DTO_FIELD(String, extract_layer);

// =====
// Libtorch options
DTO_FIELD_INFO(self_supervised)
Expand Down
7 changes: 7 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,13 @@ if (USE_TENSORRT)
"examples/trt"
"resnet_onnx_trt.tar.gz"
"resnet_onnx_trt"
)
DOWNLOAD_DATASET(
"ONNX CycleGAN model"
"https://deepdetect.com/dd/examples/tensorrt/cyclegan_resnet_attn_onnx_trt.tar.gz"
"examples/trt"
"cyclegan_resnet_attn_onnx_trt.tar.gz"
"cyclegan_resnet_attn_onnx_trt"
)

if(USE_JSON_API)
Expand Down
42 changes: 42 additions & 0 deletions tests/ut-tensorrtapi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ static std::string squeez_repo = "../examples/trt/squeezenet_ssd_trt/";
static std::string refinedet_repo = "../examples/trt/faces_512/";
static std::string age_repo = "../examples/trt/age_real/";
static std::string resnet_onnx_repo = "../examples/trt/resnet_onnx_trt/";
static std::string cyclegan_onnx_repo
= "../examples/trt/cyclegan_resnet_attn_onnx_trt/";

TEST(tensorrtapi, service_predict)
{
Expand Down Expand Up @@ -205,3 +207,43 @@ TEST(tensorrtapi, service_predict_onnx)
ASSERT_TRUE(jd["body"]["predictions"][0]["classes"][0]["prob"].GetDouble()
> 0.3);
}

TEST(tensorrtapi, service_predict_gan_onnx)
{
// create service
JsonAPI japi;
std::string sname = "onnx";
std::string jstr
= "{\"mllib\":\"tensorrt\",\"description\":\"Test gan onnx "
"import\",\"type\":\"supervised\",\"model\":{\"repository\":\""
+ cyclegan_onnx_repo
+ "\"},\"parameters\":{\"input\":{\"connector\":\"image\",\"height\":"
"360,\"width\":360},\"mllib\":{"
"\"maxBatchSize\":1,\"maxWorkspaceSize\":256,\"gpuid\":0,"
"\"datatype\":\"fp16\"}}}";
std::string joutstr = japi.jrender(japi.service_create(sname, jstr));
ASSERT_EQ(created_str, joutstr);

// predict
std::string jpredictstr
= "{\"service\":\"" + sname
+ "\",\"parameters\":{\"input\":{\"height\":360,"
"\"width\":360,\"rgb\":true,\"scale\":0.00392,\"mean\":[0.5,0.5,0.5]"
",\"std\":[0.5,0.5,0.5]},\"output\":{},\"mllib\":{\"extract_layer\":"
"\"last\"}},\"data\":[\""
+ cyclegan_onnx_repo + "horse.jpg\"]}";
joutstr = japi.jrender(japi.service_predict(jpredictstr));
JDoc jd;
// std::cout << "joutstr=" << joutstr << std::endl;
jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
ASSERT_TRUE(!jd.HasParseError());
ASSERT_EQ(200, jd["status"]["code"]);
ASSERT_TRUE(jd["body"]["predictions"].IsArray());
ASSERT_TRUE(jd["body"]["predictions"][0]["vals"].IsArray());
ASSERT_EQ(jd["body"]["predictions"][0]["vals"].Size(), 360 * 360 * 3);

jstr = "{\"clear\":\"lib\"}";
joutstr = japi.jrender(japi.service_delete(sname, jstr));
ASSERT_EQ(ok_str, joutstr);
ASSERT_TRUE(!fileops::file_exists(cyclegan_onnx_repo + "TRTengine_bs1"));
}

0 comments on commit c93188c

Please sign in to comment.