16 changes: 11 additions & 5 deletions onnxruntime/core/framework/session_state.cc
@@ -389,6 +389,7 @@ void TryCalculateSizeFromResolvedShape(int ml_value_idx, std::unordered_map<int,

} // namespace

// If this function fails, NO memory planning will take place; hence only fail (and stop training) where warranted, e.g. on size overflow.
Status SessionState::GeneratePatternGroupCache(const std::vector<std::reference_wrapper<const TensorShape>>& input_shape,
const std::vector<int>& feed_mlvalue_idxs,
MemoryPatternGroup* output,
@@ -425,12 +426,17 @@ Status SessionState::GeneratePatternGroupCache(const std::vector<std::reference_
auto* arg = node->OutputDefs()[i];
size_t is_resolved = 0;
std::vector<int64_t> resolved_shape;
ORT_RETURN_IF_ERROR(TryResolveShape(arg, map, is_resolved, resolved_shape));

// Store all valid resolved shapes. They will be queried in, for example,
// Recv operator to bypass the dependency of output shapes on inputs.
if (is_resolved != 0) {
resolved_shapes[ml_value_idx] = resolved_shape;
// Tensors whose shape cannot be resolved statically will be allocated at runtime.
if (TryResolveShape(arg, map, is_resolved, resolved_shape).IsOK()) {
// Store all valid resolved shapes. They will be queried in, for example,
// Recv operator to bypass the dependency of output shapes on inputs.
if (is_resolved != 0) {
resolved_shapes[ml_value_idx] = resolved_shape;
}
} else {
LOGS(logger_, INFO) << "[Static memory planning] Could not resolve shape for tensor with ML index "
<< ml_value_idx << ", will allocate dynamically.";
}
}
}
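To make the intent of the change concrete, here is a minimal standalone sketch of the new pattern. The names below (StatusSketch, TryResolveShapeSketch, known) are hypothetical stand-ins, not the real ONNX Runtime types: the point is that a failed shape resolution is checked locally via IsOK() and the tensor is left for dynamic allocation at runtime, rather than propagating the error and aborting memory planning altogether.

#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for onnxruntime::common::Status.
struct StatusSketch {
  bool ok = true;
  bool IsOK() const { return ok; }
};

// Hypothetical stand-in for TryResolveShape: succeeds only when the shape for
// the given value index is already known statically.
StatusSketch TryResolveShapeSketch(int ml_value_idx,
                                   const std::unordered_map<int, std::vector<int64_t>>& known,
                                   std::vector<int64_t>& resolved_shape) {
  auto it = known.find(ml_value_idx);
  if (it == known.end()) return {false};
  resolved_shape = it->second;
  return {true};
}

int main() {
  // Shapes that could be derived statically (e.g. from the feeds' resolved shapes).
  const std::unordered_map<int, std::vector<int64_t>> known = {{0, {4, 128}}, {2, {4, 1}}};
  std::unordered_map<int, std::vector<int64_t>> resolved_shapes;

  for (int ml_value_idx : {0, 1, 2}) {
    std::vector<int64_t> resolved_shape;
    if (TryResolveShapeSketch(ml_value_idx, known, resolved_shape).IsOK()) {
      // Resolved statically: record it so it can participate in the memory pattern.
      resolved_shapes[ml_value_idx] = resolved_shape;
    } else {
      // Do NOT fail the whole plan; this tensor will simply be allocated at runtime.
      std::cout << "[Static memory planning] Could not resolve shape for tensor with ML index "
                << ml_value_idx << ", will allocate dynamically.\n";
    }
  }
  return 0;
}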