Skip to content

Commit

Permalink
WebNN: Refactor GraphBuilder of DirectML backend
Browse files Browse the repository at this point in the history
The Refactoring includes:
- Introduce decoupled `InputNode` and `OperatorNode` that inherits
  `Node`. `InputNode` encapsulates graph input index while
  `OperatorNode` encapsulates operator node index and DirectML operator
  respectively. `NodeOutput` can be created for both types of node and
  connected to another `OperatorNode`
- GraphBuilder creates and maintains the `InputNode`s, `OperatorNode`s
  and `NodeOutput`s in `std::list`s. Users reference them by pointers.
  It’s users’ responsibility to ensure these pointers do not outlive the
  GraphBuilder.
- Remove `NodeOutputInfo` and `GraphBuilder::GetNodeOutput` that reduces
  one indirection for the user code.
- `GraphBuilder::CreateOperatorNode()` returns a `nullptr` when
  `IDMLDevice::CreateOperator()` fails.
- Use `base::span` to pass the `const NodeOutput*` array that avoids
  unnecessary heap allocation.


Bug: 1273291
Change-Id: I82c801c437cf3fcd44568c947d318201bedb624a
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/4871184
Commit-Queue: ningxin hu <ningxin.hu@intel.com>
Reviewed-by: Rafael Cintron <rafael.cintron@microsoft.com>
Cr-Commit-Position: refs/heads/main@{#1209173}
  • Loading branch information
huningxin authored and Chromium LUCI CQ committed Oct 13, 2023
1 parent 20b50c1 commit b6cb2a7
Show file tree
Hide file tree
Showing 4 changed files with 487 additions and 342 deletions.
222 changes: 144 additions & 78 deletions services/webnn/dml/graph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,76 @@

namespace webnn::dml {

Node::Node(Type type) : type_(type) {}
Node::~Node() = default;

Node::Type Node::GetType() const {
return type_;
}

const InputNode* Node::AsInputNode() const {
CHECK_EQ(GetType(), Node::Type::kInput);
return static_cast<const InputNode*>(this);
}

const OperatorNode* Node::AsOperatorNode() const {
CHECK_EQ(GetType(), Node::Type::kOperator);
return static_cast<const OperatorNode*>(this);
}

InputNode::InputNode(uint32_t graph_input_index)
: Node(Node::Type::kInput), graph_input_index_(graph_input_index) {}

InputNode::~InputNode() = default;

uint32_t InputNode::GetGraphInputIndex() const {
CHECK_EQ(type_, Node::Type::kInput);
return graph_input_index_;
}

OperatorNode::OperatorNode(uint32_t node_index,
ComPtr<IDMLOperator> dml_operator)
: Node(Node::Type::kOperator),
node_index_(node_index),
dml_operator_(std::move(dml_operator)) {
dml_operator_node_desc_ =
DML_OPERATOR_GRAPH_NODE_DESC{.Operator = dml_operator_.Get()};
}

OperatorNode::~OperatorNode() = default;

uint32_t OperatorNode::GetNodeIndex() const {
CHECK_EQ(type_, Node::Type::kOperator);
return node_index_;
}

const DML_OPERATOR_GRAPH_NODE_DESC& OperatorNode::GetDMLOperatorNodeDesc()
const {
CHECK_EQ(type_, Node::Type::kOperator);
return dml_operator_node_desc_;
}

NodeOutput::NodeOutput(const Node& node,
uint32_t output_index,
TensorDesc tensor_desc)
: node_(node),
output_index_(output_index),
tensor_desc_(std::move(tensor_desc)) {}

NodeOutput::~NodeOutput() = default;

const Node& NodeOutput::GetNode() const {
return node_.get();
}

uint32_t NodeOutput::GetOutputIndex() const {
return output_index_;
}

const TensorDesc& NodeOutput::GetTensorDesc() const {
return tensor_desc_;
}

GraphBuilder::GraphBuilder(ComPtr<IDMLDevice> dml_device)
: dml_device_(std::move(dml_device)) {}

Expand All @@ -21,90 +91,78 @@ GraphBuilder& GraphBuilder::operator=(GraphBuilder&& other) = default;

GraphBuilder::~GraphBuilder() = default;

NodeInfo GraphBuilder::CreateInputNode() {
// The input index should increase from 0 as the input is added.
return {NodeInfo::Type::kInput, input_count_++};
}

const NodeOutput& GraphBuilder::GetNodeOutput(
const NodeOutputInfo& node_output_info) const {
CHECK_LT(node_output_info.index,
base::checked_cast<uint32_t>(node_outputs_.size()));
return node_outputs_[node_output_info.index];
const InputNode* GraphBuilder::CreateInputNode() {
const uint32_t graph_input_index =
base::checked_cast<uint32_t>(input_nodes_.size());
input_nodes_.emplace_back(graph_input_index);
return &input_nodes_.back();
}

NodeInfo GraphBuilder::CreateOperatorNode(
const OperatorNode* GraphBuilder::CreateOperatorNode(
DML_OPERATOR_TYPE type,
const void* operator_desc,
const std::vector<NodeOutputInfo>& node_output_infos) {
DML_OPERATOR_DESC op_desc = {type, operator_desc};
Microsoft::WRL::ComPtr<IDMLOperator> dml_operator;
HRESULT hr =
dml_device_->CreateOperator(&op_desc, IID_PPV_ARGS(&dml_operator));
if (FAILED(hr)) {
DLOG(ERROR) << "Failed to create dml operator : "
<< logging::SystemErrorCodeToString(hr);
return {NodeInfo::Type::kInvalid, 0};
}

// Create the operator node. The node index is increased as the operator node
// is added.
uint32_t index = base::checked_cast<uint32_t>(dml_operators_.size());
NodeInfo node_info = {NodeInfo::Type::kOperator, index};

dml_operators_.push_back(std::move(dml_operator));
DML_OPERATOR_GRAPH_NODE_DESC dml_node_desc = {
.Operator = dml_operators_.back().Get()};
dml_nodes_.push_back(std::move(dml_node_desc));

// Connect multiple node outputs to one node to create the input edges and
// intermediate edges.
for (uint32_t input_index = 0;
input_index < base::checked_cast<uint32_t>(node_output_infos.size());
++input_index) {
NodeOutput node_output = GetNodeOutput(node_output_infos[input_index]);
NodeInfo from_node_info = node_output.node_info;
if (from_node_info.type == NodeInfo::Type::kInput) {
DML_INPUT_GRAPH_EDGE_DESC input_edge{
.GraphInputIndex = from_node_info.index,
.ToNodeIndex = node_info.index,
.ToNodeInputIndex = input_index};

dml_input_edges_.push_back(std::move(input_edge));
} else if (from_node_info.type == NodeInfo::Type::kOperator) {
DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediate_edge{
.FromNodeIndex = from_node_info.index,
.FromNodeOutputIndex = node_output.output_index,
.ToNodeIndex = node_info.index,
.ToNodeInputIndex = input_index};

dml_intermediate_edges_.push_back(std::move(intermediate_edge));
} else {
NOTREACHED_NORETURN();
base::span<const NodeOutput*> inputs) {
DML_OPERATOR_DESC op_desc{.Type = type, .Desc = operator_desc};
ComPtr<IDMLOperator> dml_operator;
RETURN_NULL_IF_FAILED(
dml_device_->CreateOperator(&op_desc, IID_PPV_ARGS(&dml_operator)));

uint32_t operator_node_index =
base::checked_cast<uint32_t>(operator_nodes_.size());
operator_nodes_.emplace_back(operator_node_index, std::move(dml_operator));
const OperatorNode* operator_node = &operator_nodes_.back();

// Connect input node outputs to this operator node that creates the input
// edges and intermediate edges.
for (uint32_t node_input_index = 0;
node_input_index < base::checked_cast<uint32_t>(inputs.size());
++node_input_index) {
const NodeOutput* operator_input = inputs[node_input_index];
CHECK(operator_input);
const Node& from_node = operator_input->GetNode();
switch (from_node.GetType()) {
case Node::Type::kInput: {
const InputNode* from_input_node = from_node.AsInputNode();
DML_INPUT_GRAPH_EDGE_DESC input_edge{
.GraphInputIndex = from_input_node->GetGraphInputIndex(),
.ToNodeIndex = operator_node->GetNodeIndex(),
.ToNodeInputIndex = node_input_index};
dml_input_edges_.push_back(std::move(input_edge));
break;
}
case Node::Type::kOperator: {
const OperatorNode* from_operator_node = from_node.AsOperatorNode();
DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediate_edge{
.FromNodeIndex = from_operator_node->GetNodeIndex(),
.FromNodeOutputIndex = operator_input->GetOutputIndex(),
.ToNodeIndex = operator_node->GetNodeIndex(),
.ToNodeInputIndex = node_input_index};
dml_intermediate_edges_.push_back(std::move(intermediate_edge));
break;
}
}
}

return node_info;
return operator_node;
}

NodeOutputInfo GraphBuilder::CreateNodeOutput(const NodeInfo& node_info,
TensorDesc tensor,
uint32_t output_index) {
CHECK_NE(node_info.type, NodeInfo::Type::kInvalid);
node_outputs_.push_back(
NodeOutput{node_info, output_index, std::move(tensor)});
// The node output index is increased as the node output is added.
return {base::checked_cast<uint32_t>(node_outputs_.size() - 1)};
const NodeOutput* GraphBuilder::CreateNodeOutput(const Node* node,
TensorDesc tensor_desc,
uint32_t output_index) {
CHECK(node);
node_outputs_.emplace_back(*node, output_index, std::move(tensor_desc));
return &node_outputs_.back();
}

uint32_t GraphBuilder::CreateOutputEdge(
const NodeOutputInfo& node_output_info) {
NodeOutput node_output = GetNodeOutput(node_output_info);
uint32_t GraphBuilder::CreateOutputEdge(const NodeOutput* node_output) {
CHECK(node_output);
const OperatorNode* from_operator_node =
node_output->GetNode().AsOperatorNode();
uint32_t graph_output_index =
base::checked_cast<uint32_t>(dml_output_edges_.size());
DML_OUTPUT_GRAPH_EDGE_DESC output_edge = {
.FromNodeIndex = node_output.node_info.index,
.FromNodeOutputIndex = node_output.output_index,
.FromNodeIndex = from_operator_node->GetNodeIndex(),
.FromNodeOutputIndex = node_output->GetOutputIndex(),
.GraphOutputIndex = graph_output_index};
dml_output_edges_.push_back(std::move(output_edge));
return graph_output_index;
Expand All @@ -113,30 +171,38 @@ uint32_t GraphBuilder::CreateOutputEdge(
ComPtr<IDMLCompiledOperator> GraphBuilder::Compile(
DML_EXECUTION_FLAGS flags) const {
TRACE_EVENT0("gpu", "dml::GraphBuilder::Compile");
std::vector<DML_GRAPH_NODE_DESC> dml_nodes(dml_nodes_.size());
for (size_t i = 0; i < dml_nodes.size(); ++i) {
dml_nodes[i] = {DML_GRAPH_NODE_TYPE_OPERATOR, &dml_nodes_[i]};
// Ensure `dml_nodes` vector is ordered by node index of operator node.
std::vector<DML_GRAPH_NODE_DESC> dml_nodes(operator_nodes_.size());
for (const auto& operator_node : operator_nodes_) {
uint32_t node_index = operator_node.GetNodeIndex();
CHECK_LT(node_index, dml_nodes.size());
dml_nodes[node_index] =
DML_GRAPH_NODE_DESC{.Type = DML_GRAPH_NODE_TYPE_OPERATOR,
.Desc = &operator_node.GetDMLOperatorNodeDesc()};
}

std::vector<DML_GRAPH_EDGE_DESC> dml_input_edges(dml_input_edges_.size());
for (size_t i = 0; i < dml_input_edges.size(); ++i) {
dml_input_edges[i] = {DML_GRAPH_EDGE_TYPE_INPUT, &dml_input_edges_[i]};
dml_input_edges[i] = DML_GRAPH_EDGE_DESC{.Type = DML_GRAPH_EDGE_TYPE_INPUT,
.Desc = &dml_input_edges_[i]};
}

std::vector<DML_GRAPH_EDGE_DESC> dml_intermediate_edges(
dml_intermediate_edges_.size());
for (size_t i = 0; i < dml_intermediate_edges.size(); ++i) {
dml_intermediate_edges[i] = {DML_GRAPH_EDGE_TYPE_INTERMEDIATE,
&dml_intermediate_edges_[i]};
dml_intermediate_edges[i] =
DML_GRAPH_EDGE_DESC{.Type = DML_GRAPH_EDGE_TYPE_INTERMEDIATE,
.Desc = &dml_intermediate_edges_[i]};
}

std::vector<DML_GRAPH_EDGE_DESC> dml_output_edges(dml_output_edges_.size());
for (size_t i = 0; i < dml_output_edges.size(); ++i) {
dml_output_edges[i] = {DML_GRAPH_EDGE_TYPE_OUTPUT, &dml_output_edges_[i]};
dml_output_edges[i] = DML_GRAPH_EDGE_DESC{
.Type = DML_GRAPH_EDGE_TYPE_OUTPUT, .Desc = &dml_output_edges_[i]};
}

DML_GRAPH_DESC dml_graph_desc = {
.InputCount = input_count_,
.InputCount = base::checked_cast<uint32_t>(input_nodes_.size()),
.OutputCount = base::checked_cast<uint32_t>(dml_output_edges_.size()),
.NodeCount = base::checked_cast<uint32_t>(dml_nodes.size()),
.Nodes = dml_nodes.data(),
Expand Down

0 comments on commit b6cb2a7

Please sign in to comment.