Skip to content

Commit

Permalink
Update APIs with model_path for ONNXRT (#621)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Chen <kevinch@nvidia.com>
  • Loading branch information
kevinch-nv committed Jan 11, 2021
1 parent 17c6d89 commit 8b5328f
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 11 deletions.
14 changes: 12 additions & 2 deletions ModelImporter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,8 @@ Status deserialize_onnx_model(int fd, bool is_serialized_as_text, ::ONNX_NAMESPA
}

bool ModelImporter::supportsModel(
void const* serialized_onnx_model, size_t serialized_onnx_model_size, SubGraphCollection_t& sub_graph_collection)
void const* serialized_onnx_model, size_t serialized_onnx_model_size, SubGraphCollection_t& sub_graph_collection,
const char* model_path)
{

::ONNX_NAMESPACE::ModelProto model;
Expand All @@ -307,6 +308,11 @@ bool ModelImporter::supportsModel(
return false;
}

if (model_path)
{
_importer_ctx.setOnnxFileLocation(model_path);
}

bool allSupported{true};

// Parse the graph and see if we hit any parsing errors
Expand Down Expand Up @@ -454,8 +460,12 @@ bool ModelImporter::parseWithWeightDescriptors(void const* serialized_onnx_model
return true;
}

bool ModelImporter::parse(void const* serialized_onnx_model, size_t serialized_onnx_model_size)
bool ModelImporter::parse(void const* serialized_onnx_model, size_t serialized_onnx_model_size, const char* model_path)
{
if (model_path)
{
_importer_ctx.setOnnxFileLocation(model_path);
}
return this->parseWithWeightDescriptors(serialized_onnx_model, serialized_onnx_model_size, 0, nullptr);
}

Expand Down
4 changes: 2 additions & 2 deletions ModelImporter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ class ModelImporter : public nvonnxparser::IParser
}
bool parseWithWeightDescriptors(void const* serialized_onnx_model, size_t serialized_onnx_model_size,
uint32_t weight_count, onnxTensorDescriptorV1 const* weight_descriptors) override;
bool parse(void const* serialized_onnx_model, size_t serialized_onnx_model_size) override;
bool parse(void const* serialized_onnx_model, size_t serialized_onnx_model_size, const char* model_path = nullptr) override;
bool supportsModel(void const* serialized_onnx_model, size_t serialized_onnx_model_size,
SubGraphCollection_t& sub_graph_collection) override;
SubGraphCollection_t& sub_graph_collection, const char* model_path = nullptr) override;

bool supportsOperator(const char* op_name) const override;
void destroy() override
Expand Down
11 changes: 7 additions & 4 deletions NvOnnxParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,14 @@ class IParser
* To obtain a better diagnostic, use the parseFromFile method below.
*
* \param serialized_onnx_model Pointer to the serialized ONNX model
* \param serialized_onnx_model_size Size of the serialized ONNX model
* in bytes
* \param serialized_onnx_model_size Size of the serialized ONNX model in bytes
* \param model_path Absolute path to the model file for loading external weights if required
* \return true if the model was parsed successfully
* \see getNbErrors() getError()
*/
virtual bool parse(void const* serialized_onnx_model,
size_t serialized_onnx_model_size)
size_t serialized_onnx_model_size,
const char* model_path = nullptr)
= 0;

/** \brief Parse an onnx model file, can be a binary protobuf or a text onnx model
Expand All @@ -158,11 +159,13 @@ class IParser
* \param serialized_onnx_model_size Size of the serialized ONNX model
* in bytes
* \param sub_graph_collection Container to hold supported subgraphs
* \param model_path Absolute path to the model file for loading external weights if required
* \return true if the model is supported
*/
virtual bool supportsModel(void const* serialized_onnx_model,
size_t serialized_onnx_model_size,
SubGraphCollection_t& sub_graph_collection)
SubGraphCollection_t& sub_graph_collection,
const char* model_path = nullptr)
= 0;

/** \brief Parse a serialized ONNX model into the TensorRT network
Expand Down
2 changes: 1 addition & 1 deletion OnnxAttrs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class OnnxAttrs
return _attrs.at(key);
}

const ::ONNX_NAMESPACE::AttributeProto::AttributeType type(const std::string& key) const
::ONNX_NAMESPACE::AttributeProto::AttributeType type(const std::string& key) const
{
return this->at(key)->type();
}
Expand Down
4 changes: 2 additions & 2 deletions onnx2trt_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1346,10 +1346,10 @@ bool parseExternalWeights(IImporterContext* ctx, std::string file, std::string p
relPathFile.seekg(offset, std::ios::beg);
int weightsBufSize = length == 0 ? fileSize : length;
weightsBuf.resize(weightsBufSize);
LOG_VERBOSE("Reading weights from external file: " << file);
LOG_VERBOSE("Reading weights from external file: " << path);
if (!relPathFile.read(weightsBuf.data(), weightsBuf.size()))
{
LOG_ERROR("Failed to read weights from external file: " << file);
LOG_ERROR("Failed to read weights from external file: " << path);
return false;
}
size = weightsBuf.size();
Expand Down

0 comments on commit 8b5328f

Please sign in to comment.