Add OptimizeEpilogue pass. (#346)
* optimize_epilogue

* Add config

* Remove licenses

* Comment out Hopper specific parameters when printing out configs

* Add benchmark parameters from flash-attention repo

* Add Z and H in the key of autotuner

---------

Co-authored-by: Lixun Zhang <lixun.zhang@amd.com>
oplavsic and zhanglx13 committed Nov 3, 2023
1 parent cb02a0b commit c65f1e6
Showing 5 changed files with 40 additions and 46 deletions.
37 changes: 9 additions & 28 deletions lib/Dialect/TritonGPU/Transforms/OptimizeEpilogue.cpp
@@ -1,26 +1,3 @@
-/*
- * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
- *
- * Permission is hereby granted, free of charge, to any person obtaining
- * a copy of this software and associated documentation files
- * (the "Software"), to deal in the Software without restriction,
- * including without limitation the rights to use, copy, modify, merge,
- * publish, distribute, sublicense, and/or sell copies of the Software,
- * and to permit persons to whom the Software is furnished to do so,
- * subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be
- * included in all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
- * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
- * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
- * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
- */
-
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
@@ -70,12 +47,16 @@ class BypassEpilogueSMEM : public mlir::RewritePattern {
     if (!cvtOp)
       return mlir::failure();

-    if (!cvtOp.getSrc()
-             .getType()
-             .cast<RankedTensorType>()
-             .getEncoding()
-             .isa<triton::gpu::MmaEncodingAttr>())
+    auto encoding =
+        cvtOp.getSrc().getType().cast<RankedTensorType>().getEncoding();
+
+#ifdef USE_ROCM
+    if (!encoding.isa<triton::gpu::MfmaEncodingAttr>())
       return mlir::failure();
+#else
+    if (!encoding.isa<triton::gpu::MmaEncodingAttr>())
+      return mlir::failure();
+#endif

     if (!cvtOp.getResult().hasOneUse())
       return mlir::failure();
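For context (illustrative, not from this commit): the BypassEpilogueSMEM pattern above fires on the layout conversion that appears when a dot-product accumulator (MFMA-encoded on ROCm, MMA-encoded on NVIDIA) is stored to global memory, and, as the pattern name suggests, lets that store bypass shared memory. The Triton kernel below is a minimal sketch of the kind of epilogue this targets; the kernel, its names, and the omitted bounds masking are assumptions, not code from this repository.

# Illustrative sketch only: a GEMM-style kernel whose epilogue stores a tl.dot
# accumulator. At the TTGIR level that store involves a layout conversion out of
# the MFMA/MMA encoding, which is what the pattern above can bypass.
import triton
import triton.language as tl

@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        offs_k = k + tl.arange(0, BLOCK_K)
        a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
        b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
        acc += tl.dot(a, b)
    # Epilogue: the accumulator (mma/mfma layout) is written straight to C.
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)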
4 changes: 2 additions & 2 deletions python/perf-kernels/06-fused-attention-transV.py
@@ -86,7 +86,7 @@ def _attn_fwd_inner(
         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=1, num_warps=4), # d64-False
         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=4), # d64-True
     ],
-    key=['N_CTX', 'STAGE', 'BLOCK_DMODEL'],
+    key=['Z', 'H', 'N_CTX', 'STAGE', 'BLOCK_DMODEL'],
 )


@@ -547,7 +547,7 @@ def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False):
         )

         ## restore the grid for bwd kernel
-        best_config = _attn_fwd.get_best_config(N_CTX = q.shape[2], STAGE = stage, BLOCK_DMODEL=Lk)
+        best_config = _attn_fwd.get_best_config(Z = q.shape[0], H = q.shape[1], N_CTX = q.shape[2], STAGE = stage, BLOCK_DMODEL=Lk)
         block_m = int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1])
         grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1)

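A note on the hunk above (illustrative, not from this commit): block_m is recovered by string-parsing Config.__str__, which relies on BLOCK_M being the first field printed. Assuming get_best_config returns a triton.Config, whose tuning parameters live in its kwargs dict (see the autotuner.py hunk further down), a less format-sensitive variant would be:

# Hedged sketch: same effect as the BLOCK_M string parsing above, assuming
# get_best_config() returns a triton.Config; meant as a drop-in inside forward().
best_config = _attn_fwd.get_best_config(Z=q.shape[0], H=q.shape[1],
                                        N_CTX=q.shape[2], STAGE=stage, BLOCK_DMODEL=Lk)
block_m = best_config.kwargs['BLOCK_M']
grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1)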
4 changes: 4 additions & 0 deletions python/src/triton.cc
@@ -1870,6 +1870,10 @@ void init_triton_ir(py::module &&m) {
            [](mlir::PassManager &self) {
              self.addPass(mlir::createTritonGPURemoveLayoutConversionsPass());
            })
+      .def("add_tritongpu_optimize_epilogue_pass",
+           [](mlir::PassManager &self) {
+             self.addPass(mlir::createTritonGPUOptimizeEpiloguePass());
+           })
       .def("add_tritongpu_reorder_instructions_pass",
            [](mlir::PassManager &self) {
              self.addPass(mlir::createTritonGPUReorderInstructionsPass());
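The binding above only exposes the new pass to Python; it still has to be scheduled in the TTGIR pass pipeline on the Python side. A hedged sketch of what that call could look like, assuming pm is the pass-manager object these bindings attach to and mod is the module being compiled; the neighboring pass and the ordering are illustrative, not taken from this commit:

# Illustrative only: invoking the new binding from the Python compiler pipeline.
# `pm` and `mod` are assumed to be the pass manager and module objects exposed by
# init_triton_ir; the pass ordering here is a guess, not part of this commit.
pm.add_tritongpu_optimize_epilogue_pass()      # the pass added by this commit
pm.add_tritongpu_reorder_instructions_pass()   # binding shown in the hunk above
pm.run(mod)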
9 changes: 5 additions & 4 deletions python/triton/runtime/autotuner.py
@@ -208,11 +208,12 @@ def __str__(self):
         for k, v in self.kwargs.items():
             res.append(f'{k}: {v}')
         res.append(f'num_warps: {self.num_warps}')
-        res.append(f'num_ctas: {self.num_ctas}')
+        ## Comment out Hopper specific parameters
+        #res.append(f'num_ctas: {self.num_ctas}')
         res.append(f'num_stages: {self.num_stages}')
-        res.append(
-            f'enable_warp_specialization: {self.enable_warp_specialization}')
-        res.append(f'enable_persistent: {self.enable_persistent}')
+        #res.append(
+        #    f'enable_warp_specialization: {self.enable_warp_specialization}')
+        #res.append(f'enable_persistent: {self.enable_persistent}')
         return ', '.join(res)


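With the Hopper-specific fields commented out, Config.__str__ now reports only the kwargs, num_warps, and num_stages. An illustrative example of the resulting string follows (values made up); BLOCK_M coming first among the kwargs is exactly what the get_best_config string parsing in the attention kernels relies on.

import triton

# Illustrative only: Config.__str__ output after this change; values are made up.
cfg = triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_warps=4, num_stages=1)
print(cfg)  # BLOCK_M: 128, BLOCK_N: 64, num_warps: 4, num_stages: 1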
32 changes: 20 additions & 12 deletions python/tutorials/06-fused-attention.py
@@ -84,11 +84,12 @@ def _attn_fwd_inner(
 @triton.autotune(
     configs=[
         triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=4),
         triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8),
         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=1, num_warps=4), # d64-False
         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=4), # d64-True
     ],
-    key=['N_CTX', 'STAGE', 'BLOCK_DMODEL'],
+    key=['Z', 'H', 'N_CTX', 'STAGE', 'BLOCK_DMODEL'],
 )


@@ -549,7 +550,7 @@ def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False):
         )

         ## restore the grid for bwd kernel
-        best_config = _attn_fwd.get_best_config(N_CTX = q.shape[2], STAGE = stage, BLOCK_DMODEL=Lk)
+        best_config = _attn_fwd.get_best_config(Z = q.shape[0], H = q.shape[1], N_CTX = q.shape[2], STAGE = stage, BLOCK_DMODEL=Lk)
         block_m = int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1])
         grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1)

@@ -730,28 +731,35 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
 FLASH_VER = None
 HAS_FLASH = FLASH_VER is not None

-BATCH, N_HEADS, N_CTX= 4, 48, 4096
 # vary seq length for fixed head and batch=4
 configs = []
 for mode in ['fwd', 'bwd']:
-    for causal in [False, True]:
-        if mode == 'bwd' and causal == False:
+    for D_HEAD in [64, 128]:
+        if mode == 'bwd' and D_HEAD == 128:
             continue
-        for D_HEAD in [128, 64]:
-            if mode == 'bwd' and D_HEAD == 128:
+        for causal in [False, True]:
+            if mode == 'bwd' and causal == False:
                 continue
             configs.append(triton.testing.Benchmark(
-                x_names=['N_CTX'],
-                x_vals=[2**i for i in range(10, 15)],
+                x_names=['BATCH', 'H','N_CTX'],
+                x_vals=[(16, 16, 1024),
+                        (8, 16, 2048),
+                        (4, 16, 4096),
+                        (2, 16, 8192),
+                        (1, 16, 16384),
+                        (4, 48, 1024),
+                        (4, 48, 2048),
+                        (4, 48, 4096),
+                        (4, 48, 8192),
+                        (4, 48, 16384),
+                        ],
                 line_arg='provider',
                 line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
                 line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []),
                 styles=[('red', '-'), ('blue', '-')],
                 ylabel='ms',
-                plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}',
+                plot_name=f'fused-attention-{mode}-d{D_HEAD}-causal={causal}',
                 args={
-                    'H': N_HEADS,
-                    'BATCH': BATCH,
                     'D_HEAD': D_HEAD,
                     'dtype': torch.float16,
                     'mode': mode,
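The benchmark now sweeps explicit (BATCH, H, N_CTX) shapes taken from the flash-attention repository instead of a fixed batch of 4 with 48 heads. As a worked example (illustrative, assuming the autotuner settles on BLOCK_M = 128), the launch grid that forward() derives for one of these points is:

import triton

# Worked example for the (BATCH, H, N_CTX) = (4, 48, 4096) benchmark point,
# assuming BLOCK_M = 128 is the tuned value (an assumption, not a measured result).
BATCH, H, N_CTX, BLOCK_M = 4, 48, 4096, 128
grid = (triton.cdiv(N_CTX, BLOCK_M), BATCH * H, 1)
print(grid)  # (32, 192, 1)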
