[Operator] Add sm75 support for attention (#259)
hjjq committed May 30, 2023
1 parent c1cb37b commit 7e2f16f
Showing 4 changed files with 186 additions and 22 deletions.
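
For context, here is a rough standalone sketch (not part of the commit) of the capability gating the diff below introduces: sm75 supports the m16n8k8 f16 MMA shape but not m16n8k16, and it has a smaller shared-memory budget than sm80+, so the schedule now checks both against the detected device. The sketch reuses `hidet.cuda.compute_capability()` and the `MmaConfig` constructors that appear in the diff; the helper name `pick_attention_mma_and_smem` is made up for illustration.

```python
import hidet
from hidet.lang.cuda import MmaConfig


def pick_attention_mma_and_smem(used_smem_bytes_per_block: int) -> MmaConfig:
    # (major, minor) of the current device, e.g. (7, 5) on a Turing GPU
    major, minor = hidet.cuda.compute_capability()
    cc = major * 10 + minor

    # sm75 only provides the m16n8k8 f16 tensor-core shape; sm80+ also has m16n8k16
    mma_config = MmaConfig.m16n8k8_f16_f16() if cc < 80 else MmaConfig.m16n8k16_f16_f16()

    # per-architecture shared-memory budget (bytes), as used by the schedule below
    smem_limits = {70: 96000, 72: 96000, 75: 64000, 80: 163000, 86: 99000, 87: 163000, 89: 99000, 90: 227000}
    max_smem = 99000 if cc > 90 else smem_limits[cc]
    if used_smem_bytes_per_block > max_smem:
        raise ValueError(f'tile needs {used_smem_bytes_per_block} B of shared memory, sm{cc} allows {max_smem} B')
    return mma_config
```

In the schedule itself these constraints are expressed as `tune.check(...)` calls, so candidates that violate them are simply pruned from the tuning space rather than raising an error.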
100 changes: 89 additions & 11 deletions python/hidet/graph/ops/definitions/attention/attention.py
@@ -24,7 +24,7 @@
from hidet.lang.cuda import MmaConfig, mma_sync, cp_async, ldmatrix, cp_async_wait_all
from hidet.graph.ops.definitions.utils import Task, Operator, Tensor, TensorNode, compute, input_like
from hidet.graph.ops.definitions.utils import broadcast_shape, broadcast_shapes, broadcast_indices
from hidet.graph.ops.definitions.utils import can_broadcast
from hidet.graph.ops.definitions.utils import can_broadcast, schedule_utils
from hidet.utils import same_list
from hidet.utils.py import cdiv, prod
from .attention_mask import AttnMaskAddOp
@@ -121,7 +121,7 @@ def implement_cuda(self, working_dir: str) -> Union[List[IRModule], IRModule]:
@tune.space(1, 'warp_elems_m', [16])
@tune.space(1, 'warp_elems_n', [128])
@tune.space(1, 'warp_elems_k', [32])
@tune.space(1, 'mma_config', [MmaConfig.m16n8k16_f16_f16()])
@tune.space(1, 'mma_config', [MmaConfig.m16n8k8_f16_f16()])
def cuda_schedule_attn(
self,
block_i=128,
@@ -141,6 +141,12 @@ def calc_swizzle_size(d):
return n, d // n
return -1, -1

compute_capability = hidet.cuda.compute_capability()
compute_capability = compute_capability[0] * 10 + compute_capability[1]
if compute_capability < 80:
# hack: sm75 only supports m16n8k8, not m16n8k16
tune.check(mma_config.k == 8)

task = self
is_causal = task.attrs['is_causal']
node_q, node_k, node_v, node_o = task.inputs[0], task.inputs[1], task.inputs[2], task.outputs[0]
@@ -265,7 +271,9 @@ def calc_swizzle_size(d):
+ smem_bytes_mij
)
used_smem_bytes_per_block = dynamic_smem_bytes
tune.check(used_smem_bytes_per_block <= 99000)
smem_limits = {70: 96000, 72: 96000, 75: 64000, 80: 163000, 86: 99000, 87: 163000, 89: 99000, 90: 227000}
max_smem = 99000 if compute_capability > 90 else smem_limits[compute_capability]
tune.check(used_smem_bytes_per_block <= max_smem)

smem_l_type = tensor_type(sm_dtype, shape=[i_rows_per_tb])
smem_m_type = tensor_type(sm_dtype, shape=[i_rows_per_tb])
@@ -326,6 +334,16 @@ def calc_swizzle_size(d):
t_per_block_k_8_floor, block_j_o // 8
)

q_g2s_layout_sm75, _ = schedule_utils.get_transfer_task_map(
task_shape=[block_i, dpad_size], num_workers=min(block_i * dpad_size, block_size), ranks=[0, 1]
)
k_g2s_layout_sm75, _ = schedule_utils.get_transfer_task_map(
task_shape=[block_k, block_j], num_workers=min(block_k * block_j, block_size), ranks=[0, 1]
)
v_g2s_layout_sm75, _ = schedule_utils.get_transfer_task_map(
task_shape=[block_k_o, block_j_o], num_workers=min(block_k_o * block_j_o, block_size), ranks=[0, 1]
)

with hidet.script_module() as module:
# --------------- helper functions ---------------------------------------------------------------------
@hidet.script
@@ -352,6 +370,11 @@ def resolve_ldmatrix(regs: ~f16, smem_addr: ~f16, is_A: hidet.lang.boolean):
b32_regs = view(regs, u32[1])
ldmatrix(regs=[b32_regs[0]], smem_addr=smem_addr, trans=True)

@hidet.script
def cp_async_sync():
if compute_capability >= 80:
cp_async_wait_all()

@hidet.script
def init_lm_smem(smem_l: smem_l_type, smem_m: smem_m_type):
for i in lm_layout.on(threadIdx.x):
@@ -360,7 +383,7 @@ def init_lm_smem(smem_l: smem_l_type, smem_m: smem_m_type):
smem_m[i] = smem_m_type.dtype.min_value

@hidet.script
def copy_k_g2s(k: f16[k_head + [d_size, n_size]], smem_k: smem_k_type, offset_j: i32, offset_k: i32):
def copy_k_g2s_sm80(k: f16[k_head + [d_size, n_size]], smem_k: smem_k_type, offset_j: i32, offset_k: i32):
o_head_index = spatial(*o_head).map(blockIdx.y)
gmem_k = k[broadcast_indices(o_head_index, k_head, o_head)][offset_k:, offset_j:]
for i, j_seg in k_g2s_layout.on(threadIdx.x):
@@ -370,7 +393,7 @@ def copy_k_g2s(k: f16[k_head + [d_size, n_size]], smem_k: smem_k_type, offset_j:
cp_async(~smem_k[i, j], ~gmem_k[i, j], cp_size=16, src_size=src_size * 2, cache_level='global')

@hidet.script
def copy_v_g2s(v: f16[v_head + [n_size, d_size]], smem_v: smem_v_type, offset_j: i32):
def copy_v_g2s_sm80(v: f16[v_head + [n_size, d_size]], smem_v: smem_v_type, offset_j: i32):
o_head_index = spatial(*o_head).map(blockIdx.y)
gmem_v = v[broadcast_indices(o_head_index, v_head, o_head)][offset_j:, :]
for i, j_seg in v_g2s_layout.on(threadIdx.x):
@@ -380,7 +403,7 @@ def copy_v_g2s(v: f16[v_head + [n_size, d_size]], smem_v: smem_v_type, offset_j:
cp_async(~smem_v[i, j], ~gmem_v[i, j], cp_size=16, src_size=src_size * 2, cache_level='global')

@hidet.script
def copy_q_g2s(q: f16[q_head + [n_size, d_size]], smem_q: smem_q_type, offset_i: i32):
def copy_q_g2s_sm80(q: f16[q_head + [n_size, d_size]], smem_q: smem_q_type, offset_i: i32):
o_head_index = spatial(*o_head).map(blockIdx.y)
gmem_q = q[broadcast_indices(o_head_index, q_head, o_head)][offset_i:, :]
for i, j_seg in q_g2s_layout.on(threadIdx.x):
@@ -389,6 +412,60 @@ def copy_q_g2s(q: f16[q_head + [n_size, d_size]], smem_q: smem_q_type, offset_i:
if threadIdx.x < q_g2s_layout.num_workers and i < smem_q_type.shape[0]:
cp_async(~smem_q[i, j], ~gmem_q[i, j], cp_size=16, src_size=src_size * 2, cache_level='global')

@hidet.script
def copy_k_g2s_sm75(k: f16[k_head + [d_size, n_size]], smem_k: smem_k_type, offset_j: i32, offset_k: i32):
o_head_index = spatial(*o_head).map(blockIdx.y)
gmem_k = k[broadcast_indices(o_head_index, k_head, o_head)][offset_k:, offset_j:]
for i, j in k_g2s_layout_sm75.on(threadIdx.x):
if threadIdx.x < k_g2s_layout_sm75.num_workers and i < smem_k_type.shape[0]:
if offset_k + i < d_size and offset_j + j < n_size:
smem_k[i, j] = gmem_k.read([i, j], protected=False)
else:
smem_k[i, j] = f16.zero

@hidet.script
def copy_v_g2s_sm75(v: f16[v_head + [n_size, d_size]], smem_v: smem_v_type, offset_j: i32):
o_head_index = spatial(*o_head).map(blockIdx.y)
gmem_v = v[broadcast_indices(o_head_index, v_head, o_head)][offset_j:, :]
for i, j in v_g2s_layout_sm75.on(threadIdx.x):
if threadIdx.x < v_g2s_layout_sm75.num_workers and i < smem_v_type.shape[0]:
if offset_j + i < n_size and j < d_size:
smem_v[i, j] = gmem_v.read([i, j], protected=False)
else:
smem_v[i, j] = f16.zero

@hidet.script
def copy_q_g2s_sm75(q: f16[q_head + [n_size, d_size]], smem_q: smem_q_type, offset_i: i32):
o_head_index = spatial(*o_head).map(blockIdx.y)
gmem_q = q[broadcast_indices(o_head_index, q_head, o_head)][offset_i:, :]
for i, j in q_g2s_layout_sm75.on(threadIdx.x):
if threadIdx.x < q_g2s_layout_sm75.num_workers and i < smem_q_type.shape[0]:
if offset_i + i < n_size and j < d_size:
smem_q[i, j] = gmem_q.read([i, j], protected=False)
else:
smem_q[i, j] = f16.zero

@hidet.script
def copy_k_g2s(k: f16[k_head + [d_size, n_size]], smem_k: smem_k_type, offset_j: i32, offset_k: i32):
if compute_capability >= 80:
copy_k_g2s_sm80(k, smem_k, offset_j, offset_k)
else:
copy_k_g2s_sm75(k, smem_k, offset_j, offset_k)

@hidet.script
def copy_v_g2s(v: f16[v_head + [n_size, d_size]], smem_v: smem_v_type, offset_j: i32):
if compute_capability >= 80:
copy_v_g2s_sm80(v, smem_v, offset_j)
else:
copy_v_g2s_sm75(v, smem_v, offset_j)

@hidet.script
def copy_q_g2s(q: f16[q_head + [n_size, d_size]], smem_q: smem_q_type, offset_i: i32):
if compute_capability >= 80:
copy_q_g2s_sm80(q, smem_q, offset_i)
else:
copy_q_g2s_sm75(q, smem_q, offset_i)

@hidet.script
def copy_o_r2g(o: f16[o_head + [n_size, d_size]], regs_o: regs_o_type, offset_i: i32):
warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32
@@ -402,7 +479,8 @@ def copy_o_r2g(o: f16[o_head + [n_size, d_size]], regs_o: regs_o_type, offset_i:
for ti, tj in mma_config.c_store_map.on(lane_id):
delta_m = wi * warp_elems_m_o + mma_i * mma_m + ti
delta_n = wj * warp_elems_n_o + mma_j * mma_n + tj
gmem_o[delta_m, delta_n] = regs_o[mma_i, mma_j, p]
if delta_m < n_size and delta_n < d_size:
gmem_o[delta_m, delta_n] = regs_o[mma_i, mma_j, p]
p += 1

@hidet.script
@@ -603,7 +681,7 @@ def attn_kernel(

# Copy first tile of k into shared memory
copy_k_g2s(k, ~smem_k[0, 0, 0], offset_j, 0)
cp_async_wait_all()
cp_async_sync()
syncthreads()

for k0 in range(k_tiles):
@@ -629,7 +707,7 @@ def attn_kernel(
~regs_k[mma_k % 2, mma_j, 0],
~regs_acc[mma_i, mma_j, 0],
)
cp_async_wait_all()
cp_async_sync()
syncthreads()

# Preload first tile of v into shared memory
@@ -657,7 +735,7 @@ def attn_kernel(
for a, b, c in grid(mmas_per_warp_m_o, mmas_per_warp_n_o, mma_config.c_elements):
regs_acc_o[a, b, c] = acc_dtype.zero

cp_async_wait_all()
cp_async_sync()
syncthreads()
for k1 in range(k_tiles_o):
# Load Vj into Smem
@@ -682,7 +760,7 @@ def attn_kernel(
~regs_v[mma_k % 2, mma_j, 0],
~regs_acc_o[mma_i, mma_j, 0],
)
cp_async_wait_all()
cp_async_sync()
syncthreads()
# ----------------------------

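
Since `cp_async` is only available from sm80 (Ampere) onward, the sm75 paths added above replace the asynchronous global-to-shared copies with plain guarded loads plus zero fill, and `cp_async_sync` becomes a no-op below sm80. As a minimal illustration (NumPy, single-threaded, not hidet script, and not part of the commit), this is the zero-padded tile load that `copy_q_g2s_sm75` performs; the function name and example shapes are made up.

```python
import numpy as np


def load_q_tile_sm75_style(q: np.ndarray, offset_i: int, block_i: int, dpad_size: int) -> np.ndarray:
    """Emulate copy_q_g2s_sm75: elements inside the source bounds are copied directly,
    elements past n_size or d_size are filled with zeros, mirroring the guarded
    `gmem_q.read(...)` / `f16.zero` branches executed per thread."""
    n_size, d_size = q.shape
    smem_q = np.zeros((block_i, dpad_size), dtype=q.dtype)  # stand-in for the shared-memory tile
    for i in range(block_i):
        for j in range(dpad_size):
            if offset_i + i < n_size and j < d_size:
                smem_q[i, j] = q[offset_i + i, j]
    return smem_q


# Example: the last Q tile of a 100-token sequence; tile rows past the sequence end stay zero.
q = np.random.rand(100, 64).astype(np.float16)
tile = load_q_tile_sm75_style(q, offset_i=64, block_i=128, dpad_size=64)
```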
