[Inductor] Triton template for conv (pytorch#422)
* TritonTemplateKernel, template_codegen, conv jinja2 template
* pip install Jinja2 in setup_nightly
pyjhzwh committed Jun 20, 2022
1 parent 9aded7d commit 20e8e04
Showing 5 changed files with 546 additions and 10 deletions.
1 change: 1 addition & 0 deletions Makefile
@@ -43,6 +43,7 @@ setup:

setup_nightly:
pip install ninja
pip install Jinja2
pip install --pre torch==1.13.0.dev20220601+cpu --extra-index-url https://download.pytorch.org/whl/nightly/cpu
pip install -v git+https://github.com/pytorch/functorch.git@ae70048d9ff538062207922e37337
pip install -r requirements.txt
246 changes: 246 additions & 0 deletions torchinductor/codegen/triton_conv.j2
@@ -0,0 +1,246 @@
import torch
import triton
import triton.language as tl

from torchinductor.triton_ops.conv_perf_model import early_config_prune
from torchinductor.triton_ops.conv_perf_model import estimate_conv_time


@triton.autotune(
configs=[
# basic configs for compute-bound matmuls
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=2, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=2, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 32, "BLOCK_K": 32}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 32, "BLOCK_K": 64}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 16, "BLOCK_K": 32}, num_stages=4, num_warps=2),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=8),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 16, "BLOCK_K": 32}, num_stages=4, num_warps=4),
# good for int8
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 32, "BLOCK_K": 128}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64}, num_stages=4, num_warps=2),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=4, num_warps=2),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 16, "BLOCK_K": 64}, num_stages=4, num_warps=2),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 16, "BLOCK_K": 64}, num_stages=5, num_warps=2),
],
# the above configs will be evaluated anytime the key changes
key=[
"BATCH",
"IN_C",
"IN_H",
"IN_W",
"KERNEL_N",
"KERNEL_H",
"KERNEL_W",
"OUT_H",
"OUT_W",
# parameters of conv
"stride_h",
"stride_w",
"padding_h",
"padding_w",
"dilation_h",
"dilation_w",
"output_padding_h",
"output_padding_w",
"groups",
],
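# prune the autotuning search space: early_config_prune filters out configs
# unsuited to the current device, estimate_conv_time predicts the runtime of
# the remaining configs, and only the top_k predicted-fastest ones are
# actually benchmarked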
prune_configs_by={
"early_config_prune": early_config_prune,
"perf_model": estimate_conv_time,
"top_k": 10,
},
)
@triton.jit
def {{kernel_name}}(
x,
w,
bias,
y,
{% for i in extra_args %}
{{i}},
{% endfor %}
# stride of tensor
stride_xn,
stride_xc,
stride_xh,
stride_xw,
stride_wn,
stride_wc,
stride_wh,
stride_ww,
stride_yn,
stride_yc,
stride_yh,
stride_yw,
stride_biasn,
# pointer inc for x
delta_x_ptr,
# Tensor dimensions
BATCH,
IN_C,
IN_H,
IN_W,
KERNEL_N,
KERNEL_H,
KERNEL_W,
OUT_H,
OUT_W,
# parameters of conv
stride_h,
stride_w,
padding_h,
padding_w,
dilation_h,
dilation_w,
output_padding_h,
output_padding_w,
groups,
# Metaparameters
ACC_TYPE: tl.constexpr,
CONV1X1_NHWC: tl.constexpr,
# blocks in different dimension
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
# reduction tiling parameter for matmul
BLOCK_K: tl.constexpr,
):
"""
each program instance computes a [BLOCK_BATCH, BLOCK_N, BLOCK_H, BLOCK_W] block of y
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of y it should compute.
pid_nhw = tl.program_id(0)
pid_k = tl.program_id(1)

# offset for output y
off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
off_y_n = off_y_nhw // (OUT_H * OUT_W)
off_y_hw = off_y_nhw % (OUT_H * OUT_W)
off_y_h = off_y_hw // OUT_W
off_y_w = off_y_hw % OUT_W
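# e.g. with OUT_H = OUT_W = 4, off_y_nhw = 21 decomposes to
# (n, h, w) = (21 // 16, (21 % 16) // 4, (21 % 16) % 4) = (1, 1, 1)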

# offset for the initial ptr for x
off_x_n = off_y_n
off_x_h = off_y_h * stride_h - padding_h
off_x_w = off_y_w * stride_w - padding_w
off_x_nhw = off_x_n * stride_xn + off_x_h * stride_xh + off_x_w * stride_xw
off_x_crs = tl.arange(0, BLOCK_K)

CRS = IN_C * KERNEL_H * KERNEL_W
# load inc ptr of x, update x_ptrs
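# delta_x_ptr holds precomputed element offsets that map each flattened
# (in_c, kernel_h, kernel_w) reduction index to its location in x, so the
# conv can be computed as an implicit GEMM over the CRS dimension; for a
# 1x1 NHWC conv the offsets are contiguous and no lookup table is needed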
if not CONV1X1_NHWC:
delta_x_ptrs = delta_x_ptr + off_x_crs
off_x_crs_unpacked = tl.load(delta_x_ptrs, mask=off_x_crs < CRS)
x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
else:
x_ptrs = x + off_x_nhw[:, None] + off_x_crs[None, :]

mask_x = (
(off_x_n < BATCH)
& (off_x_h >= 0)
& (off_x_h < IN_H)
& (off_x_w >= 0)
& (off_x_w < IN_W)
)[:, None] & (off_x_crs < CRS)[None, :]

# offset for the initial ptr for w
off_w_crs = tl.arange(0, BLOCK_K)
off_w_k = off_y_k
w_ptrs = w + off_w_crs[:, None] + off_w_k[None, :] * stride_wn
mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]

# ------ load x ------
matrix_x = tl.load(x_ptrs, mask=mask_x)
# ------ load w ------
matrix_w = tl.load(w_ptrs, mask=mask_w)

# -----------------------------------------------------------
# allocate accumulator
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
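# the reduction loop walks the flattened CRS = IN_C * KERNEL_H * KERNEL_W
# dimension in BLOCK_K chunks; each iteration multiplies the x/w tiles loaded
# at the end of the previous iteration (or just above, for the first one)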
for crs in range(0, CRS, BLOCK_K):

# ------ matrix multiplication ------
acc += tl.dot(matrix_x, matrix_w)
# ------ update ptrs ------
w_ptrs += BLOCK_K
# load inc ptr of x, update x_ptrs
if not CONV1X1_NHWC:
delta_x_ptrs += BLOCK_K
off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K)
off_x_crs_unpacked = tl.load(delta_x_ptrs, mask=off_x_crs < CRS)
x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
else:
off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K)
x_ptrs += BLOCK_K

mask_x = (
(off_x_n < BATCH)
& (off_x_h >= 0)
& (off_x_h < IN_H)
& (off_x_w >= 0)
& (off_x_w < IN_W)
)[:, None] & (off_x_crs < CRS)[None, :]
mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]
# ------ prefetch ------
# ------ load x ------
matrix_x = tl.load(x_ptrs, mask=mask_x)
# ------ load w ------
matrix_w = tl.load(w_ptrs, mask=mask_w)

{{ pointwise_computation }}
{#
z = tl.load(z_ptrs, mask=mask_z)
acc += z
#}

# add bias if bias is not None
if bias is not None:
off_bias_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
bias_ptrs = bias + off_bias_k * stride_biasn
mask_bias = off_bias_k < KERNEL_N
_bias = tl.load(bias_ptrs, mask=mask_bias)
acc += _bias[None, :]


acc = acc.to(y.dtype.element_ty)

# rematerialize -- this saves some registers
# offset for output y
off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
off_y_n = off_y_nhw // (OUT_H * OUT_W)
off_y_hw = off_y_nhw % (OUT_H * OUT_W)
# consider output padding
off_y_h = off_y_hw // OUT_W + output_padding_h
off_y_w = off_y_hw % OUT_W + output_padding_w

# y ptrs in the block of [BLOCK_M, BLOCK_N]
y_ptrs = (
y
+ off_y_n[:, None] * stride_yn
+ off_y_h[:, None] * stride_yh
+ off_y_w[:, None] * stride_yw
+ off_y_k[None, :] * stride_yc
)

# out-of-bounds check
mask_y = (
(off_y_n < BATCH)[:, None]
& (off_y_h < OUT_H + output_padding_h)[:, None]
& (off_y_w < OUT_W + output_padding_w)[:, None]
& (off_y_k < KERNEL_N)[None, :]
)

tl.store(y_ptrs, acc, mask=mask_y)

return
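Note that the file above is a Jinja2 template, not directly importable Python: {{kernel_name}}, the {% for i in extra_args %} loop, and {{ pointwise_computation }} are placeholders that template_codegen fills in when generating a kernel. As a rough sketch only (this is not the actual TritonTemplateKernel/template_codegen code from this commit, and the substituted values below are made up), rendering the template with Jinja2 could look like:

from jinja2 import Template

with open("torchinductor/codegen/triton_conv.j2") as f:
    conv_template = Template(f.read())

# hypothetical substitutions -- the real ones are produced by template_codegen
kernel_source = conv_template.render(
    kernel_name="kernel0",      # name given to the generated @triton.jit kernel
    extra_args=[],              # extra buffer arguments for fused ops, none here
    pointwise_computation="",   # optional fused epilogue code, left empty
)

# the rendered string is ordinary Python/Triton source and can be exec'd
namespace = {}
exec(compile(kernel_source, "<rendered triton_conv.j2>", "exec"), namespace)
conv_kernel = namespace["kernel0"]

The rendered kernel would then be launched over a 2D grid: axis 0 tiles the flattened (batch, out_h, out_w) dimension in BLOCK_M chunks and axis 1 tiles the KERNEL_N output channels in BLOCK_N chunks, matching the pid_nhw/pid_k split in the kernel body.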
