
adds triton flash attention2 kernel #4337

Merged: 14 commits, Sep 21, 2023
4 changes: 4 additions & 0 deletions accelerator/abstract_accelerator.py
@@ -185,6 +185,10 @@ def lazy_call(self, callback):
def communication_backend_name(self):
...

@abc.abstractmethod
def is_triton_supported(self):
...

# Tensor operations
@property
@abc.abstractmethod
3 changes: 3 additions & 0 deletions accelerator/cpu_accelerator.py
@@ -182,6 +182,9 @@ def lazy_call(self, callback):
def communication_backend_name(self):
return self._communication_backend_name

def is_triton_supported(self):
return False

# Data types
def is_bf16_supported(self):
return True
7 changes: 7 additions & 0 deletions accelerator/cuda_accelerator.py
@@ -173,6 +173,13 @@ def lazy_call(self, callback):
def communication_backend_name(self):
return self._communication_backend_name

def is_triton_supported(self):
major, _ = torch.cuda.get_device_capability()
if major >= 8:
return True
else:
return False

# Tensor operations

@property
3 changes: 3 additions & 0 deletions accelerator/mps_accelerator.py
@@ -157,6 +157,9 @@ def lazy_call(self, callback):
def communication_backend_name(self):
return self._communication_backend_name

def is_triton_supported(self):
return False

# Tensor operations
@property
def BFloat16Tensor(self):
3 changes: 3 additions & 0 deletions accelerator/npu_accelerator.py
@@ -158,6 +158,9 @@ def lazy_call(self, callback):
def communication_backend_name(self):
return self._communication_backend_name

def is_triton_supported(self):
return False

# Tensor operations

@property
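The accelerator changes above add a single is_triton_supported() hook to the abstract accelerator interface: the CUDA accelerator reports support when the device compute capability is 8.0 or higher, while the CPU, MPS, and NPU accelerators return False. Below is a minimal sketch of how calling code can gate a Triton path on this hook; the helper name is hypothetical and not part of this PR.

from deepspeed.accelerator import get_accelerator

# Hypothetical helper (illustration only): enable the Triton flash-attention
# path only when the accelerator reports Triton support and Triton is installed.
def should_use_triton_flash() -> bool:
    try:
        import triton  # noqa: F401
    except ImportError:
        return False
    return get_accelerator().is_triton_supported()

use_flash = should_use_triton_flash()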
188 changes: 180 additions & 8 deletions deepspeed/ops/transformer/inference/triton/attention.py
@@ -6,6 +6,8 @@
import math
import torch
import torch.nn as nn
import triton
import triton.language as tl
from deepspeed.accelerator import get_accelerator
from deepspeed import comm as dist
from deepspeed.ops.transformer.inference.op_binding import LinearOp, VectorMatMulOp, SoftmaxContextOp, QKVGemmOp
@@ -70,6 +72,9 @@ def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count

self.mp_group = mp_group
self.use_flash = False
# triton flash attention is enabled when the compute capability >= 8.0
if get_accelerator().is_triton_supported():
self.use_flash = True

# used for quantization
self.q_scales = q_scales
@@ -176,7 +181,7 @@ def forward(
qkv = qkv_out[0]

if use_triton_attention and (alibi is None):
context_layer = compute_attention(qkv=qkv,
context_layer = _triton_attention(qkv=qkv,
input_mask=input_mask,
scale=self.scale,
layer_past=layer_past,
@@ -204,7 +209,7 @@ def forward(
global inference_module


def compute_attention(qkv,
def _triton_attention(qkv,
input_mask,
layer_past,
alibi,
@@ -217,13 +222,180 @@ def compute_attention(qkv,
    if isinstance(qkv, list):
        qkv = qkv[0]

    #assert layer_past is None, "layer_past not supported in triton yet"
    assert alibi is None, "alibi is not supported in triton yet"

    if use_triton_flash:
        output = _triton_packed_flash(qkv,
                                      head_size,
                                      input_mask,
                                      scale,
                                      causal=triangular,
                                      add_mask=(not triangular and input_mask is not None))
    else:
        output = score_4d_matmul(qkv, head_size, triangular, scale)
        if triangular:
            output = softmax(output)
        else:
            output = softmax(output, input_mask)
        output = context_4d_matmul(output, qkv, head_size)

    return output


'''
flash attention 2
modified from the triton kernel in
https://github.com/openai/triton/blob/08c16589573621fcb8cd5a9c3b8a0537077f876d/python/tutorials/06-fused-attention.py
'''


@triton.jit
def _flash_packed_kernel(
QKV,
mask,
ADD_MASK: tl.constexpr,
IS_CAUSAL: tl.constexpr,
sm_scale,
Out,
stride_qz,
stride_qh,
stride_qm,
stride_mh,
stride_kz,
stride_kh,
stride_kn,
stride_vz,
stride_vh,
stride_vk,
stride_oz,
stride_oh,
stride_om,
Z,
H,
N_CTX,
P_SEQ,
hidden_size,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
batch = off_hz // H
head = off_hz % H

q_offset = batch * stride_qz + head * BLOCK_DMODEL
k_offset = q_offset + hidden_size
v_offset = k_offset + hidden_size

# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)

q_ptrs = QKV + q_offset + offs_m[:, None] * stride_qh + offs_d[None, :]
k_ptrs = QKV + hidden_size + q_offset + offs_n[:, None] * stride_qh + offs_d[None, :]
v_ptrs = QKV + 2 * hidden_size + q_offset + offs_n[:, None] * stride_qh + offs_d[None, :]

# mask
off_mask = batch * stride_mh + offs_n[None, :]
mask_ptrs = mask + off_mask

# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # scale sm_scale by log2(e) and use 2^x instead of exp in the loop,
    # since exp(x) == 2^(x * log2(e)) and CSE/LICM don't work as expected
    # with `exp` inside the loop; 1.44269504 is log2(e)
    qk_scale = sm_scale * 1.44269504
# load q: it will stay in SRAM throughout
q = tl.load(q_ptrs, mask=offs_m[:, None] < N_CTX, other=0.0)
q = (q * qk_scale).to(tl.float16)
# loop over k, v and update accumulator
lo = 0
hi = P_SEQ + (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX + P_SEQ
for start_n in range(lo, hi, BLOCK_N):
# -- load k, v --
k = tl.load(k_ptrs + start_n * stride_qh, mask=(start_n + offs_n)[:, None] < N_CTX, other=0.0)
v = tl.load(v_ptrs + start_n * stride_qh, mask=(start_n + offs_n)[:, None] < N_CTX, other=0.0)
# -- compute qk ---
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float16)

if ADD_MASK:
mask_val = tl.load(mask_ptrs)
mask_ptrs += BLOCK_N
qk = qk + mask_val.to(tl.float32)

if IS_CAUSAL:
qk = tl.where(P_SEQ + offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))

qk += tl.dot(q, tl.trans(k), out_dtype=tl.float16)
        # `minus_inf` is assumed to be a module-level "large negative" constant defined elsewhere in this file
        qk += tl.where((start_n + offs_n)[None, :] < N_CTX, 0, minus_inf)
# -- compute scaling constant ---
m_i_new = tl.maximum(m_i, tl.max(qk, 1))
alpha = tl.math.exp2(m_i - m_i_new)
p = tl.math.exp2(qk - m_i_new[:, None])
# -- scale and update acc --
        acc_scale = l_i * 0 + alpha  # workaround for a compiler bug
acc *= acc_scale[:, None]
acc += tl.dot(p.to(tl.float16), v.to(tl.float16))
# -- update m_i and l_i --
l_i = l_i * alpha + tl.sum(p, 1)
m_i = m_i_new

# write back l and m
acc = acc / l_i[:, None]
o_offset = batch * stride_oz + head * BLOCK_DMODEL
out_ptrs = Out + o_offset + (offs_m[:, None] * stride_oh + offs_d[None, :])
tl.store(out_ptrs, acc.to(tl.float16), mask=offs_m[:, None] < N_CTX)


def _triton_packed_flash(qkv, head_size, mask, sm_scale, causal=False, add_mask=True):
heads = qkv.shape[-1] // 3 // head_size
hidden_size = qkv.shape[-1] // 3

BLOCK_M = 128
BLOCK_N = 64 if head_size <= 64 else 32

o = torch.empty((qkv.shape[0], qkv.shape[1], hidden_size), device=qkv.device, dtype=torch.half)
if mask is None:
mask = torch.empty(0)
add_mask = False

grid = (triton.cdiv(qkv.shape[1], BLOCK_M), qkv.shape[0] * heads, 1)
num_stages = 4 if head_size <= 64 else 3
num_warps = 4
P_SEQ = 0

_flash_packed_kernel[grid](qkv,
mask,
add_mask,
causal,
sm_scale,
o,
qkv.stride(0),
qkv.stride(1),
qkv.stride(2),
mask.stride(1) if add_mask else 0,
qkv.stride(0),
qkv.stride(1),
qkv.stride(2),
qkv.stride(0),
qkv.stride(1),
qkv.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
qkv.shape[0],
heads,
qkv.shape[1],
P_SEQ,
hidden_size,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_DMODEL=head_size,
num_warps=num_warps,
num_stages=num_stages)

return o
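The wrapper above launches the packed kernel over a (batch, seq, 3 * hidden) QKV tensor laid out as [Q | K | V] along the last dimension and returns a half-precision (batch, seq, hidden) context tensor. A minimal numerical sanity check against a naive PyTorch attention is sketched below; it is not part of this PR, imports the private helper directly for illustration, assumes an Ampere-or-newer GPU with Triton installed, and uses illustrative shapes and tolerances.

import torch
from deepspeed.ops.transformer.inference.triton.attention import _triton_packed_flash

def reference_packed_attention(qkv, head_size, sm_scale, causal=True):
    # Naive attention over the same packed [Q | K | V] layout, computed in fp32.
    batch, seq, three_hidden = qkv.shape
    hidden = three_hidden // 3
    heads = hidden // head_size
    q, k, v = qkv.split(hidden, dim=-1)
    q = q.view(batch, seq, heads, head_size).transpose(1, 2).float()
    k = k.view(batch, seq, heads, head_size).transpose(1, 2).float()
    v = v.view(batch, seq, heads, head_size).transpose(1, 2).float()
    scores = torch.matmul(q, k.transpose(-1, -2)) * sm_scale
    if causal:
        causal_mask = torch.triu(torch.ones(seq, seq, dtype=torch.bool, device=qkv.device), diagonal=1)
        scores = scores.masked_fill(causal_mask, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    out = torch.matmul(probs, v)  # (batch, heads, seq, head_size)
    return out.transpose(1, 2).reshape(batch, seq, hidden).half()

batch, seq, heads, head_size = 2, 128, 16, 64
qkv = torch.randn(batch, seq, 3 * heads * head_size, device="cuda", dtype=torch.half)
sm_scale = 1.0 / head_size ** 0.5
out_triton = _triton_packed_flash(qkv, head_size, None, sm_scale, causal=True)
out_ref = reference_packed_attention(qkv, head_size, sm_scale, causal=True)
print(torch.max(torch.abs(out_triton - out_ref)))  # expected to be small, within fp16 tolerance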
5 changes: 5 additions & 0 deletions tests/unit/inference/test_inference.py
@@ -261,6 +261,11 @@ def validate_test(model_w_task, dtype, enable_cuda_graph, enable_triton):


@pytest.mark.inference
@pytest.mark.parametrize("model_w_task", [
("bert-base-cased", "fill-mask"),
("roberta-large", "fill-mask"),
],
ids=["bert", "roberta"])
class TestModelTask(DistributedTest):
world_size = 1

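For context, the new bert/roberta parametrization exercises the Triton path through the existing enable_triton test plumbing. A hedged sketch of the equivalent standalone usage follows; use_triton is assumed to be the DeepSpeed inference-config flag wired to these kernels, and the model/task mirror the ids above.

import torch
import deepspeed
from transformers import pipeline

# Build a fill-mask pipeline and inject DeepSpeed's Triton inference kernels.
pipe = pipeline("fill-mask", model="bert-base-cased", device=0)
pipe.model = deepspeed.init_inference(pipe.model,
                                      dtype=torch.float16,
                                      replace_with_kernel_inject=True,
                                      use_triton=True)
print(pipe("Paris is the [MASK] of France."))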
11 changes: 11 additions & 0 deletions tests/unit/ops/transformer/inference/inference_test_utils.py
@@ -52,3 +52,14 @@ def assert_almost_equal(x, y, decimal=2, err_msg=''):
y = y.float()
y = y.cpu().detach().numpy()
npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal)


def max_diff(a, b):
a = a.to(torch.float32).flatten()
b = b.to(torch.float32).flatten()
diff = torch.abs(a - b)
max_diff_indices = torch.argsort(diff)[-1]
print("Max difference indices:", max_diff_indices)
print("Max difference values:", diff[max_diff_indices])
print(f"{a[max_diff_indices]} vs {b[max_diff_indices]}")
return max_diff_indices
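Example usage of the helpers above (a sketch, not part of this diff): max_diff returns the flattened index of the largest absolute difference and prints the offending values, while assert_almost_equal wraps numpy's array comparison with a default of decimal=2.

import torch
from inference_test_utils import assert_almost_equal, max_diff  # import path as used inside the tests directory

ref = torch.randn(4, 128, dtype=torch.float16)
out = ref + 1e-3 * torch.randn_like(ref)

idx = max_diff(ref, out)       # index (into the flattened tensors) of the largest |ref - out|
assert_almost_equal(ref, out)  # passes within the default decimal=2 tolerance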