diff --git a/3rdparty/tvm b/3rdparty/tvm index 4a2e00f86..a12155db8 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 4a2e00f862891305098aabe07ef1a395cd4e4f7d +Subproject commit a12155db8eb818f54f27c5124ad8332918cae0ec diff --git a/benchmark/tilelang/benchmark.sh b/benchmark/tilelang/benchmark.sh new file mode 100644 index 000000000..0dced97a7 --- /dev/null +++ b/benchmark/tilelang/benchmark.sh @@ -0,0 +1,7 @@ +python /home/aiscuser/lei/BitBLAS/benchmark/tilelang/benchmark_tilelang_matmul.py --m 2048 --n 2048 --k 2048 2>&1 | tee run_gemm_tilelang_2048_2048_2048.log +python /home/aiscuser/lei/BitBLAS/benchmark/tilelang/benchmark_tilelang_matmul.py --m 4096 --n 4096 --k 4096 2>&1 | tee run_gemm_tilelang_4096_4096_4096.log +python /home/aiscuser/lei/BitBLAS/benchmark/tilelang/benchmark_tilelang_matmul.py --m 8192 --n 8192 --k 8192 2>&1 | tee run_gemm_tilelang_8192_8192_8192.log +python /home/aiscuser/lei/BitBLAS/benchmark/tilelang/benchmark_tilelang_matmul.py --m 16384 --n 16384 --k 16384 2>&1 | tee run_gemm_tilelang_16384_16384_16384.log +python /home/aiscuser/lei/BitBLAS/benchmark/tilelang/benchmark_tilelang_matmul.py --m 8192 --n 8192 --k 1024 2>&1 | tee run_gemm_tilelang_8192_8192_1024.log +python /home/aiscuser/lei/BitBLAS/benchmark/tilelang/benchmark_tilelang_matmul.py --m 8192 --n 8192 --k 2048 2>&1 | tee run_gemm_tilelang_8192_8192_2048.log +python /home/aiscuser/lei/BitBLAS/benchmark/tilelang/benchmark_tilelang_matmul.py --m 8192 --n 8192 --k 4096 2>&1 | tee run_gemm_tilelang_8192_8192_4096.log diff --git a/benchmark/tilelang/benchmark_tilelang_matmul.py b/benchmark/tilelang/benchmark_tilelang_matmul.py new file mode 100644 index 000000000..e9590116a --- /dev/null +++ b/benchmark/tilelang/benchmark_tilelang_matmul.py @@ -0,0 +1,91 @@ +import argparse +from tvm import tl +import tvm.tl.language as T +from tvm.tl.autotuner import * +import itertools + + +def ref_program(A, B): + return A @ B.T + + +def get_configs(): + block_M = [64, 
128, 256] + block_N = [64, 128, 256] + block_K = [64, 128, 256] + num_stages = [0, 1, 2, 3, 4] + thread_num = [128, 256] + enable_rasteration = [True, False] + _configs = list( + itertools.product(block_M, block_N, block_K, num_stages, thread_num, enable_rasteration)) + + configs = [{ + 'block_M': c[0], + 'block_N': c[1], + 'block_K': c[2], + 'num_stages': c[3], + 'thread_num': c[4], + 'enable_rasteration': c[5] + } for c in _configs] + return configs + + +def matmul(M, N, K): + + @autotune( + configs=get_configs(), + keys=['block_M', 'block_N', 'block_K', 'num_stages', 'thread_num', 'enable_rasteration'], # NOTE(review): every field produced by get_configs() is listed so none silently falls back to its None default; confirm against tvm.tl.autotuner + warmup=3, + rep=5) + @jit( + out_idx=[2], + supply_type=tl.TensorSupplyType.Integer, + ref_prog=ref_program, + skip_check=True, + profiler="tvm", + target="hip") + def kernel(block_M=None, + block_N=None, + block_K=None, + num_stages=None, + thread_num=None, + enable_rasteration=None): + dtype = "float16" + accum_dtype = "float" + + @T.prim_func + def main(A: T.Buffer((M, K), dtype), B: T.Buffer((N, K), dtype), C: T.Buffer((M, N), + dtype)): + with T.Kernel( + T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.use_swizzle(panel_size=10, enable=enable_rasteration) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + return kernel() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--m', type=int, default=8192, help='M') + parser.add_argument('--n', type=int, default=8192, help='N') + parser.add_argument('--k', type=int, default=8192, help='K') + args = parser.parse_args() + M, N, K = 
args.m, args.n, args.k + total_flops = 2 * M * N * K + best_latency, best_config, ref_latency = matmul(M, N, K) + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") + print(f"Ref TFlops: {total_flops / ref_latency * 1e-9}") diff --git a/benchmark/tilelang/benchmark_tilelang_mha.py b/benchmark/tilelang/benchmark_tilelang_mha.py new file mode 100644 index 000000000..a6c927e50 --- /dev/null +++ b/benchmark/tilelang/benchmark_tilelang_mha.py @@ -0,0 +1,156 @@ +import argparse +import torch +from tvm import tl +import tvm.tl.language as T +from tvm.tl.autotuner import * +from functools import partial +import itertools + + +def get_configs(): + block_M = [32, 64, 128] + block_N = [32, 64, 128] + num_stages = [0, 1, 2] + thread_num = [128, 256] + _configs = list(itertools.product(block_M, block_N, num_stages, thread_num)) + + configs = [{ + 'block_M': c[0], + 'block_N': c[1], + 'num_stages': c[2], + 'thread_num': c[3] + } for c in _configs] + return configs + + +def ref_program(Q, K, Vt, casual): + import torch.nn.functional as F + dim = Q.size(-1) + scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if casual: + mask = torch.triu(torch.ones(scores.size(-2), scores.size(-1)), diagonal=1).bool().cuda() + scores.masked_fill_(mask, float('-inf')) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum('bhqk,bdhk->bqhd', attention_weights, Vt) + return output + + +def flashattn(batch, heads, seq_len, dim, is_casual): + + @autotune( + configs=get_configs(), + keys=['block_M', 'block_N', 'num_stages', 'thread_num'], + warmup=10, + rep=5) + @jit( + out_idx=[3], + supply_type=tl.TensorSupplyType.Normal, + ref_prog=partial(ref_program, casual=is_casual), + rtol=0.01, + atol=0.01, + target="hip") + def kernel(block_M=None, block_N=None, num_stages=None, thread_num=None): + scale = (1.0 / dim)**0.5 * 1.44269504 # 
log2(e) + shape = [batch, seq_len, heads, dim] + vt_shape = [batch, dim, heads, seq_len] + dtype = "float16" + accum_dtype = "float" + + @T.prim_func + def main( + Q: T.Buffer(shape, dtype), # type: ignore + K: T.Buffer(shape, dtype), # type: ignore + Vt: T.Buffer(vt_shape, dtype), # type: ignore + Output: T.Buffer(shape, dtype), # type: ignore + ): + with T.Kernel( + T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + Vt_shared = T.alloc_shared([dim, block_N], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_shared = T.alloc_shared([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + # T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)}) + T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + loop_range = ( + T.ceildiv( + (bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N)) + for k in T.Pipelined(loop_range, num_stages=num_stages): + T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) + if is_casual: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, + -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + T.gemm( + Q_shared, + K_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow) + T.copy(Vt[bz, :, by, k * block_N:(k + 1) * block_N], Vt_shared) + for i, j in T.Parallel(block_M, block_N): # acc_s is allocated as [block_M, block_N]; iterating over dim mis-indexes it whenever dim != block_N + acc_s[i, j] *= scale 
+ T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i]) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i]) + # T.copy(acc_s, acc_s_cast) + T.copy(acc_s, acc_s_shared) + T.copy(acc_s_shared, acc_s_cast) + T.gemm( + acc_s_cast, + Vt_shared, + acc_o, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + + return main + + return kernel() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--batch', type=int, default=1, help='Batch size') + parser.add_argument('--h', type=int, default=12, help='Number of heads') + parser.add_argument('--n_ctx', type=int, default=2048, help='Context size') + parser.add_argument('--d_head', type=int, default=128, help='Head dimension') + parser.add_argument('--casual', type=lambda s: s.lower() in ('1', 'true', 'yes'), default=False, help='Casual flag') # type=bool would parse "--casual False" as True (any non-empty string is truthy) + args = parser.parse_args() + BATCH, H, N_CTX, D_HEAD = args.batch, args.h, args.n_ctx, args.d_head + casual = args.casual + flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD + total_flops = 2 * flops_per_matmul + if casual: + total_flops *= 0.5 + + best_latency, best_config, ref_latency = flashattn(BATCH, H, N_CTX, D_HEAD, casual) + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") + print(f"Ref TFlops: {total_flops / ref_latency * 1e-9}") diff --git a/bitblas/tl/mfma_layout.py b/bitblas/tl/mfma_layout.py index a7302e897..2ff9be948 
100644 --- a/bitblas/tl/mfma_layout.py +++ b/bitblas/tl/mfma_layout.py @@ -1,6 +1,5 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import os from tvm.runtime import convert @@ -62,19 +61,13 @@ def shared_16x16_to_local_64x4_layout_B(i, j): return thread_id, local -def thread_id_shared_access_64x4_to_16x16_layout_C_smooth(thread_id, local_id): - i = thread_id % 16 - j = local_id + (thread_id // 16) * 4 - return j, i - - -def thread_id_shared_access_64x4_to_16x16_layout_C(thread_id, local_id): - # This is a hacky implementation to simulate the performance - is_smooth = os.environ.get("TILE_LANG_SMOOTH_LAYOUT") == "1" - print(is_smooth) - if is_smooth: - return thread_id_shared_access_64x4_to_16x16_layout_C_smooth(thread_id, local_id) - +def thread_id_shared_access_64x4_to_16x16_layout_C_m_n(thread_id, local_id): i = local_id + (thread_id // 16) * 4 j = thread_id % 16 return i, j + + +def thread_id_shared_access_64x4_to_16x16_layout_C_n_m(thread_id, local_id): + i = thread_id % 16 + j = local_id + (thread_id // 16) * 4 + return i, j diff --git a/bitblas/tl/mfma_macro_generator.py b/bitblas/tl/mfma_macro_generator.py index cd95ed794..c68fce701 100644 --- a/bitblas/tl/mfma_macro_generator.py +++ b/bitblas/tl/mfma_macro_generator.py @@ -30,6 +30,8 @@ class MatrixCoreIntrinEmitter(object): "e5m2_float8": "e5m2", } + is_m_first = False + def __init__( self, a_dtype="float16", @@ -143,12 +145,34 @@ def get_ldmatrix_index_map(self, is_b=False): return index_map, reverse_index_map - def extract_thread_binding(self, thread_id) -> Tuple[PrimExpr, PrimExpr, PrimExpr]: + def extract_thread_binding(self, + thread_id, + is_m_first=None) -> Tuple[PrimExpr, PrimExpr, PrimExpr]: + ''' + is_m_first: True if thread_id decomposes, fastest index first, as (tx, warp_n, warp_m), + i.e. [warp_size, block_col_warps (split n), block_row_warps (split m)]; None falls back to self.is_m_first. + Otherwise it decomposes as [warp_size, block_row_warps (split m), block_col_warps (split n)]; both cases return (lane_id, warp_n, warp_m). + ''' 
WARP_SIZE = self.WARP_SIZE block_row_warps = self.block_row_warps block_col_warps = self.block_col_warps - return thread_id % WARP_SIZE, (thread_id // WARP_SIZE) % block_col_warps, ( - thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps + + # if is_m_first is None, then use the default value + if is_m_first is None: + is_m_first = self.is_m_first + + if is_m_first: + lane_id, warp_n, warp_m = thread_id % WARP_SIZE, ( + thread_id // + WARP_SIZE) % block_col_warps, (thread_id // + (WARP_SIZE * block_col_warps)) % block_row_warps, + return lane_id, warp_n, warp_m + else: + lane_id, warp_m, warp_n = thread_id % WARP_SIZE, ( + thread_id // + WARP_SIZE) % block_row_warps, (thread_id // + (WARP_SIZE * block_row_warps)) % block_col_warps, + return lane_id, warp_n, warp_m def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, thread_bindings, rk=0): warp_row_tiles = self.warp_row_tiles @@ -169,19 +193,19 @@ def _warp_ldmatrix_a( thread_bindings, rk=0, ): - tx, _, tz = self.extract_thread_binding(thread_bindings) + tx, _, warp_m = self.extract_thread_binding(thread_bindings) if is_transposed: for i in T.serial(warp_rows): for local_id in T.vectorized(local_size_a): row, col = T.meta_var(reverse_index_map(tx, local_id)) l, r = (rk * chunk + ki * micro_size_k, - tz * warp_row_tiles + i * micro_size_x) + warp_m * warp_row_tiles + i * micro_size_x) A_local_buf[i * local_size_a + local_id] = A_shared_buf[l + row, r + col] else: for i in T.serial(warp_rows): for local_id in T.vectorized(local_size_a): row, col = T.meta_var(reverse_index_map(tx, local_id)) - l, r = (tz * warp_row_tiles + i * micro_size_x, + l, r = (warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k) A_local_buf[i * local_size_a + local_id] = A_shared_buf[l + row, r + col] @@ -206,14 +230,14 @@ def _warp_ldmatrix_b( thread_bindings, rk=0, ): - tx, ty, _ = self.extract_thread_binding(thread_bindings) + tx, warp_n, _ = self.extract_thread_binding(thread_bindings) if is_transposed: for j 
in T.serial(warp_cols): for local_id in T.vectorized(local_size_b): row, col = T.meta_var(reverse_index_map(tx, local_id)) l, r = ( - ty * warp_col_tiles + j * micro_size_y, + warp_n * warp_col_tiles + j * micro_size_y, rk * chunk + ki * micro_size_k, ) B_local_buf[j * local_size_b + local_id] = B_shared_buf[l + row, r + col] @@ -223,7 +247,7 @@ def _warp_ldmatrix_b( row, col = T.meta_var(reverse_index_map(tx, local_id)) l, r = ( rk * chunk + ki * micro_size_k, - ty * warp_col_tiles + j * micro_size_y, + warp_n * warp_col_tiles + j * micro_size_y, ) B_local_buf[j * local_size_b + local_id] = B_shared_buf[l + row, r + col] @@ -251,10 +275,10 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf): compute_a_dtype, compute_b_dtype, compute_out_dtype, - A_local_buf.data, - (i * local_size_a) // local_size_a, B_local_buf.data, (j * local_size_b) // local_size_b, + A_local_buf.data, + (i * local_size_a) // local_size_a, C_local_buf.data, (i * warp_cols * local_size_out + j * local_size_out) // local_size_out, dtype=compute_out_dtype, @@ -280,22 +304,22 @@ def stmatrix(self, C_local_buf, C_buf, thread_bindings, pid_m=None, pid_n=None): # equal to the warp_size @T.macro def _warp_stmatrix_shared(C_local_buf, C_buf, thread_bindings): - tx, ty, tz = self.extract_thread_binding(thread_bindings) + tx, warp_n, warp_m = self.extract_thread_binding(thread_bindings) for i, j in T.grid(warp_rows, warp_cols): for local_id in T.serial(local_size_out): row, col = T.meta_var(mfma_store_index_map(tx, local_id)) - C_buf[tz * warp_rows + i, ty * warp_cols + j, row, + C_buf[warp_m * warp_rows + i, warp_n * warp_cols + j, row, col] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id] @T.macro def _warp_stmatrix_global(C_local_buf, C_buf, thread_bindings): - tx, ty, tz = self.extract_thread_binding(thread_bindings) + tx, warp_n, warp_m = self.extract_thread_binding(thread_bindings) for i, j in T.grid(warp_rows, warp_cols): for local_id in T.serial(local_size_out): 
row, col = T.meta_var(mfma_store_index_map(tx, local_id)) - C_buf[(pid_m * BLOCK_M + tz * warp_rows + i) * M_DIM + row, - (pid_n * BLOCK_N + ty * warp_cols + j) * N_DIM + + C_buf[(pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row, + (pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM + col] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id] diff --git a/bitblas/tl/utils.py b/bitblas/tl/utils.py index 9f354f8ce..e11dab3e7 100644 --- a/bitblas/tl/utils.py +++ b/bitblas/tl/utils.py @@ -12,8 +12,7 @@ ldmatrix_16x32_to_shared_16x32_layout_b, mma_store_32x8_to_shared_16x16_layout, ) -from .mfma_layout import ( - thread_id_shared_access_64x4_to_16x16_layout_C,) +from .mfma_layout import (thread_id_shared_access_64x4_to_16x16_layout_C_n_m) def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]): @@ -113,7 +112,7 @@ def mma_store_index_map(*args, **kwargs): def mfma_store_index_map(*args, **kwargs): - return thread_id_shared_access_64x4_to_16x16_layout_C(*args, **kwargs) + return thread_id_shared_access_64x4_to_16x16_layout_C_n_m(*args, **kwargs) def get_mma_micro_size(dtype: Literal["float16", "int8"]): diff --git a/integration/ComposableKernel/test_mfma_a_layout.py b/integration/ComposableKernel/test_mfma_a_layout.py new file mode 100644 index 000000000..638411e86 --- /dev/null +++ b/integration/ComposableKernel/test_mfma_a_layout.py @@ -0,0 +1,99 @@ +import subprocess + +layouts = [ + [False, False, False, False, False], + [False, False, False, False, True], + [False, False, False, True, False], + [False, False, False, True, True], + [False, False, True, False, False], + [False, False, True, False, True], + [False, False, True, True, False], + [False, False, True, True, True], + [False, True, False, False, False], + [False, True, False, False, True], + [False, True, False, True, False], + [False, True, False, True, True], + [False, True, True, False, False], + [False, True, True, False, True], + [False, True, True, 
True, False], + [False, True, True, True, True], + [True, False, False, False, False], + [True, False, False, False, True], + [True, False, False, True, False], + [True, False, False, True, True], + [True, False, True, False, False], + [True, False, True, False, True], + [True, False, True, True, False], + [True, False, True, True, True], + [True, True, False, False, False], + [True, True, False, False, True], + [True, True, False, True, False], + [True, True, False, True, True], + [True, True, True, False, False], + [True, True, True, False, True], + [True, True, True, True, False], + [True, True, True, True, True], +] + +raw_func = '''Fragment makeGemmFragmentACDNA(const int block_m, const int block_n, const int block_k, + const int warp_m, const int warp_n) { + // assume not transposed + ICHECK(block_m % warp_m == 0); + ICHECK(block_n % warp_n == 0); + ICHECK(warp_m % 16 == 0); + ICHECK(block_k % 16 == 0); + auto base_layout = makeGemmFragmentAB16x16CDNA()->Repeat({1, 1}, false, false); + auto warp_layout = base_layout->Repeat({block_m / warp_m, 1}, true)->Replicate(block_n / warp_n); + auto block_layout = warp_layout->Repeat({warp_m / 16, block_k / 16}, false, false); + return block_layout; +}''' +file_path = "/home/aiscuser/lei/BitBLAS/3rdparty/tvm/src/tl/layout/gemm_layouts.cc" + +for layout in layouts: + base_layout_0 = "false" if not layout[0] else "true" + base_layout_1 = "false" if not layout[1] else "true" + block_layout_0 = "false" if not layout[2] else "true" + block_layout_1 = "false" if not layout[3] else "true" + warp_layout_0 = "false" if not layout[4] else "true" + + log_path = f"base_{base_layout_0}_{base_layout_1}_warp_{warp_layout_0}_block_{block_layout_0}_{block_layout_1}.log" + + new_func = raw_func.replace( + "auto base_layout = makeGemmFragmentAB16x16CDNA()->Repeat({1, 1}, false, false);", + f"auto base_layout = makeGemmFragmentAB16x16CDNA()->Repeat({{1, 1}}, {base_layout_0}, {base_layout_1});" + ).replace( + "auto warp_layout = 
base_layout->Repeat({block_m / warp_m, 1}, true)->Replicate(block_n / warp_n);", + f"auto warp_layout = base_layout->Repeat({{warp_m / 16, block_k / 16}}, {block_layout_0}, {block_layout_1});" # second flag is layout[3]; it was block_layout_0 twice, so half of the sweep repeated the same layout + ).replace( + "auto block_layout = warp_layout->Repeat({warp_m / 16, block_k / 16}, false, false);", + f"auto block_layout = warp_layout->Repeat({{block_m / warp_m, 1}}, {warp_layout_0})->Replicate(block_n / warp_n);" + ) + + print(new_func) + with open(file_path, "r") as f: + content = f.read() + content = content.replace(raw_func, new_func) + with open(file_path, "w") as f: + f.write(content) + + with open(log_path, "w") as log_file: + # build tvm + subprocess.run(["make", "-j8"], + cwd="/home/aiscuser/lei/BitBLAS/3rdparty/tvm/build", + stdout=log_file, + stderr=log_file) + + # Execute Test log + subprocess.run([ + "python", + "/home/aiscuser/lei/BitBLAS/integration/ComposableKernel/test_mfma_fragement_gemm.py" + ], + cwd="/home/aiscuser/lei/BitBLAS/integration/ComposableKernel", + stdout=log_file, + stderr=log_file) + + # Recover + content = content.replace(new_func, raw_func) + + with open(file_path, "w") as f: + f.write(content) diff --git a/integration/ComposableKernel/test_mfma_c_layout.py b/integration/ComposableKernel/test_mfma_c_layout.py new file mode 100644 index 000000000..3e5a433e7 --- /dev/null +++ b/integration/ComposableKernel/test_mfma_c_layout.py @@ -0,0 +1,84 @@ +import subprocess + +layouts = [ + [False, False, False, False], + [False, False, False, True], + [False, False, True, False], + [False, False, True, True], + [False, True, False, False], + [False, True, False, True], + [False, True, True, False], + [False, True, True, True], + [True, False, False, False], + [True, False, False, True], + [True, False, True, False], + [True, False, True, True], + [True, True, False, False], + [True, True, False, True], + [True, True, True, False], + [True, True, True, True], +] + +raw_func = '''Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n, 
const int warp_m, const int warp_n, + const int element_size) { + if (element_size == 64) LOG(FATAL) << "Not supported"; + ICHECK(block_m % warp_m == 0); + ICHECK(block_n % warp_n == 0); + ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m; + ICHECK(warp_n % 16 == 0) << "warp_n=" << warp_n; + auto base_layout = makeGemmFragmentC16x16CDNA()->Repeat({1, 1}, false); + auto warp_layout = base_layout->Repeat({block_m / warp_m, block_n / warp_n}, false, false); + auto block_layout = warp_layout->Repeat({warp_m / 16, warp_n / 16}, true, true); + return block_layout; +}''' +file_path = "/home/aiscuser/lei/BitBLAS/3rdparty/tvm/src/tl/layout/gemm_layouts.cc" + +for layout in layouts: + block_layout_0 = "false" if not layout[0] else "true" + block_layout_1 = "false" if not layout[1] else "true" + warp_layout_0 = "false" if not layout[2] else "true" + warp_layout_1 = "false" if not layout[3] else "true" + + log_path = f"block_{block_layout_0}_{block_layout_1}_warp_{warp_layout_0}_{warp_layout_1}.log" + + # new_func = raw_func.replace( + # "base_layout->Repeat({block_m / warp_m, block_n / warp_n}, false, false);", + # f"base_layout->Repeat({{block_m / warp_m, block_n / warp_n}}, {block_layout_0}, {block_layout_1});" + # ).replace( + # "warp_layout->Repeat({warp_m / 16, warp_n / 16}, true, true);", + # f"warp_layout->Repeat({{warp_m / 16, warp_n / 16}}, {warp_layout_0}, {warp_layout_1});") + + new_func = raw_func.replace( + "base_layout->Repeat({block_m / warp_m, block_n / warp_n}, false, false);", + f"base_layout->Repeat({{warp_m / 16, warp_n / 16}}, {block_layout_0}, {block_layout_1});" + ).replace( + "warp_layout->Repeat({warp_m / 16, warp_n / 16}, true, true);", + f"warp_layout->Repeat({{block_m / warp_m, block_n / warp_n}}, {warp_layout_0}, {warp_layout_1});" + ) + print(new_func) + with open(file_path, "r") as f: + content = f.read() + content = content.replace(raw_func, new_func) + with open(file_path, "w") as f: + f.write(content) + + with open(log_path, "w") as log_file: + 
# build tvm + subprocess.run(["make", "-j8"], + cwd="/home/aiscuser/lei/BitBLAS/3rdparty/tvm/build", + stdout=log_file, + stderr=log_file) + + # Execute Test log + subprocess.run([ + "python", "/home/aiscuser/lei/BitBLAS/integration/ComposableKernel/test_block_gemm.py" + ], + cwd="/home/aiscuser/lei/BitBLAS/integration/ComposableKernel", + stdout=log_file, + stderr=log_file) + + # Recover + content = content.replace(new_func, raw_func) + + with open(file_path, "w") as f: + f.write(content) diff --git a/integration/ComposableKernel/test_mfma_fragement_gemm.py b/integration/ComposableKernel/test_mfma_fragement_gemm.py new file mode 100644 index 000000000..8e2e4b169 --- /dev/null +++ b/integration/ComposableKernel/test_mfma_fragement_gemm.py @@ -0,0 +1,136 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.backends +from bitblas import tvm as tvm +from tvm import tl as TL +import tvm.tl.language as T +from bitblas.ops.base_scheduler import simplify_prim_func + + +def make_pad_layout(shared_buf, pad_offset=4): + shape = shared_buf.shape + stride = shape[-1] + + def transform(i, j): + idx = i * (stride + pad_offset) + j + return idx + + return T.Layout(shape, transform) + + +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + assert in_dtype in [ + "float16", + "int8", + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + "float16", + "float32", + "int32", + ], "Currently only float16, float32 and int32 are supported" + + # This is a debug config + # block_row_warps = 2 + # block_col_warps = 2 + # warp_row_tiles = 64 + # warp_col_tiles = 64 + # chunk = 32 + + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 32 + warp_col_tiles = 32 + chunk = 32 + + # shared_scope = "shared.dyn" + shared_scope = "shared" + + # Pipeline Stage + stage = 1 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + 
block_K = chunk + + A_shape = (M, K) + B_shape = (N, K) + C_shape = (M, N) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = (block_M, block_N) + + warp_size = 64 + threads = warp_size * (block_row_warps * block_col_warps) + + @T.prim_func + def main(A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), + C: T.Buffer(C_shape, out_dtype)): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + A_local = T.alloc_fragment(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_local = T.alloc_fragment(C_shared_shape, accum_dtype) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=stage): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.copy(A_shared, A_local) + T.gemm(A_local, B_shared, C_local, transpose_B=True) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): + matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + print(matmul) + mod, params = TL.lower(matmul, target="hip") + src_code = mod.imported_modules[0].get_source() + # src_code is the generated cuda source + assert src_code is not None + print(src_code) + torch.random.manual_seed(0) + if in_dtype == "int8": + A = torch.randint(-128, 127, (M, K), device="cuda", dtype=torch.int8) + B = torch.randint(-128, 127, (N, K), device="cuda", dtype=torch.int8) + else: + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + # B = torch.ones((N, K), device="cuda", dtype=getattr(torch, in_dtype)) + print(f"{A=}") + print(f"{B=}") + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + mod = TL.Profiler(mod, params, [], 
TL.TensorSupplyType.Integer) + + mod(A, B, C) + + # latency = mod.do_bench(mod.func, warmup=5, rep=10) + + # # Ensure that the latency is not None + # assert latency is not None + # print(f"{latency=}") + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) + print(C) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def test_assert_tl_matmul(): + assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32") + + +if __name__ == "__main__": + assert_tl_matmul_correctness(256, 256, 256, "float16", "float32", "float32")