diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h
index 11ca73790ea79..e485eb3181502 100644
--- a/include/onnxruntime/core/graph/graph.h
+++ b/include/onnxruntime/core/graph/graph.h
@@ -1454,12 +1454,16 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
return Resolve(default_options);
}
+ /// <summary>
+ /// This function converts the graph's TensorProto initializers into OrtValues
+ /// and creates an in-memory external data reference for each OrtValue. This is
+ /// applied to this graph and all of its subgraphs.
+ /// </summary>
+ /// <returns>Status indicating success or failure.</returns>
+ Status ConvertInitializersIntoOrtValues();
+
/**
- * @brief Converts a subset of graph TensorProto initializers into OrtValues and updates the graph proto.
- *
- * This function converts specified TensorProto initializers in the graph into OrtValues and
- * creates in-memory external data references for each OrtValue. It then updates the provided
- * GraphProto with the modified initializers.
+ * @brief Examines the specified initializers in the graph and inlines any that
+ * have in-memory external data.
*
* @param iterators Span of iterators pointing to the initializers, in the order in which they should be processed
* @param output_graph_proto The GraphProto to be updated with the modified initializers
@@ -1633,17 +1637,6 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
/// <returns>Status indicating success or failure.</returns>
Status ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_graph_proto) const;
- /// <summary>
- /// This function replaces all of the initializers within output_graph_proto
- /// from this Graph instance. All in memory initializers are regenerated and inlined.
- /// This is necessary even if the graph_proto_ is already up to date because initializers() may
- /// contain obsolete initializers that are no longer in use due to optimizations and contain obsolete
- /// references to OrtValues that may no longer be around (since we like appending rather than replacing).
- /// </summary>
- /// <param name="output_graph_proto">Destination GraphProto to receive the updated initializers.</param>
- /// <returns>Status indicating success or failure.</returns>
- Status RegenerateInitializersAndReplaceInMemory(ONNX_NAMESPACE::GraphProto& output_graph_proto) const;
-
/// <summary>
/// This function traverses the graph bottom up and externalizes
/// constant initializers along with their pre-packed blobs from different
diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc
index 3d67314cf693a..8b599dc86d997 100644
--- a/onnxruntime/core/graph/graph.cc
+++ b/onnxruntime/core/graph/graph.cc
@@ -1231,28 +1231,6 @@ Graph::Graph(const Model& owning_model,
ArgNameToTypeMap name_to_type_map;
const auto& model_path = ModelPath();
- // If the tensor proto data is large enough, move data from TensorProto to an OrtValue
- // - Add external data reference to TensorProto that points to an OrtValue.
- // This lambda should not be used on initializers that already have external data reference.
- // Otherwise, this function does nothing.
- auto put_large_tensor_in_ort_value = [this, &model_path](ONNX_NAMESPACE::TensorProto& tensor_proto) {
- size_t size_in_bytes = 0;
- ORT_THROW_IF_ERROR(utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &size_in_bytes));
- if (size_in_bytes > utils::kSmallTensorExternalDataThreshold) {
- OrtValue ort_value;
- ORT_THROW_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), model_path, tensor_proto,
- CPUAllocator::DefaultInstance(), ort_value));
- constexpr const bool use_tensor_buffer_true = true;
- auto tensor_proto_to_add = utils::TensorToTensorProto(ort_value.Get(), tensor_proto.name(),
- use_tensor_buffer_true);
- assert(ort_value.IsAllocated());
- auto ins_result = ortvalue_initializers_.insert_or_assign(tensor_proto_to_add.name(), std::move(ort_value));
- ORT_ENFORCE(ins_result.second, "Unexpected duplicate insert or assign OrtValue for tensor: ", tensor_proto_to_add.name(),
- " in the initializer list.");
- tensor_proto = std::move(tensor_proto_to_add);
- }
- };
-
// Process 'Constant' nodes
// Put the 'TensorProto' stored in the 'Constant' nodes attribute into the graphs initializer list
for (auto& node : graph_proto_->node()) {
@@ -1272,8 +1250,6 @@ Graph::Graph(const Model& owning_model,
}
}
- put_large_tensor_in_ort_value(*tensor);
-
// Ensure initializers are also graph inputs.
if (ir_version_ < 4) {
TypeProto t{utils::TypeProtoFromTensorProto(*tensor)};
@@ -1350,25 +1326,7 @@ Graph::Graph(const Model& owning_model,
}
// Copy initial tensors to a map.
- for (int i = 0, lim = graph_proto_->initializer_size(); i < lim; ++i) {
- auto& tensor = *graph_proto_->mutable_initializer(i);
- // If data is on disk, it will be loaded either by optimizers
- // or during session state finalization.
- // If data is already in memory, do nothing.
- if (!utils::HasExternalData(tensor)) {
- // sparse_tensor_names_ contain references to strings to save memory
- // in case we replace the tensor_proto, we want to make sure we remove
- // the old reference first, and then add a new one.
- const bool is_sparse = sparse_tensor_names_.count(tensor.name());
- if (is_sparse) {
- sparse_tensor_names_.erase(tensor.name());
- }
- put_large_tensor_in_ort_value(tensor);
- if (is_sparse) {
- sparse_tensor_names_.emplace(tensor.name());
- }
- }
-
+ for (auto& tensor : graph_proto_->initializer()) {
auto p = name_to_initial_tensor_.emplace(tensor.name(), &tensor);
if (!p.second) {
LOGS(logger_, WARNING) << "Duplicate initializer (dense, sparse or ConstantNode): '" << tensor.name()
@@ -3457,6 +3415,38 @@ Status Graph::Resolve(const ResolveOptions& options) {
return ForThisAndAllSubgraphs(all_subgraphs, finalize_func);
}
+Status Graph::ConvertInitializersIntoOrtValues() {
+ std::vector<Graph*> all_subgraphs;
+ FindAllSubgraphs(all_subgraphs);
+
+ auto put_weights_maybe_in_memory_func = [&](Graph& graph) -> Status {
+ // Convert any initializers that are not yet backed by in-memory external data into OrtValues.
+ const auto& model_path = graph.ModelPath();
+ auto& graph_proto = *graph.graph_proto_;
+ for (int i = 0, lim = graph_proto.initializer_size(); i < lim; ++i) {
+ auto& tensor_proto = *graph_proto.mutable_initializer(i);
+ if (utils::HasExternalData(tensor_proto)) {
+ continue; // data on disk will be loaded later, either by the EP or at session state finalization
+ }
+
+ size_t size_in_bytes = 0;
+ ORT_RETURN_IF_ERROR(utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &size_in_bytes));
+ if (size_in_bytes > utils::kSmallTensorExternalDataThreshold) {
+ OrtValue ort_value;
+ ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), model_path, tensor_proto,
+ CPUAllocator::DefaultInstance(), ort_value));
+ constexpr const bool use_tensor_buffer_true = true;
+ auto tensor_proto_to_add = utils::TensorToTensorProto(ort_value.Get(), tensor_proto.name(),
+ use_tensor_buffer_true);
+ ORT_RETURN_IF_ERROR(graph.ReplaceInitializedTensor(tensor_proto_to_add, ort_value));
+ }
+ }
+ return Status::OK();
+ };
+
+ return ForThisAndAllSubgraphs(all_subgraphs, put_weights_maybe_in_memory_func);
+}
+
void Graph::SetName(const std::string& name) {
graph_proto_->set_name(name);
}
diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
index e2a8005aba1da..c3b87ff1d64cb 100644
--- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
+++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
@@ -1654,11 +1654,8 @@ SubGraphCollection_t NvExecutionProvider::GetSupportedList(SubGraphCollection_t
SetAllGraphInputs(graph_build);
}
- auto status = graph_build.Resolve();
- if (!status.IsOK()) {
- LOGS_DEFAULT(ERROR) << status.ErrorMessage();
- ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ONNX graph resolve failed: " + status.ErrorMessage()));
- }
+ ORT_THROW_IF_ERROR(graph_build.Resolve());
+
// Add parent graph output to the subgraph
int i = 0;
std::vector<const NodeArg*> subgraph_outputs;
@@ -1705,41 +1702,38 @@ SubGraphCollection_t NvExecutionProvider::GetSupportedList(SubGraphCollection_t
auto model = graph_viewer->CreateModel(*GetLogger());
auto model_proto = model->ToProto();
- // ORT's default topological sort is using reversed DFS.
- // When creating model proto from graph viewer, let ORT use priority-based topological sort based on node index.
- // The reason is, in some cases, for example ResNet50, using default topological sort will end up with generating
- // the model proto that has different node ordering compared to original onnx model.
-
// save user provided external data in memory instead of writing to ModelProto
// needed for models > 2GB
std::vector<TensorrtUserWeights> userWeights;
if (use_external_data_initializer_) {
- auto c_api = Ort::GetApi();
- const InitializedTensorSet& allInitializers = graph_viewer->GetAllInitializedTensors();
+ const auto& allInitializers = graph_viewer->GetAllInitializedTensors();
userWeights.reserve(allInitializers.size());
- for (auto& entry : allInitializers) {
- OrtValue initializer_value;
- auto* tp = entry.second;
+ for (const auto& [name, tp] : allInitializers) {
if (utils::HasRawData(*tp)) {
- userWeights.emplace_back(TensorrtUserWeights(tp->name(), tp->raw_data().data(), tp->raw_data().size()));
- } else if (graph_viewer->GetOrtValueInitializer(tp->name(), initializer_value)) {
- // the initializer was marked as external data by the ORT graph at load time since it was provided in memory
- size_t size = 0;
- const void* ptr = nullptr;
- Ort::ThrowOnError(c_api.GetTensorSizeInBytes(&initializer_value, &size));
- Ort::ThrowOnError(c_api.GetTensorData(&initializer_value, &ptr));
- userWeights.emplace_back(tp->name(), ptr, size);
+ // Keep initializers in memory instead of writing them to the ModelProto.
+ userWeights.emplace_back(name, tp->raw_data().data(), tp->raw_data().size());
} else if (utils::HasExternalDataInMemory(*tp)) {
- // only copy and take ownership of the data if none of the above conditions are met
- std::unique_ptr<ONNX_NAMESPACE::TensorProto> full_init;
- ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init));
- userWeights.emplace_back(std::move(full_init->name()), std::move(full_init->raw_data()));
+ // the initializer was marked as external data by the ORT graph at load time since it was provided in memory
+ if (OrtValue v; graph_viewer->GetOrtValueInitializer(name, v)) {
+ Ort::ConstValue initializer_value{&v};
+ const size_t size = initializer_value.GetTensorSizeInBytes();
+ const void* ptr = initializer_value.GetTensorRawData();
+ userWeights.emplace_back(name, ptr, size);
+ } else {
+ // only copy and take ownership of the data if none of the above conditions are met
+ std::unique_ptr<ONNX_NAMESPACE::TensorProto> full_init;
+ ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init));
+ userWeights.emplace_back(name, full_init->raw_data());
+ }
}
}
}
+ // ORT's default topological sort uses reverse DFS.
+ // When creating the model proto from the graph viewer, let ORT use a priority-based topological
+ // sort keyed on node index. Otherwise, for some models (e.g. ResNet50), the default sort can
+ // generate a model proto whose node ordering differs from the original ONNX model.
graph_viewer->ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/, !use_external_data_initializer_ /*include raw initializers*/);
-
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
std::string string_buf;
@@ -2567,30 +2561,27 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
// exclude weights if external
std::vector<TensorrtUserWeights> userWeights;
if (use_external_data_initializer_) {
- auto c_api = Ort::GetApi();
const InitializedTensorSet& allInitializers = graph_body_viewer.GetAllInitializedTensors();
userWeights.reserve(allInitializers.size());
- for (auto& entry : allInitializers) {
- OrtValue initializer_value;
- auto* tp = entry.second;
+ for (const auto& [name, tp] : allInitializers) {
if (utils::HasRawData(*tp)) {
- userWeights.emplace_back(TensorrtUserWeights(tp->name(), tp->raw_data().data(), tp->raw_data().size()));
- } else if (graph_body_viewer.GetOrtValueInitializer(tp->name(), initializer_value)) {
- // the initializer was marked as external data by the ORT graph at load time since it was provided in memory
- size_t size = 0;
- const void* ptr = nullptr;
- Ort::ThrowOnError(c_api.GetTensorSizeInBytes(&initializer_value, &size));
- Ort::ThrowOnError(c_api.GetTensorData(&initializer_value, &ptr));
- userWeights.emplace_back(tp->name(), ptr, size);
+ userWeights.emplace_back(name, tp->raw_data().data(), tp->raw_data().size());
} else if (utils::HasExternalDataInMemory(*tp)) {
- // only copy and take ownership of the data if none of the above conditions are met
- std::unique_ptr<ONNX_NAMESPACE::TensorProto> full_init;
- ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init));
- userWeights.emplace_back(TensorrtUserWeights(std::move(full_init->name()), std::move(full_init->raw_data())));
+ // the initializer was marked as external data by the ORT graph at load time since it was provided in memory
+ if (OrtValue v; graph_body_viewer.GetOrtValueInitializer(name, v)) {
+ Ort::ConstValue initializer_value{&v};
+ const size_t size = initializer_value.GetTensorSizeInBytes();
+ const void* ptr = initializer_value.GetTensorRawData();
+ userWeights.emplace_back(name, ptr, size);
+ } else {
+ // only copy and take ownership of the data if none of the above conditions are met
+ std::unique_ptr<ONNX_NAMESPACE::TensorProto> full_init;
+ ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init));
+ userWeights.emplace_back(name, full_init->raw_data());
+ }
}
}
}
-
// ORT's default topological sort is using reversed DFS.
// When creating model proto from graph viewer, let ORT use priority-based topological sort based on node index.
// The reason is, in some cases, for example ResNet50, using default topological sort will end up with generating
diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h
index f1d545d0c6b17..6bb468435e47c 100644
--- a/onnxruntime/core/providers/shared_library/provider_interfaces.h
+++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h
@@ -70,13 +70,27 @@ struct IteratorHolder {
bool operator!=(const IteratorHolder& p) const { return p_->operator!=(*p.p_); }
void operator++() { p_->operator++(); }
- const TResult& operator*() { return p_->operator*(); }
+ TResult& operator*() { return p_->operator*(); }
T* operator->() { return p_.get(); }
private:
std::unique_ptr<T> p_;
};
+struct TensorProto_ConstIterator {
+ virtual ~TensorProto_ConstIterator() = default;
+ virtual bool operator!=(const TensorProto_ConstIterator& p) const = 0;
+ virtual void operator++() = 0;
+ virtual const ONNX_NAMESPACE::TensorProto& operator*() const = 0;
+};
+
+struct TensorProto_Iterator {
+ virtual ~TensorProto_Iterator() = default;
+ virtual bool operator!=(const TensorProto_Iterator& p) const = 0;
+ virtual void operator++() = 0;
+ virtual ONNX_NAMESPACE::TensorProto& operator*() const = 0;
+};
+
struct NodeAttributes_Iterator {
virtual ~NodeAttributes_Iterator() {}
@@ -439,7 +453,8 @@ struct ProviderHost {
// GraphProto
virtual std::unique_ptr<ONNX_NAMESPACE::GraphProto> GraphProto__construct() = 0;
virtual void GraphProto__operator_delete(ONNX_NAMESPACE::GraphProto* p) = 0;
- virtual void GraphProto__operator_assign(ONNX_NAMESPACE::GraphProto* p, const ONNX_NAMESPACE::GraphProto& v) = 0;
+ virtual ONNX_NAMESPACE::GraphProto& GraphProto__operator_assign(ONNX_NAMESPACE::GraphProto* p, const ONNX_NAMESPACE::GraphProto& v) = 0;
+ virtual ONNX_NAMESPACE::GraphProto& GraphProto__operator_move_assign(ONNX_NAMESPACE::GraphProto* p, ONNX_NAMESPACE::GraphProto&& v) = 0;
virtual const ONNX_NAMESPACE::ValueInfoProto& GraphProto__input(const ONNX_NAMESPACE::GraphProto* p, int index) = 0;
virtual ONNX_NAMESPACE::ValueInfoProtos* GraphProto__mutable_input(ONNX_NAMESPACE::GraphProto* p) = 0;
@@ -492,7 +507,8 @@ struct ProviderHost {
// TensorProto
virtual std::unique_ptr<ONNX_NAMESPACE::TensorProto> TensorProto__construct() = 0;
virtual void TensorProto__operator_delete(ONNX_NAMESPACE::TensorProto* p) = 0;
- virtual void TensorProto__operator_assign(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto& v) = 0;
+ virtual ONNX_NAMESPACE::TensorProto& TensorProto__operator_assign(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto& v) = 0;
+ virtual ONNX_NAMESPACE::TensorProto& TensorProto__operator_move_assign(ONNX_NAMESPACE::TensorProto* p, ONNX_NAMESPACE::TensorProto&& v) = 0;
virtual bool TensorProto__has_name(const ONNX_NAMESPACE::TensorProto* p) = 0;
virtual void TensorProto__set_name(ONNX_NAMESPACE::TensorProto* p, const ::std::string& name) = 0;
virtual const ::std::string& TensorProto__name(const ONNX_NAMESPACE::TensorProto* p) = 0;
@@ -521,8 +537,12 @@ struct ProviderHost {
// TensorProtos
virtual ONNX_NAMESPACE::TensorProto* TensorProtos__Add(ONNX_NAMESPACE::TensorProtos* p) = 0;
- virtual int TensorProtos__size(ONNX_NAMESPACE::TensorProtos* p) = 0;
+ virtual int TensorProtos__size(const ONNX_NAMESPACE::TensorProtos* p) = 0;
virtual ONNX_NAMESPACE::TensorProto& TensorProtos__at(ONNX_NAMESPACE::TensorProtos* p, int index) = 0;
+ virtual std::unique_ptr<TensorProto_ConstIterator> TensorProtos__begin(const ONNX_NAMESPACE::TensorProtos* p) = 0;
+ virtual std::unique_ptr<TensorProto_ConstIterator> TensorProtos__end(const ONNX_NAMESPACE::TensorProtos* p) = 0;
+ virtual std::unique_ptr<TensorProto_Iterator> TensorProtos__begin(ONNX_NAMESPACE::TensorProtos* p) = 0;
+ virtual std::unique_ptr<TensorProto_Iterator> TensorProtos__end(ONNX_NAMESPACE::TensorProtos* p) = 0;
// TensorShapeProto_Dimension
virtual int TensorShapeProto_Dimension__value_case(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) = 0;
diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h
index c7400c276f912..d3584d12df235 100644
--- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h
+++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h
@@ -175,7 +175,8 @@ struct AttributeProto final {
struct GraphProto final {
static std::unique_ptr<GraphProto> Create() { return g_host->GraphProto__construct(); }
static void operator delete(void* p) { g_host->GraphProto__operator_delete(reinterpret_cast<GraphProto*>(p)); }
- void operator=(const GraphProto& v) { return g_host->GraphProto__operator_assign(this, v); }
+ GraphProto& operator=(const GraphProto& v) { return g_host->GraphProto__operator_assign(this, v); }
+ GraphProto& operator=(GraphProto&& v) noexcept { return g_host->GraphProto__operator_move_assign(this, std::move(v)); }
const ValueInfoProto& input(int index) const { return g_host->GraphProto__input(this, index); }
ValueInfoProtos* mutable_input() { return g_host->GraphProto__mutable_input(this); }
@@ -241,7 +242,10 @@ struct NodeProto final {
struct TensorProto final {
static std::unique_ptr<TensorProto> Create() { return g_host->TensorProto__construct(); }
static void operator delete(void* p) { g_host->TensorProto__operator_delete(reinterpret_cast<TensorProto*>(p)); }
- void operator=(const TensorProto& v) { g_host->TensorProto__operator_assign(this, v); }
+ TensorProto& operator=(const TensorProto& v) {
+ return g_host->TensorProto__operator_assign(this, v);
+ }
+ TensorProto& operator=(TensorProto&& v) noexcept { return g_host->TensorProto__operator_move_assign(this, std::move(v)); }
bool has_name() const { return g_host->TensorProto__has_name(this); }
void set_name(const ::std::string& name) { return g_host->TensorProto__set_name(this, name); }
@@ -283,8 +287,12 @@ struct TensorProto final {
struct TensorProtos final {
TensorProto* Add() { return g_host->TensorProtos__Add(this); }
- int size() { return g_host->TensorProtos__size(this); }
+ int size() const { return g_host->TensorProtos__size(this); }
TensorProto& at(int index) { return g_host->TensorProtos__at(this, index); }
+ IteratorHolder<TensorProto_ConstIterator, const TensorProto> begin() const { return g_host->TensorProtos__begin(this); }
+ IteratorHolder<TensorProto_ConstIterator, const TensorProto> end() const { return g_host->TensorProtos__end(this); }
+ IteratorHolder<TensorProto_Iterator, TensorProto> begin() { return g_host->TensorProtos__begin(this); }
+ IteratorHolder<TensorProto_Iterator, TensorProto> end() { return g_host->TensorProtos__end(this); }
PROVIDER_DISALLOW_ALL(TensorProtos)
};
@@ -935,9 +943,9 @@ struct NodeAttributes final {
ONNX_NAMESPACE::AttributeProto& operator[](const std::string& string) { return g_host->NodeAttributes__operator_array(this, string); }
const ONNX_NAMESPACE::AttributeProto& at(const std::string& string) const { return g_host->NodeAttributes__at(this, string); }
- IteratorHolder<NodeAttributes_Iterator, std::pair<const std::string, ONNX_NAMESPACE::AttributeProto>> begin() const { return g_host->NodeAttributes__begin(this); }
- IteratorHolder<NodeAttributes_Iterator, std::pair<const std::string, ONNX_NAMESPACE::AttributeProto>> end() const { return g_host->NodeAttributes__end(this); }
- IteratorHolder<NodeAttributes_Iterator, std::pair<const std::string, ONNX_NAMESPACE::AttributeProto>> find(const std::string& key) const { return g_host->NodeAttributes__find(this, key); }
+ IteratorHolder<NodeAttributes_Iterator, const std::pair<const std::string, ONNX_NAMESPACE::AttributeProto>> begin() const { return g_host->NodeAttributes__begin(this); }
+ IteratorHolder<NodeAttributes_Iterator, const std::pair<const std::string, ONNX_NAMESPACE::AttributeProto>> end() const { return g_host->NodeAttributes__end(this); }
+ IteratorHolder<NodeAttributes_Iterator, const std::pair<const std::string, ONNX_NAMESPACE::AttributeProto>> find(const std::string& key) const { return g_host->NodeAttributes__find(this, key); }
void insert(const NodeAttributes& v) { return g_host->NodeAttributes__insert(this, v); }
void emplace(const std::string& k, const ONNX_NAMESPACE::AttributeProto& v) { g_host->NodeAttributes__emplace(this, k, v); }
void emplace(const std::string& k, ONNX_NAMESPACE::AttributeProto&& v) { g_host->NodeAttributes__emplace(this, k, std::move(v)); }
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index cd0c0e4bffdb5..0f281cfb272a0 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -2280,7 +2280,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
SetAllGraphInputs(graph_build);
}
- ORT_ENFORCE(graph_build.Resolve().IsOK());
+ ORT_THROW_IF_ERROR(graph_build.Resolve());
// Add parent graph output to the subgraph
int i = 0;
@@ -2295,7 +2295,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
auto& graph_build_outputs = graph_build.GetOutputs();
subgraph_outputs.insert(subgraph_outputs.begin(), graph_build_outputs.begin(), graph_build_outputs.end());
graph_build.SetOutputs(graph_build_outputs);
- ORT_ENFORCE(graph_build.Resolve().IsOK());
+ ORT_THROW_IF_ERROR(graph_build.Resolve());
// Check if input tensors have shapes
if (iterations > 1) {
@@ -2332,27 +2332,25 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
// When creating model proto from graph viewer, let ORT use priority-based topological sort based on node index.
// The reason is, in some cases, for example ResNet50, using default topological sort will end up with generating
// the model proto that has different node ordering compared to original onnx model.
- // Save Initializer Data.
- std::vector<TensorrtUserWeights> userWeights;
+ auto graph_proto = ONNX_NAMESPACE::GraphProto::Create();
+ graph_viewer->ToProto(*graph_proto, true, true, 1 /*priority-based topological sort*/, !load_user_initializer_ /*include_initializer_data*/);
- // Keep inits in memory instead of writing to ModelProto.
+ // Save Initializer Data.
+ std::vector<TensorrtUserWeights> userWeights;
if (load_user_initializer_) {
- auto allInitializers = graph_viewer->GetAllInitializedTensors();
-
- for (auto& entry : allInitializers) {
- auto* tp = entry.second;
+ const auto& allInitializers = graph_viewer->GetAllInitializedTensors();
+ for (const auto& [name, tp] : allInitializers) {
if (tp->has_raw_data()) {
- userWeights.emplace_back(tp->name(), tp->raw_data());
+ userWeights.emplace_back(name, tp->raw_data());
} else if (utils::HasExternalDataInMemory(*tp)) {
std::unique_ptr<ONNX_NAMESPACE::TensorProto> full_init;
ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init));
- userWeights.emplace_back(full_init->name(), full_init->raw_data());
+ userWeights.emplace_back(name, full_init->raw_data());
}
}
}
-
- graph_viewer->ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/, !load_user_initializer_ /*include_initializer_data*/);
+ *model_proto->mutable_graph() = std::move(*graph_proto);
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
std::string string_buf;
@@ -3098,22 +3096,17 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
auto model = graph_body_viewer.CreateModel(*GetLogger());
auto model_proto = model->ToProto();
+ // Note: wrapping std::vector in a smart pointer is redundant, as the vector already owns its storage.
auto userWeights = std::make_unique<std::vector<TensorrtUserWeights>>();
-
if (load_user_initializer_) {
- auto allInitializers = graph_body_viewer.GetAllInitializedTensors();
-
- for (auto& entry : allInitializers) {
- auto name = entry.first;
- auto* tp = entry.second;
- if (tp->has_raw_data()) {
- userWeights->emplace_back(
- TensorrtUserWeights(tp->name(), tp->raw_data()));
+ const auto& allInitializers = graph_body_viewer.GetAllInitializedTensors();
+ for (const auto& [name, tp] : allInitializers) {
+ if (utils::HasRawData(*tp)) {
+ userWeights->emplace_back(name, tp->raw_data());
} else if (utils::HasExternalDataInMemory(*tp)) {
std::unique_ptr<ONNX_NAMESPACE::TensorProto> full_init;
ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init));
- userWeights->emplace_back(
- TensorrtUserWeights(full_init->name(), full_init->raw_data()));
+ userWeights->emplace_back(name, full_init->raw_data());
}
}
}
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
index ab3932e7abfb4..6cde27ac1a6ae 100644
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ -1320,6 +1320,29 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool
*session_logger_));
}
+ // We convert initializers into OrtValues before partitioning so that plug-in EPs can take
+ // advantage of the initializers already being in OrtValue format and do not have to deal
+ // with protobuf.
+ //
+ // The initializer data is transferred to an OrtValue, and the original TensorProto is replaced
+ // with one that has the same data type, shape and name but uses its external data fields in a
+ // non-standard way: the location is set to the string constant utils::kTensorProtoMemoryAddressTag,
+ // the file offset is set to the address of the OrtValue's data buffer, and the length is set to
+ // the size of that buffer. Because this external location is non-standard, ONNX code cannot
+ // handle it. For that reason we do not convert in the Graph constructor: Node::ToProto()
+ // reconstructs Graph instances for subgraphs, and we do not want initializers converted at
+ // shape-inference time, since Resolve() is called from EPs while op_types are not yet assigned.
+ //
+ // Transformations applied later do not introduce in-memory initializers; type and shape
+ // inference runs only on newly added nodes, and any new initializers are converted at
+ // session finalization time.
+ //
+ // Within ConvertInitializersIntoOrtValues(), the conversion boils down to:
+ //   constexpr const bool use_tensor_buffer_true = true;
+ //   auto tensor_proto_to_add = utils::TensorToTensorProto(ort_value.Get(), tensor_proto.name(),
+ //                                                         use_tensor_buffer_true);
+ //   ORT_RETURN_IF_ERROR(graph.ReplaceInitializedTensor(tensor_proto_to_add, ort_value));
+ ORT_RETURN_IF_ERROR_SESSIONID_(graph.ConvertInitializersIntoOrtValues());
+
auto apply_transformer_once = [](const GraphTransformer& transformer, const logging::Logger& logger,
Graph& graph, bool* is_graph_modified = nullptr) -> onnxruntime::common::Status {
bool modified = false;
@@ -2515,6 +2538,12 @@ common::Status InferenceSession::Initialize() {
LOGS(*session_logger_, ERROR) << status.ErrorMessage();
});
}
+ ORT_CATCH(const OnnxRuntimeException& ex) {
+ ORT_HANDLE_EXCEPTION([&]() {
+ status = Status(ex.Category(), ex.Code(), MakeString("Exception during initialization: ", ex.what()));
+ LOGS(*session_logger_, ERROR) << status.ErrorMessage();
+ });
+ }
ORT_CATCH(const std::exception& ex) {
ORT_HANDLE_EXCEPTION([&]() {
status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Exception during initialization: ", ex.what());
diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc
index 19c636ba6aff1..e01bad4acff76 100644
--- a/onnxruntime/core/session/provider_bridge_ort.cc
+++ b/onnxruntime/core/session/provider_bridge_ort.cc
@@ -168,6 +168,25 @@ struct TensorShapeProto_Dimension_Iterator_Impl : TensorShapeProto_Dimension_Ite
google::protobuf::internal::RepeatedPtrIterator<const ONNX_NAMESPACE::TensorShapeProto_Dimension> v_;
};
+struct TensorProto_ConstIterator_Impl : TensorProto_ConstIterator {
+  explicit TensorProto_ConstIterator_Impl(google::protobuf::internal::RepeatedPtrIterator<const ONNX_NAMESPACE::TensorProto>&& v) : v_{std::move(v)} {}
+
+  bool operator!=(const TensorProto_ConstIterator& p) const override { return v_ != static_cast<const TensorProto_ConstIterator_Impl*>(&p)->v_; }
+
+  void operator++() override { v_.operator++(); }
+  const ONNX_NAMESPACE::TensorProto& operator*() const override { return *v_; }
+
+  google::protobuf::internal::RepeatedPtrIterator<const ONNX_NAMESPACE::TensorProto> v_;
+};
+
+struct TensorProto_Iterator_Impl : TensorProto_Iterator {
+  explicit TensorProto_Iterator_Impl(google::protobuf::internal::RepeatedPtrIterator<ONNX_NAMESPACE::TensorProto>&& v) : v_{std::move(v)} {}
+  bool operator!=(const TensorProto_Iterator& p) const override { return v_ != static_cast<const TensorProto_Iterator_Impl*>(&p)->v_; }
+  void operator++() override { v_.operator++(); }
+  ONNX_NAMESPACE::TensorProto& operator*() const override { return *v_; }
+  google::protobuf::internal::RepeatedPtrIterator<ONNX_NAMESPACE::TensorProto> v_;
+};
+
struct NodeAttributes_Iterator_Impl : NodeAttributes_Iterator {
NodeAttributes_Iterator_Impl(NodeAttributes::const_iterator&& v) : v_{std::move(v)} {}
@@ -594,7 +613,14 @@ struct ProviderHostImpl : ProviderHost {
std::string* GraphProto__mutable_name(ONNX_NAMESPACE::GraphProto* p) override { return p->mutable_name(); }
ONNX_NAMESPACE::NodeProto* GraphProto__mutable_node(ONNX_NAMESPACE::GraphProto* p, int index) override { return p->mutable_node(index); }
- void GraphProto__operator_assign(ONNX_NAMESPACE::GraphProto* p, const ONNX_NAMESPACE::GraphProto& v) override { *p = v; }
+ ONNX_NAMESPACE::GraphProto& GraphProto__operator_assign(ONNX_NAMESPACE::GraphProto* p, const ONNX_NAMESPACE::GraphProto& v) override {
+ *p = v;
+ return *p;
+ }
+ ONNX_NAMESPACE::GraphProto& GraphProto__operator_move_assign(ONNX_NAMESPACE::GraphProto* p, ONNX_NAMESPACE::GraphProto&& v) override {
+ *p = std::move(v);
+ return *p;
+ }
void GraphProto__set_name(ONNX_NAMESPACE::GraphProto* p, const std::string& name) override { p->set_name(name); }
void GraphProto__set_doc_string(ONNX_NAMESPACE::GraphProto* p, const std::string& doc_str) override {
@@ -633,7 +659,14 @@ struct ProviderHostImpl : ProviderHost {
// TensorProto (wrapped)
std::unique_ptr<ONNX_NAMESPACE::TensorProto> TensorProto__construct() override { return std::make_unique<ONNX_NAMESPACE::TensorProto>(); }
void TensorProto__operator_delete(ONNX_NAMESPACE::TensorProto* p) override { delete p; }
- void TensorProto__operator_assign(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto& v) override { *p = v; }
+ ONNX_NAMESPACE::TensorProto& TensorProto__operator_assign(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto& v) override {
+ *p = v;
+ return *p;
+ }
+ ONNX_NAMESPACE::TensorProto& TensorProto__operator_move_assign(ONNX_NAMESPACE::TensorProto* p, ONNX_NAMESPACE::TensorProto&& v) override {
+ *p = std::move(v);
+ return *p;
+ }
bool TensorProto__has_name(const ONNX_NAMESPACE::TensorProto* p) override { return p->has_name(); }
void TensorProto__set_name(ONNX_NAMESPACE::TensorProto* p, const ::std::string& name) override { p->set_name(name); }
const ::std::string& TensorProto__name(const ONNX_NAMESPACE::TensorProto* p) override { return p->name(); }
@@ -663,8 +696,20 @@ struct ProviderHostImpl : ProviderHost {
// TensorProtos (wrapped)
ONNX_NAMESPACE::TensorProto* TensorProtos__Add(ONNX_NAMESPACE::TensorProtos* p) override { return p->Add(); }
- int TensorProtos__size(ONNX_NAMESPACE::TensorProtos* p) override { return p->size(); }
+ int TensorProtos__size(const ONNX_NAMESPACE::TensorProtos* p) override { return p->size(); }
ONNX_NAMESPACE::TensorProto& TensorProtos__at(ONNX_NAMESPACE::TensorProtos* p, int index) override { return p->at(index); };
+ std::unique_ptr<TensorProto_ConstIterator> TensorProtos__begin(const ONNX_NAMESPACE::TensorProtos* p) override {
+ return std::make_unique<TensorProto_ConstIterator_Impl>(p->begin());
+ }
+ std::unique_ptr<TensorProto_ConstIterator> TensorProtos__end(const ONNX_NAMESPACE::TensorProtos* p) override {
+ return std::make_unique<TensorProto_ConstIterator_Impl>(p->end());
+ }
+ std::unique_ptr<TensorProto_Iterator> TensorProtos__begin(ONNX_NAMESPACE::TensorProtos* p) override {
+ return std::make_unique<TensorProto_Iterator_Impl>(p->begin());
+ }
+ std::unique_ptr<TensorProto_Iterator> TensorProtos__end(ONNX_NAMESPACE::TensorProtos* p) override {
+ return std::make_unique<TensorProto_Iterator_Impl>(p->end());
+ }
// TensorShapeProto_Dimension (wrapped)
int TensorShapeProto_Dimension__value_case(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) override { return p->value_case(); }
diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc
index 6a3f2f974b9f5..4d80cb704748c 100644
--- a/onnxruntime/test/ir/graph_test.cc
+++ b/onnxruntime/test/ir/graph_test.cc
@@ -2813,8 +2813,15 @@ TEST_F(GraphTest, ShapeInferenceAfterInitializerExternalization) {
ASSERT_TRUE(graph.GetInitializedTensor("split_sizes", initializer_after));
ASSERT_NE(initializer_after, nullptr);
// Debug: verify it was externalized
+ ASSERT_FALSE(utils::HasExternalDataInMemory(*initializer_after))
+ << "We no longer externalize data in the Graph constructor.";
+
+ // Now externalize explicitly to trigger the bug scenario
+ ASSERT_STATUS_OK(graph.ConvertInitializersIntoOrtValues());
+ ASSERT_TRUE(graph.GetInitializedTensor("split_sizes", initializer_after));
+ ASSERT_NE(initializer_after, nullptr);
ASSERT_TRUE(utils::HasExternalDataInMemory(*initializer_after))
- << "Initializer was not externalized to in-memory external data";
+ << "The initializer should externalize now";
// Mark the graph as needing resolve to force shape inference to run again
graph.SetGraphResolveNeeded();
diff --git a/onnxruntime/test/ir/utils_test.cc b/onnxruntime/test/ir/utils_test.cc
index e9744ccacbdd5..ae212a726cf4c 100644
--- a/onnxruntime/test/ir/utils_test.cc
+++ b/onnxruntime/test/ir/utils_test.cc
@@ -7,6 +7,7 @@
#include "core/graph/model.h"
#include "test/test_environment.h"
+#include "test/util/include/asserts.h"
using ONNX_NAMESPACE::Utils::DataTypeUtils;
using namespace ONNX_NAMESPACE;
@@ -178,8 +179,7 @@ static void CreateNodeRemovalGraph(Model& model, bool removal_allowed, bool test
if_node.AddAttribute("then_branch", then_branch);
if_node.AddAttribute("else_branch", else_branch);
- auto status = graph.Resolve();
- ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
+ ASSERT_STATUS_OK(graph.Resolve());
}
static void CheckNodeRemovalSubgraphUpdate(const std::string& new_name, const Graph& subgraph) {