[inductor] generate triton kernel benchmark (#95506)
A PR to generate benchmark code for individual triton kernels. With the saved compiled kernel, we can explore improving autotuning directly. This can potentially speed up our iteration and separate that concern from the upstream components that generate the compiled module.

Since I'm still ramping up on inductor, I'll summarize what I learned here so people can correct me if I'm wrong. In inductor, the WrapperCodeGen class is used to generate the compiled module for CUDA (or triton). Here is an example compiled module for a toy model like `def f(x): return sin(x) + cos(x)`: https://gist.github.com/shunting314/c6ed9f571919e3b414166f1696dcc61b . A compiled module contains the following parts:
- various triton kernels
- a wrapper (a method named `call`; the name is hardcoded) that calls the triton kernels, and potentially ATen kernels, to efficiently do the same work as the original FX graph being compiled by inductor
- some utility code that generates random inputs and runs the wrapper

The triton kernels in the compiled module are annotated with decorators like `pointwise`, which are used for autotuning.
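
For the toy model above, a single pointwise kernel in the compiled module looks roughly like the following (a hand-trimmed sketch for illustration, not the exact generated code):

```
import triton
import triton.language as tl

# In real inductor output this kernel also carries a @pointwise(...) decorator
# (from torch._inductor.triton_ops.autotune) supplying size hints and metadata
# used for autotuning; it is omitted here to keep the sketch minimal.
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    tmp0 = tl.load(in_ptr0 + xindex, xmask)
    tl.store(out_ptr0 + xindex, tl.sin(tmp0) + tl.cos(tmp0), xmask)
```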

This PR adds a config; enabling it generates the per-kernel benchmark code and prints the path of the compiled module. It can be controlled via an environment variable as well.
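
In code, either of the following should turn it on (the config name matches the `torch/_inductor/config.py` change below):

```
# enable per-kernel benchmark code generation before compiling the model
import torch._inductor.config as inductor_config

inductor_config.benchmark_kernel = True
```

or set `TORCHINDUCTOR_BENCHMARK_KERNEL=1` in the environment, as in the example command below.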

The path to each compiled triton kernel is added as a comment in the compiled module, e.g.:
```
# kernel path: /tmp/torchinductor_shunting/gn/cgn6x3mqoltu7q77gjnu2elwfupinsvcovqwibc6fhsoiy34tvga.py
triton__0 = async_compile.triton('''
import triton
import triton.language as tl
...
""")
````

Example command:
```
TORCHINDUCTOR_OUTPUT_COMPILED_MODULE_PATH=1 TORCHINDUCTOR_BENCHMARK_KERNEL=1 python benchmarks/dynamo/huggingface.py --backend inductor --amp --performance --training --dashboard --only AlbertForMaskedLM --disable-cudagraphs
```
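
With the flag on, the benchmark section that `codegen_kernel_benchmark` (see the `triton.py` diff below) appends to each generated kernel file looks roughly like this, which is what makes the saved kernel file runnable as a standalone script (a reconstructed sketch; the buffer shapes, device index 0, and the `triton_` kernel name are illustrative):

```
def get_args():
    # random inputs matching the shapes/strides/dtypes the kernel was compiled for
    arg0 = rand_strided((1024,), (1,), device='cuda:0', dtype=torch.float32)
    buf0 = rand_strided((1024,), (1,), device='cuda:0', dtype=torch.float32)
    return arg0, buf0,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)  # no-op to ensure context
        stream0 = get_cuda_stream(0)
        triton_.run(*args, 1024, grid=grid(1024), stream=stream0)


if __name__ == '__main__':
    from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import get_num_bytes
    import torch
    from torch._inductor.triton_ops.autotune import grid
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)[0]
    num_gb = get_num_bytes(*args) / 1e9
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
```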

Pull Request resolved: pytorch/pytorch#95506
Approved by: https://github.com/Chillee
shunting314 authored and cyyever committed Mar 5, 2023
1 parent 9062417 commit 1c5042d
Showing 6 changed files with 120 additions and 27 deletions.
100 changes: 86 additions & 14 deletions torch/_inductor/codegen/triton.py
@@ -14,6 +14,7 @@

from ..._dynamo import config as dynamo_config
from .. import config, ir, scheduler
from ..codecache import get_code_path
from ..ir import ReductionHint
from ..optimize_indexing import indexing_dtype_strength_reduction
from ..utils import (
@@ -1173,6 +1174,78 @@ def codegen_body(self):
self.stores.clear()
self.suffix.clear()

def codegen_kernel_benchmark(self):
result = IndentedBuffer()
argdefs, call_args, signature = self.args.python_argdefs()

result.writelines(["", "", "def get_args():"])
with result.indent():
for arg_name in call_args:
buf = V.graph.get_buffer(arg_name)
if buf:
result.writeline(
f"{arg_name} = rand_strided({tuple(buf.get_size())}, {tuple(buf.get_stride())}, device='{buf.get_device()}', dtype={buf.get_dtype()})" # noqa: B950 line too long
)
elif arg_name in V.graph.constants:
# note that random seed is put in V.graph.constants
const_tensor = V.graph.constants[arg_name]
result.writeline(
f"{arg_name} = rand_strided({tuple(const_tensor.size())}, {tuple(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # noqa: B950 line too long
)
else:
raise KeyError(
f"Don't find the buffer or const tensor for {arg_name}"
)
result.writeline(f"return {', '.join(call_args)},")

result.writelines(["\n", "\n", "def call(args):"])
grid = []
extra_args = []
with result.indent():
index = V.graph.scheduler.current_device.index
result.writeline(f"with torch.cuda._DeviceGuard({index}):")
with result.indent():
result.writeline(
f"torch.cuda.set_device({index})"
) # no-op to ensure context
for tree in self.range_trees:
expr = pexpr(tree.numel)
if tree.prefix != "r" or self.inside_reduction:
extra_args.append(expr)
if tree.prefix != "r":
grid.append(expr)

stream_name = f"stream{index}"
result.writeline(f"{stream_name} = get_cuda_stream({index})")
extra_args_str = ", ".join(map(str, extra_args)) + ", "
result.writeline(
f"triton_.run(*args, {extra_args_str}grid=grid({', '.join(grid)}), stream={stream_name})"
)

result.writelines(["\n", "\n", "if __name__ == '__main__':"])
with result.indent():
result.writeline(
"from torch._C import _cuda_getCurrentRawStream as get_cuda_stream"
)
result.writeline("from torch._dynamo.testing import rand_strided")
result.writeline("from torch._inductor.utils import get_num_bytes")
result.writeline("import torch")
result.writeline("from torch._inductor.triton_ops.autotune import grid")
result.writeline("from triton.testing import do_bench")
result.writeline("")

result.writeline("args = get_args()")
result.writeline(
"ms = do_bench(lambda: call(args), rep=40, fast_flush=True)[0]"
)
result.writeline("num_gb = get_num_bytes(*args) / 1e9")
result.writeline("gb_per_s = num_gb / (ms / 1e3)")
result.writeline(
'print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")'
)

return result

def codegen_kernel(self, name=None):
from triton import next_power_of_2

@@ -1279,21 +1352,13 @@ def codegen_kernel(self, name=None):
code.writeline(f"{old} = {new}")
code.splice(self.body)

if config.benchmark_kernel:
code.splice(self.codegen_kernel_benchmark())

if name is not None:
return code.getvalue()

wrapper = IndentedBuffer()
wrapper.writeline("async_compile.triton('''")
wrapper.splice(code.getvalue(), strip=True)
wrapper.writeline("''')")
return wrapper.getvalue()

def codegen_template_wrapper(self, src_code):
wrapper = IndentedBuffer()
wrapper.writeline("async_compile.triton('''")
wrapper.splice(src_code, strip=True)
wrapper.writeline("''')")
return wrapper.getvalue()
return code.getvalue()

def codegen_static_numels(self, code):
"""
@@ -1586,7 +1651,14 @@ def define_kernel(self, src_code, node_schedule):
# TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does
# not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
src_code = src_code.replace("#pragma CMT", "#")
wrapper.define_kernel(kernel_name, src_code)

_, _, kernel_path = get_code_path(src_code, "py", extra="")
compile_wrapper = IndentedBuffer()
compile_wrapper.writeline("async_compile.triton('''")
compile_wrapper.splice(src_code, strip=True)
compile_wrapper.writeline("''')")

wrapper.define_kernel(kernel_name, compile_wrapper.getvalue(), kernel_path)
return kernel_name

def codegen_template(self, template_node, epilogue_nodes):
@@ -1603,7 +1675,7 @@ def codegen_template(self, template_node, epilogue_nodes):
for node in epilogue_nodes:
node.codegen(kernel.split_and_set_ranges(node.get_ranges()))

src_code = kernel.codegen_template_wrapper(render())
src_code = render()
kernel_name = self.define_kernel(src_code, [template_node, *epilogue_nodes])
kernel.call_kernel(V.graph.wrapper_code, kernel_name)
self.scheduler.free_buffers()
5 changes: 3 additions & 2 deletions torch/_inductor/codegen/wrapper.py
@@ -606,8 +606,9 @@ def add_expr_input(name, val):
f"print_performance(lambda: call([{', '.join(V.graph.graph_inputs.keys())}]))"
)

def define_kernel(self, name: str, kernel: str):
self.header.splice(f"\n\n{name} = {kernel}")
def define_kernel(self, name: str, kernel: str, kernel_path: str = None):
kernel_path_comment = f"# kernel path: {kernel_path}\n" if kernel_path else ""
self.header.splice(f"\n\n{kernel_path_comment}{name} = {kernel}")

def load_kernel(self, name: str = None, kernel: str = None, arg_types: List = None):
return
2 changes: 2 additions & 0 deletions torch/_inductor/config.py
@@ -82,6 +82,8 @@

comment_origin = False

benchmark_kernel = os.environ.get("TORCHINDUCTOR_BENCHMARK_KERNEL", "0") == "1"


def is_fbcode():
return not hasattr(torch.version, "git_version")
11 changes: 9 additions & 2 deletions torch/_inductor/graph.py
@@ -155,6 +155,13 @@ def warn_fallback(self, name):
def fake_mode(self):
return V.fake_mode

def get_buffer(self, buffer_name: str):
if buffer_name in self.name_to_buffer:
return self.name_to_buffer[buffer_name]
if buffer_name in self.graph_inputs:
return self.graph_inputs[buffer_name]
return None

def get_dtype(self, buffer_name: str):
if buffer_name in self.constants:
return self.constants[buffer_name].dtype
@@ -599,8 +606,8 @@ def compile_to_module(self):
for name, value in self.constants.items():
setattr(mod, name, value)

if dynamo_config.output_code:
log.info("Output code: %s", mod.__file__)
if config.benchmark_kernel:
print(f"Compiled module path: {mod.__file__}", file=sys.stderr)
V.debug.output_code(mod.__file__)
V.debug.rename(os.path.splitext(mod.__file__)[0] + ".debug")
return mod
18 changes: 9 additions & 9 deletions torch/_inductor/triton_ops/autotune.py
@@ -17,7 +17,14 @@
from .. import config
from ..codecache import cache_dir
from ..ir import ReductionHint, TileHint
from ..utils import ceildiv, conditional_product, do_bench, has_triton, next_power_of_2
from ..utils import (
ceildiv,
conditional_product,
do_bench,
get_num_bytes,
has_triton,
next_power_of_2,
)
from .conv_perf_model import (
early_config_prune as conv_early_config_prune,
estimate_conv_time,
@@ -238,18 +245,11 @@ def run(self, *args, grid, stream):
super().run(*args, grid=grid, stream=stream)
(launcher,) = self.launchers

def get_num_bytes(*args):
return sum(
arg.numel() * arg.element_size()
for arg in args
if isinstance(arg, torch.Tensor)
)

ms = self.bench(launcher, *args, grid=grid)[0]
num_gb = get_num_bytes(*args) / 1e9
gb_per_s = num_gb / (ms / 1e3)

collected_calls.append((kernel_name, ms, num_gb, 1e3 * num_gb / ms))
collected_calls.append((kernel_name, ms, num_gb, gb_per_s))
import colorama

info_str = f"{kernel_name}\t {ms:.3f}ms\t{num_gb:.3f} GB \t {gb_per_s:.2f}GB/s"
11 changes: 11 additions & 0 deletions torch/_inductor/utils.py
@@ -559,3 +559,14 @@ def developer_warning(msg):
log.warning(msg)
else:
log.info(msg)


def get_num_bytes(*args):
"""
Return the total number of bytes the arguments of tensor type takes.
"""
return sum(
arg.numel() * arg.element_size()
for arg in args
if isinstance(arg, torch.Tensor)
)
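
A minimal usage sketch for this helper (the tensors and timing value are illustrative):

```
import torch
from torch._inductor.utils import get_num_bytes

x = torch.randn(1024, 1024, device="cuda")
y = torch.empty_like(x)

num_gb = get_num_bytes(x, y) / 1e9  # bytes touched (read + written), in GB
ms = 0.05                           # e.g. measured via triton.testing.do_bench
print(f"{num_gb / (ms / 1e3):.2f} GB/s")
```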
