[Testing][Models] Add gpt2 module in testing models #252

Merged: 5 commits, May 29, 2023
13 changes: 13 additions & 0 deletions include/hidet/runtime/memory_planner.h
@@ -34,6 +34,15 @@ static int64_t memory_planner_allocate(int64_t size) {
// max_segments = std::max(max_segments, (int)memory_planner.regions.size());
// printf("%d (%d)\n", (int)memory_planner.regions.size(), max_segments);
// memory_planner.print();

// auto ret = memory_planner.regions.begin()->start;
// memory_planner.regions.begin()->start += size;
// return ret;
if(size == 0) {
return -1;
}

size = (size + 127) / 128 * 128; // ceil to 128 bytes
for (auto it = memory_planner.regions.begin(); it != memory_planner.regions.end(); ++it) {
if (it->size >= size) {
auto region = *it;
@@ -59,6 +68,10 @@ static void memory_planner_free(int64_t ptr) {
// max_segments = std::max(max_segments, (int)memory_planner.regions.size());
// printf("%d (%d)\n", (int)memory_planner.regions.size(), max_segments);
// memory_planner.print();
if(ptr == -1) {
return;
}

int64_t start = ptr;
int64_t size = memory_planner.size_map[ptr];
auto it = memory_planner.regions.begin();
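The two new guards give zero-sized requests the sentinel offset -1 (so the matching `memory_planner_free` call becomes a no-op) and round every other request up to a 128-byte boundary before first-fit scanning the free regions. A minimal Python sketch of the same bookkeeping, with a hypothetical `Region` class standing in for the C++ structures:

```python
from dataclasses import dataclass

@dataclass
class Region:
    start: int
    size: int

class MemoryPlanner:
    """First-fit offset planner (sketch of memory_planner.h, not the real API)."""

    def __init__(self):
        self.regions = [Region(0, 1 << 62)]  # one effectively unbounded free region
        self.size_map = {}                   # allocated offset -> rounded size

    def allocate(self, size: int) -> int:
        if size == 0:
            return -1                        # sentinel: free() will ignore it
        size = (size + 127) // 128 * 128     # ceil to a multiple of 128 bytes
        for i, region in enumerate(self.regions):
            if region.size >= size:
                start = region.start
                region.start += size
                region.size -= size
                if region.size == 0:
                    del self.regions[i]
                self.size_map[start] = size
                return start
        raise MemoryError('no free region is large enough')

    def free(self, ptr: int) -> None:
        if ptr == -1:
            return                           # zero-sized allocation: nothing reserved
        size = self.size_map.pop(ptr)
        # The real planner merges the freed range with adjacent free regions;
        # appending without merging is enough to illustrate the bookkeeping.
        self.regions.append(Region(ptr, size))
```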
16 changes: 10 additions & 6 deletions python/hidet/backend/codegen.py
@@ -138,13 +138,12 @@ def param_declare(self, v: Var):

def local_var_declare(self, v: Var):
v_type = v.type
name_doc = self(v)
if isinstance(v_type, DataType):
dtype_doc = self(v_type)
name_doc = self(v)
return dtype_doc + ' ' + name_doc
elif isinstance(v_type, TensorType):
dtype_doc = self(v_type.dtype)
name_doc = self(v)
shape_doc = Doc()
for s in v_type.shape:
shape_doc += '[' + self(s) + ']'
@@ -155,20 +154,25 @@ def local_var_declare(self, v: Var):
else:
attr_doc = Doc()
base_type_doc = self(v_type.base_type)
name_doc = self(v)
if v_type.use_bracket:
return attr_doc + base_type_doc + ' ' + name_doc + '[]'
else:
return attr_doc + base_type_doc + ' *' + name_doc
elif isinstance(v_type, TensorPointerType):
dtype_doc = self(v_type.tensor_type.dtype)
name_doc = self(v)
return dtype_doc + ' *' + name_doc
elif isinstance(v_type, FuncType):
return_type_doc = self(v_type.ret_type)
name_doc = self(v)
args_doc = doc_join([self(param_type) for param_type in v_type.param_types], sep=', ')
return return_type_doc + ' (*' + name_doc + ')(' + args_doc + ')'
elif isinstance(v_type, ArrayType):
if isinstance(v_type.base_type, FuncType):
return_type_doc = self(v_type.base_type.ret_type)
args_doc = doc_join([self(param_type) for param_type in v_type.base_type.param_types], sep=', ')
return return_type_doc + ' (*' + name_doc + '[' + self(v_type.size) + '])(' + args_doc + ')'
else:
base_type_doc = self(v_type.base_type)
return base_type_doc + ' ' + name_doc + '[' + self(v_type.size) + ']'
else:
assert False

@@ -203,7 +207,7 @@ def visit_IRModule(self, module: IRModule) -> Doc:
for name, var in module.global_vars.items():
if name in module.functions:
continue
doc += self.param_declare(var) + ';' + NewLine()
doc += self.local_var_declare(var) + ';' + NewLine()

# define functions
call_graph = CallGraph(module)
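Routing module-level globals through `local_var_declare` (rather than `param_declare`) is what lets the array-typed globals introduced in `flow_graph_build` below, such as `kernels`, be declared at module scope, and the new `ArrayType`-of-`FuncType` branch covers arrays of function pointers using C's inside-out declarator syntax. A small standalone sketch of the string that branch assembles, with plain strings in place of `Doc` objects:

```python
def declare_func_ptr_array(ret_type: str, name: str, size: int, param_types: list) -> str:
    # Mirrors the ArrayType-of-FuncType branch: ret_type (*name[size])(param_types...)
    return '{} (*{}[{}])({})'.format(ret_type, name, size, ', '.join(param_types))

# e.g. an array of three kernel entry points, each taking two pointer arguments:
print(declare_func_ptr_array('void', 'kernels', 3, ['void*', 'void*']))
# -> void (*kernels[3])(void*, void*)
```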
2 changes: 2 additions & 0 deletions python/hidet/ffi/utils.py
@@ -24,6 +24,8 @@ def from_param(obj):
return ctypes.cast(char_array, ctypes.c_void_p)
elif isinstance(obj, str):
return ctypes.c_char_p(obj.encode('utf-8'))
elif isinstance(obj, ctypes.c_void_p):
return obj
else:
raise ValueError(f"Argument type '{type(obj)}' cannot be converted to a pointer.")

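The new branch lets a caller hand an already-constructed `ctypes.c_void_p` (for instance, the hook pointer consumed by the new `register_hook` entry point below) straight through instead of raising. A hedged standalone sketch of the dispatch, not the hidet module itself:

```python
import ctypes

def from_param(obj):
    if isinstance(obj, str):
        return ctypes.c_char_p(obj.encode('utf-8'))
    elif isinstance(obj, ctypes.c_void_p):
        return obj  # already a raw pointer: pass through unchanged
    raise ValueError(f"Argument type '{type(obj)}' cannot be converted to a pointer.")

hook_ptr = ctypes.c_void_p(0)   # e.g. a (null) function pointer
assert from_param(hook_ptr) is hook_ptr
```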
51 changes: 23 additions & 28 deletions python/hidet/graph/flow_graph.py
@@ -20,7 +20,7 @@
import hidet.graph.operator
import hidet.cuda
from hidet import option
from hidet.ir.expr import Var
from hidet.ir.expr import is_constant
from hidet.ir.task import Task
from hidet.graph.tensor import Tensor, zeros_like, randn_like
from hidet.graph.operator import Operator, SymbolVar
@@ -35,12 +35,10 @@ def before_graph(self, graph: FlowGraph, inputs: List[Tensor]) -> None:
def after_graph(self, graph: FlowGraph, inputs: List[Tensor], outputs: List[Tensor]) -> None:
pass

def before_operator(self, op: Operator, inputs: List[Tensor], shape_map: Dict[SymbolVar, int]) -> None:
def before_operator(self, op: Operator, inputs: List[Tensor]) -> None:
pass

def after_operator(
self, op: Operator, inputs: List[Tensor], shape_map: Dict[SymbolVar, int], outputs: List[Tensor]
) -> None:
def after_operator(self, op: Operator, inputs: List[Tensor], outputs: List[Tensor]) -> None:
pass


@@ -64,30 +62,28 @@ def current() -> GraphForwardContext:
return GraphForwardContext._stack[-1]

@staticmethod
def before_graph(graph: FlowGraph, inputs: List[Tensor]) -> None:
def _before_graph(graph: FlowGraph, inputs: List[Tensor]) -> None:
ctx = GraphForwardContext.current()
for instrument in ctx.instruments:
instrument.before_graph(graph, inputs)

@staticmethod
def after_graph(graph: FlowGraph, inputs: List[Tensor], outputs: List[Tensor]) -> None:
def _after_graph(graph: FlowGraph, inputs: List[Tensor], outputs: List[Tensor]) -> None:
ctx = GraphForwardContext.current()
for instrument in ctx.instruments:
instrument.after_graph(graph, inputs, outputs)

@staticmethod
def before_operator(op: Operator, inputs: List[Tensor], shape_map: Dict[SymbolVar, int]) -> None:
def _before_operator(op: Operator, inputs: List[Tensor]) -> None:
ctx = GraphForwardContext.current()
for instrument in ctx.instruments:
instrument.before_operator(op, inputs, shape_map)
instrument.before_operator(op, inputs)

@staticmethod
def after_operator(
op: Operator, inputs: List[Tensor], shape_map: Dict[SymbolVar, int], outputs: List[Tensor]
) -> None:
def _after_operator(op: Operator, inputs: List[Tensor], outputs: List[Tensor]) -> None:
ctx = GraphForwardContext.current()
for instrument in ctx.instruments:
instrument.after_operator(op, inputs, shape_map, outputs)
instrument.after_operator(op, inputs, outputs)

def append_instrument(self, instrument: GraphForwardInstrument):
self.instruments.append(instrument)
@@ -213,20 +209,26 @@ def forward(self, *inputs: Tensor) -> Union[List[Tensor], Tensor]:

# set the symbol values
for expect_input, actual_input in zip(self.inputs, inputs):
if expect_input.device != actual_input.device:
raise ValueError(
'Expect input {} to have device {}, got {}.'.format(
expect_input, expect_input.device, actual_input.device
)
)
for expect_dim, actual_dim in zip(expect_input.shape, actual_input.shape):
if isinstance(expect_dim, SymbolVar):
runtime_api.set_symbol_value(expect_dim.name, int(actual_dim))
else:
assert is_constant(actual_dim, expect_dim) and expect_dim == actual_dim

GraphForwardContext.before_graph(self, inputs)
GraphForwardContext._before_graph(self, inputs)

# count the usage of each tensor. We use this count to determine whether
# a tensor should be freed after running an operator.
usage_count = self.usage_count.copy()
tensor_map: Dict[Tensor, Tensor] = {} # symbolic tensor -> actual tensor during the forward process
shape_map: Dict[SymbolVar, int] = {} # symbolic dimension -> actual shape dimension for the symbolic tensors
for st, at in zip(self.inputs, inputs):
tensor_map[st] = at
shape_map.update({dim: at.shape[idx] for idx, dim in enumerate(st.shape) if isinstance(dim, Var)})

# run each operator in the graph in a topological order
for idx, node in enumerate(self.nodes):
@@ -246,23 +248,16 @@ def forward(self, *inputs: Tensor) -> Union[List[Tensor], Tensor]:
node_inputs = node_inputs[: len(node.inputs)]

# run node
GraphForwardContext.before_operator(node, node_inputs, shape_map)
GraphForwardContext._before_operator(node, node_inputs)
logger.debug('[%4d/%d] run operator %s, %s', idx, len(self.nodes), node.name, node.task)
logger.debug(' inputs: %s', [x.signature() for x in node_inputs])
node_outputs = node.imperative_run(node_inputs)
logger.debug(' outputs: %s', [x.signature() for x in node_outputs])
GraphForwardContext.after_operator(node, node_inputs, shape_map, node_outputs)
GraphForwardContext._after_operator(node, node_inputs, node_outputs)

# update map
for node_output, symbolic_output in zip(node_outputs, node.outputs):
tensor_map[symbolic_output] = node_output
shape_map.update(
{
dim: node_output.shape[idx]
for idx, dim in enumerate(symbolic_output.shape)
if isinstance(dim, Var)
}
)

outputs = []
for graph_output in self.outputs:
Expand All @@ -273,7 +268,7 @@ def forward(self, *inputs: Tensor) -> Union[List[Tensor], Tensor]:
else:
raise RuntimeError('Graph output {} is not produced by any operator.'.format(graph_output.signature()))

GraphForwardContext.after_graph(self, inputs, outputs)
GraphForwardContext._after_graph(self, inputs, outputs)
return outputs

def dummy_inputs(self) -> List[Tensor]:
@@ -348,7 +343,7 @@ def update_nodes(self):
self.inputs = free_vars
return self

def build(self):
def build(self, allow_hook=False):
"""
Build the flow graph to a compiled model (hidet.runtime.CompiledModel).

@@ -359,7 +354,7 @@ def build(self):
"""
from hidet.graph.graph_utils.build import flow_graph_build

return flow_graph_build(self)
return flow_graph_build(self, allow_hook=allow_hook)

def cuda_graph(self):
"""Create a CudaGraph from FlowGraph.
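With `shape_map` dropped from the instrument interface, concrete extents are now read directly off the actual tensors, and symbolic dimensions are bound once per call through `runtime_api.set_symbol_value` in `forward`. A minimal instrument under the new signatures (`ShapeLogger` is an illustrative name, not part of the PR):

```python
from typing import List
from hidet.graph.tensor import Tensor
from hidet.graph.flow_graph import GraphForwardInstrument

class ShapeLogger(GraphForwardInstrument):
    """Logs concrete shapes; note there is no shape_map parameter anymore."""

    def before_operator(self, op, inputs: List[Tensor]) -> None:
        print('run ', op.name, [tuple(int(d) for d in x.shape) for x in inputs])

    def after_operator(self, op, inputs: List[Tensor], outputs: List[Tensor]) -> None:
        print('done', op.name, '->', [tuple(int(d) for d in y.shape) for y in outputs])
```

An instance can then be attached via `GraphForwardContext`'s `append_instrument` before calling `forward`, exactly as with the old interface.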
6 changes: 3 additions & 3 deletions python/hidet/graph/frontend/onnx/onnx.py
@@ -1128,8 +1128,8 @@ def __init__(self, graph: onnx.GraphProto, op_sets: List[int], env_tensors: Opti
self.name: str = graph.name
for param in graph.initializer:
numpy_array = onnx.numpy_helper.to_array(tensor=param)
self.parameters[param.name] = from_numpy(numpy_array).cuda()
self.input_names: List[str] = [input.name for input in graph.input if input.name not in self.parameters]
self._parameters[param.name] = from_numpy(numpy_array).cuda()
self.input_names: List[str] = [input.name for input in graph.input if input.name not in self._parameters]
self.output_names: List[str] = [output.name for output in graph.output]
self.operators: List[OnnxOperator] = dispatch_operators(graph.node, op_sets)
# self.operators: List[OnnxOperator] = [dispatch(node, op_sets=self.op_sets) for node in graph.node]
@@ -1142,7 +1142,7 @@ def forward(self, *args):
name2tensor.update(self.env_tensors)
assert len(args) == len(self.input_names)
# parameters
for name, param in self.parameters.items():
for name, param in self._parameters.items():
name2tensor[name] = param
# inputs
for name, inp in zip(self.input_names, args):
5 changes: 2 additions & 3 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -390,9 +390,8 @@ def ones(

@register_function(torch.nn.functional.gelu)
def gelu(x: Tensor, approximate: Optional[str] = "none"):
if approximate is not None and approximate != "none":
warnings.warn_once("hidet: gelu with approximate {repr(approximate)} is not supported. Treat as 'none'.")
return ops.gelu(x)
approximate = {"none": False, "tanh": True}[approximate]
return ops.gelu(x, approximate=approximate)


@register_function(torch.nn.functional.layer_norm)
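The silent warn-and-fall-back path is replaced with an explicit mapping: `'none'` selects the exact erf-based gelu, `'tanh'` selects the tanh approximation, and any other value now raises a `KeyError` instead of being quietly treated as `'none'`. A quick sketch of the mapping's behavior (assuming `ops.gelu` accepts a boolean `approximate` flag, as the diff suggests):

```python
def gelu_flag(approximate: str) -> bool:
    # 'none' -> exact gelu, 'tanh' -> tanh approximation
    return {"none": False, "tanh": True}[approximate]

assert gelu_flag("none") is False and gelu_flag("tanh") is True
# gelu_flag("sigmoid")  # raises KeyError: unsupported approximation mode
```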
60 changes: 54 additions & 6 deletions python/hidet/graph/graph_utils/build.py
@@ -1,9 +1,9 @@
from typing import List, Set, Dict

import hidet
from hidet.ir.type import FuncType, void, byte_p
from hidet.ir.type import FuncType, void, byte_p, func_type
from hidet.ir.expr import SymbolVar, Var, Expr, var
from hidet.ir.stmt import AssignStmt, DeclareStmt
from hidet.ir.stmt import AssignStmt, DeclareStmt, BufferStoreStmt
from hidet.graph.tensor import Tensor
from hidet.graph.flow_graph import FlowGraph
from hidet.runtime.module import CompiledModule
@@ -61,7 +61,7 @@ def get_graph_meta_data(graph: FlowGraph) -> ModelMetaData:
)


def flow_graph_build(graph) -> CompiledModel:
def flow_graph_build(graph, allow_hook=False) -> CompiledModel:
from hidet.lang import void_p, attrs, int32, int64, meta, cast
from hidet.ir.primitives.runtime import memory_planner_init, memory_planner_allocate, memory_planner_free
from hidet.ir.primitives.runtime import memory_planner_used
@@ -79,10 +79,12 @@ def flow_graph_build(graph) -> CompiledModel:
workspace = var('workspace', byte_p)
weights = var('weights', void_p[len(graph_weights)])
kernels = var('kernels', void_p[len(graph_nodes)])
exec_hook = var('exec_hook', func_type([~int64], void))

script_module.define_global_var(workspace)
script_module.define_global_var(weights)
script_module.define_global_var(kernels)
script_module.define_global_var(exec_hook)

@hidet.script
def init(num_kernels: int, p_kernels: ~void_p, num_weights: int, p_weights: ~void_p):
@@ -133,6 +135,48 @@ def set_workspace(space: void_p):

AssignStmt(workspace, space)

@hidet.script
def register_hook(hook: void_p):
attrs.func_kind = 'public'

assert allow_hook, "Hook is not allowed when building the graph"
nonlocal exec_hook
exec_hook = hook

def call_exec_hook(idx: int, node_params: List[Expr]):
sb = hidet.ir.builders.StmtBuilder()

with sb.if_then(exec_hook != 0):
args = []
args.append(idx) # kernel index

tensors: List[Tensor]
if idx < len(graph_nodes):
args.append(len(graph_nodes[idx].inputs))
args.append(len(graph_nodes[idx].outputs))
tensors = graph_nodes[idx].inputs + graph_nodes[idx].outputs
else:
args.append(len(graph.inputs))
args.append(len(graph.outputs))
tensors = graph.inputs + graph.outputs

assert len(tensors) == len(node_params), "Expect {} parameters, got {}".format(
len(tensors), len(node_params)
)

for tensor, param in zip(tensors, node_params):
args.append(tensor.dtype.name)
args.append(len(tensor.shape))
args.extend([cast(d, 'int64') for d in tensor.shape])
args.append(param)

args_var = var('args', int64[len(args)])
sb += DeclareStmt(args_var)
for i in range(len(args)):
sb += BufferStoreStmt(args_var, [i], args[i])
sb += exec_hook(args_var)
return sb.finish()

def launch_impl(inputs: List[Var], outputs: List[Var]):
sb = hidet.ir.builders.StmtBuilder()
usage_count = graph.usage_count
@@ -160,15 +204,19 @@ def launch_impl(inputs: List[Var], outputs: List[Var]):
else:
raise RuntimeError("Unknown tensor {}".format(y))

func_type = FuncType([void_p for _ in node_params], void)
kernel_var = var("kernel_{}".format(idx), func_type)
with sb.let(kernel_var, cast(kernels[idx], func_type)):
kernel_type = FuncType([void_p for _ in node_params], void)
kernel_var = var("k{}_{}".format(idx, graph_nodes[idx].name), kernel_type)
with sb.let(kernel_var, cast(kernels[idx], kernel_type)):
sb += kernel_var(*node_params)
if allow_hook:
sb += call_exec_hook(idx, node_params)

for x in node.inputs:
usage_count[x] -= 1
if usage_count[x] == 0 and x in graph_intermediates:
sb += memory_planner_free(tensor_ptr[x])
if allow_hook:
sb += call_exec_hook(len(graph_nodes), inputs + outputs)
return sb.finish()

@hidet.script
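Taken together, `register_hook` stores a user callback in the new `exec_hook` global, and `call_exec_hook` packs a description of each kernel launch into a flat `int64` buffer: the kernel index (with `len(graph_nodes)` marking the final whole-graph call over `inputs + outputs`), the input and output counts, and then, per tensor, its dtype, rank, extents, and data pointer. A hedged host-side sketch of consuming that buffer through ctypes; the callback signature follows `func_type([~int64], void)`, while the dtype encoding and the exact registration call are assumptions:

```python
import ctypes

# void (*hook)(int64_t* args) -- matches exec_hook's func_type([~int64], void)
HOOK_TYPE = ctypes.CFUNCTYPE(None, ctypes.POINTER(ctypes.c_int64))

def on_kernel(args):
    # Buffer layout written by call_exec_hook:
    #   args[0] : kernel index; len(graph_nodes) marks the final whole-graph call
    #   args[1] : number of input tensors, args[2] : number of output tensors
    #   then per tensor: dtype tag, rank, <rank> extents, data pointer
    idx, num_in, num_out = args[0], args[1], args[2]
    cursor = 3
    for _ in range(num_in + num_out):
        dtype_tag = args[cursor]          # encoding of tensor.dtype.name (assumed)
        rank = args[cursor + 1]
        shape = [args[cursor + 2 + i] for i in range(rank)]
        ptr = args[cursor + 2 + rank]
        cursor += 3 + rank
        print(f'kernel {idx}: dtype_tag={dtype_tag}, shape={shape}, ptr={hex(ptr)}')

hook = HOOK_TYPE(on_kernel)  # keep a reference alive for as long as the model runs
# Hypothetical registration, relying on the new c_void_p branch in ffi/utils.py:
# compiled_model.register_hook(ctypes.cast(hook, ctypes.c_void_p))
```

Since `register_hook` is only compiled in when the graph is built with `build(allow_hook=True)`, the per-kernel `call_exec_hook` blocks add no overhead to a default build.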