[TensorRT EP] Enhance EP context configs in session options and provider options #19154

Merged · 26 commits · Jan 21, 2024
include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
@@ -66,8 +66,8 @@
*
*/
int trt_dump_ep_context_model{0}; // Dump EP context node model
const char* trt_ep_context_file_path{nullptr}; // Specify file name to dump EP context node model.
const char* trt_ep_context_file_path{nullptr};  // Specify file name to dump EP context node model. Can be a directory path, a file name, or a file name with a path.
int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data

const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix
};
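For reference, a minimal sketch of how an application might exercise these provider options to dump an EP context model. This is an assumption-laden illustration, not part of the diff: the model path "model.onnx" and option values are hypothetical, and the Ort::Env / Ort::SessionOptions / AppendExecutionProvider_TensorRT_V2 calls are standard ONNX Runtime C++ API usage.

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "trt_ep_context_demo");
  Ort::SessionOptions so;

  // EP context provider options covered by this PR (see the struct above).
  OrtTensorRTProviderOptionsV2 trt_options;
  trt_options.trt_engine_cache_enable = 1;
  trt_options.trt_dump_ep_context_model = 1;                    // dump the EP context node model
  trt_options.trt_ep_context_file_path = "context_model_dir";   // file name, directory, or file name with path
  trt_options.trt_ep_context_embed_mode = 0;                    // 0 = engine cache path, 1 = embedded engine binary

  so.AppendExecutionProvider_TensorRT_V2(trt_options);

  // Creating the session builds the TRT engine and, with the options above,
  // dumps the EP context model under "context_model_dir".
  Ort::Session session(env, "model.onnx", so);
  return 0;
}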
9 changes: 7 additions & 2 deletions onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc
@@ -262,16 +262,22 @@
// engine cache path to be the relative path like "../file_path" or the absolute path.
// It only allows the engine cache to be in the same directory or sub directory of the context model.
if (IsAbsolutePath(cache_path)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "For security purpose, the ep_cache_context attribute should be set with a relative path, but it is an absolute path: " + cache_path);

}
if (IsRelativePathToParentPath(cache_path)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "The file path in ep_cache_context attribute has '..'. For security purpose, it's not allowed to point outside the directory.");

}

// The engine cache and context model (current model) should be in the same directory
std::filesystem::path ctx_model_dir(GetPathOrParentPathOfCtxModel(ep_context_model_path_));
auto engine_cache_path = ctx_model_dir.append(cache_path);

if (!std::filesystem::exists(engine_cache_path)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP can't find engine cache: " + engine_cache_path.string() +
". Please make sure engine cache is in the same directory or sub-directory of context model.");

}

std::ifstream engine_file(engine_cache_path.string(), std::ios::binary | std::ios::in);
engine_file.seekg(0, std::ios::end);
size_t engine_size = engine_file.tellg();
@@ -281,8 +287,7 @@
*(trt_engine_) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_runtime_->deserializeCudaEngine(engine_buf.get(), engine_size));
if (!(*trt_engine_)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not deserialize engine from cache: " + engine_cache_path.string() +
". Please make sure engine cache is inside the directory of trt_ep_context_file_path.");
"TensorRT EP could not deserialize engine from cache: " + engine_cache_path.string());
}
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path.string();
}
@@ -302,7 +307,7 @@
if (attrs.count(COMPUTE_CAPABILITY) > 0) {
std::string model_compute_capability = attrs.at(COMPUTE_CAPABILITY).s();
if (model_compute_capability != compute_capability_) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine was compiled for a different compatibility level and might not work or perform suboptimal";

LOGS_DEFAULT(WARNING) << "[TensorRT EP] The compute capability of the engine: " << model_compute_capability;
LOGS_DEFAULT(WARNING) << "[TensorRT EP] The compute capability of the GPU: " << compute_capability_;
}
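The path checks in this file (IsAbsolutePath, IsRelativePathToParentPath, GetPathOrParentPathOfCtxModel) are helpers whose implementations are not shown in this diff. A rough sketch of their assumed behavior, for orientation only; the real TRT EP code may differ:

#include <filesystem>
#include <string>

// Assumed behavior: true if the engine cache path is absolute.
static bool IsAbsolutePath(const std::string& path) {
  return std::filesystem::path(path).is_absolute();
}

// Assumed behavior: true if the path contains a ".." component and could escape the context model directory.
static bool IsRelativePathToParentPath(const std::string& path) {
  for (const auto& component : std::filesystem::path(path)) {
    if (component == "..") return true;
  }
  return false;
}

// Assumed behavior: if the context model path names a file, return its parent directory;
// otherwise treat the path itself as the directory.
static std::string GetPathOrParentPathOfCtxModel(const std::string& ctx_model_path) {
  std::filesystem::path p(ctx_model_path);
  return p.has_extension() ? p.parent_path().string() : p.string();
}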
44 changes: 27 additions & 17 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -1459,17 +1459,17 @@
force_timing_cache_match_ = (std::stoi(timing_force_match_env) == 0 ? false : true);
}

const std::string dump_ep_context_model_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDumpEpContextModel);

if (!dump_ep_context_model_env.empty()) {
dump_ep_context_model_ = (std::stoi(dump_ep_context_model_env) == 0 ? false : true);
}

const std::string ep_context_file_path_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEpContextComputeCapabilityEnable);

if (!ep_context_file_path_env.empty()) {
ep_context_file_path_ = ep_context_file_path_env;
}

const std::string ep_context_embed_mode_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEpContextEmbedMode);

if (!ep_context_embed_mode_env.empty()) {
ep_context_embed_mode_ = std::stoi(ep_context_embed_mode_env);
}
@@ -1578,6 +1578,13 @@
dla_core_ = 0;
}

// If ep_context_file_path_ is provided as a directory, create it if it does not exist
if (dump_ep_context_model_ && !ep_context_file_path_.empty() && std::filesystem::path(ep_context_file_path_).extension().empty() && !std::filesystem::is_directory(ep_context_file_path_)) {

if (!std::filesystem::create_directory(ep_context_file_path_)) {
throw std::runtime_error("Failed to create directory " + ep_context_file_path_);
}
}

// If dump_ep_context_model_ is enabled, TRT EP forces cache_path_ to be a path relative to ep_context_file_path_.
// For example,
// - original cache path = "engine_cache_dir" -> new cache path = "./context_model_dir/engine_cache_dir"
@@ -1586,23 +1593,16 @@
// For security reasons, it needs to make sure the engine cache is saved inside the context model directory.
if (dump_ep_context_model_ && engine_cache_enable_) {
if (IsAbsolutePath(cache_path_)) {
LOGS_DEFAULT(ERROR) << "In the case of dumping context model and for security purpose, the trt_engine_cache_path should be set with a relative path, but it is an absolute path: " << cache_path_;

}
if (IsRelativePathToParentPath(cache_path_)) {
LOGS_DEFAULT(ERROR) << "In the case of dumping context model and for security purpose, The trt_engine_cache_path has '..', it's not allowed to point outside the directory.";

}

// Make cache_path_ to be the relative path of ep_context_file_path_
cache_path_ = GetPathOrParentPathOfCtxModel(ep_context_file_path_).append(cache_path_).string();
}

// If ep_context_file_path_ is provided as a directory, create it if it's not existed
if (!ep_context_file_path_.empty() && std::filesystem::path(ep_context_file_path_).extension().empty() && !std::filesystem::is_directory(ep_context_file_path_)) {
if (!std::filesystem::create_directory(ep_context_file_path_)) {
throw std::runtime_error("Failed to create directory " + ep_context_file_path_);
}
}

if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) {
if (!cache_path_.empty() && !fs::is_directory(cache_path_)) {
if (!fs::create_directory(cache_path_)) {
@@ -2335,6 +2335,14 @@
// Construct subgraph capability from node list
std::vector<std::unique_ptr<ComputeCapability>> result;

// Get ModelPath
const auto& path_string = graph.ModelPath().ToPathString();
#ifdef _WIN32
wcstombs_s(nullptr, model_path_, sizeof(model_path_), path_string.c_str(), sizeof(model_path_));
#else
strcpy(model_path_, path_string.c_str());

#endif

// If the model consists of only a single "EPContext" contrib op, it means TRT EP can fetch the precompiled engine info from the node and
// load the engine directly without having to go through the processes of graph proto reconstruction, calling TRT parser and engine compilation.
// So, simply return the ComputeCapability here.
@@ -2345,14 +2353,6 @@
return result;
}

// Get ModelPath
const auto& path_string = graph.ModelPath().ToPathString();
#ifdef _WIN32
wcstombs_s(nullptr, model_path_, sizeof(model_path_), path_string.c_str(), sizeof(model_path_));
#else
strcpy(model_path_, path_string.c_str());
#endif

// Generate unique kernel name for TRT graph
HashValue model_hash = TRTGenerateId(graph);

@@ -2869,7 +2869,7 @@
}

// Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache
// Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity

const std::string cache_path_prefix = cache_path + "_sm" + compute_capability_;
const std::string engine_cache_path = cache_path_prefix + ".engine";
const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted";
@@ -3016,9 +3016,15 @@
}
// dump EP context node model
if (dump_ep_context_model_) {

// "ep_cache_context" node attribute should be a relative path to context model directory
if (ep_cache_context_attr_.empty()) {
ep_cache_context_attr_ = std::filesystem::relative(engine_cache_path, ep_context_file_path_).string();
}

std::unique_ptr<ONNX_NAMESPACE::ModelProto> model_proto{CreateCtxModel(graph_body_viewer,
engine_cache_path,
ep_cache_context_attr_,
reinterpret_cast<char*>(serialized_engine->data()),

serialized_engine->size(),
ep_context_embed_mode_,
compute_capability_,
@@ -3083,8 +3089,12 @@
// TRT EP will serialize the model at inference time because the engine can be updated and the updated engine should be included in the model.
// However, if the embed_mode is 0 (only includes engine path), TRT EP will serialize it here.
if (dump_ep_context_model_ && has_dynamic_shape) {
// "ep_cache_context" node attribute should be a relative path to context model directory
if (ep_cache_context_attr_.empty()) {
ep_cache_context_attr_ = std::filesystem::relative(engine_cache_path, ep_context_file_path_).string();
}
model_proto_.reset(CreateCtxModel(graph_body_viewer,
engine_cache_path,
ep_cache_context_attr_,
nullptr,
0,
ep_context_embed_mode_,
@@ -3411,7 +3421,7 @@

// dump ep context model
if (dump_ep_context_model_ && ep_context_embed_mode_) {
UpdateCtxNodeModelEngineContext(model_proto_.get(), reinterpret_cast<char*>(serialized_engine->data()), serialized_engine->size());

DumpCtxModel(model_proto_.get(), ctx_model_path_);
}
context_update = true;
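Although this commit only shows the provider-option path, the PR title also covers session options. A hedged sketch of that counterpart, assuming the generic EP context keys "ep.context_enable", "ep.context_file_path", and "ep.context_embed_mode" from onnxruntime_session_options_config_keys.h (the exact keys honored at this commit may differ):

Ort::SessionOptions so;
so.AddConfigEntry("ep.context_enable", "1");                      // ask the EP to dump an EP context model
so.AddConfigEntry("ep.context_file_path", "context_model_dir");   // output file name or directory
so.AddConfigEntry("ep.context_embed_mode", "0");                  // 0 = engine cache path, 1 = embedded engine binary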
@@ -293,7 +293,6 @@ class TensorrtExecutionProvider : public IExecutionProvider {
bool force_timing_cache_match_ = false;
bool detailed_build_log_ = false;
bool cuda_graph_enable_ = false;
std::string ctx_model_path_;
std::string cache_prefix_;

// The OrtAllocator object will be get during ep compute time
@@ -304,6 +303,8 @@ class TensorrtExecutionProvider : public IExecutionProvider {
bool dump_ep_context_model_ = false;
std::string ep_context_file_path_;
int ep_context_embed_mode_ = 0;
std::string ctx_model_path_;
std::string ep_cache_context_attr_;
std::unique_ptr<ONNX_NAMESPACE::ModelProto> model_proto_ = ONNX_NAMESPACE::ModelProto::Create();

std::unordered_set<std::string> control_flow_op_set_ = {"If", "Loop", "Scan"};
67 changes: 62 additions & 5 deletions onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc
@@ -394,7 +394,16 @@
std::vector<int64_t> expected_dims_mul_m = {1, 3, 2};
std::vector<float> expected_values_mul_m = {3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f};

// Dump context model with specific name
/*
* Test case 1: Dump context model
*
* provider options=>
* trt_ep_context_file_path = "EP_Context_model.onnx"
*
* expected result =>
* context model "EP_Context_model.onnx" should be created in current directory
*
*/
OrtTensorRTProviderOptionsV2 params;
params.trt_engine_cache_enable = 1;
params.trt_dump_ep_context_model = 1;
Expand All @@ -405,23 +414,34 @@
ASSERT_TRUE(status.IsOK());
status = session_object.Initialize();
ASSERT_TRUE(status.IsOK());
ASSERT_TRUE(HasCacheFileWithPrefix(params.trt_ep_context_file_path)); // "EP_Context_model.onnx" should be created
ASSERT_TRUE(HasCacheFileWithPrefix(params.trt_ep_context_file_path));

// Dump context model to specific path
/*
* Test case 2: Dump context model
*
* provider options=>
* trt_engine_cache_prefix = "TRT_engine_cache"
* trt_ep_context_file_path = "context_model_folder"
* trt_engine_cache_path = "engine_cache_folder"
*
* expected result =>
* engine cache "./context_model_folder/engine_cache_folder/TRT_engine_cache...engine" should be created
* context model "./context_model_folder/EPContextNode_test_ctx.onnx" should be created
*/
InferenceSession session_object2{so, GetEnvironment()};
OrtTensorRTProviderOptionsV2 params2;
params2.trt_engine_cache_enable = 1;
params2.trt_dump_ep_context_model = 1;
params2.trt_engine_cache_prefix = "TRT_engine_cache";
params2.trt_engine_cache_path = "engine_cache_folder"; // due to dump_ep_context_model = 1, the new cache path is ./context_model_folder/engine_cache_folder

params2.trt_ep_context_file_path = "./context_model_folder";
params2.trt_ep_context_file_path = "context_model_folder";
execution_provider = TensorrtExecutionProviderWithOptions(&params2);
EXPECT_TRUE(session_object2.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
status = session_object2.Load(model_name);
ASSERT_TRUE(status.IsOK());
status = session_object2.Initialize();
ASSERT_TRUE(status.IsOK());
auto new_engine_cache_path = std::filesystem::path(params2.trt_ep_context_file_path).append(params2.trt_engine_cache_path).string();

// Test engine cache path:
// "./context_model_folder/engine_cache_folder/TRT_engine_cache...engine" should be created
ASSERT_TRUE(HasCacheFileWithPrefix(params2.trt_engine_cache_prefix, new_engine_cache_path));
@@ -429,7 +449,16 @@
// "./context_model_folder/EPContextNode_test_ctx.onnx" should be created
ASSERT_TRUE(HasCacheFileWithPrefix("EPContextNode_test_ctx.onnx", params2.trt_ep_context_file_path));

// Context model inference
/*
* Test case 3: Run the dumped context model
*
* context model path = "./EP_Context_model.onnx" (created from case 1)
*
* expected result=>
* engine cache is also in the same current directory as "./xxxxx.engine"
* and the "ep_cache_context" attribute node of the context model should point to that.
*
*/
InferenceSession session_object3{so, GetEnvironment()};
OrtTensorRTProviderOptionsV2 params3;
model_name = params.trt_ep_context_file_path;
Expand All @@ -447,6 +476,34 @@
// Y: 1, 3, 3, 2, 2, 2
// Z: 1, 3, 3, 2, 2, 2
RunSession(session_object3, run_options, feeds, output_names, expected_dims_mul_m, expected_values_mul_m);

/*
* Test case 4: Run the dumped context model
*
* context model path = "./context_model_folder/EPContextNode_test_ctx.onnx" (created from case 2)
*
* expected result=>
* engine cache path is "./context_model_folder/engine_cache_folder/xxxxx.engine"
* and the "ep_cache_context" attribute node of the context model should point to that.
*
*/
InferenceSession session_object4{so, GetEnvironment()};
OrtTensorRTProviderOptionsV2 params4;
model_name = "./context_model_folder/EPContextNode_test_ctx.onnx";
execution_provider = TensorrtExecutionProviderWithOptions(&params4);
EXPECT_TRUE(session_object4.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
status = session_object4.Load(model_name);
ASSERT_TRUE(status.IsOK());
status = session_object4.Initialize();
ASSERT_TRUE(status.IsOK());
// run inference
// TRT engine will be created and cached
// TRT profile will be created and cached only for dynamic input shape
// Data in profile,
// X: 1, 3, 3, 2, 2, 2
// Y: 1, 3, 3, 2, 2, 2
// Z: 1, 3, 3, 2, 2, 2
RunSession(session_object4, run_options, feeds, output_names, expected_dims_mul_m, expected_values_mul_m);
}
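The assertions above rely on a HasCacheFileWithPrefix helper that is not part of this diff. A plausible sketch of such a check, assumed for illustration only (the real helper in the test suite may differ):

#include <filesystem>
#include <string>

// Assumed test helper: returns true if any regular file in `directory` (default: current directory)
// has a file name starting with `prefix`.
static bool HasCacheFileWithPrefix(const std::string& prefix, const std::string& directory = "./") {
  namespace fs = std::filesystem;
  if (!fs::is_directory(directory)) return false;
  for (const auto& entry : fs::directory_iterator(directory)) {
    if (!entry.is_regular_file()) continue;
    const std::string file_name = entry.path().filename().string();
    if (file_name.rfind(prefix, 0) == 0) return true;  // file name starts with prefix
  }
  return false;
}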

TEST(TensorrtExecutionProviderTest, TRTPluginsCustomOpTest) {