[Fixbug] Call fuse in reduce_fp16 operator #105

Merged (5 commits) on Feb 16, 2023
3 changes: 2 additions & 1 deletion python/hidet/backend/codegen.py
@@ -592,8 +592,9 @@ def visit_AsmStmt(self, stmt: AsmStmt):
)

def visit_LaunchKernelStmt(self, stmt: LaunchKernelStmt):
assert isinstance(stmt.func_var, Var)
return NewLine() + Text('{}<<<dim3({}), dim3({}), {}, {}>>>({});').format(
self(stmt.func_var),
self.canonize_funcname(stmt.func_var.hint),
self(stmt.grid_dim),
self(stmt.block_dim),
self(stmt.shared_mem_bytes),
4 changes: 2 additions & 2 deletions python/hidet/cli/bench/model.py
@@ -79,7 +79,7 @@ def inputs_str(self) -> str:
items.append('{}={}'.format(k, self.tensor_str(v)))
return ', '.join(items)

def bench_with_backend(self, backend: str, mode=None, passes=None, warmup=3, number=10, repeat=10):
def bench_with_backend(self, backend: str, mode=None, warmup=3, number=10, repeat=10):
import torch.backends.cudnn
import torch.backends.cuda

@@ -97,7 +97,7 @@ def bench_with_backend(self, backend: str, mode=None, passes=None, warmup=3, num
kwargs = {k: v.cuda() for k, v in kwargs.items()}
dynamo.reset()
with torch.no_grad():
model_opt = torch.compile(model, backend=backend, mode=mode, passes=passes)
model_opt = torch.compile(model, backend=backend, mode=mode)
latency = benchmark_func(
run_func=lambda: model_opt(*args, **kwargs), warmup=warmup, number=number, repeat=repeat
)
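For context beyond the diff: the `passes` keyword appears to have been dropped here because newer torch builds no longer accept it in `torch.compile` (later releases expose `options` instead). A minimal sketch of the updated benchmarking call, with the model, input, and backend name purely illustrative rather than taken from this PR:

import torch

model = torch.nn.Linear(16, 16).cuda().eval()
x = torch.randn(4, 16, device='cuda')

with torch.no_grad():
    # assumes the 'hidet' backend is registered with dynamo (see the entry_points note below)
    model_opt = torch.compile(model, backend='hidet', mode=None)  # no `passes` kwarg
    y = model_opt(x)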
19 changes: 4 additions & 15 deletions python/hidet/graph/frontend/torch/__init__.py
@@ -49,18 +49,7 @@ def from_torch(module, concrete_args=None):


def register_dynamo_backends():
"""
Register the 'hidet' and 'onnx2hidet' backends for torch dynamo.

By default, if torch has already been imported and dynamo is available, the backends will be registered immediately.
Otherwise, the user can call this function to register the backends manually.
"""
from torch._dynamo.backends.registry import register_backend, list_backends
from .dynamo_backends import hidet_backend

if 'hidet' not in list_backends():
register_backend(hidet_backend, name='hidet')


if imported() and dynamo_available():
register_dynamo_backends()
print(
'Now, hidet will use the entry_points mechanism to register as a dynamo backend. \n'
'Feel free to remove the line `hidet.frontend.torch.register_dynamo_backends()` in your code.'
)
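For context beyond the diff: torch dynamo can discover third-party backends through Python entry points, which is why the explicit `register_backend` call above was removed. A hedged sketch of the packaging metadata such registration relies on; the entry-point group `torch_dynamo_backends` is the one recent torch versions scan, and the module path is inferred from the import in the removed code:

# setup.py (illustrative sketch, not taken from this PR)
from setuptools import setup

setup(
    name='hidet',
    entry_points={
        'torch_dynamo_backends': [
            'hidet = hidet.graph.frontend.torch.dynamo_backends:hidet_backend',
        ],
    },
)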
24 changes: 24 additions & 0 deletions python/hidet/graph/frontend/torch/dynamo_config.py
@@ -9,14 +9,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional


class DynamoConfig:
def __init__(self):
self._search_space: int = 0
self._parallel_k: str = 'default'
self._use_fp16: bool = False
self._use_fp16_reduction: bool = False
self._use_cuda_graph: bool = True
self._use_tensor_core: bool = True
self._print_input_graph: bool = False
self._dump_graph_ir: Optional[str] = None
self._correctness_report: bool = False

def __getitem__(self, item: str):
@@ -43,6 +48,13 @@ def search_space(self, level: int = 2):
self._search_space = level
return self

def use_tensor_core(self, flag=True):
"""
Whether to use tensor core
"""
self._use_tensor_core = flag
return self

def parallel_k(self, strategy="default"):
"""
Parallelization on k dimension of the matrix multiplication
@@ -91,6 +103,18 @@ def print_input_graph(self, flag=True):
self._print_input_graph = flag
return self

def dump_graph_ir(self, output_dir: str):
"""
Dump the graph IR to the given directory

Parameters
----------
output_dir: str
The output directory to dump the graph ir.
"""
self._dump_graph_ir = output_dir
return self

def correctness_report(self, flag=True):
"""
Whether to check correctness and print an error report
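A minimal usage sketch (not part of the diff) of the configuration knobs added above; it instantiates `DynamoConfig` directly from the module shown in the file header, though hidet may also expose a ready-made instance:

from hidet.graph.frontend.torch.dynamo_config import DynamoConfig

config = DynamoConfig()
config.use_tensor_core(True)        # enable tensor-core kernels
config.dump_graph_ir('./graph_ir')  # dump the graph IR into ./graph_ir
config.correctness_report()         # every setter returns self, so the calls can also be chained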
2 changes: 1 addition & 1 deletion python/hidet/graph/ir/flow_graph.py
@@ -346,7 +346,7 @@ def update_nodes(self):
f'When there are multiple free '
f'variables, it is mandatory to specify the "inputs" argument explicitly when calling '
f'hidet.trace_from(...):\n'
' hidet.trace_from(..., free_vars=[tensor1, tensor2, ...])\n'
' hidet.trace_from(..., inputs=[tensor1, tensor2, ...])\n'
)
self.inputs = free_vars
return self
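A small sketch of the corrected call referenced by the error message above; the `hidet.symbol` and `hidet.ops.matmul` usage is assumed from hidet's public API of this era and may differ slightly in your version:

import hidet

x = hidet.symbol([1, 8], dtype='float16', device='cuda')
w = hidet.symbol([8, 8], dtype='float16', device='cuda')
y = hidet.ops.matmul(x, w)
graph = hidet.trace_from(y, inputs=[x, w])  # pass the free variables explicitly via `inputs`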
11 changes: 6 additions & 5 deletions python/hidet/graph/ops/definitions/matmul/matmul_f16.py
@@ -74,7 +74,7 @@ def allow_prologue(self) -> bool:
return False

def allow_epilogue(self) -> bool:
return True
return False

def implement_cuda(self, working_dir: str) -> IRModule:
return tune.tune(self.schedule, task=self, target_device='cuda', working_dir=working_dir)
@@ -87,8 +87,8 @@ def implement_cuda(self, working_dir: str) -> IRModule:
@tune.space(2, 'warp_k', [8, 16, 32, 64])
@tune.space(2, 'mma', ['m16n8k16'])
@tune.space(1, 'block_m', [128])
@tune.space(1, 'block_n', [64])
@tune.space(1, 'block_k', [32])
@tune.space(1, 'block_n', [128])
@tune.space(1, 'block_k', [16])
@tune.space(1, 'warp_m', [64])
@tune.space(1, 'warp_n', [64])
@tune.space(1, 'warp_k', [16])
@@ -105,7 +105,7 @@ def schedule(
from hidet.lang.cuda import blockIdx, threadIdx, syncthreads, dynamic_shared_memory
from hidet.lang.cuda import MmaConfig, mma_sync, cp_async, cp_async_wait_all, ldmatrix
from hidet.lang.cuda import register_tensor
from hidet.transforms.tools import add_packed_func
from hidet.transforms.tools import fuse_and_pack

# input shapes
node_a, node_b, node_c = self.inputs[0], self.inputs[1], self.outputs[0]
@@ -282,6 +282,7 @@ def matmul_f16_kernel(
offset_m, offset_n = blockIdx.x * block_m, blockIdx.y * block_n
c_head_index = spatial(*c_head).map(blockIdx.z)
gmem_c = c[c_head_index][offset_m:, offset_n:]

for k_round in range(warp_count_k):
for wi, wj, wk in spatial(warp_count_m, warp_count_n, warp_count_k).on(warp_id):
if wk == k_round:
@@ -302,7 +303,7 @@

ir_module = module.ir_module()
assert isinstance(matmul_f16_kernel, Function)
add_packed_func(ir_module, matmul_f16_kernel, self.name)
fuse_and_pack(ir_module, matmul_f16_kernel, task=self)

return ir_module

5 changes: 3 additions & 2 deletions python/hidet/graph/ops/definitions/reduce/reduce_f16.py
@@ -17,7 +17,7 @@
from hidet.ir.layout import DataLayout
from hidet.lang import f16, f32, spatial, repeat, attr, tensor_pointer
from hidet.lang.cuda import blockIdx, threadIdx, register_tensor
from hidet.transforms.tools import add_packed_func
from hidet.transforms.tools import add_packed_func, fuse_and_pack
from hidet.graph.ops.definitions.utils import Task, Operator, Tensor, TensorNode, ReduceType
from hidet.graph.ops.definitions.utils import compute, input_like, normalize_dim
from hidet.utils import prod
@@ -81,6 +81,7 @@ def allow_prologue(self) -> bool:
return False

def allow_epilogue(self) -> bool:
# return False
rank = len(self.inputs[0].const_shape())
if rank - 1 in self.dims: # pylint: disable=simplifiable-if-statement
# use self.cuda_schedule_reduce_by_warp
@@ -169,7 +170,7 @@ def reduce_kernel(x: f16[x.const_shape()], y: f16[y.const_shape()]):
y.write(indices, rv[0], protected=False)

ir_module = module.ir_module()
add_packed_func(ir_module, func=reduce_kernel, pack_func_name=self.name)
fuse_and_pack(ir_module, reduce_kernel, task=self)
return ir_module

def cuda_schedule_reduce_by_default(self) -> IRModule:
47 changes: 43 additions & 4 deletions python/hidet/transforms/tools/apply_prologue_epilogue.py
@@ -63,8 +63,11 @@ def __init__(self, task: Task):
self.input2task[input_tensor] = internal_task

self.binding: Dict[TensorNode, Var] = {}
self.new_params: List[Var] = []
self.anchor_inputs: List[Var] = []
self.anchor_outputs: List[Var] = []
self.anchor_input_new_index: Dict[int, int] = {}
self.anchor_output_new_index: Dict[int, int] = {}

def visit_Function(self, func: Function):
anchor_num_inputs = len(self.task.inputs)
@@ -73,23 +76,59 @@ def visit_Function(self, func: Function):
self.anchor_inputs: List[Var] = func.params[:anchor_num_inputs]
self.anchor_outputs: List[Var] = func.params[anchor_num_inputs:]

for input_index in range(anchor_num_inputs):
tn = self.task.inputs[input_index]
while tn in self.task_graph.consume:
tn = self.task_graph.consume[tn]
if tn in self.task_graph.input_tensors:
self.anchor_input_new_index[input_index] = self.task_graph.input_tensors.index(tn)

for output_index in range(anchor_num_outputs):
tn = self.task.outputs[output_index]
if tn in self.task_graph.output_tensors:
self.anchor_output_new_index[output_index] = self.task_graph.output_tensors.index(tn)

# create parameters for fused function, and bind task graph parameters to function parameters
# todo: do not create new parameters for the inputs/outputs that have not been fused
new_params: List[Var] = []
self.new_params: List[Var] = []
for tensor_node in self.task_graph.input_tensors + self.task_graph.output_tensors:
new_params.append(Var(tensor_node.name, tensor_node.ttype))
self.binding[tensor_node] = new_params[-1]
self.new_params.append(Var(tensor_node.name, tensor_node.ttype))
self.binding[tensor_node] = self.new_params[-1]

return Function(
name=func.name,
params=new_params,
params=self.new_params,
body=self.visit(func.body),
ret_type=func.ret_type,
kind=func.kind,
extern_vars=func.extern_vars,
attrs=func.attrs,
)

def visit_Var(self, e: Var):
if e in self.anchor_inputs:
input_index = self.anchor_inputs.index(e)
if input_index in self.anchor_input_new_index:
return self.new_params[self.anchor_input_new_index[input_index]]
else:
# we encounter a usage of an input tensor of the task other than TensorElement and BufferStoreStmt
raise ValueError(
'Did you use a tensor in an expression other than tensor[...] and tensor[...] = ...'
' while marking the task as allowing prologue?'
)
elif e in self.anchor_outputs:
output_index = self.anchor_outputs.index(e)
if output_index in self.anchor_output_new_index:
return self.new_params[len(self.task_graph.input_tensors) + self.anchor_output_new_index[output_index]]
else:
# we encounter a usage of an output tensor of the task other than TensorElement and BufferStoreStmt
raise ValueError(
'Did you use a tensor in an expression other than tensor[...] and tensor[...] = ...'
' while marking the task as allowing epilogue?'
)
else:
return e

def visit_TensorElement(self, e: TensorElement):
if isinstance(e.base, TensorNode):
# apply prologue
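A toy illustration (plain Python, hypothetical names, not hidet code) of the bookkeeping the new `visit_Var` performs: an anchor parameter that is still a task-graph input or output after fusion is redirected to the corresponding parameter of the fused function, while a fused-away parameter may only appear as `tensor[...]` reads and writes (those are rewritten by `visit_TensorElement` and the `BufferStoreStmt` visitor), hence the `ValueError` otherwise:

# hypothetical names; only the index bookkeeping mirrors the pass above
new_params = ['graph_in0', 'graph_in1', 'graph_out0']
anchor_input_new_index = {0: 1}  # anchor input 0 is task-graph input 1; anchor input 1 comes from a prologue

def remap_anchor_input(index: int) -> str:
    if index in anchor_input_new_index:
        return new_params[anchor_input_new_index[index]]
    raise ValueError('anchor input was fused away; only tensor[...] accesses are rewritten')

print(remap_anchor_input(0))  # -> graph_in1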