forked from pytorch/pytorch
[Inductor] Triton template for conv (pytorch#422)
* TritonTemplateKernel, template_codegen, conv jinja2 template
* pip install Jinja2 in setup_nightly
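As a rough illustration of the templating step (not the actual template_codegen code from this PR), a Jinja2 template like the one in this diff could be rendered into Triton source roughly as follows; the file path and values are hypothetical:

from jinja2 import Template

# Hypothetical rendering of the conv template below; template_codegen is the
# real entry point in this PR.
with open("conv.py.j2") as f:           # placeholder template path
    src = Template(f.read()).render(
        kernel_name="triton_conv",      # fills {{kernel_name}}
        extra_args=[],                  # drives the {% for i in extra_args %} loop
        pointwise_computation="",       # fills {{ pointwise_computation }}
    )
# `src` is now plain Triton/Python source ready to be compiled.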
Showing 5 changed files with 546 additions and 10 deletions.
@@ -0,0 +1,246 @@
import torch
import triton
import triton.language as tl

from torchinductor.triton_ops.conv_perf_model import early_config_prune
from torchinductor.triton_ops.conv_perf_model import estimate_conv_time

@triton.autotune(
    configs=[
        # basic configs for compute-bound matmuls
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 32, "BLOCK_K": 32}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 32, "BLOCK_K": 64}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 16, "BLOCK_K": 32}, num_stages=4, num_warps=2),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=8),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 16, "BLOCK_K": 32}, num_stages=4, num_warps=4),
        # good for int8
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 32, "BLOCK_K": 128}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64}, num_stages=4, num_warps=2),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=4, num_warps=2),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 16, "BLOCK_K": 64}, num_stages=4, num_warps=2),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 16, "BLOCK_K": 64}, num_stages=5, num_warps=2),
    ],
    # the above configs will be evaluated anytime the key changes
    key=[
        "BATCH",
        "IN_C",
        "IN_H",
        "IN_W",
        "KERNEL_N",
        "KERNEL_H",
        "KERNEL_W",
        "OUT_H",
        "OUT_W",
        # parameters of conv
        "stride_h",
        "stride_w",
        "padding_h",
        "padding_w",
        "dilation_h",
        "dilation_w",
        "output_padding_h",
        "output_padding_w",
        "groups",
    ],
    prune_configs_by={
        "early_config_prune": early_config_prune,
        "perf_model": estimate_conv_time,
        "top_k": 10,
    },
)
@triton.jit
def {{kernel_name}}(
    x,
    w,
    bias,
    y,
    {% for i in extra_args %}
    {{i}},
    {% endfor %}
    # stride of tensor
    stride_xn,
    stride_xc,
    stride_xh,
    stride_xw,
    stride_wn,
    stride_wc,
    stride_wh,
    stride_ww,
    stride_yn,
    stride_yc,
    stride_yh,
    stride_yw,
    stride_biasn,
    # pointer inc for x
    delta_x_ptr,
    # Tensor dimensions
    BATCH,
    IN_C,
    IN_H,
    IN_W,
    KERNEL_N,
    KERNEL_H,
    KERNEL_W,
    OUT_H,
    OUT_W,
    # parameters of conv
    stride_h,
    stride_w,
    padding_h,
    padding_w,
    dilation_h,
    dilation_w,
    output_padding_h,
    output_padding_w,
    groups,
    # Metaparameters
    ACC_TYPE: tl.constexpr,
    CONV1X1_NHWC: tl.constexpr,
    # blocks in different dimension
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    # reduction tiling parameter for matmul
    BLOCK_K: tl.constexpr,
):
""" | ||
each program instance computes a [BLOCK_BATCH, BLOCK_N, BLOCK_H, BLOCK_W] block of y | ||
""" | ||
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of y it should compute.
    pid_nhw = tl.program_id(0)
    pid_k = tl.program_id(1)
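    # The conv is computed as an implicit GEMM: the M axis is the flattened
    # (batch, out_h, out_w) dimension, the N axis is the output channels
    # KERNEL_N, and the reduction axis is CRS = IN_C * KERNEL_H * KERNEL_W.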

    # offset for output y
    off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
    off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
    off_y_n = off_y_nhw // (OUT_H * OUT_W)
    off_y_hw = off_y_nhw % (OUT_H * OUT_W)
    off_y_h = off_y_hw // OUT_W
    off_y_w = off_y_hw % OUT_W

    # offset for the initial ptr for x
    off_x_n = off_y_n
    off_x_h = off_y_h * stride_h - padding_h
    off_x_w = off_y_w * stride_w - padding_w
    off_x_nhw = off_x_n * stride_xn + off_x_h * stride_xh + off_x_w * stride_xw
    off_x_crs = tl.arange(0, BLOCK_K)

    CRS = IN_C * KERNEL_H * KERNEL_W
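    # delta_x_ptr points to a precomputed table of element offsets into x, one
    # per position along the CRS axis, so the im2col gather never materializes
    # the unfolded tensor; a 1x1 NHWC conv reads its channels contiguously and
    # can index x directly, skipping the table.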
    # load inc ptr of x, update x_ptrs
    if not CONV1X1_NHWC:
        delta_x_ptrs = delta_x_ptr + off_x_crs
        off_x_crs_unpacked = tl.load(delta_x_ptrs, mask=off_x_crs < CRS)
        x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
    else:
        x_ptrs = x + off_x_nhw[:, None] + off_x_crs[None, :]

    mask_x = (
        (off_x_n < BATCH)
        & (off_x_h >= 0)
        & (off_x_h < IN_H)
        & (off_x_w >= 0)
        & (off_x_w < IN_W)
    )[:, None] & (off_x_crs < CRS)[None, :]

    # offset for the initial ptr for w
    off_w_crs = tl.arange(0, BLOCK_K)
    off_w_k = off_y_k
    w_ptrs = w + off_w_crs[:, None] + off_w_k[None, :] * stride_wn
    mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]

    # ------ load x ------
    matrix_x = tl.load(x_ptrs, mask=mask_x)
    # ------ load w ------
    matrix_w = tl.load(w_ptrs, mask=mask_w)

    # -----------------------------------------------------------
    # allocate accumulator
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for crs in range(0, CRS, BLOCK_K):
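
        # Pipelined loop body: multiply the tiles loaded in the previous
        # iteration, advance the pointers, then prefetch the next K-tile so
        # the loads can overlap with tl.dot.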
        # ------ matrix multiplication ------
        acc += tl.dot(matrix_x, matrix_w)
        # ------ update ptrs ------
        w_ptrs += BLOCK_K
        # load inc ptr of x, update x_ptrs
        if not CONV1X1_NHWC:
            delta_x_ptrs += BLOCK_K
            off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K)
            off_x_crs_unpacked = tl.load(delta_x_ptrs, mask=off_x_crs < CRS)
            x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
        else:
            off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K)
            x_ptrs += BLOCK_K

        mask_x = (
            (off_x_n < BATCH)
            & (off_x_h >= 0)
            & (off_x_h < IN_H)
            & (off_x_w >= 0)
            & (off_x_w < IN_W)
        )[:, None] & (off_x_crs < CRS)[None, :]
        mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]
        # ------ prefetch ------
        # ------ load x ------
        matrix_x = tl.load(x_ptrs, mask=mask_x)
        # ------ load w ------
        matrix_w = tl.load(w_ptrs, mask=mask_w)

    {{ pointwise_computation }}
    {#
    z = tl.load(z_ptrs, mask=mask_z)
    acc += z
    #}
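    # The Jinja2 comment above sketches the kind of fused epilogue (here, a
    # residual add from a hypothetical tensor z) that can be substituted into
    # {{ pointwise_computation }} when the template is rendered.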

    # add bias if it is not None
    if bias is not None:
        off_bias_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
        bias_ptrs = bias + off_bias_k * stride_biasn
        mask_bias = off_bias_k < KERNEL_N
        _bias = tl.load(bias_ptrs, mask=mask_bias)
        acc += _bias[None, :]

    acc = acc.to(y.dtype.element_ty)

    # rematerialize -- this saves some registers
    # offset for output y
    off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
    off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
    off_y_n = off_y_nhw // (OUT_H * OUT_W)
    off_y_hw = off_y_nhw % (OUT_H * OUT_W)
    # consider output padding
    off_y_h = off_y_hw // OUT_W + output_padding_h
    off_y_w = off_y_hw % OUT_W + output_padding_w

    # y ptrs in the block of [BLOCK_M, BLOCK_N]
    y_ptrs = (
        y
        + off_y_n[:, None] * stride_yn
        + off_y_h[:, None] * stride_yh
        + off_y_w[:, None] * stride_yw
        + off_y_k[None, :] * stride_yc
    )

    # out-of-bounds check
    mask_y = (
        (off_y_n < BATCH)[:, None]
        & (off_y_h < OUT_H + output_padding_h)[:, None]
        & (off_y_w < OUT_W + output_padding_w)[:, None]
        & (off_y_k < KERNEL_N)[None, :]
    )

    tl.store(y_ptrs, acc, mask=mask_y)

    return
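
For reference, a minimal sketch of how a rendered instance of this template might be launched; the kernel name and shape variables are hypothetical (the real call site is generated by template_codegen), but the grid follows the pid_nhw / pid_k tiling above:

# BLOCK_M / BLOCK_N are chosen by the autotuner, so the grid is expressed as
# a function of the selected meta-parameters.
grid = lambda META: (
    triton.cdiv(BATCH * OUT_H * OUT_W, META["BLOCK_M"]),  # pid_nhw axis
    triton.cdiv(KERNEL_N, META["BLOCK_N"]),               # pid_k axis
)
conv_kernel[grid](
    x, w, bias, y,
    # ... strides, delta_x_ptr, tensor dimensions, conv parameters ...
    ACC_TYPE=tl.float32,
    CONV1X1_NHWC=False,
)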