/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2017 by Contributors
*
* \brief Tiny graph runtime that can run a graph
* containing only tvm PackedFunc.
* \file graph_runtime.h
*/
#ifndef TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_H_
#define TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_H_
#include <dlpack/dlpack.h>
#include <dmlc/memory_io.h>
#include <dmlc/json.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
#include <string>
namespace tvm {
namespace runtime {
/*! \brief macro to do C API call */
#define TVM_CCALL(func) \
{ \
int ret = (func); \
CHECK_EQ(ret, 0) \
<< TVMGetLastError(); \
}
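/*!
* Usage sketch for TVM_CCALL: wrap a TVM C API call that returns zero on
* success so that a failure aborts with the message from TVMGetLastError().
* The particular call and the `ctx`/`stream` variables are illustrative:
* \code
*   TVMStreamHandle stream;
*   TVM_CCALL(TVMStreamCreate(ctx.device_type, ctx.device_id, &stream));
* \endcode
*/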
/*! \brief Magic number for NDArray list file */
constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;
/*! \brief Operator attributes of a tvm op */
struct TVMOpParam {
std::string func_name;
uint32_t num_inputs;
uint32_t num_outputs;
uint32_t flatten_data;
};
/*!
* \brief Tiny graph runtime.
*
* This runtime can be accessed from various languages via the
* TVM runtime PackedFunc API.
*/
class GraphRuntime : public ModuleNode {
struct OpArgs {
std::vector<DLTensor> args;
std::unordered_map<uint32_t, std::vector<uint32_t> > input_entry_ids;
std::vector<TVMValue> arg_values;
std::vector<int> arg_tcodes;
std::vector<int64_t> shape_data;
};
public:
/*!
* \brief Get a member function to be exposed to the front-end.
* \param name The name of the function.
* \param sptr_to_self The pointer to the module node.
* \return The corresponding member function.
*/
virtual PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self);
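/*!
* Hedged example: front-ends usually reach the executor through the Module
* interface rather than by calling members directly. The function names
* below are assumed to be the ones registered by the accompanying
* implementation (graph_runtime.cc), not guaranteed by this header:
* \code
*   tvm::runtime::Module m = ...;  // a module backed by a GraphRuntime
*   PackedFunc set_input = m.GetFunction("set_input");
*   PackedFunc run = m.GetFunction("run");
*   PackedFunc get_output = m.GetFunction("get_output");
* \endcode
*/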
/*!
* \return The type key of the executor.
*/
const char* type_key() const final {
return "GraphRuntime";
}
void Run();
/*!
* \brief Initialize the graph executor with graph and context.
* \param graph_json The execution graph.
* \param module The module containing the compiled functions for the host
* processor.
* \param ctxs The contexts of the host and devices on which the graph
* nodes will be executed.
*/
void Init(const std::string& graph_json,
tvm::runtime::Module module,
const std::vector<TVMContext>& ctxs);
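/*!
* A minimal end-to-end usage sketch, assuming a compiled library module
* `mod`, the graph JSON string `graph_json`, and caller-owned DLTensor
* buffers `in`/`out` (these names are illustrative, not part of this API):
* \code
*   GraphRuntime exec;
*   TVMContext ctx;
*   ctx.device_type = kDLCPU;
*   ctx.device_id = 0;
*   exec.Init(graph_json, mod, {ctx});
*   exec.SetInput(exec.GetInputIndex("data"), &in);  // "data" is a placeholder input name
*   exec.Run();
*   exec.CopyOutputTo(0, &out);
* \endcode
*/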
/*!
* \brief Get the input index given the name of input.
* \param name The name of the input.
* \return The index of input.
*/
int GetInputIndex(const std::string& name);
/*!
* \brief Set the index-th input to the graph.
* \param index The input index.
* \param data_in The input data.
*/
void SetInput(int index, DLTensor* data_in);
/*!
* \brief Set the index-th input to the graph without copying the data.
* \param index The input index.
* \param data_ref The input data to be referenced (not copied).
*/
void SetInputZeroCopy(int index, DLTensor* data_ref);
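/*!
* Zero-copy sketch: the runtime refers to the caller's buffer instead of
* copying it, so the DLTensor must outlive any subsequent Run() calls and
* keep its shape and layout. `input` is an illustrative caller-owned tensor:
* \code
*   exec.SetInputZeroCopy(exec.GetInputIndex("data"), &input);  // no copy made
*   exec.Run();  // reads directly from the caller's buffer
* \endcode
*/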
/*!
* \brief Get the number of outputs
*
* \return The number of outputs from graph.
*/
int NumOutputs() const;
/*!
* \brief Return NDArray for given input index.
* \param index The input index.
*
* \return NDArray corresponding to given input node index.
*/
NDArray GetInput(int index) const;
/*!
* \brief Return NDArray for given output index.
* \param index The output index.
*
* \return NDArray corresponding to given output node index.
*/
NDArray GetOutput(int index) const;
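/*!
* Sketch of reading results after Run() (assumes an initialized executor
* `exec`). The returned NDArray is expected to view the runtime's internal
* output entry; use CopyOutputTo if an independent buffer is needed:
* \code
*   for (int i = 0; i < exec.NumOutputs(); ++i) {
*     NDArray out = exec.GetOutput(i);
*   }
* \endcode
*/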
/*!
* \brief Copy the index-th output to data_out.
* \param index The output index.
* \param data_out The output buffer to copy into.
*/
void CopyOutputTo(int index, DLTensor* data_out);
/*!
* \brief Load parameters from a binary stream.
* \param strm The input stream.
*/
void LoadParams(dmlc::Stream* strm);
/*!
* \brief Load parameters from a serialized parameter blob.
* \param param_blob A binary blob of parameters.
*/
void LoadParams(const std::string& param_blob);
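/*!
* Sketch of loading parameters from a serialized blob read off disk; the
* file name and stream handling are illustrative (requires <fstream> and
* <iterator>):
* \code
*   std::ifstream fs("deploy_params.bin", std::ios::binary);
*   std::string param_blob((std::istreambuf_iterator<char>(fs)),
*                          std::istreambuf_iterator<char>());
*   exec.LoadParams(param_blob);
* \endcode
*/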
/*!
* \brief Share parameters with a pre-existing GraphRuntime instance.
* \param other A GraphRuntime instance on which |LoadParams| was previously
* called with the identical |param_blob|.
* \param strm The input stream.
*/
void ShareParams(const GraphRuntime& other, dmlc::Stream* strm);
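/*!
* Sketch of sharing parameters between two executors so that the second does
* not duplicate weight storage. Assumes `exec_a` already called LoadParams on
* the same `param_blob` (a std::string holding the serialized parameters):
* \code
*   dmlc::MemoryStringStream strm(&param_blob);
*   exec_b.ShareParams(exec_a, &strm);
* \endcode
*/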
/*!
* \brief Get total number of nodes.
* \return Total number of nodes.
*/
uint32_t GetNumOfNodes() const {
return static_cast<uint32_t>(nodes_.size());
}
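/*!
* \brief Get the name of the nid-th node.
* \param nid The node index.
* \return The node name.
*/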
std::string GetNodeName(uint32_t nid) const {
return nodes_[nid].name;
}
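/*!
* Example: enumerating node names, e.g. for debugging (assumes an
* initialized executor `exec`):
* \code
*   for (uint32_t nid = 0; nid < exec.GetNumOfNodes(); ++nid) {
*     LOG(INFO) << nid << ": " << exec.GetNodeName(nid);
*   }
* \endcode
*/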
protected:
// Memory pool entry.
struct PoolEntry {
size_t size;
int device_type;
PoolEntry(int s, int dev_type) : size(s), device_type(dev_type) {}
};
// Node entry
struct NodeEntry {
uint32_t node_id;
uint32_t index;
uint32_t version;
// JSON Loader
void Load(dmlc::JSONReader *reader) {
reader->BeginArray();
CHECK(reader->NextArrayItem()) << "invalid json format";
reader->Read(&node_id);
CHECK(reader->NextArrayItem()) << "invalid json format";
reader->Read(&index);
if (reader->NextArrayItem()) {
reader->Read(&version);
CHECK(!reader->NextArrayItem()) << "invalid json format";
} else {
version = 0;
}
}
};
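// A node entry is serialized as a JSON array of two or three integers,
// matching the loader above:
//   [node_id, index]           // version defaults to 0
//   [node_id, index, version]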
// Node
struct Node {
// operator type in string
std::string op_type;
// name of the op
std::string name;
// parameters
TVMOpParam param;
// inputs
std::vector<NodeEntry> inputs;
// control deps
std::vector<uint32_t> control_deps;
// JSON Loader
void LoadAttrs(dmlc::JSONReader *reader, TVMOpParam* param) {
int bitmask = 0;
std::string key, value;
reader->BeginObject();
while (reader->NextObjectItem(&key)) {
reader->Read(&value);
if (key == "func_name") {
param->func_name = value;
bitmask |= 1;
} else if (key == "num_inputs") {
param->num_inputs = strtoul(value.c_str(), nullptr, 10);
bitmask |= 2;
} else if (key == "num_outputs") {
param->num_outputs = strtoul(value.c_str(), nullptr, 10);
bitmask |= 4;
} else if (key == "flatten_data") {
param->flatten_data = strtoul(value.c_str(), nullptr, 10);
bitmask |= 8;
}
}
CHECK_EQ(bitmask, 1|2|4|8) << "invalid format";
}
// JSON Loader
void Load(dmlc::JSONReader *reader) {
reader->BeginObject();
int bitmask = 0;
std::string key;
while (reader->NextObjectItem(&key)) {
if (key == "op") {
reader->Read(&op_type);
bitmask |= 1;
} else if (key == "name") {
reader->Read(&name);
bitmask |= 2;
} else if (key == "inputs") {
reader->Read(&inputs);
bitmask |= 4;
} else if (key == "attr" || key == "attrs") {
this->LoadAttrs(reader, &param);
} else if (key == "control_deps") {
reader->Read(&control_deps);
} else {
LOG(FATAL) << "do not support key " << key;
}
}
CHECK_EQ(bitmask, 1|2|4) << "invalid format";
}
};
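// Illustrative JSON for a single node, as parsed by Load/LoadAttrs above.
// The op type, names, and attribute values are made up; note that the
// attribute values arrive as strings and are converted with strtoul:
//   {
//     "op": "tvm_op",
//     "name": "fused_add",
//     "attrs": {
//       "func_name": "fused_add",
//       "num_inputs": "2",
//       "num_outputs": "1",
//       "flatten_data": "0"
//     },
//     "inputs": [[0, 0, 0], [1, 0, 0]]
//   }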
struct GraphAttr {
size_t storage_num_not_alloctaed{0};
std::vector<int> storage_id;
std::vector<int> device_index;
std::vector<std::string> dltype;
std::vector<std::vector<int64_t> > shape;
// The graph attribute fields.
void Load(dmlc::JSONReader *reader) {
reader->BeginObject();
int bitmask = 0;
std::string key, type;
while (reader->NextObjectItem(&key)) {
if (key == "dltype") {
reader->BeginArray();
CHECK(reader->NextArrayItem());
reader->Read(&type);
CHECK_EQ(type, "list_str");
CHECK(reader->NextArrayItem());
reader->Read(&dltype);
CHECK(!reader->NextArrayItem());
bitmask |= 1;
} else if (key == "storage_id") {
reader->BeginArray();
CHECK(reader->NextArrayItem());
reader->Read(&type);
CHECK_EQ(type, "list_int");
CHECK(reader->NextArrayItem());
reader->Read(&storage_id);
CHECK(!reader->NextArrayItem());
bitmask |= 2;
} else if (key == "shape") {
reader->BeginArray();
CHECK(reader->NextArrayItem());
reader->Read(&type);
CHECK_EQ(type, "list_shape");
CHECK(reader->NextArrayItem());
reader->Read(&shape);
CHECK(!reader->NextArrayItem());
bitmask |= 4;
} else if (key == "device_index") {
reader->BeginArray();
CHECK(reader->NextArrayItem());
reader->Read(&type);
CHECK_EQ(type, "list_int");
CHECK(reader->NextArrayItem());
reader->Read(&device_index);
CHECK(!reader->NextArrayItem());
} else {
reader->BeginArray();
CHECK(reader->NextArrayItem());
reader->Read(&type);
if (type == "list_int") {
CHECK(reader->NextArrayItem());
std::vector<int> temp;
reader->Read(&temp);
} else if (type == "size_t") {
CHECK(reader->NextArrayItem());
size_t temp;
reader->Read(&temp);
} else {
LOG(FATAL) << "cannot skip graph attr " << key;
}
CHECK(!reader->NextArrayItem());
}
}
CHECK_EQ(bitmask, 1|2|4) << "invalid format";
}
};
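// Illustrative "attrs" section matching the loader above. Each field is a
// two-element array of [type_tag, payload]; the concrete values are made up:
//   {
//     "dltype":       ["list_str",   ["float32", "float32", "float32"]],
//     "storage_id":   ["list_int",   [0, 1, 2]],
//     "shape":        ["list_shape", [[1, 3, 224, 224], [1, 3, 224, 224], [1, 3, 224, 224]]],
//     "device_index": ["list_int",   [1, 1, 1]]   // optional
//   }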
// The graph attribute fields.
void Load(dmlc::JSONReader *reader) {
reader->BeginObject();
int bitmask = 0;
std::string key;
while (reader->NextObjectItem(&key)) {
if (key == "nodes") {
reader->Read(&nodes_);
bitmask |= 1;
} else if (key == "arg_nodes") {
reader->Read(&input_nodes_);
bitmask |= 2;
} else if (key == "node_row_ptr") {
reader->Read(&node_row_ptr_);
bitmask |= 4;
} else if (key == "heads") {
reader->Read(&outputs_);
bitmask |= 8;
} else if (key == "attrs") {
reader->Read(&attrs_);
bitmask |= 16;
} else if (key == "metadata") {
break;
} else {
LOG(FATAL) << "key " << key << " is not supported";
}
}
CHECK_EQ(bitmask, 1|2|4|8|16) << "invalid format";
}
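// Illustrative top-level layout of the graph JSON consumed by Load above;
// the nested formats are documented next to their loaders. Contents are
// made up:
//   {
//     "nodes":        [...],          // array of Node objects
//     "arg_nodes":    [0],            // node ids of inputs/parameters
//     "node_row_ptr": [0, 1, 2],      // row pointer into node entries
//     "heads":        [[1, 0, 0]],    // output node entries
//     "attrs":        {...}           // GraphAttr section
//   }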
/*! \brief Set up the temporary storage. */
void SetupStorage();
/*! \brief Set up the op executors. */
void SetupOpExecs();
/*!
* \brief Create an execution function for an op given its attributes and arguments.
* \param attrs The node attributes.
* \param args The arguments to the functor, including inputs and outputs.
* \param num_inputs Number of inputs.
* \return The created executor.
*/
std::pair<std::function<void()>, std::shared_ptr<OpArgs> > CreateTVMOp(
const TVMOpParam& attrs, const std::vector<DLTensor>& args,
size_t num_inputs);
// Get node entry index.
uint32_t entry_id(uint32_t nid, uint32_t index) const {
return node_row_ptr_[nid] + index;
}
// Get node entry index.
uint32_t entry_id(const NodeEntry& e) const {
return entry_id(e.node_id, e.index);
}
// Number of node entries.
uint32_t num_node_entries() const {
return node_row_ptr_.back();
}
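// Worked example of the row-pointer indexing above: with
// node_row_ptr_ = {0, 1, 3}, node 0 owns entry 0, node 1 owns entries 1 and 2
// (one per output), entry_id(1, 1) == node_row_ptr_[1] + 1 == 2, and
// num_node_entries() == node_row_ptr_.back() == 3.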
/*! \brief The graph nodes. */
std::vector<Node> nodes_;
/*! \brief The argument nodes. */
std::vector<uint32_t> input_nodes_;
/*! \brief Used for quick entry indexing. */
std::vector<uint32_t> node_row_ptr_;
/*! \brief Output entries. */
std::vector<NodeEntry> outputs_;
/*! \brief Additional graph attributes. */
GraphAttr attrs_;
/*! \brief The code module that contains both host and device code. */
tvm::runtime::Module module_;
/*! \brief Execution context of all devices including the host. */
std::vector<TVMContext> ctxs_;
/*! \brief Common storage pool for all devices. */
std::vector<NDArray> storage_pool_;
/*! \brief Data entry of each node. */
std::vector<NDArray> data_entry_;
/*! \brief Data alignment of each node. */
std::vector<size_t> data_alignment_;
/*! \brief Operator on each node. */
std::vector<std::function<void()> > op_execs_;
/*! \brief Arg info of TVM ops */
std::vector<std::shared_ptr<OpArgs> > op_args_;
};
std::vector<TVMContext> GetAllContext(const TVMArgs& args);
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_H_