Skip to content

Commit

Permalink
Test only smaller BLOCK_K for mm_plus_mm (#96385)
Browse files Browse the repository at this point in the history
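Trim the number of tested mm_plus_mm configs to work around triton-lang/triton#1298: a config whose BLOCK_K equals K triggers an LLVM error, so only configs with BLOCK_K smaller than K are kept as autotuning choices.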

Pull Request resolved: pytorch/pytorch#96385
Approved by: https://github.com/bertmaher, https://github.com/jansel
  • Loading branch information
ngimel authored and cyyever committed Mar 12, 2023
1 parent 29a6c51 commit 3892cfd
Showing 1 changed file with 10 additions and 13 deletions.
23 changes: 10 additions & 13 deletions torch/_inductor/kernel/mm_plus_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,7 @@ def ref_mm_plus_mm(a, b, c, d, out):
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
# Splitting this into two loops causes an internal triton LLVM error
# https://github.com/openai/triton/issues/967
# for k2 in range(K2, 0, -BLOCK_K):
k2 = k1
for k2 in range(K1, 0, -BLOCK_K):
# Second matmul with C @ D
if EVEN_K:
Expand All @@ -92,9 +89,6 @@ def ref_mm_plus_mm(a, b, c, d, out):
C += BLOCK_K * stride_ck
D += BLOCK_K * stride_dk
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
idx_m = rm[:, None]
idx_n = rn[None, :]
Expand Down Expand Up @@ -163,12 +157,15 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
choices = [aten_mm_plus_mm.bind((mat1, mat2, mat3, mat4), layout)]
if use_triton_template(layout):
for config in mm_configs():
choices.append(
mm_plus_mm_template.generate(
(mat1, mat2, mat3, mat4),
layout,
**mm_options(config, k, layout),
# See https://github.com/openai/triton/issues/1298:
# BLOCK_K == K causes an LLVM error, so only configs with a smaller BLOCK_K are used.
if config.kwargs["BLOCK_K"] < k:
choices.append(
mm_plus_mm_template.generate(
(mat1, mat2, mat3, mat4),
layout,
**mm_options(config, k, layout),
)
)
)

return autotune_select_algorithm(choices, [mat1, mat2, mat3, mat4], layout)

0 comments on commit 3892cfd

Please sign in to comment.