[Testing][Models] Add gpt2 module in testing models (#252)
Added a GPT-2 module as hidet.testing.models.gpt2, implementing a version that supports both initial generation and decoding with a key-value cache.
Enhancements:
- Added hook support for compiled models, used to inspect the execution of a compiled model (mainly for debugging).
- Fixed a bug in the memory planner.
yaoyaoding committed May 29, 2023
1 parent 13217c7 commit 6d4bd3d
Showing 33 changed files with 759 additions and 212 deletions.
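
The gpt2 module itself is not among the hunks shown below, but the two modes the commit message describes follow the standard two-phase decoding pattern; here is a generic sketch for orientation (all names are hypothetical, not hidet.testing.models.gpt2's actual API):

```python
# Generic two-phase decoding sketch (hypothetical API, for orientation only).
def generate(model, prompt_ids, num_new_tokens):
    # Phase 1, "initial generation": run the whole prompt once and build the
    # per-layer key/value cache.
    logits, kv_cache = model(prompt_ids, kv_cache=None)
    token = logits[-1].argmax()
    output = [token]
    # Phase 2, cached decoding: feed one token at a time, reusing the cache so
    # each step adds O(1) new attention keys/values instead of re-encoding.
    for _ in range(num_new_tokens - 1):
        logits, kv_cache = model([token], kv_cache=kv_cache)
        token = logits[-1].argmax()
        output.append(token)
    return output
```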
13 changes: 13 additions & 0 deletions include/hidet/runtime/memory_planner.h
@@ -34,6 +34,15 @@ static int64_t memory_planner_allocate(int64_t size) {
// max_segments = std::max(max_segments, (int)memory_planner.regions.size());
// printf("%d (%d)\n", (int)memory_planner.regions.size(), max_segments);
// memory_planner.print();

// auto ret = memory_planner.regions.begin()->start;
// memory_planner.regions.begin()->start += size;
// return ret;
if(size == 0) {
return -1;
}

size = (size + 127) / 128 * 128; // ceil to 128 bytes
for (auto it = memory_planner.regions.begin(); it != memory_planner.regions.end(); ++it) {
if (it->size >= size) {
auto region = *it;
@@ -59,6 +68,10 @@ static void memory_planner_free(int64_t ptr) {
// max_segments = std::max(max_segments, (int)memory_planner.regions.size());
// printf("%d (%d)\n", (int)memory_planner.regions.size(), max_segments);
// memory_planner.print();
if(ptr == -1) {
return;
}

int64_t start = ptr;
int64_t size = memory_planner.size_map[ptr];
auto it = memory_planner.regions.begin();
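
For context, the allocation policy this hunk implements, in a minimal Python sketch: first-fit over a list of free regions, request sizes rounded up to 128 bytes, and -1 as the sentinel offset for zero-sized requests. The names and the flat list are illustrative, not Hidet's API; the planner hands out offsets into a workspace rather than real pointers.

```python
# Minimal sketch of the planner's policy (illustrative, not Hidet's API).
regions = [{"start": 0, "size": 1 << 40}]   # a single large free region
size_map = {}                               # offset -> rounded allocation size

def allocate(size: int) -> int:
    if size == 0:
        return -1                           # sentinel: free() will ignore it
    size = (size + 127) // 128 * 128        # ceil to a 128-byte boundary
    for region in regions:
        if region["size"] >= size:          # first fit wins
            start = region["start"]
            region["start"] += size
            region["size"] -= size
            size_map[start] = size
            return start
    raise MemoryError("no free region is large enough")

def free(ptr: int) -> None:
    if ptr == -1:                           # zero-sized allocation: no-op
        return
    regions.append({"start": ptr, "size": size_map.pop(ptr)})
    # the real planner also merges adjacent free regions; omitted here
```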
16 changes: 10 additions & 6 deletions python/hidet/backend/codegen.py
@@ -138,13 +138,12 @@ def param_declare(self, v: Var):

def local_var_declare(self, v: Var):
v_type = v.type
name_doc = self(v)
if isinstance(v_type, DataType):
dtype_doc = self(v_type)
name_doc = self(v)
return dtype_doc + ' ' + name_doc
elif isinstance(v_type, TensorType):
dtype_doc = self(v_type.dtype)
name_doc = self(v)
shape_doc = Doc()
for s in v_type.shape:
shape_doc += '[' + self(s) + ']'
@@ -155,20 +154,25 @@ def local_var_declare(self, v: Var):
else:
attr_doc = Doc()
base_type_doc = self(v_type.base_type)
name_doc = self(v)
if v_type.use_bracket:
return attr_doc + base_type_doc + ' ' + name_doc + '[]'
else:
return attr_doc + base_type_doc + ' *' + name_doc
elif isinstance(v_type, TensorPointerType):
dtype_doc = self(v_type.tensor_type.dtype)
name_doc = self(v)
return dtype_doc + ' *' + name_doc
elif isinstance(v_type, FuncType):
return_type_doc = self(v_type.ret_type)
name_doc = self(v)
args_doc = doc_join([self(param_type) for param_type in v_type.param_types], sep=', ')
return return_type_doc + ' (*' + name_doc + ')(' + args_doc + ')'
elif isinstance(v_type, ArrayType):
if isinstance(v_type.base_type, FuncType):
return_type_doc = self(v_type.base_type.ret_type)
args_doc = doc_join([self(param_type) for param_type in v_type.base_type.param_types], sep=', ')
return return_type_doc + ' (*' + name_doc + '[' + self(v_type.size) + '])(' + args_doc + ')'
else:
base_type_doc = self(v_type.base_type)
return base_type_doc + ' ' + name_doc + '[' + self(v_type.size) + ']'
else:
assert False

@@ -203,7 +207,7 @@ def visit_IRModule(self, module: IRModule) -> Doc:
for name, var in module.global_vars.items():
if name in module.functions:
continue
doc += self.param_declare(var) + ';' + NewLine()
doc += self.local_var_declare(var) + ';' + NewLine()

# define functions
call_graph = CallGraph(module)
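
The codegen change exists so that visit_IRModule can declare module-level globals (including arrays of function pointers, like the kernels table added in build.py below) with local_var_declare instead of param_declare. Here is a hand-rolled sketch of the declaration strings the two new ArrayType branches produce, using plain string formatting rather than Hidet's Doc machinery:

```python
# Hand-rolled sketch of the C declarations the new branches emit
# (plain string formatting, not Hidet's Doc machinery).
def func_ptr_array_decl(ret: str, name: str, size: int, params: list) -> str:
    # ArrayType whose base type is a FuncType: "ret (*name[size])(params)"
    return f"{ret} (*{name}[{size}])({', '.join(params)})"

def plain_array_decl(base: str, name: str, size: int) -> str:
    # ArrayType with an ordinary base type: "base name[size]"
    return f"{base} {name}[{size}]"

print(func_ptr_array_decl("void", "kernels", 4, ["void*"]))  # void (*kernels[4])(void*)
print(plain_array_decl("int64_t", "args", 8))                # int64_t args[8]
```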
2 changes: 2 additions & 0 deletions python/hidet/ffi/utils.py
@@ -24,6 +24,8 @@ def from_param(obj):
return ctypes.cast(char_array, ctypes.c_void_p)
elif isinstance(obj, str):
return ctypes.c_char_p(obj.encode('utf-8'))
elif isinstance(obj, ctypes.c_void_p):
return obj
else:
raise ValueError(f"Argument type '{type(obj)}' cannot be converted to a pointer.")

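
The new branch lets a raw pointer pass through unchanged, which is what a hook address needs when crossing the FFI boundary. A hypothetical illustration with a ctypes callback (names are mine, not Hidet's):

```python
import ctypes

# Hypothetical illustration: a ctypes callback cast to c_void_p now passes
# through from_param verbatim instead of raising ValueError.
HOOK_TYPE = ctypes.CFUNCTYPE(None, ctypes.POINTER(ctypes.c_int64))

def my_hook(args):
    print("first packed value:", args[0])

keep_alive = HOOK_TYPE(my_hook)                      # must outlive any call
hook_ptr = ctypes.cast(keep_alive, ctypes.c_void_p)  # now accepted as-is
```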
51 changes: 23 additions & 28 deletions python/hidet/graph/flow_graph.py
@@ -20,7 +20,7 @@
import hidet.graph.operator
import hidet.cuda
from hidet import option
from hidet.ir.expr import Var
from hidet.ir.expr import is_constant
from hidet.ir.task import Task
from hidet.graph.tensor import Tensor, zeros_like, randn_like
from hidet.graph.operator import Operator, SymbolVar
@@ -35,12 +35,10 @@ def before_graph(self, graph: FlowGraph, inputs: List[Tensor]) -> None:
def after_graph(self, graph: FlowGraph, inputs: List[Tensor], outputs: List[Tensor]) -> None:
pass

def before_operator(self, op: Operator, inputs: List[Tensor], shape_map: Dict[SymbolVar, int]) -> None:
def before_operator(self, op: Operator, inputs: List[Tensor]) -> None:
pass

def after_operator(
self, op: Operator, inputs: List[Tensor], shape_map: Dict[SymbolVar, int], outputs: List[Tensor]
) -> None:
def after_operator(self, op: Operator, inputs: List[Tensor], outputs: List[Tensor]) -> None:
pass


@@ -64,30 +62,28 @@ def current() -> GraphForwardContext:
return GraphForwardContext._stack[-1]

@staticmethod
def before_graph(graph: FlowGraph, inputs: List[Tensor]) -> None:
def _before_graph(graph: FlowGraph, inputs: List[Tensor]) -> None:
ctx = GraphForwardContext.current()
for instrument in ctx.instruments:
instrument.before_graph(graph, inputs)

@staticmethod
def after_graph(graph: FlowGraph, inputs: List[Tensor], outputs: List[Tensor]) -> None:
def _after_graph(graph: FlowGraph, inputs: List[Tensor], outputs: List[Tensor]) -> None:
ctx = GraphForwardContext.current()
for instrument in ctx.instruments:
instrument.after_graph(graph, inputs, outputs)

@staticmethod
def before_operator(op: Operator, inputs: List[Tensor], shape_map: Dict[SymbolVar, int]) -> None:
def _before_operator(op: Operator, inputs: List[Tensor]) -> None:
ctx = GraphForwardContext.current()
for instrument in ctx.instruments:
instrument.before_operator(op, inputs, shape_map)
instrument.before_operator(op, inputs)

@staticmethod
def after_operator(
op: Operator, inputs: List[Tensor], shape_map: Dict[SymbolVar, int], outputs: List[Tensor]
) -> None:
def _after_operator(op: Operator, inputs: List[Tensor], outputs: List[Tensor]) -> None:
ctx = GraphForwardContext.current()
for instrument in ctx.instruments:
instrument.after_operator(op, inputs, shape_map, outputs)
instrument.after_operator(op, inputs, outputs)

def append_instrument(self, instrument: GraphForwardInstrument):
self.instruments.append(instrument)
@@ -213,20 +209,26 @@ def forward(self, *inputs: Tensor) -> Union[List[Tensor], Tensor]:

# set the symbol values
for expect_input, actual_input in zip(self.inputs, inputs):
if expect_input.device != actual_input.device:
raise ValueError(
'Expect input {} to have device {}, got {}.'.format(
expect_input, expect_input.device, actual_input.device
)
)
for expect_dim, actual_dim in zip(expect_input.shape, actual_input.shape):
if isinstance(expect_dim, SymbolVar):
runtime_api.set_symbol_value(expect_dim.name, int(actual_dim))
else:
assert is_constant(actual_dim, expect_dim) and expect_dim == actual_dim

GraphForwardContext.before_graph(self, inputs)
GraphForwardContext._before_graph(self, inputs)

# count the usage of each tensor. We use this count to determine whether
# a tensor should be freed after running an operator.
usage_count = self.usage_count.copy()
tensor_map: Dict[Tensor, Tensor] = {} # symbolic tensor -> actual tensor during the forward process
shape_map: Dict[SymbolVar, int] = {} # symbolic dimension -> actual shape dimension for the symbolic tensors
for st, at in zip(self.inputs, inputs):
tensor_map[st] = at
shape_map.update({dim: at.shape[idx] for idx, dim in enumerate(st.shape) if isinstance(dim, Var)})

# run each operator in the graph in a topological order
for idx, node in enumerate(self.nodes):
@@ -246,23 +248,16 @@ def forward(self, *inputs: Tensor) -> Union[List[Tensor], Tensor]:
node_inputs = node_inputs[: len(node.inputs)]

# run node
GraphForwardContext.before_operator(node, node_inputs, shape_map)
GraphForwardContext._before_operator(node, node_inputs)
logger.debug('[%4d/%d] run operator %s, %s', idx, len(self.nodes), node.name, node.task)
logger.debug(' inputs: %s', [x.signature() for x in node_inputs])
node_outputs = node.imperative_run(node_inputs)
logger.debug(' outputs: %s', [x.signature() for x in node_outputs])
GraphForwardContext.after_operator(node, node_inputs, shape_map, node_outputs)
GraphForwardContext._after_operator(node, node_inputs, node_outputs)

# update map
for node_output, symbolic_output in zip(node_outputs, node.outputs):
tensor_map[symbolic_output] = node_output
shape_map.update(
{
dim: node_output.shape[idx]
for idx, dim in enumerate(symbolic_output.shape)
if isinstance(dim, Var)
}
)

outputs = []
for graph_output in self.outputs:
Expand All @@ -273,7 +268,7 @@ def forward(self, *inputs: Tensor) -> Union[List[Tensor], Tensor]:
else:
raise RuntimeError('Graph output {} is not produced by any operator.'.format(graph_output.signature()))

GraphForwardContext.after_graph(self, inputs, outputs)
GraphForwardContext._after_graph(self, inputs, outputs)
return outputs

def dummy_inputs(self) -> List[Tensor]:
@@ -348,7 +343,7 @@ def update_nodes(self):
self.inputs = free_vars
return self

def build(self):
def build(self, allow_hook=False):
"""
Build the flow graph to a compiled model (hidet.runtime.CompiledModel).
@@ -359,7 +354,7 @@ def build(self):
"""
from hidet.graph.graph_utils.build import flow_graph_build

return flow_graph_build(self)
return flow_graph_build(self, allow_hook=allow_hook)

def cuda_graph(self):
"""Create a CudaGraph from FlowGraph.
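
Since an operator's inputs are now concrete tensors whose shapes can be read directly, the shape_map argument was redundant and is dropped from the instrument callbacks. A minimal sketch of an instrument under the new signatures, assuming only the class shown in the hunk above; the printing logic is illustrative:

```python
from typing import List

from hidet.graph.flow_graph import GraphForwardInstrument

# Minimal sketch of an instrument under the new callback signatures
# (no shape_map); the printing logic is illustrative.
class LoggingInstrument(GraphForwardInstrument):
    def before_operator(self, op, inputs: List) -> None:
        print('running {}: {}'.format(op.name, [x.signature() for x in inputs]))

    def after_operator(self, op, inputs: List, outputs: List) -> None:
        print('finished {}: {}'.format(op.name, [x.signature() for x in outputs]))
```

Instruments are attached through GraphForwardContext.append_instrument and invoked by the _before/_after helpers during FlowGraph.forward.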
6 changes: 3 additions & 3 deletions python/hidet/graph/frontend/onnx/onnx.py
@@ -1128,8 +1128,8 @@ def __init__(self, graph: onnx.GraphProto, op_sets: List[int], env_tensors: Opti
self.name: str = graph.name
for param in graph.initializer:
numpy_array = onnx.numpy_helper.to_array(tensor=param)
self.parameters[param.name] = from_numpy(numpy_array).cuda()
self.input_names: List[str] = [input.name for input in graph.input if input.name not in self.parameters]
self._parameters[param.name] = from_numpy(numpy_array).cuda()
self.input_names: List[str] = [input.name for input in graph.input if input.name not in self._parameters]
self.output_names: List[str] = [output.name for output in graph.output]
self.operators: List[OnnxOperator] = dispatch_operators(graph.node, op_sets)
# self.operators: List[OnnxOperator] = [dispatch(node, op_sets=self.op_sets) for node in graph.node]
@@ -1142,7 +1142,7 @@ def forward(self, *args):
name2tensor.update(self.env_tensors)
assert len(args) == len(self.input_names)
# parameters
for name, param in self.parameters.items():
for name, param in self._parameters.items():
name2tensor[name] = param
# inputs
for name, inp in zip(self.input_names, args):
5 changes: 2 additions & 3 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -390,9 +390,8 @@ def ones(

@register_function(torch.nn.functional.gelu)
def gelu(x: Tensor, approximate: Optional[str] = "none"):
if approximate is not None and approximate != "none":
warnings.warn_once("hidet: gelu with approximate {repr(approximate)} is not supported. Treat as 'none'.")
return ops.gelu(x)
approximate = {"none": False, "tanh": True}[approximate]
return ops.gelu(x, approximate=approximate)


@register_function(torch.nn.functional.layer_norm)
60 changes: 54 additions & 6 deletions python/hidet/graph/graph_utils/build.py
@@ -1,9 +1,9 @@
from typing import List, Set, Dict

import hidet
from hidet.ir.type import FuncType, void, byte_p
from hidet.ir.type import FuncType, void, byte_p, func_type
from hidet.ir.expr import SymbolVar, Var, Expr, var
from hidet.ir.stmt import AssignStmt, DeclareStmt
from hidet.ir.stmt import AssignStmt, DeclareStmt, BufferStoreStmt
from hidet.graph.tensor import Tensor
from hidet.graph.flow_graph import FlowGraph
from hidet.runtime.module import CompiledModule
@@ -61,7 +61,7 @@ def get_graph_meta_data(graph: FlowGraph) -> ModelMetaData:
)


def flow_graph_build(graph) -> CompiledModel:
def flow_graph_build(graph, allow_hook=False) -> CompiledModel:
from hidet.lang import void_p, attrs, int32, int64, meta, cast
from hidet.ir.primitives.runtime import memory_planner_init, memory_planner_allocate, memory_planner_free
from hidet.ir.primitives.runtime import memory_planner_used
@@ -79,10 +79,12 @@ def flow_graph_build(graph, allow_hook=False) -> CompiledModel:
workspace = var('workspace', byte_p)
weights = var('weights', void_p[len(graph_weights)])
kernels = var('kernels', void_p[len(graph_nodes)])
exec_hook = var('exec_hook', func_type([~int64], void))

script_module.define_global_var(workspace)
script_module.define_global_var(weights)
script_module.define_global_var(kernels)
script_module.define_global_var(exec_hook)

@hidet.script
def init(num_kernels: int, p_kernels: ~void_p, num_weights: int, p_weights: ~void_p):
@@ -133,6 +135,48 @@ def set_workspace(space: void_p):

AssignStmt(workspace, space)

@hidet.script
def register_hook(hook: void_p):
attrs.func_kind = 'public'

assert allow_hook, "Hook is not allowed when building the graph"
nonlocal exec_hook
exec_hook = hook

def call_exec_hook(idx: int, node_params: List[Expr]):
sb = hidet.ir.builders.StmtBuilder()

with sb.if_then(exec_hook != 0):
args = []
args.append(idx) # kernel index

tensors: List[Tensor]
if idx < len(graph_nodes):
args.append(len(graph_nodes[idx].inputs))
args.append(len(graph_nodes[idx].outputs))
tensors = graph_nodes[idx].inputs + graph_nodes[idx].outputs
else:
args.append(len(graph.inputs))
args.append(len(graph.outputs))
tensors = graph.inputs + graph.outputs

assert len(tensors) == len(node_params), "Expect {} parameters, got {}".format(
len(tensors), len(node_params)
)

for tensor, param in zip(tensors, node_params):
args.append(tensor.dtype.name)
args.append(len(tensor.shape))
args.extend([cast(d, 'int64') for d in tensor.shape])
args.append(param)

args_var = var('args', int64[len(args)])
sb += DeclareStmt(args_var)
for i in range(len(args)):
sb += BufferStoreStmt(args_var, [i], args[i])
sb += exec_hook(args_var)
return sb.finish()

def launch_impl(inputs: List[Var], outputs: List[Var]):
sb = hidet.ir.builders.StmtBuilder()
usage_count = graph.usage_count
@@ -160,15 +204,19 @@ def launch_impl(inputs: List[Var], outputs: List[Var]):
else:
raise RuntimeError("Unknown tensor {}".format(y))

func_type = FuncType([void_p for _ in node_params], void)
kernel_var = var("kernel_{}".format(idx), func_type)
with sb.let(kernel_var, cast(kernels[idx], func_type)):
kernel_type = FuncType([void_p for _ in node_params], void)
kernel_var = var("k{}_{}".format(idx, graph_nodes[idx].name), kernel_type)
with sb.let(kernel_var, cast(kernels[idx], kernel_type)):
sb += kernel_var(*node_params)
if allow_hook:
sb += call_exec_hook(idx, node_params)

for x in node.inputs:
usage_count[x] -= 1
if usage_count[x] == 0 and x in graph_intermediates:
sb += memory_planner_free(tensor_ptr[x])
if allow_hook:
sb += call_exec_hook(len(graph_nodes), inputs + outputs)
return sb.finish()

@hidet.script
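
Putting the hook pieces together: a graph built with build(allow_hook=True) exposes register_hook, and the hook fires after every kernel, plus once more with index len(graph_nodes) for the whole graph's inputs and outputs. Below is a hypothetical Python-side decoder for the single int64 array that call_exec_hook packs. The layout (kernel index, input count, output count, then per tensor: dtype, rank, the extents, and the kernel parameter) is read off the code above, but this is not a documented ABI, and how dtype.name is encoded into an int64 slot is left unspecified in the diff, so that field is treated as opaque:

```python
import ctypes

# Hypothetical decoder for the packed hook arguments; layout inferred from
# call_exec_hook above, not a documented ABI.
HOOK_TYPE = ctypes.CFUNCTYPE(None, ctypes.POINTER(ctypes.c_int64))

def decode_hook_args(args):
    idx, num_inputs, num_outputs = args[0], args[1], args[2]
    pos, tensors = 3, []
    for _ in range(num_inputs + num_outputs):
        dtype, rank = args[pos], args[pos + 1]       # dtype encoding: opaque here
        shape = [args[pos + 2 + i] for i in range(rank)]
        ptr = args[pos + 2 + rank]                   # the kernel parameter itself
        tensors.append((dtype, shape, ptr))
        pos += 3 + rank                              # dtype + rank + dims + param
    return idx, num_inputs, num_outputs, tensors

@HOOK_TYPE
def print_hook(args):
    idx, n_in, n_out, tensors = decode_hook_args(args)
    print('kernel {}: {} inputs, {} outputs'.format(idx, n_in, n_out))
    for dtype, shape, ptr in tensors:
        print('  tensor shape={} ptr={:#x}'.format(shape, ptr))
```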
