Skip to content

Commit

Permalink
[IR][Pass] Refactor the fusion implementation (#164)
Browse files Browse the repository at this point in the history
* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .
  • Loading branch information
yaoyaoding committed Apr 7, 2023
1 parent 6289f46 commit 3cc75b6
Show file tree
Hide file tree
Showing 46 changed files with 752 additions and 839 deletions.
4 changes: 4 additions & 0 deletions docs/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ BUILDDIR = build
# Print the help output of sphinx-build's make mode (-M help).
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

# Run sphinx-build's make-mode "clean" for $(BUILDDIR), then also delete the
# ./source/gallery directory, which the Sphinx clean step above does not target.
clean:
@$(SPHINXBUILD) -M clean "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
rm -rf ./source/gallery

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
Expand Down
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

import hidet
hidet.option.cache_dir(os.path.join(hidet.option.get_cache_dir(), 'docs-cache'))
hidet.utils.hidet_clear_op_cache()
print('Build docs with under cache: {}'.format(hidet.option.get_cache_dir()))

# -- Project information -----------------------------------------------------
Expand Down
6 changes: 1 addition & 5 deletions gallery/developer-guides/hidet-script-dynamic-kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@


def matmul_simt_kernel():
from hidet.transforms.tools import add_packed_func
from hidet.lang import attr
from hidet.lang import float32, int32
from hidet.lang import as_tensor_pointer, tensor
Expand Down Expand Up @@ -134,10 +133,7 @@ def matmul_kernel(
assert isinstance(matmul_kernel, hidet.ir.Function) # matmul is a hidet.ir.Function

ir_module = script_module.ir_module()
add_packed_func(ir_module, matmul_kernel, pack_func_name='matmul')
compiled_function: hidet.runtime.CompiledFunction = hidet.driver.build_ir_module(
ir_module, func_name='matmul'
)
compiled_function: hidet.runtime.CompiledFunction = hidet.driver.build_ir_module(ir_module)
return compiled_function


Expand Down
4 changes: 0 additions & 4 deletions gallery/how-to-guides/add-new-operator-template-based.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def batch_matmul_mma_fp16_schedule(task: BatchMatmulFp16Task) -> IRModule:
from hidet.lang.mapping import repeat, spatial
from hidet.lang.cuda import blockIdx, threadIdx, syncthreads
from hidet.lang.cuda import MmaConfig, mma_sync
from hidet.transforms.tools import add_packed_func

# get the workload size
bs = task.attributes['batch_size']
Expand Down Expand Up @@ -205,9 +204,6 @@ def batch_matmul_kernel(
store_c(regs_c, c)

ir_module = module.ir_module()
# conduct the fusion (when the task has prologue or epilogue) and generate the packed function
# ir_module = fuse_and_pack(ir_module, kernel_func=batch_matmul_kernel, task=task)
add_packed_func(ir_module, func=batch_matmul_kernel, pack_func_name=task.name)
return ir_module


Expand Down
2 changes: 1 addition & 1 deletion python/hidet/backend/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def load_task_func(lib_path: str, task) -> CompiledFunction:
print("Removed the file '{}'".format(lib_path))
os.remove(lib_path)
raise e
func_name = 'hidet_{}'.format(task.name)
func_name = 'hidet_launch'
param_types = [param.type for param in task.parameters]
packed_func = PackedFunc(param_types=param_types, c_func_pointer=lib[func_name])

Expand Down
40 changes: 15 additions & 25 deletions python/hidet/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def build_task(task: Task, target_device='cuda', load=True) -> Optional[Compiled
with open(os.path.join(task_dir, 'task.txt'), 'w') as f:
f.write(task_string)
# implement task
ir_module = task.implement(target=target_device, workding_dir=task_dir)
ir_module = task.implement(target=target_device, working_dir=task_dir)
# lower ir module
if option.get_option('save_lower_ir'):
instruments = [
Expand Down Expand Up @@ -139,7 +139,6 @@ def build_task_batch(tasks: List[Task], target_device: str = 'cuda', raise_on_er

def build_ir_module(
ir_module: IRModule,
func_name: str,
output_dir='./outs/ir_module',
save_ir: bool = True,
profile_pass: bool = True,
Expand All @@ -153,14 +152,6 @@ def build_ir_module(
src_path = os.path.join(output_dir, 'source.cu')
lib_path = os.path.join(output_dir, 'lib.so')

# get function type
func: Function = ir_module.lookup(func_name)
if func.kind == 'packed_func':
packed_func = ir_module.lookup(func.attrs['packed_func'])
func_type = FuncType.from_func(packed_func)
else:
func_type = FuncType.from_func(func)

# lower ir module
instruments = []
if save_ir:
Expand All @@ -170,6 +161,11 @@ def build_ir_module(
with PassContext(instruments=instruments):
ir_module = lower(ir_module)

# get function type
func: Function = ir_module.lookup('launch')
kernel_func = ir_module.lookup(func.attrs['packed_func'])
func_type = FuncType.from_func(kernel_func)

# code generation
codegen(ir_module, src_out_path=src_path)

Expand All @@ -178,18 +174,16 @@ def build_ir_module(

if load:
# load function
return load_lib_func(lib_path, 'hidet_' + func_name, func_type=func_type, src_path=src_path)
return load_lib_func(lib_path, 'hidet_launch', func_type=func_type, src_path=src_path)
else:
return lib_path, func_name, func_type
return lib_path, func_type


def _build_ir_module_job(args) -> Optional[Tuple[str, str, FuncType]]:
ir_module, func_name, output_dir, dumped_options = args
def _build_ir_module_job(args) -> Optional[Tuple[str, FuncType]]:
ir_module, output_dir, dumped_options = args
option.restore_options(dumped_options)
try:
return build_ir_module(
ir_module, func_name, output_dir, save_ir=False, profile_pass=False, load=False, use_hash_dir=False
)
return build_ir_module(ir_module, output_dir, save_ir=False, profile_pass=False, load=False, use_hash_dir=False)
except subprocess.CalledProcessError:
print('Failed launch subprocess to compile the lowered source code via nvcc.')
return None
Expand All @@ -199,7 +193,7 @@ def _build_ir_module_job(args) -> Optional[Tuple[str, str, FuncType]]:


def build_ir_module_batch(
ir_modules: Sequence[IRModule], func_name: str, output_dir: str, parallel=True, verbose=False
ir_modules: Sequence[IRModule], output_dir: str, parallel=True, verbose=False
) -> List[Optional[CompiledFunction]]:
"""
Build a batch of ir modules.
Expand All @@ -209,9 +203,6 @@ def build_ir_module_batch(
ir_modules: Sequence[IRModule]
A sequence of ir modules to build.
func_name: str
The name of the function to load after building.
output_dir: str
The output directory to save the compiled library and source code (lib.so and source.cu).
Expand All @@ -230,8 +221,7 @@ def build_ir_module_batch(
with Timer() as timer:
dumped_options = option.dump_options()
jobs = [
(ir_module, func_name, os.path.join(output_dir, str(idx)), dumped_options)
for idx, ir_module in enumerate(ir_modules)
(ir_module, os.path.join(output_dir, str(idx)), dumped_options) for idx, ir_module in enumerate(ir_modules)
]
build_results = []
if parallel:
Expand Down Expand Up @@ -260,8 +250,8 @@ def build_ir_module_batch(
funcs: List[Optional[CompiledFunction]] = []
for build_result in build_results:
if build_result is not None:
lib_path, func_name, func_type = build_result
funcs.append(load_lib_func(lib_path, 'hidet_' + func_name, func_type))
lib_path, func_type = build_result
funcs.append(load_lib_func(lib_path, 'hidet_launch', func_type))
else:
funcs.append(None)
if verbose:
Expand Down
10 changes: 3 additions & 7 deletions python/hidet/ffi/packedfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from ctypes import c_int32, c_void_p, pointer, c_float, cast
from ctypes import POINTER, Structure
from hidet.ir.type import TypeNode, DataType, TensorType, PointerType, TensorPointerType
from hidet.utils.py import same_list
from .ffi import _LIB

c_int32_p = POINTER(c_int32)
Expand Down Expand Up @@ -103,24 +102,21 @@ def convert_args(self, args: Sequence):
elif isinstance(arg, Tensor):
if isinstance(param_type, TensorType):
expect_dtype = param_type.dtype
expect_shape = param_type.const_shape()
elif isinstance(param_type, TensorPointerType):
expect_dtype = param_type.tensor_type.dtype
expect_shape = param_type.tensor_type.const_shape()
elif isinstance(param_type, PointerType):
isinstance(param_type.base_type, DataType)
expect_dtype = param_type.base_type
expect_shape = None
else:
raise ValueError(
'The callee expects the {}-th element to be a {}, but got a {}.'.format(
i + 1, param_type, type(arg)
)
)
if arg.dtype != expect_dtype or (expect_shape is not None and not same_list(arg.shape, expect_shape)):
if arg.dtype != expect_dtype:
raise ValueError(
'The callee expects the {}-th element to be a {}{}, but got a {}{}.'.format(
i + 1, expect_dtype, expect_shape if expect_shape else " tensor", arg.dtype, arg.shape
'The callee expects the {}-th element to be a {} tensor, but got a {} tensor.'.format(
i + 1, expect_dtype, arg.dtype
)
)
converted_args.append(cast(arg.storage.addr, c_void_p))
Expand Down
49 changes: 34 additions & 15 deletions python/hidet/graph/ir/flow_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,12 @@ def forward_context() -> GraphForwardContext:
class FlowGraph:
"""The computation graph representation."""

def __init__(self, outputs: List[Tensor], inputs=None, nodes=None):
self.outputs: List[Tensor] = outputs
self.inputs: Optional[List[Tensor]] = inputs
self.nodes: Optional[List[Operator]] = nodes
self.usage_count: Optional[Dict[Tensor, int]] = None
def __init__(self, outputs: Sequence[Tensor], inputs: Optional[Sequence[Tensor]] = None, nodes=None):
self.outputs: List[Tensor] = list(outputs)
self.inputs: Optional[List[Tensor]] = list(inputs) if inputs is not None else None
self._nodes: Optional[List[Operator]] = nodes
self._usage_count: Optional[Dict[Tensor, int]] = None
self.update_nodes()

def __call__(self, *inputs: Tensor) -> Union[List[Tensor], Tensor]:
"""Run the computation graph.
Expand All @@ -111,14 +112,12 @@ def __call__(self, *inputs: Tensor) -> Union[List[Tensor], Tensor]:
return self.forward(*inputs)

def __str__(self):
if any(v is None for v in [self.inputs, self.nodes, self.usage_count]):
self.update_nodes()
namer = Namer()

def get_tensor_sig(x: Tensor) -> Doc:
return Text(x.dtype.name) + '[' + doc_join([str(v) for v in x.shape], ', ') + ']'

def get_attr_repr(value: Union[float, int, bool, str, list, tuple]) -> Doc:
def get_attr_repr(value: Union[float, int, bool, str, list, tuple, FlowGraph]) -> Doc:
if isinstance(value, (float, int, bool)):
return Text(str(value))
elif isinstance(value, str):
Expand All @@ -127,6 +126,8 @@ def get_attr_repr(value: Union[float, int, bool, str, list, tuple]) -> Doc:
return '[' + doc_join([get_attr_repr(v) for v in value], ', ') + ']'
elif isinstance(value, tuple):
return '(' + doc_join([get_attr_repr(v) for v in value], ', ') + ')'
elif isinstance(value, FlowGraph):
return Text('FlowGraph({})'.format(', '.join(u.name for u in value.nodes)))
else:
return Text(str(value))

Expand All @@ -153,7 +154,7 @@ def get_attr_repr(value: Union[float, int, bool, str, list, tuple]) -> Doc:
output: Tensor = outputs[0]
line_doc = Doc()
line_doc += namer(output) + ': ' + get_tensor_sig(output) + ' = '
line_doc += op.name + ('*' if len(op.task.task_graph.nodes) > 1 else '') + '('
line_doc += op.name + '('
line_doc += doc_join([namer(x) for x in op.inputs], sep=', ')
if op.attrs:
line_doc += ', ' + doc_join(
Expand All @@ -168,6 +169,24 @@ def get_attr_repr(value: Union[float, int, bool, str, list, tuple]) -> Doc:
graph_doc = head_doc + '{' + const_doc.indent() + body_doc.indent() + NewLine() + '}'
return str(graph_doc)

@property
def nodes(self) -> List[Operator]:
"""The list of operators in the computation graph.

Reads the cached ``self._nodes``. When the cache is ``None`` it is rebuilt
first by calling ``self.update_nodes()``. The cached list object itself is
returned, not a copy.
"""
if self._nodes is None:
self.update_nodes()
return self._nodes

@property
def usage_count(self) -> Dict[Tensor, int]:
"""The usage count of each tensor in the computation graph.

When the cached ``self._usage_count`` is ``None`` it is rebuilt first by
calling ``self.update_nodes()``. A shallow copy of the cached dict is
returned, so changes the caller makes to the result do not alter the cache.
"""
if self._usage_count is None:
self.update_nodes()
return self._usage_count.copy()

def invalid_cache(self):
"""Reset the cached ``_nodes`` and ``_usage_count`` to ``None``.

The ``nodes`` and ``usage_count`` properties call ``update_nodes()`` the
next time they are read while these caches are ``None``.
"""
self._nodes = None
self._usage_count = None

def build(self):
tasks = []
tunable_tasks = []
Expand All @@ -179,7 +198,10 @@ def build(self):
if task_key in task_keys:
continue
task_keys.add(task_key)
if search_space == 0 or 'implement_cuda' not in node.task.__class__.__dict__:
if search_space == 0 or all(
method not in node.task.__class__.__dict__
for method in ['implement_cuda', 'implement_cpu', 'implement']
):
tasks.append(node.task)
else:
tunable_tasks.append(node.task)
Expand Down Expand Up @@ -211,8 +233,6 @@ def forward(self, *inputs: Tensor) -> Union[List[Tensor], Tensor]:
if tensor.storage is None:
msg = 'Expect non-symbolic input tensors, got symbolic input {} ({}).'.format(idx, tensor.signature())
raise ValueError(msg)
if any(v is None for v in [self.inputs, self.nodes, self.usage_count]):
self.update_nodes()
self.build()

GraphForwardContext.current()._trigger_before_graph(self, inputs)
Expand Down Expand Up @@ -284,7 +304,7 @@ def save(self, model_file: str):
# Before saving, clear the packed func cache because ctypes objects cannot be pickled.
for node in self.nodes:
node.task_func = None
self.usage_count, self.nodes = None, None
self._usage_count, self._nodes = None, None

dirname = os.path.dirname(model_file)
os.makedirs(dirname, exist_ok=True)
Expand All @@ -311,11 +331,10 @@ def load(model_file: str) -> FlowGraph:
ret = pickle.load(f)
if not isinstance(ret, FlowGraph):
raise TypeError('Expect to load FlowGraph, got {}'.format(type(ret)))
ret.update_nodes()
return ret

def update_nodes(self):
free_vars, self.nodes, self.usage_count = self._analyze(self.outputs)
free_vars, self._nodes, self._usage_count = self._analyze(self.outputs)
if self.inputs:
non_bound_free_vars: Set[Tensor] = set(free_vars) - set(self.inputs)
if len(non_bound_free_vars) > 0:
Expand Down
1 change: 1 addition & 0 deletions python/hidet/graph/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from .definitions.transform import squeeze, unsqueeze, flatten, concat, cast, take, rearrange, strided_slice, reshape
from .definitions.transform import transpose, broadcast, pad, tile, split, conv_pad, expand_dims
from .definitions.transform import permute_dims
from .definitions.fusion import fused_operator
from .definitions.special import barrier

from .definitions import utils
Expand Down
28 changes: 3 additions & 25 deletions python/hidet/graph/ops/definitions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,9 @@
# limitations under the License.
# pylint: disable=redefined-builtin
from .create import full
from .arithmetic import (
add,
subtract,
multiply,
divide,
negative,
sqrt,
rsqrt,
where,
maximum,
minimum,
reciprocal,
exp,
log,
abs,
)
from .arithmetic import (
bitwise_and,
bitwise_invert,
bitwise_or,
bitwise_xor,
ceil,
bitwise_right_shift,
bitwise_left_shift,
)
from .arithmetic import add, subtract, multiply, divide, negative, sqrt, rsqrt, where, maximum, minimum, reciprocal
from .arithmetic import bitwise_and, bitwise_invert, bitwise_or, bitwise_xor, bitwise_right_shift, bitwise_left_shift
from .arithmetic import ceil, exp, log, abs
from .compare import equal, less, greater, less_equal, greater_equal, logical_not, logical_and
from .transform import squeeze, unsqueeze, flatten, concat, cast, take, rearrange, strided_slice, split, pad, conv_pad
from .pool import avg_pool2d, adaptive_avg_pool1d, adaptive_avg_pool2d, adaptive_avg_pool3d
Expand Down
Empty file.

0 comments on commit 3cc75b6

Please sign in to comment.