Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Document the Graph header files and cleanup some issues. #42

Merged
merged 8 commits into from Nov 28, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 13 additions & 1 deletion include/onnxruntime/core/graph/function.h
Expand Up @@ -13,17 +13,29 @@ class Node;

namespace onnxruntime {

// Function representation class.
/**
@class Function
Class representing a Function.
*/
class Function {
public:
virtual ~Function() {}

/** Gets the OpSchema for the Function. */
virtual const ONNX_NAMESPACE::OpSchema& OpSchema() const = 0;

/** Gets the Graph instance for the Function body subgraph. */
virtual const onnxruntime::Graph& Body() const = 0;

/** Gets the IndexedSubGraph for the Function. */
virtual const IndexedSubGraph& GetIndexedSubGraph() const = 0;
};

/**
Create a new Function instance.
@param graph The graph containing the Function.
@param customized_func the IndexedSubGraph to use for the Function.
*/
std::unique_ptr<Function> MakeFunction(const onnxruntime::Graph& graph,
std::unique_ptr<IndexedSubGraph> customized_func);
} // namespace onnxruntime
63 changes: 54 additions & 9 deletions include/onnxruntime/core/graph/graph.h
Expand Up @@ -11,53 +11,98 @@ struct IndexedSubGraph;
} // namespace onnxruntime

namespace onnxruntime {
// A graph viewer representation class.

/**
@class GraphViewer
Class that provides a read-only view of the Graph.
@remarks If the underlying Graph is changed, GetNodesInTopologicalOrder and GetRootNodes may become invalid.
*/
class GraphViewer {
public:
/**
Construct a GraphViewer from the provided Graph instance.
*/
GraphViewer(const Graph& graph);

// Graph name.
/** Gets the Graph name. */
const std::string& Name() const noexcept;

/** Gets the Graph description. */
const std::string& Description() const noexcept;

/**
Gets a tensor created from an initializer.
@param tensor_name The tensor name
@param[out] value Sets the pointer to the TensorProto if found, or nullptr if not.
@returns True if found. False if not.
*/
bool GetInitializedTensor(const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) const;

// Graph inputs excluding initializers.
/**
Gets the Graph inputs, excluding initializers.
@returns Collection of NodeArg pointers for the graph inputs, excluding inputs that have matching initializers.
@remarks No nullptr values in the returned collection. The order will be the same as in the GraphProto.
*/
const std::vector<const NodeArg*>& GetInputs() const noexcept;
// Graph inputs including initializers. Contains no nullptr values.
// This will match the number and order of inputs from the GraphProto.

/**
Gets the Graph inputs, including any initializers.
@returns Collection of NodeArg pointers for all the graph inputs.
@remarks No nullptr values in the returned collection. The order will be the same as in the GraphProto.
*/
const std::vector<const NodeArg*>& GetInputsIncludingInitializers() const noexcept;

// Graph outputs. Should have no nullptr values.
/**
Gets the Graph outputs.
@returns Collection of NodeArg pointers for all the graph outputs.
@remarks No nullptr values in the returned collection. The order will be the same as in the GraphProto.
*/
const std::vector<const NodeArg*>& GetOutputs() const noexcept;

// Get graph value infos.
/** Gets all ValueInfo NodeArg instances in the Graph. */
const std::vector<const NodeArg*>& GetValueInfo() const noexcept;

// Get const Node given specific node index. May return nullptr if node as been freed.
/**
Gets the Node instance at the specified index.
@param node_index Index to retrieve Node from.
@remarks May return nullptr if index no longer points to a valid node due to the node being freed.
*/
const Node* GetNode(NodeIndex node_index) const;

/** Gets an iterator over all the valid Nodes in the Graph. */
const GraphNodes& Nodes() const noexcept;

/** Gets the number of valid nodes in the Graph. */
int NumberOfNodes() const noexcept;

/** Gets the maximum NodeIndex value used by Nodes in the Graph. */
int MaxNodeIndex() const noexcept;

/** Gets the NodeIndex values for the Graph nodes, sorted into topological order. */
const std::vector<NodeIndex>& GetNodesInTopologicalOrder() const;

/**
Gets the NodeIndex values for the root nodes in the Graph.
The root nodes are the topmost nodes in the Graph that receive inputs from the Graph inputs
and no other nodes in the Graph.
*/
const std::vector<NodeIndex>& GetRootNodes() const;

/** Gets all tensors created from initializers. */
const InitializedTensorSet& GetAllInitializedTensors() const noexcept;

/**
Gets the NodeArg instance for the given name.
@returns A NodeArg if found, a nullptr if not.
*/
const NodeArg* GetNodeArg(const std::string& name) const;

private:
ONNXRUNTIME_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphViewer);

const Graph* graph_;

// The topological order of node index.
// The NodeIndex values of the graph nodes sorted in topological order.
std::vector<NodeIndex> nodes_in_topological_order_;
// Graph root nodes.
std::vector<NodeIndex> root_nodes_;
Expand Down