From f3b1eb9862ebc0ee2bbd38194ace250fa6be9b0b Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 28 Sep 2024 07:43:40 +0000 Subject: [PATCH 1/8] Refactor tilelang dequantize module and add matmul_blocked_weight_only function --- .../general_matmul/tilelang/dense/__init__.py | 6 + .../general_matmul/tilelang/dense/matmul.py | 484 ++++++++++++++++++ .../tilelang/dequantize/__init__.py | 2 + .../tilelang/dequantize/matmul_weight_only.py | 110 ++++ .../test_general_matmul_tilelang_kernel.py | 383 ++++++++++++++ .../tilelang/test_tilelang_dequantize_gemm.py | 44 +- 6 files changed, 1007 insertions(+), 22 deletions(-) create mode 100644 bitblas/ops/general_matmul/tilelang/dequantize/__init__.py create mode 100644 bitblas/ops/general_matmul/tilelang/dequantize/matmul_weight_only.py create mode 100644 testing/python/operators/test_general_matmul_tilelang_kernel.py diff --git a/bitblas/ops/general_matmul/tilelang/dense/__init__.py b/bitblas/ops/general_matmul/tilelang/dense/__init__.py index 03b5a81f3..23cda34db 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/__init__.py +++ b/bitblas/ops/general_matmul/tilelang/dense/__init__.py @@ -6,3 +6,9 @@ matmul_macro_tensorcore, # noqa: F401 matmul_macro_tensorcore_weight_propagation_level_ldmatrix # noqa: F401 ) + +from .matmul import ( + MatmulScheduler, # noqa: F401 + MatmulFineGrainScheduler, # noqa: F401 + MatmulWeightPropagationScheduler, # noqa: F401 +) diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul.py b/bitblas/ops/general_matmul/tilelang/dense/matmul.py index 49858bf2f..f5ae7a648 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul.py @@ -16,6 +16,490 @@ from bitblas.ops.operator import TransformKind +from dataclasses import dataclass + + +@dataclass +class MatmulScheduler: + + # OP Related Config + M: int + N: int + K: int + trans_A: bool = False + trans_B: bool = False + dtypeAB: str = "float16" + dtypeC: str = "float16" + accum_dtype: str = "float16" + + # Default Tile Related Params + block_M: int = 64 + block_N: int = 64 + block_K: int = 32 + num_stages: int = 2 + threads: int = 128 + enable_rasterization: bool = False # Enhance L2 Locality + + def with_default_config(self): + block_M = getattr(self, "block_M", 64) + block_N = getattr(self, "block_N", 64) + block_K = getattr(self, "block_K", 32) + num_stages = getattr(self, "num_stages", 2) + threads = getattr(self, "threads", 128) + enable_rasterization = getattr(self, "enable_rasterization", False) + + return self.apply_config( + block_M=block_M, + block_N=block_N, + block_K=block_K, + num_stages=num_stages, + threads=threads, + enable_rasterization=enable_rasterization, + ) + + def apply_config( + self, + block_M=64, + block_N=64, + block_K=32, + num_stages=2, + threads=128, + # Enhance L2 Locality + enable_rasterization=False, + ): + M, N, K = self.M, self.N, self.K + trans_A, trans_B = self.trans_A, self.trans_B + dtypeAB, dtypeC, accum_dtype = self.dtypeAB, self.dtypeC, self.accum_dtype + + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Buffer(A_shape, dtypeAB), + B: T.Buffer(B_shape, dtypeAB), + C: T.Buffer((M, N), dtypeC), + ): + with T.Kernel( + T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, dtypeAB) + B_shared = T.alloc_shared(B_shared_shape, dtypeAB) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + if enable_rasterization: + # rasterization factor + T.use_swizzle(10) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + def __post_init__(self): + # Add Config Validation + return + + +@dataclass +class MatmulFineGrainScheduler: + # Fine-grained matrix multiplication scheduler + # Allows for more detailed configuration. + + # Operation Configuration + M: int + N: int + K: int + dtypeAB: str = "float16" + dtypeC: str = "float16" + trans_A: bool = False + trans_B: bool = True + accum_dtype: str = "float16" + + # Tensor Core Warp Configuration + block_row_warps: int = 2 + block_col_warps: int = 2 + warp_row_tiles: int = 32 + warp_col_tiles: int = 32 + chunk: int = 32 # Usually determines the K-dimension split size + + # Tiling and Other Optimization Parameters + num_stages: int = 2 + enable_rasterization: bool = False + + def with_default_config(self): + block_row_warps = getattr(self, "block_row_warps", 2) + block_col_warps = getattr(self, "block_col_warps", 2) + warp_row_tiles = getattr(self, "warp_row_tiles", 32) + warp_col_tiles = getattr(self, "warp_col_tiles", 32) + chunk = getattr(self, "chunk", 32) + num_stages = getattr(self, "num_stages", 2) + enable_rasterization = getattr(self, "enable_rasterization", False) + + return self.apply_config( + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + num_stages=num_stages, + enable_rasterization=enable_rasterization, + ) + + def apply_config(self, + block_row_warps=2, + block_col_warps=2, + warp_row_tiles=32, + warp_col_tiles=32, + chunk=16, + num_stages=2, + enable_rasterization=False): + + M, N, K = self.M, self.N, self.K + trans_A, trans_B = self.trans_A, self.trans_B + dtypeAB, dtypeC, accum_dtype = self.dtypeAB, self.dtypeC, self.accum_dtype + + # Calculate the micro size per warp using a helper function + micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtypeAB) + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + # Define the shapes of matrices and shared memory buffers + A_shape = (M, K) + B_shape = (N, K) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = (block_M // micro_size_x, block_N // micro_size_y, micro_size_x, + micro_size_y) + + # GPU warp configuration for NVIDIA GPUs + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + + # Calculate local fragment sizes for tensor core + local_size = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + shared_scope = "shared.dyn" + + # Configure the tensor core intrinsic emitter + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=dtypeAB, + b_dtype=dtypeAB, + accum_dtype=accum_dtype, + a_transposed=trans_A, + b_transposed=trans_B, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk) + + # Define the main kernel using the generated configuration + @T.prim_func + def main( + A: T.Buffer(A_shape, dtypeAB), + B: T.Buffer(B_shape, dtypeAB), + C: T.Buffer((M, N), dtypeC), + ): + # Grid and thread configuration for CUDA kernel + with T.Kernel( + T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + + # Allocate shared memory and local fragments + A_shared = T.alloc_shared(A_shared_shape, dtypeAB, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, dtypeAB, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, dtypeC, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size), dtypeAB) + B_local = T.alloc_local((warp_cols * local_size), dtypeAB) + C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) + + # Thread-level parallelism for Tensor Cores + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + # Apply memory layout optimizations + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + }) + + # Optional rasterization for L2 locality enhancement + if enable_rasterization: + T.use_swizzle(panel_size=10) + + # Initialize accumulation buffer to zero + T.clear(C_local) + + # Main matrix multiplication pipeline with multiple stages + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + + # Load A matrix into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B matrix into shared memory + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] + + # Perform the matrix multiplication on tensor core fragments + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Load B fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Matrix multiplication on fragments + mma_emitter.mma(A_local, B_local, C_local) + + # Store the result back to C shared memory + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + # Store results from shared memory to global memory + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, + bx * block_N + j] = C_shared[i // micro_size_x, j // micro_size_y, + i % micro_size_x, j % micro_size_y] + + return main + + def __post_init__(self): + # Validate the matrix transpose settings + assert self.trans_A is False, "Currently only support Matrix A not transposed" + assert self.trans_B is True, "Currently only support Matrix B transposed" + + return + + +@dataclass +class MatmulWeightPropagationScheduler: + # Fine-grained matrix multiplication scheduler + # Allows for more detailed configuration. + + # Operation Configuration + M: int + N: int + K: int + dtypeAB: str = "float16" + dtypeC: str = "float16" + trans_A: bool = False + trans_B: bool = True + accum_dtype: str = "float16" + + # Tensor Core Warp Configuration + block_row_warps: int = 2 + block_col_warps: int = 2 + warp_row_tiles: int = 32 + warp_col_tiles: int = 32 + chunk: int = 32 # Usually determines the K-dimension split size + + # Tiling and Other Optimization Parameters + num_stages: int = 2 + enable_rasterization: bool = False + + def with_default_config(self): + block_row_warps = getattr(self, "block_row_warps", 2) + block_col_warps = getattr(self, "block_col_warps", 2) + warp_row_tiles = getattr(self, "warp_row_tiles", 4) + warp_col_tiles = getattr(self, "warp_col_tiles", 4) + chunk = getattr(self, "chunk", 16) + num_stages = getattr(self, "num_stages", 2) + enable_rasterization = getattr(self, "enable_rasterization", False) + + return self.apply_config( + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + num_stages=num_stages, + enable_rasterization=enable_rasterization, + ) + + def apply_config(self, + block_row_warps=2, + block_col_warps=2, + warp_row_tiles=32, + warp_col_tiles=32, + chunk=16, + num_stages=2, + enable_rasterization=False): + + M, N, K = self.M, self.N, self.K + trans_A, trans_B = self.trans_A, self.trans_B + dtypeAB, dtypeC, accum_dtype = self.dtypeAB, self.dtypeC, self.accum_dtype + + # Calculate the micro size per warp using a helper function + micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtypeAB) + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + # TODO(lei): Can be generalized to analyzed from bank size + pad_factor = 8 if dtypeAB == "float16" else 16 + + can_swizzle_a = block_K * DataType(dtypeAB).bits == 512 + apply_pad_a = not can_swizzle_a + + micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtypeAB) + + # Define the shapes of matrices and shared memory buffers + A_shape = (M, K) + B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k) + A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K) + B_shared_shape = (block_N // micro_size_y, block_K // micro_size_k, micro_size_y, + micro_size_k) + C_shared_shape = (block_M // micro_size_x, block_N // micro_size_y, micro_size_x, + micro_size_y) + + # GPU warp configuration for NVIDIA GPUs + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + + # Calculate local fragment sizes for tensor core + local_size = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + shared_scope = "shared.dyn" + + # Configure the tensor core intrinsic emitter + mma_emitter = TensorCoreIntrinEmitterWithLadderTransform( + a_dtype=dtypeAB, + b_dtype=dtypeAB, + accum_dtype=accum_dtype, + a_transposed=trans_A, + b_transposed=trans_B, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + transform_kind_b=TransformKind.LDMatrixTransform, + ) + + # Define the main kernel using the generated configuration + @T.prim_func + def main( + A: T.Buffer(A_shape, dtypeAB), + B: T.Buffer(B_shape, dtypeAB), + C: T.Buffer((M, N), dtypeC), + ): + # Grid and thread configuration for CUDA kernel + with T.Kernel( + T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + + # Allocate shared memory and local fragments + A_shared = T.alloc_shared(A_shared_shape, dtypeAB, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, dtypeAB, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, dtypeC, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size), dtypeAB) + B_local = T.alloc_local((warp_cols * local_size), dtypeAB) + C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) + + # Thread-level parallelism for Tensor Cores + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + # Apply memory layout optimizations + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + }) + + # Optional rasterization for L2 locality enhancement + if enable_rasterization: + T.use_swizzle(panel_size=10) + + # Initialize accumulation buffer to zero + T.clear(C_local) + + # Main matrix multiplication pipeline with multiple stages + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + + # Load A matrix into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B matrix into shared memory + for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // micro_size_k, + micro_size_y, micro_size_k): + B_shared[j, k, jj, kk] = B[bx * (block_N // micro_size_y) + j, + ko * (block_K // micro_size_k) + k, jj, kk] + + # Perform the matrix multiplication on tensor core fragments + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Load B fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Matrix multiplication on fragments + mma_emitter.mma(A_local, B_local, C_local) + + # Store the result back to C shared memory + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + # Store results from shared memory to global memory + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, + bx * block_N + j] = C_shared[i // micro_size_x, j // micro_size_y, + i % micro_size_x, j % micro_size_y] + + return main + + def __post_init__(self): + # Validate the matrix transpose settings + assert self.trans_A is False, "Currently only support Matrix A not transposed" + assert self.trans_B is True, "Currently only support Matrix B transposed" + + return + def matmul_blocked( M, diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py b/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py new file mode 100644 index 000000000..59e481eb9 --- /dev/null +++ b/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_weight_only.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_weight_only.py new file mode 100644 index 000000000..0bb0e3ce2 --- /dev/null +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_weight_only.py @@ -0,0 +1,110 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from bitblas import tvm as tvm +from tvm import DataType +import tvm.tl.language as T + +from bitblas.tl.utils import ( + get_mma_micro_size, + make_swizzle_layout, +) + +from bitblas.tl.macro_generator import ( + TensorCoreIntrinEmitter, + TensorCoreIntrinEmitterWithLadderTransform, +) + +from bitblas.quantization import ( + _tir_packed_int_to_int_convert, + _tir_packed_to_signed_convert, + _tir_packed_to_unsigned_convert, + _tir_u32_to_f4_to_f16, + _tir_u8_to_f8_e4m3_to_f16, + _tir_packed_to_unsigned_convert_with_zeros, +) + +from bitblas.ops.operator import TransformKind + +# TODO(lei): Implement A General Matmul Emitter for Dequantize + +def matmul_blocked_weight_only( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + zeros_mode="original", + # Tile Related Params + block_M=64, + block_N=64, + block_K=32, + num_stages=2, + threads=128, + enable_rasterization=False, # Enhance L2 Locality +): + num_elems_per_byte = 8 // bit + A_shape = (M, K) + B_shape = (N, K // num_elems_per_byte) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K // num_elems_per_byte) + B_dequantize_shared_shape = (block_N, block_K) + + import tvm.tl.language as T + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, storage_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_local = T.alloc_fragment([8], storage_dtype, "local") + B_dequantize_local = T.alloc_fragment([16], in_dtype, "local") + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + tx = T.thread_binding(0, threads, thread="threadIdx.x") + + if enable_rasterization: + # rasterization factor + T.use_swizzle(10) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + T.copy(A[by * block_M, k * block_K], A_shared) + + for i, j in T.Parallel(block_N, block_K // num_elems_per_byte): + B_shared[i, j] = B[bx * block_N + i, k * block_K // num_elems_per_byte + j] + + for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * 4)): + for v in T.vectorized(0, 4): + vi = (i * threads * 4 + tx * 4 + v) // (block_K // num_elems_per_byte) + vj = (i * threads * 4 + tx * 4 + v) % (block_K // num_elems_per_byte) + B_local[v] = B_shared[vi, vj] + for v in T.serial(0, 8): + B_dequantize_local[v] = _tir_packed_to_unsigned_convert("int", 8)( + bit, + B_local[v // 2], + v % 2, + dtype=in_dtype, + ) + for v in T.vectorized(0, 8): + vi = (i * threads * 8 + tx * 8 + v) // (block_K) + vj = (i * threads * 8 + tx * 8 + v) % (block_K) + B_dequantize_shared[vi, vj] = B_dequantize_local[v] + + T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main \ No newline at end of file diff --git a/testing/python/operators/test_general_matmul_tilelang_kernel.py b/testing/python/operators/test_general_matmul_tilelang_kernel.py new file mode 100644 index 000000000..2ca273560 --- /dev/null +++ b/testing/python/operators/test_general_matmul_tilelang_kernel.py @@ -0,0 +1,383 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from bitblas import tvm as tvm +import bitblas.testing +from tvm import tl +from bitblas.ops.general_matmul.tilelang.dense.matmul import ( + MatmulScheduler, + MatmulFineGrainScheduler, + MatmulWeightPropagationScheduler, +) + +import torch +import torch.backends + +torch.manual_seed(0) + + +def assert_matmul_blocked_with_default_correctness(M, + N, + K, + trans_A=False, + trans_B=True, + dtypeAB="float16", + dtypeC="float16", + accum_dtype="float16" + ): + matmul = MatmulScheduler( + M=M, + N=N, + K=K, + trans_A=trans_A, + trans_B=trans_B, + dtypeAB=dtypeAB, + dtypeC=dtypeC, + accum_dtype=accum_dtype, + ).with_default_config() + + mod, params = tl.lower(matmul) + src_code = mod.imported_modules[0].get_source() + + # src_code is the generated cuda source + assert src_code is not None + + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, dtypeAB)) + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, dtypeAB)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + + mod(A, B, C) + + latency = mod.do_bench(mod.func, warmup=25) + + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype)) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def assert_matmul_blocked_apply_config_correctness(M, + N, + K, + block_M=64, + block_N=64, + block_K=32, + trans_A=False, + trans_B=True, + dtypeAB="float16", + dtypeC="float16", + accum_dtype="float16", + num_stages=2, + threads=128, + enable_rasterization=False): + matmul = MatmulScheduler( + M=M, + N=N, + K=K, + trans_A=trans_A, + trans_B=trans_B, + dtypeAB=dtypeAB, + dtypeC=dtypeC, + accum_dtype=accum_dtype, + ).apply_config( + block_M=block_M, + block_N=block_N, + block_K=block_K, + num_stages=num_stages, + threads=threads, + enable_rasterization=enable_rasterization, + ) + + mod, params = tl.lower(matmul) + src_code = mod.imported_modules[0].get_source() + + # src_code is the generated cuda source + assert src_code is not None + + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, dtypeAB)) + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, dtypeAB)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + + mod(A, B, C) + + latency = mod.do_bench(mod.func, warmup=25) + + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype)) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def assert_matmul_fine_grained_with_default_correctness(M, + N, + K, + trans_A=False, + trans_B=True, + dtypeAB="float16", + dtypeC="float16", + accum_dtype="float16" + ): + + matmul = MatmulFineGrainScheduler( + M=M, + N=N, + K=K, + trans_A=trans_A, + trans_B=trans_B, + dtypeAB=dtypeAB, + dtypeC=dtypeC, + accum_dtype=accum_dtype, + ).with_default_config() + + mod, params = tl.lower(matmul) + src_code = mod.imported_modules[0].get_source() + + # src_code is the generated cuda source + assert src_code is not None + + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, dtypeAB)) - 0.5 + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, dtypeAB)) - 0.5 + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, dtypeC)) + + mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + + mod(A, B, C) + + latency = mod.do_bench(mod.func, warmup=25) + + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A, B.T).to(getattr(torch, dtypeC)) + print(C) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-1, atol=1e-1) + + +def assert_matmul_fine_grained_apply_config_correctness(M, + N, + K, + trans_A=False, + trans_B=True, + dtypeAB="float16", + dtypeC="float16", + accum_dtype="float16", + block_row_warps=1, + block_col_warps=1, + warp_row_tiles=16, + warp_col_tiles=16, + chunk=32, + num_stages=2, + enable_rasterization=False, +): + + matmul = MatmulFineGrainScheduler( + M=M, + N=N, + K=K, + trans_A=trans_A, + trans_B=trans_B, + dtypeAB=dtypeAB, + dtypeC=dtypeC, + accum_dtype=accum_dtype, + ).apply_config( + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + num_stages=num_stages, + enable_rasterization=enable_rasterization, + ) + + + mod, params = tl.lower(matmul) + src_code = mod.imported_modules[0].get_source() + + # src_code is the generated cuda source + assert src_code is not None + + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, dtypeAB)) + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, dtypeAB)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + + mod(A, B, C) + + latency = mod.do_bench(mod.func, warmup=25) + + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype)) + torch.testing.assert_close(C, ref_c, rtol=1e-1, atol=1e-1) + + +def assert_matmul_weight_propagation_with_default_correctness(M, + N, + K, + trans_A=False, + trans_B=True, + dtypeAB="float16", + dtypeC="float16", + accum_dtype="float16" + ): + + matmul = MatmulWeightPropagationScheduler( + M=M, + N=N, + K=K, + trans_A=trans_A, + trans_B=trans_B, + dtypeAB=dtypeAB, + dtypeC=dtypeC, + accum_dtype=accum_dtype, + ).with_default_config() + + mod, params = tl.lower(matmul) + src_code = mod.imported_modules[0].get_source() + + # src_code is the generated cuda source + assert src_code is not None + + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, dtypeAB)) - 0.5 + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, dtypeAB)) - 0.5 + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, dtypeC)) + + ladder_permutate_config = bitblas.ops.LadderPermutateConfig( + M=N, + N=K, + transform_kind=3, + transpose_matrix=True, + ) + + ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) + + LB = ladder_permutate(B.cpu()).cuda() + + mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + + mod(A, LB, C) + + latency = mod.do_bench(mod.func, warmup=25) + + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A, B.T).to(getattr(torch, dtypeC)) + print(C) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-1, atol=1e-1) + +def assert_matmul_weight_propagation_apply_config_correctness(M, + N, + K, + trans_A=False, + trans_B=True, + dtypeAB="float16", + dtypeC="float16", + accum_dtype="float16", + block_row_warps=1, + block_col_warps=1, + warp_row_tiles=16, + warp_col_tiles=16, + chunk=32, + num_stages=2, + enable_rasterization=False, + ): + + matmul = MatmulWeightPropagationScheduler( + M=M, + N=N, + K=K, + trans_A=trans_A, + trans_B=trans_B, + dtypeAB=dtypeAB, + dtypeC=dtypeC, + accum_dtype=accum_dtype, + ).apply_config( + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + num_stages=num_stages, + enable_rasterization=enable_rasterization, + ) + + mod, params = tl.lower(matmul) + src_code = mod.imported_modules[0].get_source() + + # src_code is the generated cuda source + assert src_code is not None + + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, dtypeAB)) - 0.5 + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, dtypeAB)) - 0.5 + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, dtypeC)) + + ladder_permutate_config = bitblas.ops.LadderPermutateConfig( + M=N, + N=K, + transform_kind=3, + transpose_matrix=True, + ) + + ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) + + LB = ladder_permutate(B.cpu()).cuda() + + mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + + mod(A, LB, C) + + latency = mod.do_bench(mod.func, warmup=25) + + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A, B.T).to(getattr(torch, dtypeC)) + torch.testing.assert_close(C, ref_c, rtol=1e-1, atol=1e-1) + + +def test_matmul_blocked(): + # Default + assert_matmul_blocked_with_default_correctness(1024, 1024, 1024) + # Pipeline + assert_matmul_blocked_apply_config_correctness(1024, 1024, 1024, num_stages=2) + assert_matmul_blocked_apply_config_correctness(1024, 1024, 1024, num_stages=1) + # L2 Cache + assert_matmul_blocked_apply_config_correctness(1024, 1024, 1024, enable_rasterization=True) + +def test_matmul_fine_grained(): + # Default + assert_matmul_fine_grained_with_default_correctness(1024, 1024, 1024) + # Pipeline + assert_matmul_fine_grained_apply_config_correctness(1024, 1024, 1024, num_stages=2) + assert_matmul_fine_grained_apply_config_correctness(1024, 1024, 1024, num_stages=1) + # L2 Cache + assert_matmul_fine_grained_apply_config_correctness(1024, 1024, 1024, enable_rasterization=True) + +def test_matmul_weight_propagation(): + # Default + assert_matmul_weight_propagation_with_default_correctness(1024, 1024, 1024) + # Pipeline + assert_matmul_weight_propagation_apply_config_correctness(1024, 1024, 1024, num_stages=2) + assert_matmul_weight_propagation_apply_config_correctness(1024, 1024, 1024, num_stages=1) + # L2 Cache + assert_matmul_weight_propagation_apply_config_correctness(1024, 1024, 1024, enable_rasterization=True) + +if __name__ == "__main__": + bitblas.testing.main() diff --git a/testing/python/tilelang/test_tilelang_dequantize_gemm.py b/testing/python/tilelang/test_tilelang_dequantize_gemm.py index f8217157a..27af4bd54 100644 --- a/testing/python/tilelang/test_tilelang_dequantize_gemm.py +++ b/testing/python/tilelang/test_tilelang_dequantize_gemm.py @@ -65,34 +65,36 @@ def main( with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, dtypeAB) B_shared = T.alloc_shared(B_shared_shape, storage_dtype) - B_local = T.alloc_fragment([8], storage_dtype, "local") - B_dequantize_local = T.alloc_fragment([16], dtypeAB, "local") + B_local = T.alloc_local([8], storage_dtype) + B_dequantize_local = T.alloc_local([16], dtypeAB) B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, dtypeAB) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + tx = T.thread_binding(0, threads, thread="threadIdx.x") + T.clear(C_local) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A[by * block_M, k * block_K], A_shared) for i, j in T.Parallel(block_N, block_K // num_elems_per_byte): B_shared[i, j] = B[bx * block_N + i, k * block_K // num_elems_per_byte + j] for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * 4)): - for t in T.thread_binding(0, threads, thread="threadIdx.x"): - for v in T.vectorized(0, 4): - vi = (i * threads * 4 + t * 4 + v) // (block_K // num_elems_per_byte) - vj = (i * threads * 4 + t * 4 + v) % (block_K // num_elems_per_byte) - B_local[v] = B_shared[vi, vj] - for v in T.serial(0, 8): - B_dequantize_local[v] = _tir_packed_to_unsigned_convert("int", 8)( - num_bits, - B_local[v // 2], - v % 2, - dtype=dtypeAB, - ) - for v in T.vectorized(0, 8): - vi = (i * threads * 8 + t * 8 + v) // (block_K) - vj = (i * threads * 8 + t * 8 + v) % (block_K) - B_dequantize_shared[vi, vj] = B_dequantize_local[v] + for v in T.vectorized(0, 4): + vi = (i * threads * 4 + tx * 4 + v) // (block_K // num_elems_per_byte) + vj = (i * threads * 4 + tx * 4 + v) % (block_K // num_elems_per_byte) + B_local[v] = B_shared[vi, vj] + for v in T.serial(0, 8): + B_dequantize_local[v] = _tir_packed_to_unsigned_convert("int", 8)( + num_bits, + B_local[v // 2], + v % 2, + dtype=dtypeAB, + ) + for v in T.vectorized(0, 8): + vi = (i * threads * 8 + tx * 8 + v) // (block_K) + vj = (i * threads * 8 + tx * 8 + v) % (block_K) + B_dequantize_shared[vi, vj] = B_dequantize_local[v] T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) T.copy(C_local, C[by * block_M, bx * block_N]) @@ -125,14 +127,12 @@ def run_gemm( num_stages, num_threads, ) - print(program) mod, params = TL.lower(program) mod = TL.Profiler(mod, params, [2], TL.TensorSupplyType.Integer) out = mod.run_once() - - print(f"output is {out}") + assert out is not None def ref_program(A, qB): import torch From 730d13ea17530d720c95ffc4c4550cce94416bf5 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 28 Sep 2024 07:46:55 +0000 Subject: [PATCH 2/8] remove un-implemented code. --- .../tilelang/dequantize/matmul_weight_only.py | 110 ------------- .../test_general_matmul_tilelang_kernel.py | 147 +++++++++--------- 2 files changed, 75 insertions(+), 182 deletions(-) delete mode 100644 bitblas/ops/general_matmul/tilelang/dequantize/matmul_weight_only.py diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_weight_only.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_weight_only.py deleted file mode 100644 index 0bb0e3ce2..000000000 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_weight_only.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -from bitblas import tvm as tvm -from tvm import DataType -import tvm.tl.language as T - -from bitblas.tl.utils import ( - get_mma_micro_size, - make_swizzle_layout, -) - -from bitblas.tl.macro_generator import ( - TensorCoreIntrinEmitter, - TensorCoreIntrinEmitterWithLadderTransform, -) - -from bitblas.quantization import ( - _tir_packed_int_to_int_convert, - _tir_packed_to_signed_convert, - _tir_packed_to_unsigned_convert, - _tir_u32_to_f4_to_f16, - _tir_u8_to_f8_e4m3_to_f16, - _tir_packed_to_unsigned_convert_with_zeros, -) - -from bitblas.ops.operator import TransformKind - -# TODO(lei): Implement A General Matmul Emitter for Dequantize - -def matmul_blocked_weight_only( - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - bit=4, - storage_dtype="int8", - source_format="uint", - with_scaling=False, - with_zeros=False, - group_size=-1, - fast_decoding=False, - with_bias=False, - zeros_mode="original", - # Tile Related Params - block_M=64, - block_N=64, - block_K=32, - num_stages=2, - threads=128, - enable_rasterization=False, # Enhance L2 Locality -): - num_elems_per_byte = 8 // bit - A_shape = (M, K) - B_shape = (N, K // num_elems_per_byte) - A_shared_shape = (block_M, block_K) - B_shared_shape = (block_N, block_K // num_elems_per_byte) - B_dequantize_shared_shape = (block_N, block_K) - - import tvm.tl.language as T - - @T.prim_func - def main( - A: T.Buffer(A_shape, in_dtype), - B: T.Buffer(B_shape, storage_dtype), - C: T.Buffer((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) - B_shared = T.alloc_shared(B_shared_shape, storage_dtype) - B_local = T.alloc_fragment([8], storage_dtype, "local") - B_dequantize_local = T.alloc_fragment([16], in_dtype, "local") - B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - tx = T.thread_binding(0, threads, thread="threadIdx.x") - - if enable_rasterization: - # rasterization factor - T.use_swizzle(10) - - T.clear(C_local) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): - T.copy(A[by * block_M, k * block_K], A_shared) - - for i, j in T.Parallel(block_N, block_K // num_elems_per_byte): - B_shared[i, j] = B[bx * block_N + i, k * block_K // num_elems_per_byte + j] - - for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * 4)): - for v in T.vectorized(0, 4): - vi = (i * threads * 4 + tx * 4 + v) // (block_K // num_elems_per_byte) - vj = (i * threads * 4 + tx * 4 + v) % (block_K // num_elems_per_byte) - B_local[v] = B_shared[vi, vj] - for v in T.serial(0, 8): - B_dequantize_local[v] = _tir_packed_to_unsigned_convert("int", 8)( - bit, - B_local[v // 2], - v % 2, - dtype=in_dtype, - ) - for v in T.vectorized(0, 8): - vi = (i * threads * 8 + tx * 8 + v) // (block_K) - vj = (i * threads * 8 + tx * 8 + v) % (block_K) - B_dequantize_shared[vi, vj] = B_dequantize_local[v] - - T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) - T.copy(C_local, C[by * block_M, bx * block_N]) - - return main \ No newline at end of file diff --git a/testing/python/operators/test_general_matmul_tilelang_kernel.py b/testing/python/operators/test_general_matmul_tilelang_kernel.py index 2ca273560..2890af3af 100644 --- a/testing/python/operators/test_general_matmul_tilelang_kernel.py +++ b/testing/python/operators/test_general_matmul_tilelang_kernel.py @@ -17,14 +17,13 @@ def assert_matmul_blocked_with_default_correctness(M, - N, - K, - trans_A=False, - trans_B=True, - dtypeAB="float16", - dtypeC="float16", - accum_dtype="float16" - ): + N, + K, + trans_A=False, + trans_B=True, + dtypeAB="float16", + dtypeC="float16", + accum_dtype="float16"): matmul = MatmulScheduler( M=M, N=N, @@ -35,7 +34,7 @@ def assert_matmul_blocked_with_default_correctness(M, dtypeC=dtypeC, accum_dtype=accum_dtype, ).with_default_config() - + mod, params = tl.lower(matmul) src_code = mod.imported_modules[0].get_source() @@ -61,19 +60,19 @@ def assert_matmul_blocked_with_default_correctness(M, def assert_matmul_blocked_apply_config_correctness(M, - N, - K, - block_M=64, - block_N=64, - block_K=32, - trans_A=False, - trans_B=True, - dtypeAB="float16", - dtypeC="float16", - accum_dtype="float16", - num_stages=2, - threads=128, - enable_rasterization=False): + N, + K, + block_M=64, + block_N=64, + block_K=32, + trans_A=False, + trans_B=True, + dtypeAB="float16", + dtypeC="float16", + accum_dtype="float16", + num_stages=2, + threads=128, + enable_rasterization=False): matmul = MatmulScheduler( M=M, N=N, @@ -91,7 +90,7 @@ def assert_matmul_blocked_apply_config_correctness(M, threads=threads, enable_rasterization=enable_rasterization, ) - + mod, params = tl.lower(matmul) src_code = mod.imported_modules[0].get_source() @@ -117,14 +116,13 @@ def assert_matmul_blocked_apply_config_correctness(M, def assert_matmul_fine_grained_with_default_correctness(M, - N, - K, - trans_A=False, - trans_B=True, - dtypeAB="float16", - dtypeC="float16", - accum_dtype="float16" - ): + N, + K, + trans_A=False, + trans_B=True, + dtypeAB="float16", + dtypeC="float16", + accum_dtype="float16"): matmul = MatmulFineGrainScheduler( M=M, @@ -163,21 +161,22 @@ def assert_matmul_fine_grained_with_default_correctness(M, torch.testing.assert_close(C, ref_c, rtol=1e-1, atol=1e-1) -def assert_matmul_fine_grained_apply_config_correctness(M, - N, - K, - trans_A=False, - trans_B=True, - dtypeAB="float16", - dtypeC="float16", - accum_dtype="float16", - block_row_warps=1, - block_col_warps=1, - warp_row_tiles=16, - warp_col_tiles=16, - chunk=32, - num_stages=2, - enable_rasterization=False, +def assert_matmul_fine_grained_apply_config_correctness( + M, + N, + K, + trans_A=False, + trans_B=True, + dtypeAB="float16", + dtypeC="float16", + accum_dtype="float16", + block_row_warps=1, + block_col_warps=1, + warp_row_tiles=16, + warp_col_tiles=16, + chunk=32, + num_stages=2, + enable_rasterization=False, ): matmul = MatmulFineGrainScheduler( @@ -198,7 +197,6 @@ def assert_matmul_fine_grained_apply_config_correctness(M, num_stages=num_stages, enable_rasterization=enable_rasterization, ) - mod, params = tl.lower(matmul) src_code = mod.imported_modules[0].get_source() @@ -225,14 +223,13 @@ def assert_matmul_fine_grained_apply_config_correctness(M, def assert_matmul_weight_propagation_with_default_correctness(M, - N, - K, - trans_A=False, - trans_B=True, - dtypeAB="float16", - dtypeC="float16", - accum_dtype="float16" - ): + N, + K, + trans_A=False, + trans_B=True, + dtypeAB="float16", + dtypeC="float16", + accum_dtype="float16"): matmul = MatmulWeightPropagationScheduler( M=M, @@ -281,22 +278,24 @@ def assert_matmul_weight_propagation_with_default_correctness(M, print(ref_c) torch.testing.assert_close(C, ref_c, rtol=1e-1, atol=1e-1) -def assert_matmul_weight_propagation_apply_config_correctness(M, - N, - K, - trans_A=False, - trans_B=True, - dtypeAB="float16", - dtypeC="float16", - accum_dtype="float16", - block_row_warps=1, - block_col_warps=1, - warp_row_tiles=16, - warp_col_tiles=16, - chunk=32, - num_stages=2, - enable_rasterization=False, - ): + +def assert_matmul_weight_propagation_apply_config_correctness( + M, + N, + K, + trans_A=False, + trans_B=True, + dtypeAB="float16", + dtypeC="float16", + accum_dtype="float16", + block_row_warps=1, + block_col_warps=1, + warp_row_tiles=16, + warp_col_tiles=16, + chunk=32, + num_stages=2, + enable_rasterization=False, +): matmul = MatmulWeightPropagationScheduler( M=M, @@ -361,6 +360,7 @@ def test_matmul_blocked(): # L2 Cache assert_matmul_blocked_apply_config_correctness(1024, 1024, 1024, enable_rasterization=True) + def test_matmul_fine_grained(): # Default assert_matmul_fine_grained_with_default_correctness(1024, 1024, 1024) @@ -370,6 +370,7 @@ def test_matmul_fine_grained(): # L2 Cache assert_matmul_fine_grained_apply_config_correctness(1024, 1024, 1024, enable_rasterization=True) + def test_matmul_weight_propagation(): # Default assert_matmul_weight_propagation_with_default_correctness(1024, 1024, 1024) @@ -377,7 +378,9 @@ def test_matmul_weight_propagation(): assert_matmul_weight_propagation_apply_config_correctness(1024, 1024, 1024, num_stages=2) assert_matmul_weight_propagation_apply_config_correctness(1024, 1024, 1024, num_stages=1) # L2 Cache - assert_matmul_weight_propagation_apply_config_correctness(1024, 1024, 1024, enable_rasterization=True) + assert_matmul_weight_propagation_apply_config_correctness( + 1024, 1024, 1024, enable_rasterization=True) + if __name__ == "__main__": bitblas.testing.main() From 8047ee7a00f0e84f46fa96da88deab32541756a9 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 28 Sep 2024 08:07:16 +0000 Subject: [PATCH 3/8] Implement BaseScheduler to wrap some related items. --- .../general_matmul/tilelang/dense/matmul.py | 177 ++++++++++++------ .../test_general_matmul_tilelang_scheduler.py | 38 ++++ 2 files changed, 162 insertions(+), 53 deletions(-) create mode 100644 testing/python/operators/test_general_matmul_tilelang_scheduler.py diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul.py b/bitblas/ops/general_matmul/tilelang/dense/matmul.py index f5ae7a648..3b677b4ad 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul.py @@ -2,8 +2,10 @@ # Licensed under the MIT License. from bitblas import tvm as tvm from tvm import DataType +from tvm import IRModule +from tvm.tir import PrimFunc import tvm.tl.language as T - +from typing import Union, Optional from bitblas.tl.utils import ( get_mma_micro_size, make_swizzle_layout, @@ -20,12 +22,40 @@ @dataclass -class MatmulScheduler: +class BaseScheduler: + + enable_simplify: bool = True + + @staticmethod + def Simplify(stmt: Union[PrimFunc, IRModule]): + if isinstance(stmt, PrimFunc): + return tvm.tir.transform.Simplify()(tvm.IRModule.from_expr(stmt))["main"] + elif isinstance(stmt, IRModule): + return tvm.tir.transform.Simplify()(stmt) + else: + raise ValueError(f"Unsupported type: {type(stmt)}") + + def enable_simplify(self): + self.enable_simplify = True + return self + + def disable_simplify(self): + self.enable_simplify = False + return self + + def maybe_simplify(self, stmt: Union[PrimFunc, IRModule]): + if self.enable_simplify: + return self.Simplify(stmt) + return stmt + + +@dataclass +class MatmulScheduler(BaseScheduler): # OP Related Config - M: int - N: int - K: int + M: Optional[int] = None + N: Optional[int] = None + K: Optional[int] = None trans_A: bool = False trans_B: bool = False dtypeAB: str = "float16" @@ -105,7 +135,7 @@ def main( T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) T.copy(C_local, C[by * block_M, bx * block_N]) - return main + return self.maybe_simplify(main) def __post_init__(self): # Add Config Validation @@ -113,14 +143,14 @@ def __post_init__(self): @dataclass -class MatmulFineGrainScheduler: +class MatmulFineGrainScheduler(BaseScheduler): # Fine-grained matrix multiplication scheduler # Allows for more detailed configuration. # Operation Configuration - M: int - N: int - K: int + M: Optional[int] = None + N: Optional[int] = None + K: Optional[int] = None dtypeAB: str = "float16" dtypeC: str = "float16" trans_A: bool = False @@ -157,14 +187,16 @@ def with_default_config(self): enable_rasterization=enable_rasterization, ) - def apply_config(self, - block_row_warps=2, - block_col_warps=2, - warp_row_tiles=32, - warp_col_tiles=32, - chunk=16, - num_stages=2, - enable_rasterization=False): + def apply_config( + self, + block_row_warps=2, + block_col_warps=2, + warp_row_tiles=32, + warp_col_tiles=32, + chunk=16, + num_stages=2, + enable_rasterization=False, + ): M, N, K = self.M, self.N, self.K trans_A, trans_B = self.trans_A, self.trans_B @@ -182,8 +214,12 @@ def apply_config(self, B_shape = (N, K) A_shared_shape = (block_M, block_K) B_shared_shape = (block_N, block_K) - C_shared_shape = (block_M // micro_size_x, block_N // micro_size_y, micro_size_x, - micro_size_y) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) # GPU warp configuration for NVIDIA GPUs warp_size = 32 @@ -207,7 +243,8 @@ def apply_config(self, block_col_warps=block_col_warps, warp_row_tiles=warp_row_tiles, warp_col_tiles=warp_col_tiles, - chunk=chunk) + chunk=chunk, + ) # Define the main kernel using the generated configuration @T.prim_func @@ -288,9 +325,9 @@ def main( for i, j in T.Parallel(block_M, block_N): C[by * block_M + i, bx * block_N + j] = C_shared[i // micro_size_x, j // micro_size_y, - i % micro_size_x, j % micro_size_y] + i % micro_size_x, j % micro_size_y,] - return main + return self.maybe_simplify(main) def __post_init__(self): # Validate the matrix transpose settings @@ -301,14 +338,14 @@ def __post_init__(self): @dataclass -class MatmulWeightPropagationScheduler: +class MatmulWeightPropagationScheduler(BaseScheduler): # Fine-grained matrix multiplication scheduler # Allows for more detailed configuration. # Operation Configuration - M: int - N: int - K: int + M: Optional[int] = None + N: Optional[int] = None + K: Optional[int] = None dtypeAB: str = "float16" dtypeC: str = "float16" trans_A: bool = False @@ -345,14 +382,16 @@ def with_default_config(self): enable_rasterization=enable_rasterization, ) - def apply_config(self, - block_row_warps=2, - block_col_warps=2, - warp_row_tiles=32, - warp_col_tiles=32, - chunk=16, - num_stages=2, - enable_rasterization=False): + def apply_config( + self, + block_row_warps=2, + block_col_warps=2, + warp_row_tiles=32, + warp_col_tiles=32, + chunk=16, + num_stages=2, + enable_rasterization=False, + ): M, N, K = self.M, self.N, self.K trans_A, trans_B = self.trans_A, self.trans_B @@ -377,10 +416,18 @@ def apply_config(self, A_shape = (M, K) B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k) A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K) - B_shared_shape = (block_N // micro_size_y, block_K // micro_size_k, micro_size_y, - micro_size_k) - C_shared_shape = (block_M // micro_size_x, block_N // micro_size_y, micro_size_x, - micro_size_y) + B_shared_shape = ( + block_N // micro_size_y, + block_K // micro_size_k, + micro_size_y, + micro_size_k, + ) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) # GPU warp configuration for NVIDIA GPUs warp_size = 32 @@ -451,10 +498,14 @@ def main( A_shared[i, k] = A[by * block_M + i, ko * block_K + k] # Load B matrix into shared memory - for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // micro_size_k, - micro_size_y, micro_size_k): + for j, k, jj, kk in T.Parallel( + block_N // micro_size_y, + block_K // micro_size_k, + micro_size_y, + micro_size_k, + ): B_shared[j, k, jj, kk] = B[bx * (block_N // micro_size_y) + j, - ko * (block_K // micro_size_k) + k, jj, kk] + ko * (block_K // micro_size_k) + k, jj, kk,] # Perform the matrix multiplication on tensor core fragments for ki in T.serial(0, (block_K // micro_size_k)): @@ -489,9 +540,9 @@ def main( for i, j in T.Parallel(block_M, block_N): C[by * block_M + i, bx * block_N + j] = C_shared[i // micro_size_x, j // micro_size_y, - i % micro_size_x, j % micro_size_y] + i % micro_size_x, j % micro_size_y,] - return main + return self.maybe_simplify(main) def __post_init__(self): # Validate the matrix transpose settings @@ -583,7 +634,12 @@ def matmul_macro_tensorcore( B_shape = (N, K) A_shared_shape = (block_M, block_K) B_shared_shape = (block_N, block_K) - C_shared_shape = (block_M // micro_size_x, block_N // micro_size_y, micro_size_x, micro_size_y) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) warp_size = 32 # nvidia gpu warp size is 32 threads = warp_size * (block_row_warps * block_col_warps) @@ -602,7 +658,8 @@ def matmul_macro_tensorcore( block_col_warps=block_col_warps, warp_row_tiles=warp_row_tiles, warp_col_tiles=warp_col_tiles, - chunk=chunk) + chunk=chunk, + ) @T.prim_func def main( @@ -667,7 +724,7 @@ def main( for i, j in T.Parallel(block_M, block_N): C[by * block_M + i, bx * block_N + j] = C_shared[i // micro_size_x, j // micro_size_y, - i % micro_size_x, j % micro_size_y] + i % micro_size_x, j % micro_size_y,] return main @@ -707,8 +764,18 @@ def matmul_macro_tensorcore_weight_propagation_level_ldmatrix( A_shape = (M, K) B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k) A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K) - B_shared_shape = (block_N // micro_size_y, block_K // micro_size_k, micro_size_y, micro_size_k) - C_shared_shape = (block_M // micro_size_x, block_N // micro_size_y, micro_size_x, micro_size_y) + B_shared_shape = ( + block_N // micro_size_y, + block_K // micro_size_k, + micro_size_y, + micro_size_k, + ) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) warp_size = 32 # nvidia gpu warp size is 32 threads = warp_size * (block_row_warps * block_col_warps) @@ -762,10 +829,14 @@ def main( for i, k in T.Parallel(block_M, block_K): A_shared[i, k] = A[by * block_M + i, ko * block_K + k] - for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // micro_size_k, - micro_size_y, micro_size_k): + for j, k, jj, kk in T.Parallel( + block_N // micro_size_y, + block_K // micro_size_k, + micro_size_y, + micro_size_k, + ): B_shared[j, k, jj, kk] = B[bx * (block_N // micro_size_y) + j, - ko * (block_K // micro_size_k) + k, jj, kk] + ko * (block_K // micro_size_k) + k, jj, kk,] for ki in T.serial(0, (block_K // micro_size_k)): @@ -796,6 +867,6 @@ def main( for i, j in T.Parallel(block_M, block_N): C[by * block_M + i, bx * block_N + j] = C_shared[i // micro_size_x, j // micro_size_y, - i % micro_size_x, j % micro_size_y] + i % micro_size_x, j % micro_size_y,] return main diff --git a/testing/python/operators/test_general_matmul_tilelang_scheduler.py b/testing/python/operators/test_general_matmul_tilelang_scheduler.py new file mode 100644 index 000000000..26f823a97 --- /dev/null +++ b/testing/python/operators/test_general_matmul_tilelang_scheduler.py @@ -0,0 +1,38 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from bitblas import tvm as tvm +import bitblas.testing +from tvm import tl +from tvm.ir import structural_equal +from bitblas.ops.general_matmul.tilelang.dense.matmul import ( + MatmulScheduler, +) + +def test_scheduler_simplify(M, + N, + K, + trans_A=False, + trans_B=True, + dtypeAB="float16", + dtypeC="float16", + accum_dtype="float16"): + matmul = MatmulScheduler( + M=M, + N=N, + K=K, + trans_A=trans_A, + trans_B=trans_B, + dtypeAB=dtypeAB, + dtypeC=dtypeC, + accum_dtype=accum_dtype, + ).disable_simplify().with_default_config() + + simplified = MatmulScheduler.Simplify(matmul) + + is_equal = structural_equal(matmul, simplified) + + assert is_equal == False, "Simplify should not return the same schedule" + +if __name__ == "__main__": + bitblas.testing.main() From 64db0655683342ede824c4ca95d0e448479e2e5c Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 28 Sep 2024 08:11:27 +0000 Subject: [PATCH 4/8] lint fix --- .../test_general_matmul_tilelang_scheduler.py | 33 ++++++++++--------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/testing/python/operators/test_general_matmul_tilelang_scheduler.py b/testing/python/operators/test_general_matmul_tilelang_scheduler.py index 26f823a97..c75d4872c 100644 --- a/testing/python/operators/test_general_matmul_tilelang_scheduler.py +++ b/testing/python/operators/test_general_matmul_tilelang_scheduler.py @@ -3,20 +3,19 @@ from bitblas import tvm as tvm import bitblas.testing -from tvm import tl from tvm.ir import structural_equal from bitblas.ops.general_matmul.tilelang.dense.matmul import ( - MatmulScheduler, -) - -def test_scheduler_simplify(M, - N, - K, - trans_A=False, - trans_B=True, - dtypeAB="float16", - dtypeC="float16", - accum_dtype="float16"): + MatmulScheduler,) + + +def assert_scheduler_simplify(M, + N, + K, + trans_A=False, + trans_B=True, + dtypeAB="float16", + dtypeC="float16", + accum_dtype="float16"): matmul = MatmulScheduler( M=M, N=N, @@ -27,12 +26,16 @@ def test_scheduler_simplify(M, dtypeC=dtypeC, accum_dtype=accum_dtype, ).disable_simplify().with_default_config() - + simplified = MatmulScheduler.Simplify(matmul) is_equal = structural_equal(matmul, simplified) - - assert is_equal == False, "Simplify should not return the same schedule" + assert is_equal is False, "Simplify should not return the same schedule" + + +def test_scheduler_simplify(): + assert_scheduler_simplify(128, 128, 128) + if __name__ == "__main__": bitblas.testing.main() From cef04a875e022445d6ad2b28ddd5e6b3ca939266 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 28 Sep 2024 09:13:19 +0000 Subject: [PATCH 5/8] test skip --- .../python/operators/test_general_matmul_tilelang_kernel.py | 4 ++-- .../operators/test_general_matmul_tilelang_scheduler.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/testing/python/operators/test_general_matmul_tilelang_kernel.py b/testing/python/operators/test_general_matmul_tilelang_kernel.py index 2890af3af..18115f450 100644 --- a/testing/python/operators/test_general_matmul_tilelang_kernel.py +++ b/testing/python/operators/test_general_matmul_tilelang_kernel.py @@ -276,7 +276,7 @@ def assert_matmul_weight_propagation_with_default_correctness(M, ref_c = torch.matmul(A, B.T).to(getattr(torch, dtypeC)) print(C) print(ref_c) - torch.testing.assert_close(C, ref_c, rtol=1e-1, atol=1e-1) + torch.testing.assert_close(C, ref_c, rtol=1e0, atol=1e0) def assert_matmul_weight_propagation_apply_config_correctness( @@ -348,7 +348,7 @@ def assert_matmul_weight_propagation_apply_config_correctness( # Get Reference Result ref_c = torch.matmul(A, B.T).to(getattr(torch, dtypeC)) - torch.testing.assert_close(C, ref_c, rtol=1e-1, atol=1e-1) + torch.testing.assert_close(C, ref_c, rtol=1e0, atol=1e0) def test_matmul_blocked(): diff --git a/testing/python/operators/test_general_matmul_tilelang_scheduler.py b/testing/python/operators/test_general_matmul_tilelang_scheduler.py index c75d4872c..1e6bd6466 100644 --- a/testing/python/operators/test_general_matmul_tilelang_scheduler.py +++ b/testing/python/operators/test_general_matmul_tilelang_scheduler.py @@ -25,7 +25,7 @@ def assert_scheduler_simplify(M, dtypeAB=dtypeAB, dtypeC=dtypeC, accum_dtype=accum_dtype, - ).disable_simplify().with_default_config() + ).deactivate_simplify().with_default_config() simplified = MatmulScheduler.Simplify(matmul) From f1652e9841d4bbe903825bbbae85688442fc9a8c Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 29 Sep 2024 06:57:01 +0000 Subject: [PATCH 6/8] Refactor tilelang dequantize module and add matmul_blocked_weight_only function --- 3rdparty/tvm | 2 +- bitblas/builder/lib_generator/__init__.py | 20 +- bitblas/builder/wrapper/__init__.py | 1 + bitblas/builder/wrapper/base.py | 16 ++ bitblas/builder/wrapper/tir.py | 37 +--- bitblas/builder/wrapper/tl.py | 197 ++++++++++++++++++ bitblas/cache/operator.py | 4 +- bitblas/gpu/matmul_mma.py | 2 +- bitblas/gpu/matmul_mma_dequantize.py | 2 +- bitblas/ops/base_scheduler.py | 45 ++++ bitblas/ops/common.py | 20 ++ bitblas/ops/general_matmul/__init__.py | 32 ++- bitblas/ops/general_matmul/cuda/__init__.py | 3 +- .../ops/general_matmul/tilelang/__init__.py | 2 - .../general_matmul/tilelang/dense/__init__.py | 52 +++++ .../general_matmul/tilelang/dense/matmul.py | 184 +++++++--------- .../tirscript/matmul_dequantize_impl.py | 2 +- .../general_matmul/tirscript/matmul_impl.py | 2 +- bitblas/ops/general_matmul_splitk.py | 2 +- .../ops/impl/batch_matmul_dequantize_impl.py | 2 +- bitblas/ops/impl/batch_matmul_impl.py | 2 +- bitblas/ops/impl/matmul_dequantize_impl.py | 2 +- .../ops/impl/matmul_dequantize_splitk_impl.py | 2 +- bitblas/ops/impl/matmul_impl.py | 2 +- bitblas/ops/impl/matmul_splitk_impl.py | 2 +- bitblas/ops/impl/param_permutate_impl.py | 2 +- bitblas/ops/ladder_permutate/__init__.py | 2 +- bitblas/ops/operator.py | 189 ++++++++++------- bitblas/tl/macro_generator.py | 2 +- bitblas/utils/post_process.py | 3 +- bitblas/utils/rtmod_analysis.py | 74 ++++++- docs/ExtendOperatorsWithDSL.md | 2 +- .../builder/test_backend_tir_builder.py | 5 +- .../test_general_matmul_ops_backend_tl.py | 50 +++++ .../test_general_matmul_tilelang_impl.py | 36 ++-- .../test_general_matmul_tilelang_kernel.py | 84 ++++---- .../test_general_matmul_tilelang_scheduler.py | 8 +- .../tilelang/test_tilelang_dequantize_gemm.py | 60 +++--- .../test_tilelang_dyanmic_symbolic.py | 92 ++++---- testing/python/tilelang/test_tilelang_gemm.py | 22 +- .../tilelang/test_tilelang_macro_gemm.py | 144 ++++++------- 41 files changed, 937 insertions(+), 475 deletions(-) create mode 100644 bitblas/builder/wrapper/tl.py create mode 100644 bitblas/ops/base_scheduler.py create mode 100644 bitblas/ops/common.py create mode 100644 testing/python/operators/test_general_matmul_ops_backend_tl.py diff --git a/3rdparty/tvm b/3rdparty/tvm index c115bfd4c..d0c06c764 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit c115bfd4cc9c5257b0b7b3046571d5ab60db39d3 +Subproject commit d0c06c7641956a3bd9ab1174ed05a1aa2a624d2a diff --git a/bitblas/builder/lib_generator/__init__.py b/bitblas/builder/lib_generator/__init__.py index 64eaee9e8..46336e0c2 100644 --- a/bitblas/builder/lib_generator/__init__.py +++ b/bitblas/builder/lib_generator/__init__.py @@ -4,6 +4,7 @@ from bitblas.base.arch import TileDevice import ctypes import os +import os.path as osp import tempfile import subprocess import logging @@ -26,7 +27,7 @@ def update_lib_code(self, lib_code: str): def load_lib(self): return ctypes.CDLL(self.libpath) - def compile_lib(self, timeout: float = None): + def compile_lib(self, timeout: float = None, with_tl: bool = False): arch = self.arch src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False) compute_version = arch.compute_capability @@ -45,9 +46,22 @@ def compile_lib(self, timeout: float = None): "-lcuda", "-gencode", f"arch=compute_{compute_version},code=sm_{compute_version}", - "-o", - libpath, ] + if with_tl: + tvm_root = osp.join(osp.dirname(__file__), "../../../3rdparty/tvm") + tl_template_path = osp.abspath(osp.join(tvm_root, "src/tl")) + if "TL_CUTLASS_PATH" in os.environ: + cutlass_path = os.environ["TL_CUTLASS_PATH"] + else: + cutlass_path = osp.abspath(osp.join(tvm_root, "3rdparty/cutlass/include")) + + command += [ + "-I" + tl_template_path, + "-I" + cutlass_path, + ] + command += ["-diag-suppress=20013"] + command += ["-o", libpath] + src.write(self.lib_code) src.flush() try: diff --git a/bitblas/builder/wrapper/__init__.py b/bitblas/builder/wrapper/__init__.py index c864f7a4b..9f089c13c 100644 --- a/bitblas/builder/wrapper/__init__.py +++ b/bitblas/builder/wrapper/__init__.py @@ -1,3 +1,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from .tir import TIRWrapper # noqa: F401 +from .tl import TLWrapper # noqa: F401 diff --git a/bitblas/builder/wrapper/base.py b/bitblas/builder/wrapper/base.py index 1705af2cc..c63b9ee26 100644 --- a/bitblas/builder/wrapper/base.py +++ b/bitblas/builder/wrapper/base.py @@ -2,6 +2,22 @@ # Licensed under the MIT License. from abc import ABC, abstractmethod +PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY = """ + cudaFuncSetAttribute({}, cudaFuncAttributeMaxDynamicSharedMemorySize, {}); +""" + +PREDEF_INIT_FUNC = """ +extern "C" void init() {{ + {} +}} +""" + +PREDEF_HOST_FUNC = """ +extern "C" void call({}) {{ +{} +}} +""" + class BaseWrapper(ABC): diff --git a/bitblas/builder/wrapper/tir.py b/bitblas/builder/wrapper/tir.py index f39c7cfab..b57981515 100644 --- a/bitblas/builder/wrapper/tir.py +++ b/bitblas/builder/wrapper/tir.py @@ -9,26 +9,11 @@ import re import logging -from .base import BaseWrapper +from .base import (BaseWrapper, PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY, PREDEF_INIT_FUNC, + PREDEF_HOST_FUNC) logger = logging.getLogger(__name__) -PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY = """ - cudaFuncSetAttribute({}, cudaFuncAttributeMaxDynamicSharedMemorySize, {}); -""" - -PREDEF_INIT_FUNC = """ -extern "C" void init() {{ - {} -}} -""" - -PREDEF_HOST_FUNC = """ -extern "C" void call({}) {{ -{} -}} -""" - class TIRCUDASourceWrapper(object): _TYPE_MAP = { @@ -48,8 +33,8 @@ class TIRCUDASourceWrapper(object): "uchar": "uint8_t", } - def __init__(self, optimized_mod: IRModule, source: str, arch: TileDevice): - self.mod = optimized_mod + def __init__(self, scheduled_ir_module: IRModule, source: str, arch: TileDevice): + self.mod = scheduled_ir_module self.arch = arch self.source = source self.function_name: Optional[str] = None @@ -190,8 +175,8 @@ def prim_func(self): class TIRCUDASourceWrapperWithDynamic(TIRCUDASourceWrapper): - def __init__(self, optimized_mod: IRModule, source: str, arch: TileDevice): - super().__init__(optimized_mod, source, arch) + def __init__(self, scheduled_ir_module: IRModule, source: str, arch: TileDevice): + super().__init__(scheduled_ir_module, source, arch) def get_cuda_init_func(self): # Initialize an empty string to accumulate CUDA function calls for setting dynamic shared memory @@ -387,16 +372,16 @@ class TIRWrapper(BaseWrapper): def __init__(self, arch: TileDevice): super().__init__() - self.optimized_mod = None + self.scheduled_ir_module = None self.arch = arch self.lib = None - def assign_optimized_module(self, optimized_mod: IRModule): - self.optimized_mod = optimized_mod + def assign_optimized_module(self, scheduled_ir_module: IRModule): + self.scheduled_ir_module = scheduled_ir_module # Get Scheduled Rt Module and return source to be compiled def wrap(self, c_source: str, is_dynamic: bool = False): - assert self.optimized_mod is not None, "Please assign optimized module first." + assert self.scheduled_ir_module is not None, "Please assign optimized module first." wrapper_class = TIRCUDASourceWrapper if not is_dynamic else TIRCUDASourceWrapperWithDynamic - wrapper = wrapper_class(self.optimized_mod, c_source, self.arch) + wrapper = wrapper_class(self.scheduled_ir_module, c_source, self.arch) return wrapper.lib_code diff --git a/bitblas/builder/wrapper/tl.py b/bitblas/builder/wrapper/tl.py new file mode 100644 index 000000000..cdd19a172 --- /dev/null +++ b/bitblas/builder/wrapper/tl.py @@ -0,0 +1,197 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from bitblas import tvm +from typing import Optional, List, Dict, Union +from tvm import IRModule +from bitblas.base.arch import TileDevice +from bitblas.utils import match_global_kernel +from bitblas.utils.rtmod_analysis import get_annotated_device_mod +import re +import logging + +from .base import ( + BaseWrapper, + PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY, + PREDEF_INIT_FUNC, + PREDEF_HOST_FUNC +) + +logger = logging.getLogger(__name__) + + +class TLCUDASourceWrapper(object): + _TYPE_MAP = { + "float32": "float", + "float16": "half_t", + "bfloat16": "__nv_bfloat16", + "e4m3_float8": "__nv_fp8_e4m3", + "e5m2_float8": "__nv_fp8_e5m2", + "float64": "double", + "int64": "int64_t", + "int32": "int", + "uint32": "unsigned int", + "bool": "int8_t", + "int8": "int8_t", + "uint8": "uint8_t", + "int16": "int16_t", + "uchar": "uint8_t", + } + + def __init__(self, scheduled_ir_module: IRModule, source: str, arch: TileDevice): + self.mod = scheduled_ir_module + self.arch = arch + self.source = source + self.function_name: Optional[str] = None + self.dynamic_smem_buf: Optional[int] = None + self.block_info: Union[List[int], Dict] = [1, 1, 1] + self.grid_info: Union[List[int], Dict] = [1, 1, 1] + self.parse_source_information() + self.srcpath: Optional[str] = None + self.libpath: Optional[str] = None + self.lib_code: Optional[str] = self.update_lib_code(source) + + def parse_source_information(self): + device_mod = get_annotated_device_mod(self.mod, self.arch.target, backend="tl") + assert (len(device_mod.functions) == 1 + ), "Only support one function in the module for static shape kernel." + for g_var, func in device_mod.functions.items(): + self.function_name = g_var.name_hint + attrs = func.attrs + if "dyn_shared_memory_buf" in attrs: + self.dynamic_smem_buf = int(attrs["dyn_shared_memory_buf"]) + if "thread_extent" in attrs: + thread_extent = attrs["thread_extent"] + for tag, extent in thread_extent.items(): + if "threadIdx" in tag: + self.block_info["xyz".index(tag[-1])] = extent + elif "blockIdx" in tag: + self.grid_info["xyz".index(tag[-1])] = extent + + def get_dynamic_symbolic_set(self, prim_func): + # Determine the set of dynamic symbols used in the function + dynamic_symbolic_set = set() + for param in prim_func.params: + buffer = prim_func.buffer_map[param] + for dim in buffer.shape: + if isinstance(dim, tvm.tir.Var): + dynamic_symbolic_set.add(dim.name) + return dynamic_symbolic_set + + def get_cuda_init_func(self): + # Initialize an empty string for the CUDA function call + call_str = """""" + # If dynamic shared memory buffer is specified, prepare the cudaFuncSetAttribute call + if self.dynamic_smem_buf is not None: + call_str = ( + PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY.format(self.function_name, + self.dynamic_smem_buf)) + # Format the initialization function using the call_str + init_funcs = PREDEF_INIT_FUNC.format(call_str) + return init_funcs + + def update_lib_code(self, code: str): + # Update the library code with the given code string + self.lib_code = code + # Find the index of the global kernel function in the code + index = match_global_kernel(code) + # Extract the declaration of the function starting from the found index + declaration = code[index:].split(";")[0] + + function_name = self.function_name + # Get the CUDA initialization function + init_func = self.get_cuda_init_func() + + # Locate the opening brace of the function to insert arguments + index = code.index("{", index) + function_args = [] + # Populate the function arguments from the primary function's parameters and buffers + for param in self.prim_func.params: + buffer = self.prim_func.buffer_map[param] + function_args.append({ + "name": buffer.name, + "type": self._TYPE_MAP[buffer.dtype] + "* __restrict__", + }) + + dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func) + # Add dynamic symbolic parameters as integers to the function arguments + for dyn_sym in dynamic_symbolic_set: + function_args.append({"name": dyn_sym, "type": "int"}) + + function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},) + # Format the function arguments for declaration + def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args]) + + def func_call_args(s, function_args): + # Extract the function call arguments matching the function definition + pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)" + matches = re.findall(pattern, s) + call_args = [] + for match in matches: + for arg in function_args: + if arg["name"] == match: + call_args.append(match) + return call_args + + call_args = ", ".join(func_call_args(declaration, function_args)) + block_info, grid_info = self.block_info, self.grid_info + + def legalize_c(p): + # Convert TIR expressions to legal C expressions + # Directly convert to string since the special case handling + # does not alter the string representation for `tvm.tir.Var` and `IntImm`. + # Replace Python's floor division operator with C's division operator + if isinstance(p, tvm.tir.IntImm): + p = int(p) + return str(p).replace("//", "/") + + # Prepare the block and grid dimensions for the CUDA kernel launch + block_str = "dim3({}, {}, {})".format( + legalize_c(block_info[0]), + legalize_c(block_info[1]), + legalize_c(block_info[2]), + ) + grid_str = "dim3({}, {}, {})".format( + legalize_c(grid_info[0]), legalize_c(grid_info[1]), legalize_c(grid_info[2])) + # Determine the shared memory size, defaulting to 0 if not specified + smem_str = 0 if self.dynamic_smem_buf is None else self.dynamic_smem_buf + # Format the CUDA kernel launch string + if len(dynamic_symbolic_set) != 0: + call_str = "if ({} == 0) return; \n\t\t".format(list(dynamic_symbolic_set)[0]) + else: + call_str = "" + call_str += "{}<<<{}, {}, {}, stream>>>({});".format(function_name, grid_str, block_str, + smem_str, call_args) + # Create the host function wrapper for the CUDA kernel + host_func = PREDEF_HOST_FUNC.format(def_args, call_str) + # Combine the source, initialization function, and host function to form the complete library code + lib_code = self.source + init_func + host_func + return lib_code + + @property + def prim_func(self): + if len(self.mod.get_global_vars()) == 1: + return self.mod[self.mod.get_global_vars()[0]] + elif "main" in self.mod: + return self.mod["main"] + else: + raise ValueError("Unable to determine primary function.") + + +class TLWrapper(BaseWrapper): + + def __init__(self, arch: TileDevice): + super().__init__() + self.scheduled_ir_module = None + self.arch = arch + self.lib = None + + def assign_optimized_module(self, scheduled_ir_module: IRModule): + self.scheduled_ir_module = scheduled_ir_module + + # Get Scheduled Rt Module and return source to be compiled + def wrap(self, c_source: str, is_dynamic: bool = False): + assert is_dynamic is False, "Dynamic kernel is not supported in TLWrapper." + assert self.scheduled_ir_module is not None, "Please assign optimized module first." + wrapper_class = TLCUDASourceWrapper + wrapper = wrapper_class(self.scheduled_ir_module, c_source, self.arch) + return wrapper.lib_code diff --git a/bitblas/cache/operator.py b/bitblas/cache/operator.py index 0e7ecaa54..0dbbdf96b 100644 --- a/bitblas/cache/operator.py +++ b/bitblas/cache/operator.py @@ -108,8 +108,8 @@ def _save_operator_config_and_artifact(self, config, op_inst, config_path): # For writing optimized.py file optimized_file_path = os.path.join(config_path, "optimized.py") with open(optimized_file_path, "w") as optimized_file: - if op_inst.optimized_mod is not None: - optimized_file.write(op_inst.optimized_mod.script(show_meta=False)) + if op_inst.scheduled_ir_module is not None: + optimized_file.write(op_inst.scheduled_ir_module.script(show_meta=False)) if op_inst.libpath is not None: # copy lib name to the same directory as the artifact srcpath = op_inst.srcpath diff --git a/bitblas/gpu/matmul_mma.py b/bitblas/gpu/matmul_mma.py index 591d6ced9..5ed6f0723 100644 --- a/bitblas/gpu/matmul_mma.py +++ b/bitblas/gpu/matmul_mma.py @@ -8,7 +8,7 @@ from tvm import tir, DataType from tvm.target import Target -from ..ops.operator import TransformKind +from ..ops.common import TransformKind from ..base.roller import Hint from ..base.roller.rasterization import NoRasterization from ..base import analysis diff --git a/bitblas/gpu/matmul_mma_dequantize.py b/bitblas/gpu/matmul_mma_dequantize.py index 7dfbd2408..9932e69fc 100644 --- a/bitblas/gpu/matmul_mma_dequantize.py +++ b/bitblas/gpu/matmul_mma_dequantize.py @@ -9,7 +9,7 @@ from tvm import tir, DataType from tvm.target import Target -from ..ops.operator import TransformKind +from ..ops.common import TransformKind from ..base.roller.hint import Hint, IntrinInfo from ..base.roller.rasterization import NoRasterization from ..base import analysis diff --git a/bitblas/ops/base_scheduler.py b/bitblas/ops/base_scheduler.py new file mode 100644 index 000000000..3b8291b41 --- /dev/null +++ b/bitblas/ops/base_scheduler.py @@ -0,0 +1,45 @@ +from tvm import IRModule +from tvm.tir import PrimFunc +from typing import Union +from dataclasses import dataclass, field +from tvm.tir.transform import Simplify +from abc import ABC, abstractmethod + +@dataclass +class BaseScheduler(ABC): + + _enable_simplify: bool = field(default=True, init=False, repr=False) + + @staticmethod + def Simplify(stmt: Union[PrimFunc, IRModule]): + if isinstance(stmt, PrimFunc): + return Simplify()(IRModule.from_expr(stmt))["main"] + elif isinstance(stmt, IRModule): + return Simplify()(stmt) + else: + raise ValueError(f"Unsupported type: {type(stmt)}") + + def activate_simplify(self): + self._enable_simplify = True + return self + + def deactivate_simplify(self): + self._enable_simplify = False + return self + + def maybe_simplify(self, stmt: Union[PrimFunc, IRModule]): + if self._enable_simplify: + return self.Simplify(stmt) + return stmt + + @abstractmethod + def with_default_config(self): + pass + + @abstractmethod + def apply_config( + self, + *args, + **kwargs, + ): + pass diff --git a/bitblas/ops/common.py b/bitblas/ops/common.py new file mode 100644 index 000000000..1b1b77fcb --- /dev/null +++ b/bitblas/ops/common.py @@ -0,0 +1,20 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from enum import IntEnum + +class OptimizeStrategy(IntEnum): + SingleBatchDecodeOnly = 0 + ContigousBatching = 1 + + +class TransformKind(IntEnum): + NonTransform = 0 + InterWarpTransform = 1 + IntraWarpTransform = 2 + LDMatrixTransform = 3 + + +class BackendKind(IntEnum): + TIR = 0 + TileLang = 1 diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py index c26b9c7a9..b7b884443 100644 --- a/bitblas/ops/general_matmul/__init__.py +++ b/bitblas/ops/general_matmul/__init__.py @@ -4,13 +4,14 @@ from tvm.target import Target import operator from functools import reduce -from enum import IntEnum from bitblas.base.arch.cuda import CUDA from bitblas.base.roller.hint import Hint from typing import Any, Literal, Optional, Tuple, Union -from ..operator import OperatorConfig, Operator, TransformKind, OPExecutorCPU, BaseKernelNameGenerator +from ..operator import OperatorConfig, Operator, OPExecutorCPU, BaseKernelNameGenerator +from ..common import TransformKind, OptimizeStrategy from .tirscript.matmul_dequantize_impl import select_implementation as weight_dequantize_implementation from .tirscript.matmul_impl import select_implementation as consistent_implementation +from .tilelang.dense import select_scheduler as consistent_scheduler from ...base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 from bitblas.utils.target_detector import auto_detect_nvidia_target from dataclasses import dataclass @@ -48,11 +49,6 @@ def is_native_compute(A_dtype, W_dtype) -> bool: """ -class OptimizeStrategy(IntEnum): - SingleBatchDecodeOnly = 0 - ContigousBatching = 1 - - @dataclass(frozen=True) class MatmulConfig(OperatorConfig): M: Union[int, Tuple[int]] = None @@ -357,8 +353,7 @@ def __init__( self.source_format = source_format self.bit = bit - self.backend = backend - super().__init__(name, config, target) + super().__init__(name, config, target, backend) if source_format == "int" and self.with_zeros: logger.warning( @@ -381,7 +376,7 @@ def dispatch_tir(self, if isinstance(self.M, Tuple): self.dynamic_range = {"m": self.M} - self.prim_func_mod["main"] = self.prim_func_mod["main"].with_attrs( + self.ir_module["main"] = self.ir_module["main"].with_attrs( {"opt_shapes": self.dynamic_range}) else: self.dynamic_range = None @@ -577,6 +572,23 @@ def _select_implementation(self): propagate_b=self.propagate_b, ) + def _select_scheduler(self): + if is_native_compute(self.A_dtype, self.W_dtype): + return consistent_scheduler( + M=self.M, + N=self.N, + K=self.K, + in_dtype=self.A_dtype, + out_dtype=self.out_dtype, + accum_dtype=self.accum_dtype, + with_bias=self.with_bias, + layout=self.layout, + propagate_a=self.propagate_a, + propagate_b=self.propagate_b, + ) + else: + raise ValueError("Currently only support native compute for scheduler") + def post_process(self, code: str) -> str: code = tensor_replace_dp4a(code) code = tensor_remove_make_int4(code) diff --git a/bitblas/ops/general_matmul/cuda/__init__.py b/bitblas/ops/general_matmul/cuda/__init__.py index a0366abd3..b57beb358 100644 --- a/bitblas/ops/general_matmul/cuda/__init__.py +++ b/bitblas/ops/general_matmul/cuda/__init__.py @@ -1,8 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# TODO: Not Implemented Yet -from bitblas.ops.operator import TransformKind +from bitblas.ops.common import TransformKind from bitblas.base import TileDevice from .template import i4_scale_template_source diff --git a/bitblas/ops/general_matmul/tilelang/__init__.py b/bitblas/ops/general_matmul/tilelang/__init__.py index 92956855c..59e481eb9 100644 --- a/bitblas/ops/general_matmul/tilelang/__init__.py +++ b/bitblas/ops/general_matmul/tilelang/__init__.py @@ -1,4 +1,2 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - -# TODO: Not Implemented Yet diff --git a/bitblas/ops/general_matmul/tilelang/dense/__init__.py b/bitblas/ops/general_matmul/tilelang/dense/__init__.py index 23cda34db..2a929355c 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/__init__.py +++ b/bitblas/ops/general_matmul/tilelang/dense/__init__.py @@ -12,3 +12,55 @@ MatmulFineGrainScheduler, # noqa: F401 MatmulWeightPropagationScheduler, # noqa: F401 ) + +from bitblas.ops.common import TransformKind +from typing import Union + + +def parse_layout(layout: str): + if len(layout) != 2 or layout[0] not in "nt" or layout[1] not in "nt": + raise ValueError(f"Invalid layout: {layout}") + + trans_A = layout[0] == 't' + trans_B = layout[1] == 't' + + return trans_A, trans_B + + +def is_non_transform_kind(kind) -> bool: + return kind == TransformKind.NonTransform + + +def select_scheduler( + M=None, + N=16384, + K=16384, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + layout="nt", + propagate_a: Union[int, TransformKind] = TransformKind.NonTransform, + propagate_b: Union[int, TransformKind] = TransformKind.NonTransform, +): + if isinstance(propagate_a, int): + propagate_a = TransformKind(propagate_a) + if isinstance(propagate_b, int): + propagate_b = TransformKind(propagate_b) + if with_bias: + raise NotImplementedError + + trans_A, trans_B = parse_layout(layout) + if is_non_transform_kind(propagate_a) and is_non_transform_kind(propagate_b): + return MatmulScheduler( + M=M, + N=N, + K=K, + trans_A=trans_A, + trans_B=trans_B, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + ) + else: + raise ValueError(f"Unsupported transform kind: {propagate_a}, {propagate_b}") diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul.py b/bitblas/ops/general_matmul/tilelang/dense/matmul.py index 3b677b4ad..1c28ff695 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul.py @@ -2,10 +2,8 @@ # Licensed under the MIT License. from bitblas import tvm as tvm from tvm import DataType -from tvm import IRModule -from tvm.tir import PrimFunc import tvm.tl.language as T -from typing import Union, Optional +from typing import Optional from bitblas.tl.utils import ( get_mma_micro_size, make_swizzle_layout, @@ -15,40 +13,12 @@ TensorCoreIntrinEmitter, TensorCoreIntrinEmitterWithLadderTransform, ) - -from bitblas.ops.operator import TransformKind +from bitblas.ops.common import TransformKind +from bitblas.ops.base_scheduler import BaseScheduler from dataclasses import dataclass -@dataclass -class BaseScheduler: - - enable_simplify: bool = True - - @staticmethod - def Simplify(stmt: Union[PrimFunc, IRModule]): - if isinstance(stmt, PrimFunc): - return tvm.tir.transform.Simplify()(tvm.IRModule.from_expr(stmt))["main"] - elif isinstance(stmt, IRModule): - return tvm.tir.transform.Simplify()(stmt) - else: - raise ValueError(f"Unsupported type: {type(stmt)}") - - def enable_simplify(self): - self.enable_simplify = True - return self - - def disable_simplify(self): - self.enable_simplify = False - return self - - def maybe_simplify(self, stmt: Union[PrimFunc, IRModule]): - if self.enable_simplify: - return self.Simplify(stmt) - return stmt - - @dataclass class MatmulScheduler(BaseScheduler): @@ -58,8 +28,8 @@ class MatmulScheduler(BaseScheduler): K: Optional[int] = None trans_A: bool = False trans_B: bool = False - dtypeAB: str = "float16" - dtypeC: str = "float16" + in_dtype: str = "float16" + out_dtype: str = "float16" accum_dtype: str = "float16" # Default Tile Related Params @@ -99,7 +69,7 @@ def apply_config( ): M, N, K = self.M, self.N, self.K trans_A, trans_B = self.trans_A, self.trans_B - dtypeAB, dtypeC, accum_dtype = self.dtypeAB, self.dtypeC, self.accum_dtype + in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype A_shape = (K, M) if trans_A else (M, K) B_shape = (N, K) if trans_B else (K, N) @@ -108,14 +78,14 @@ def apply_config( @T.prim_func def main( - A: T.Buffer(A_shape, dtypeAB), - B: T.Buffer(B_shape, dtypeAB), - C: T.Buffer((M, N), dtypeC), + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), ): with T.Kernel( T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, dtypeAB) - B_shared = T.alloc_shared(B_shared_shape, dtypeAB) + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) if enable_rasterization: @@ -151,8 +121,8 @@ class MatmulFineGrainScheduler(BaseScheduler): M: Optional[int] = None N: Optional[int] = None K: Optional[int] = None - dtypeAB: str = "float16" - dtypeC: str = "float16" + in_dtype: str = "float16" + out_dtype: str = "float16" trans_A: bool = False trans_B: bool = True accum_dtype: str = "float16" @@ -200,10 +170,10 @@ def apply_config( M, N, K = self.M, self.N, self.K trans_A, trans_B = self.trans_A, self.trans_B - dtypeAB, dtypeC, accum_dtype = self.dtypeAB, self.dtypeC, self.accum_dtype + in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype # Calculate the micro size per warp using a helper function - micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtypeAB) + micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(in_dtype) block_M = block_row_warps * warp_row_tiles block_N = block_col_warps * warp_col_tiles @@ -234,8 +204,8 @@ def apply_config( # Configure the tensor core intrinsic emitter mma_emitter = TensorCoreIntrinEmitter( - a_dtype=dtypeAB, - b_dtype=dtypeAB, + a_dtype=in_dtype, + b_dtype=in_dtype, accum_dtype=accum_dtype, a_transposed=trans_A, b_transposed=trans_B, @@ -249,20 +219,20 @@ def apply_config( # Define the main kernel using the generated configuration @T.prim_func def main( - A: T.Buffer(A_shape, dtypeAB), - B: T.Buffer(B_shape, dtypeAB), - C: T.Buffer((M, N), dtypeC), + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), ): # Grid and thread configuration for CUDA kernel with T.Kernel( T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): # Allocate shared memory and local fragments - A_shared = T.alloc_shared(A_shared_shape, dtypeAB, scope=shared_scope) - B_shared = T.alloc_shared(B_shared_shape, dtypeAB, scope=shared_scope) - C_shared = T.alloc_shared(C_shared_shape, dtypeC, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size), dtypeAB) - B_local = T.alloc_local((warp_cols * local_size), dtypeAB) + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size), in_dtype) + B_local = T.alloc_local((warp_cols * local_size), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) # Thread-level parallelism for Tensor Cores @@ -346,8 +316,8 @@ class MatmulWeightPropagationScheduler(BaseScheduler): M: Optional[int] = None N: Optional[int] = None K: Optional[int] = None - dtypeAB: str = "float16" - dtypeC: str = "float16" + in_dtype: str = "float16" + out_dtype: str = "float16" trans_A: bool = False trans_B: bool = True accum_dtype: str = "float16" @@ -395,22 +365,22 @@ def apply_config( M, N, K = self.M, self.N, self.K trans_A, trans_B = self.trans_A, self.trans_B - dtypeAB, dtypeC, accum_dtype = self.dtypeAB, self.dtypeC, self.accum_dtype + in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype # Calculate the micro size per warp using a helper function - micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtypeAB) + micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(in_dtype) block_M = block_row_warps * warp_row_tiles block_N = block_col_warps * warp_col_tiles block_K = chunk # TODO(lei): Can be generalized to analyzed from bank size - pad_factor = 8 if dtypeAB == "float16" else 16 + pad_factor = 8 if in_dtype == "float16" else 16 - can_swizzle_a = block_K * DataType(dtypeAB).bits == 512 + can_swizzle_a = block_K * DataType(in_dtype).bits == 512 apply_pad_a = not can_swizzle_a - micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtypeAB) + micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(in_dtype) # Define the shapes of matrices and shared memory buffers A_shape = (M, K) @@ -442,8 +412,8 @@ def apply_config( # Configure the tensor core intrinsic emitter mma_emitter = TensorCoreIntrinEmitterWithLadderTransform( - a_dtype=dtypeAB, - b_dtype=dtypeAB, + a_dtype=in_dtype, + b_dtype=in_dtype, accum_dtype=accum_dtype, a_transposed=trans_A, b_transposed=trans_B, @@ -458,20 +428,20 @@ def apply_config( # Define the main kernel using the generated configuration @T.prim_func def main( - A: T.Buffer(A_shape, dtypeAB), - B: T.Buffer(B_shape, dtypeAB), - C: T.Buffer((M, N), dtypeC), + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), ): # Grid and thread configuration for CUDA kernel with T.Kernel( T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): # Allocate shared memory and local fragments - A_shared = T.alloc_shared(A_shared_shape, dtypeAB, scope=shared_scope) - B_shared = T.alloc_shared(B_shared_shape, dtypeAB, scope=shared_scope) - C_shared = T.alloc_shared(C_shared_shape, dtypeC, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size), dtypeAB) - B_local = T.alloc_local((warp_cols * local_size), dtypeAB) + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size), in_dtype) + B_local = T.alloc_local((warp_cols * local_size), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) # Thread-level parallelism for Tensor Cores @@ -561,8 +531,8 @@ def matmul_blocked( block_K=32, trans_A=False, trans_B=False, - dtypeAB="float16", - dtypeC="float16", + in_dtype="float16", + out_dtype="float16", accum_dtype="float16", num_stages=2, threads=128, @@ -575,13 +545,13 @@ def matmul_blocked( @T.prim_func def main( - A: T.Buffer(A_shape, dtypeAB), - B: T.Buffer(B_shape, dtypeAB), - C: T.Buffer((M, N), dtypeC), + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, dtypeAB) - B_shared = T.alloc_shared(B_shared_shape, dtypeAB) + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) if enable_rasterization: @@ -608,8 +578,8 @@ def matmul_macro_tensorcore( M, N, K, - dtypeAB, - dtypeC, + in_dtype, + out_dtype, trans_A, trans_B, accum_dtype, @@ -628,7 +598,7 @@ def matmul_macro_tensorcore( block_N = block_col_warps * warp_col_tiles block_K = chunk - micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtypeAB) + micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(in_dtype) A_shape = (M, K) B_shape = (N, K) @@ -649,8 +619,8 @@ def matmul_macro_tensorcore( shared_scope = "shared.dyn" # Literal["shared", "shared.dyn"] while shared for static shared memory mma_emitter = TensorCoreIntrinEmitter( - a_dtype=dtypeAB, - b_dtype=dtypeAB, + a_dtype=in_dtype, + b_dtype=in_dtype, accum_dtype=accum_dtype, a_transposed=False, b_transposed=True, @@ -663,17 +633,17 @@ def matmul_macro_tensorcore( @T.prim_func def main( - A: T.Buffer(A_shape, dtypeAB), - B: T.Buffer(B_shape, dtypeAB), - C: T.Buffer((M, N), dtypeC), + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, dtypeAB, scope=shared_scope) - B_shared = T.alloc_shared(B_shared_shape, dtypeAB, scope=shared_scope) - C_shared = T.alloc_shared(C_shared_shape, dtypeC, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size), dtypeAB) - B_local = T.alloc_local((warp_cols * local_size), dtypeAB) + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size), in_dtype) + B_local = T.alloc_local((warp_cols * local_size), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) thread_bindings = T.thread_binding(0, threads, "threadIdx.x") @@ -733,8 +703,8 @@ def matmul_macro_tensorcore_weight_propagation_level_ldmatrix( M, N, K, - dtypeAB, - dtypeC, + in_dtype, + out_dtype, trans_A, trans_B, accum_dtype, @@ -754,12 +724,12 @@ def matmul_macro_tensorcore_weight_propagation_level_ldmatrix( block_K = chunk # TODO(lei): Can be generalized to analyzed from bank size - pad_factor = 8 if dtypeAB == "float16" else 16 + pad_factor = 8 if in_dtype == "float16" else 16 - can_swizzle_a = block_K * DataType(dtypeAB).bits == 512 + can_swizzle_a = block_K * DataType(in_dtype).bits == 512 apply_pad_a = not can_swizzle_a - micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtypeAB) + micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(in_dtype) A_shape = (M, K) B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k) @@ -785,8 +755,8 @@ def matmul_macro_tensorcore_weight_propagation_level_ldmatrix( shared_scope = "shared.dyn" # Literal["shared", "shared.dyn"] while shared for static shared memory mma_emitter = TensorCoreIntrinEmitterWithLadderTransform( - a_dtype=dtypeAB, - b_dtype=dtypeAB, + a_dtype=in_dtype, + b_dtype=in_dtype, accum_dtype=accum_dtype, a_transposed=trans_A, b_transposed=trans_B, @@ -800,17 +770,17 @@ def matmul_macro_tensorcore_weight_propagation_level_ldmatrix( @T.prim_func def main( - A: T.Buffer(A_shape, dtypeAB), - B: T.Buffer(B_shape, dtypeAB), - C: T.Buffer((M, N), dtypeC), + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, dtypeAB, scope=shared_scope) - B_shared = T.alloc_shared(B_shared_shape, dtypeAB, scope=shared_scope) - C_shared = T.alloc_shared(C_shared_shape, dtypeC, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size), dtypeAB) - B_local = T.alloc_local((warp_cols * local_size), dtypeAB) + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size), in_dtype) + B_local = T.alloc_local((warp_cols * local_size), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) thread_bindings = T.thread_binding(0, threads, "threadIdx.x") diff --git a/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py b/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py index a86f6469a..0cd17feb3 100644 --- a/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py +++ b/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py @@ -4,7 +4,7 @@ from bitblas import tvm from tvm import te, DataType from tvm.tir import IndexMap -from bitblas.ops.operator import TransformKind +from bitblas.ops.common import TransformKind from bitblas.gpu.matmul_analysis import get_propagate_map from bitblas.quantization import ( _tir_packed_int_to_int_convert, diff --git a/bitblas/ops/general_matmul/tirscript/matmul_impl.py b/bitblas/ops/general_matmul/tirscript/matmul_impl.py index 6a3e1de2d..911c8ea76 100644 --- a/bitblas/ops/general_matmul/tirscript/matmul_impl.py +++ b/bitblas/ops/general_matmul/tirscript/matmul_impl.py @@ -4,7 +4,7 @@ from bitblas import tvm from tvm import te from bitblas.gpu.matmul_analysis import get_propagate_map -from bitblas.ops.operator import TransformKind +from bitblas.ops.common import TransformKind from typing import Union diff --git a/bitblas/ops/general_matmul_splitk.py b/bitblas/ops/general_matmul_splitk.py index 39671432a..d16674564 100644 --- a/bitblas/ops/general_matmul_splitk.py +++ b/bitblas/ops/general_matmul_splitk.py @@ -4,7 +4,7 @@ import operator from functools import reduce from typing import Any, Optional, Union -from .operator import TransformKind +from .common import TransformKind from .impl.matmul_splitk_impl import select_implementation as consistent_implementation from .impl.matmul_dequantize_splitk_impl import select_implementation as weight_dequantize_implementation from dataclasses import dataclass diff --git a/bitblas/ops/impl/batch_matmul_dequantize_impl.py b/bitblas/ops/impl/batch_matmul_dequantize_impl.py index 6303f4bf8..6a5f740a0 100644 --- a/bitblas/ops/impl/batch_matmul_dequantize_impl.py +++ b/bitblas/ops/impl/batch_matmul_dequantize_impl.py @@ -4,7 +4,7 @@ from bitblas import tvm from tvm import te, DataType from tvm.tir import IndexMap -from bitblas.ops.operator import TransformKind +from bitblas.ops.common import TransformKind from bitblas.gpu.matmul_analysis import get_propagate_map from bitblas.quantization import (_tir_packed_int_to_int_convert, _tir_packed_to_signed_convert, _tir_packed_to_unsigned_convert, _tir_u32_to_f4_to_f16, diff --git a/bitblas/ops/impl/batch_matmul_impl.py b/bitblas/ops/impl/batch_matmul_impl.py index 3904f36e6..064dd061f 100644 --- a/bitblas/ops/impl/batch_matmul_impl.py +++ b/bitblas/ops/impl/batch_matmul_impl.py @@ -3,7 +3,7 @@ # pre-transformed tir expression of matmul from bitblas import tvm from tvm import te -from bitblas.ops.operator import TransformKind +from bitblas.ops.common import TransformKind from .base import TIRScriptEmitter, TIRScriptSelector diff --git a/bitblas/ops/impl/matmul_dequantize_impl.py b/bitblas/ops/impl/matmul_dequantize_impl.py index 55d672097..ec450610a 100644 --- a/bitblas/ops/impl/matmul_dequantize_impl.py +++ b/bitblas/ops/impl/matmul_dequantize_impl.py @@ -4,7 +4,7 @@ from bitblas import tvm from tvm import te, DataType from tvm.tir import IndexMap -from bitblas.ops.operator import TransformKind +from bitblas.ops.common import TransformKind from bitblas.gpu.matmul_analysis import get_propagate_map from bitblas.quantization import ( _tir_packed_int_to_int_convert, diff --git a/bitblas/ops/impl/matmul_dequantize_splitk_impl.py b/bitblas/ops/impl/matmul_dequantize_splitk_impl.py index 657b45a42..bb63b10e5 100644 --- a/bitblas/ops/impl/matmul_dequantize_splitk_impl.py +++ b/bitblas/ops/impl/matmul_dequantize_splitk_impl.py @@ -4,7 +4,7 @@ from bitblas import tvm from tvm import te, DataType from tvm.tir import IndexMap -from bitblas.ops.operator import TransformKind +from bitblas.ops.common import TransformKind from bitblas.gpu.matmul_analysis import get_propagate_map from bitblas.quantization import (_tir_packed_int_to_int_convert, _tir_packed_to_signed_convert, _tir_packed_to_unsigned_convert, _tir_u32_to_f4_to_f16, diff --git a/bitblas/ops/impl/matmul_impl.py b/bitblas/ops/impl/matmul_impl.py index db4f4d3f3..9c9cc2e1e 100644 --- a/bitblas/ops/impl/matmul_impl.py +++ b/bitblas/ops/impl/matmul_impl.py @@ -4,7 +4,7 @@ from bitblas import tvm from tvm import te from bitblas.gpu.matmul_analysis import get_propagate_map -from bitblas.ops.operator import TransformKind +from bitblas.ops.common import TransformKind def matmul_nn( diff --git a/bitblas/ops/impl/matmul_splitk_impl.py b/bitblas/ops/impl/matmul_splitk_impl.py index c314fa6ca..3a825ac4f 100644 --- a/bitblas/ops/impl/matmul_splitk_impl.py +++ b/bitblas/ops/impl/matmul_splitk_impl.py @@ -3,7 +3,7 @@ # pre-transformed tir expression of matmul from bitblas import tvm from tvm import te -from bitblas.ops.operator import TransformKind +from bitblas.ops.common import TransformKind def matmul_nt( diff --git a/bitblas/ops/impl/param_permutate_impl.py b/bitblas/ops/impl/param_permutate_impl.py index 4ecb17709..8f9ce04ff 100644 --- a/bitblas/ops/impl/param_permutate_impl.py +++ b/bitblas/ops/impl/param_permutate_impl.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from bitblas.gpu.matmul_analysis import get_propagate_map -from ..operator import TransformKind +from ..common import TransformKind from typing import Literal from tvm import te, IRModule diff --git a/bitblas/ops/ladder_permutate/__init__.py b/bitblas/ops/ladder_permutate/__init__.py index 65ad06679..c3406f6a0 100644 --- a/bitblas/ops/ladder_permutate/__init__.py +++ b/bitblas/ops/ladder_permutate/__init__.py @@ -38,7 +38,7 @@ def __init__( target = self.target if target.kind.name == "cuda": - self.optimized_mod = self.apply_default_schedule(self.prim_func_mod, target) + self.scheduled_ir_module = self.apply_default_schedule(self.ir_module, target) if enable_tuning: self.hardware_aware_finetune() if not from_database: diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index c8a9cb08a..eb02fdf70 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -2,22 +2,24 @@ # Licensed under the MIT License. from abc import ABC, abstractmethod from bitblas import tvm +from tvm import tl from tvm import IRModule +from tvm.runtime.module import Module from tvm.target import Target from tvm.tir import PrimFunc from tvm.contrib.dlpack import to_pytorch_func import bitblas import ctypes -from typing import List, Dict, Any, Optional, Tuple +from typing import (List, Dict, Any, Optional, Tuple, Literal, Callable) import numpy as np from bitblas.base import fast_tune, fast_tune_with_dynamic_range from copy import deepcopy -from bitblas.base.arch import get_arch +from bitblas.ops.base_scheduler import BaseScheduler +from bitblas.base.arch import get_arch, TileDevice from bitblas.base.roller.hint import Hint -from bitblas.builder.wrapper import TIRWrapper +from bitblas.builder.wrapper import TIRWrapper, TLWrapper from bitblas.builder.lib_generator import LibraryGenerator from dataclasses import dataclass -from enum import IntEnum import logging logger = logging.getLogger(__name__) @@ -33,13 +35,6 @@ "Please perform hardware-aware tuning manually.") -class TransformKind(IntEnum): - NonTransform = 0 - InterWarpTransform = 1 - IntraWarpTransform = 2 - LDMatrixTransform = 3 - - @dataclass(frozen=True) class OperatorConfig: """Base class for operator configurations. Used for typing.""" @@ -64,33 +59,53 @@ def generate(self, hint: Hint = None) -> str: pass -class Operator(ABC): +class Operator(object): - def __init__(self, name, config: OperatorConfig, target: Target = None): + def __init__(self, + name, + config: OperatorConfig, + target: Target = None, + backend: Literal["tir", "tl"] = "tir"): if isinstance(target, str): target = Target(target) self.name = name self.config = config self.target = target - self.prim_func_mod = self._select_implementation() - self.optimized_mod = None - self.rt_mod = None - self.time_evaluator = None - self.arch = get_arch(target) if target else None - self.dynamic_range = None - self.pass_context: Dict = {} - self.num_args = len(self.prim_func.params) - self.num_output_args: int = ( - 1 # todo(lei): should be analyzed from the prim_func. - ) + self.backend = backend + + self.ir_module: Optional[IRModule] = ( + self._select_implementation() if self.is_tir_backend() else None) + self.scheduler: Optional[BaseScheduler] = ( + self._select_scheduler() if self.is_tilelang_backend() else None) + + self.scheduled_ir_module: Optional[IRModule] = None + self.rt_mod: Optional[Module] = None + self.time_evaluator: Optional[Callable] = None + self.dynamic_range: Optional[Dict] = None + self.arch: Optional[TileDevice] = get_arch(target) if target else None + self.pass_context: Optional[Dict] = None + self.kernel_name_generator: Optional[BaseKernelNameGenerator] = ( self.get_kernel_name_generator()) self.lib_generator = LibraryGenerator(self.arch) - self.wrapper = TIRWrapper(self.arch) - self.lib = None + + if self.is_tir_backend(): + self.wrapper = TIRWrapper(self.arch) + elif self.is_tilelang_backend(): + self.wrapper = TLWrapper(self.arch) + else: + raise ValueError(f"Unsupported backend: {self.backend}") + + self.lib: Optional[ctypes.CDLL] = None + + def is_tir_backend(self): + return self.backend == "tir" + + def is_tilelang_backend(self): + return self.backend == "tl" def get_kernel_name_generator(self) -> Optional[BaseKernelNameGenerator]: - return None + raise NotImplementedError def get_source(self, target: Optional[Target] = None, kenrel_only=False) -> str: if target is None: @@ -123,7 +138,7 @@ def _build_runtime_module(self, target: Target): # Check if the platform is CUDA and we have an optimized function if self.arch.platform == "CUDA": - if self.optimized_mod is None: + if self.scheduled_ir_module is None: return None @tvm.register_func(func_name="tvm_callback_cuda_postproc", override=True) @@ -131,12 +146,22 @@ def tvm_callback_cuda_postproc(code, _): return self.post_process(code) try: - with tvm.transform.PassContext(config={ - "tir.use_async_copy": True, - "tir.disable_cse_tir": True, - **self.pass_context - }): - rt_mod = tvm.build(self.optimized_mod, target=target) + with tvm.transform.PassContext( + config={ + "tir.use_async_copy": True, + "tir.disable_cse_tir": True, + **(self.pass_context if self.pass_context else {}) + }): + if self.is_tir_backend(): + rt_mod = tvm.build(self.scheduled_ir_module, target=target) + elif self.is_tilelang_backend(): + # check only have one function in the module + if len(self.scheduled_ir_module.functions) > 1: + raise ValueError("Only support one function in the module") + tl_prim_func = list(self.scheduled_ir_module.functions.values())[0] + rt_mod, _ = tl.lower(tl_prim_func, target=target) + else: + raise ValueError(f"Unsupported backend: {self.backend}") except Exception: # noqa: F841 logger.debug( BUILD_RUNTIME_LIBRARY_FAILED_MESSAGE.format(self.__class__.__name__, target, @@ -156,12 +181,13 @@ def tvm_callback_cuda_postproc(code, _): if self.arch.platform == "CUDA": try: is_dynamic = ( - self.dynamic_range is not None and len(self.optimized_mod.functions) > 1) - self.wrapper.assign_optimized_module(self.optimized_mod) + self.dynamic_range is not None and + len(self.scheduled_ir_module.functions) > 1) + self.wrapper.assign_optimized_module(self.scheduled_ir_module) wrapped_source = self.wrapper.wrap( self.get_source(target, kenrel_only=True), is_dynamic) self.lib_generator.update_lib_code(wrapped_source) - self.lib_generator.compile_lib() + self.lib_generator.compile_lib(with_tl=self.is_tilelang_backend()) self.lib = self.lib_generator.load_lib() self.lib.init() @@ -172,10 +198,16 @@ def tvm_callback_cuda_postproc(code, _): return rt_mod + def scheduler_with_default(self, scheduler: BaseScheduler): + scheduled_ir_module = IRModule.from_expr(scheduler.with_default_config()) + if scheduled_ir_module is not None: + return scheduled_ir_module + return None + def apply_default_schedule(self, func_mod: IRModule, target: Target) -> IRModule: mod_for_opt = deepcopy(func_mod) with target: - optimized_mod = ( + scheduled_ir_module = ( bitblas.ApplyDefaultSchedule( # pylint: disable=not-callable bitblas.gpu.Matmul(), bitblas.gpu.GEMV(), @@ -184,26 +216,29 @@ def apply_default_schedule(self, func_mod: IRModule, target: Target) -> IRModule bitblas.gpu.Fallback(), )(mod_for_opt)) - if optimized_mod is not None: - return optimized_mod + if scheduled_ir_module is not None: + return scheduled_ir_module return None - def _update_optimized_mod(self, optimized_mod: IRModule): - self.optimized_mod = optimized_mod + def _update_optimized_mod(self, scheduled_ir_module: IRModule): + self.scheduled_ir_module = scheduled_ir_module def _build_default_module(self, target: Target): try: - scheduled_mod = self.apply_default_schedule(self.prim_func_mod, target) + if self.is_tir_backend(): + scheduled_mod = self.apply_default_schedule(self.ir_module, target) + elif self.is_tilelang_backend(): + scheduled_mod = self.scheduler_with_default(self.scheduler) assert len(scheduled_mod.get_global_vars()) == 1, ( "The optimized module should only have one global variable for default schedule.") assert "main" in scheduled_mod, ( "The optimized module should have a function named 'main' for default schedule.") default_kernal_name = self.kernel_name_generator.generate() func = scheduled_mod["main"].with_attr("global_symbol", default_kernal_name) - optimized_mod = tvm.IRModule({default_kernal_name: func}) - self._update_optimized_mod(optimized_mod) + scheduled_ir_module = tvm.IRModule({default_kernal_name: func}) + self._update_optimized_mod(scheduled_ir_module) except Exception as apply_schedule_error: - self.optimized_mod = None + self.scheduled_ir_module = None logger.warning( APPLY_SCHEDULE_FAILED_MESSAGE.format(self.__class__.__name__, target, "default", apply_schedule_error)) @@ -232,15 +267,15 @@ def apply_fast_tuning_with_dynamic_range( topk: int = 20, dynamic_range: Dict[str, List[int]] = None, ): - optimized_mod = fast_tune_with_dynamic_range( + scheduled_ir_module = fast_tune_with_dynamic_range( func, target, topk=topk, parallel_build=True, dynamic_range=dynamic_range, kernel_name_generator=self.kernel_name_generator) - if optimized_mod is not None: - return optimized_mod + if scheduled_ir_module is not None: + return scheduled_ir_module return None def hardware_aware_finetune(self, @@ -252,7 +287,7 @@ def hardware_aware_finetune(self, dynamic_range = self.dynamic_range func = self.prim_func if dynamic_range is not None: - self.optimized_mod = self.apply_fast_tuning_with_dynamic_range( + self.scheduled_ir_module = self.apply_fast_tuning_with_dynamic_range( func, target, topk, dynamic_range) else: scheduled_mod, best_hint = self.apply_fast_tuning( @@ -263,8 +298,8 @@ def hardware_aware_finetune(self, "The optimized module should have a function named 'main' for default schedule.") default_kernal_name = self.kernel_name_generator.generate(best_hint) func = scheduled_mod["main"].with_attr("global_symbol", default_kernal_name) - optimized_mod = tvm.IRModule({default_kernal_name: func}) - self._update_optimized_mod(optimized_mod) + scheduled_ir_module = tvm.IRModule({default_kernal_name: func}) + self._update_optimized_mod(scheduled_ir_module) self._build_runtime_module(self.target) @@ -330,33 +365,17 @@ def profile_latency(self, dynamic_symbolic_constraints: Optional[Dict] = None) - dynamic_symbolic_constraints = {} profile_tensors = self.get_profile_tensors(dynamic_symbolic_constraints) latency = self.time_evaluator(*profile_tensors).mean * 1e3 - # release the memory + # release the memory of profile tensors for tensor in profile_tensors: del tensor return latency - def _tensor_adapter(self, tensor, device): - import torch - from torch.utils.dlpack import to_dlpack - - if isinstance(tensor, tvm.te.Tensor): - return tensor - elif isinstance(tensor, torch.Tensor): - return tvm.runtime.ndarray.from_dlpack(to_dlpack(tensor)) - elif isinstance(tensor, np.ndarray): - return tvm.nd.array(tensor, device=device) - else: - raise RuntimeError("Not supported type: ", type(tensor)) - def _forward_from_torch_func(self, *args): # Torch func is not reliable as the runtime overhead dlpack # is not negaliable, ref to https://discuss.tvm.apache.org/t/strange-overhead-of-tvm-runtime-ndarray-from-dlpack/16516 self.torch_func(*args) return args[-1] - def forward(self, *args): - return self._forward_from_torch_func(*args) - def _forward_from_prebuild_lib(self, *args, stream=0): ctypes_args = [ ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args @@ -364,14 +383,14 @@ def _forward_from_prebuild_lib(self, *args, stream=0): ctypes_args.append(ctypes.c_void_p(stream)) self.lib.call(*ctypes_args) - def call_lib(self, *args, stream=0): - self.lib.call(*args, ctypes.c_void_p(stream)) + def forward(self, *args): + return self._forward_from_torch_func(*args) def __call__(self, *args: Any) -> Any: return self.forward(*args) def update_func(self, func: PrimFunc): - self.prim_func_mod["main"] = func + self.ir_module["main"] = func def update_runtime_module(self, rt_mod=None, srcpath=None, libpath=None): if rt_mod is not None: @@ -382,26 +401,36 @@ def update_runtime_module(self, rt_mod=None, srcpath=None, libpath=None): if srcpath is not None: assert self.lib_generator is not None, "lib_generator is not initialized" self.lib_generator.set_src_path(srcpath) + # TODO(lei): update the lib code from srcpath if libpath is not None: assert self.lib_generator is not None, "lib_generator is not initialized" self.lib_generator.set_lib_path(libpath) self.lib = ctypes.CDLL(libpath) self.lib.init() - # TODO: update the lib code from srcpath def cleanup(self): raise NotImplementedError - @abstractmethod - def _select_implementation(self) -> IRModule: - pass + def check_only_tir_backend(self): + assert self.is_tir_backend(), "Only support tir backend" + + def check_only_tilelang_backend(self): + assert self.is_tilelang_backend(), "Only support tilelang backend" + + def _select_implementation(self) -> Optional[IRModule]: + # only roller based template schedule + raise NotImplementedError + + def _select_scheduler(self) -> Optional[BaseScheduler]: + # only tilelang based template schedule + raise NotImplementedError @property def prim_func(self): - if len(self.prim_func_mod.get_global_vars()) == 1: - return self.prim_func_mod[self.prim_func_mod.get_global_vars()[0]] - elif "main" in self.prim_func_mod: - return self.prim_func_mod["main"] + if len(self.ir_module.get_global_vars()) == 1: + return self.ir_module[self.ir_module.get_global_vars()[0]] + elif "main" in self.ir_module: + return self.ir_module["main"] else: raise ValueError("Unable to determine primary function.") diff --git a/bitblas/tl/macro_generator.py b/bitblas/tl/macro_generator.py index 0f0b361c5..f3db7d88a 100644 --- a/bitblas/tl/macro_generator.py +++ b/bitblas/tl/macro_generator.py @@ -4,7 +4,7 @@ import tvm.tl.language as T from typing import Union -from bitblas.ops.operator import TransformKind +from bitblas.ops.common import TransformKind from tvm import DataType from tvm.runtime import convert from .utils import ( diff --git a/bitblas/utils/post_process.py b/bitblas/utils/post_process.py index cabee6be1..4eba191dc 100644 --- a/bitblas/utils/post_process.py +++ b/bitblas/utils/post_process.py @@ -6,7 +6,7 @@ def match_global_kernel(source: str) -> int: pattern = r"__global__\s+void\s+[__launch_bounds__\(\d+\)\s+]\w+" matched = re.findall(pattern, source) - assert len(matched) > 1 # may have statement before kernel + assert len(matched) >= 1 # may have statement before kernel return source.index(matched[0]) @@ -28,6 +28,7 @@ def tensor_remove_make_int4(source: str) -> str: ) return source + def tensor_remove_make_int2(source: str) -> str: # remove make_int4 with 16 signed char arguments # TODO(lei): this is a stuff that should be fixed in the tvm in the future diff --git a/bitblas/utils/rtmod_analysis.py b/bitblas/utils/rtmod_analysis.py index 69a08dfdc..e3fe4c1cb 100644 --- a/bitblas/utils/rtmod_analysis.py +++ b/bitblas/utils/rtmod_analysis.py @@ -6,9 +6,72 @@ from tvm.driver import lower from tvm.target import Target from typing import Tuple, List +from tvm import tir +from tvm import tl +from tvm.tl.engine import is_device_call -def get_annotated_device_mod(mod: IRModule, target: Target) -> "IRModule": +def get_annotated_device_mod_from_tl(mod: IRModule, target: Target) -> "IRModule": + target_host = tvm.target.Target("llvm -keys=cpu") + target = tvm.target.Target(target, target_host) + mod = tir.transform.BindTarget(target)(mod) + + mod = tl.transform.FrontendLegalize()(mod) + mod = tir.transform.Simplify()(mod) + mod = tl.transform.LayoutInference()(mod) + mod = tl.transform.LowerTileOp()(mod) + mod = tir.transform.Simplify()(mod) + + if target.arch == "sm_90": + mod = tl.transform.WarpSpecializedPipeline()(mod) + else: + mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod) + mod = tl.transform.PipelinePlanning()(mod) + mod = tl.transform.InjectSoftwarePipeline()(mod) + + mod = tir.transform.LowerOpaqueBlock()(mod) + mod = tir.transform.FlattenBuffer()(mod) + mod = tir.transform.NarrowDataType(32)(mod) + mod = tir.transform.Simplify()(mod) + + mod = tir.transform.VectorizeLoop()(mod) + mod = tir.transform.StorageRewrite()(mod) + mod = tir.transform.UnrollLoop()(mod) + mod = tir.transform.RenormalizeSplitPattern()(mod) + mod = tir.transform.Simplify()(mod) + mod = tir.transform.RemoveNoOp()(mod) + mod = tir.transform.RewriteUnsafeSelect()(mod) + mod = tir.transform.HoistIfThenElse()(mod) + + mod = tir.transform.VerifyMemory()(mod) + mod = tir.transform.AnnotateEntryFunc()(mod) + mod = tir.transform.ThreadSync("shared")(mod) + # TODO(lei): This is a hack to make sure the + # thread level allreduce pass can be applied + # in TL. As Tl only use one thread dimension + # the var binding information will be lost + # in the lowering process with Legalization + # and Simplify pass. + # We can find a way better to create var instead + # of putting the LowerThreadAllreduce before + # the Legalization. + mod = tir.transform.LowerThreadAllreduce()(mod) + mod = tir.transform.ThreadSync("shared.dyn")(mod) + mod = tl.transform.LowerHopperIntrin()(mod) + mod = tir.transform.InjectPTXAsyncCopy()(mod) + + mod = tir.transform.AnnotateDeviceRegions()(mod) + mod = tir.transform.SplitHostDevice()(mod) + mod = tir.transform.MergeSharedMemoryAllocations()(mod) + mod = tir.transform.MakePackedAPI()(mod) + mod = tir.transform.LowerDeviceKernelLaunch()(mod) + + device_mod = tir.transform.Filter(is_device_call)(mod) + + return device_mod + + +def get_annotated_device_mod_from_tir(mod: IRModule, target: Target) -> "IRModule": """ Lower the given IRModule and create a device module for the specified target. @@ -50,6 +113,15 @@ def get_annotated_device_mod(mod: IRModule, target: Target) -> "IRModule": return device_mod +def get_annotated_device_mod(mod: IRModule, target: Target, backend="tir") -> "IRModule": + if backend == "tir": + return get_annotated_device_mod_from_tir(mod, target) + elif backend == "tl": + return get_annotated_device_mod_from_tl(mod, target) + else: + raise ValueError("Unsupported backend: {}".format(backend)) + + def get_thread_block_information(mod: IRModule) -> Tuple[List[int], List[int]]: """ Extracts the thread block and grid dimensions for the reduction block within a given IRModule. diff --git a/docs/ExtendOperatorsWithDSL.md b/docs/ExtendOperatorsWithDSL.md index 8c717b43e..ec62356b5 100644 --- a/docs/ExtendOperatorsWithDSL.md +++ b/docs/ExtendOperatorsWithDSL.md @@ -137,7 +137,7 @@ class MatmulNT: from bitblas import fast_tune_with_dynamic_range # Tune with dynamic symbolic -optimized_mod = fast_tune_with_dynamic_range( +scheduled_ir_module = fast_tune_with_dynamic_range( func, target, topk=topk, parallel_build=True, dynamic_range={ "M": [1, 1024] diff --git a/testing/python/builder/test_backend_tir_builder.py b/testing/python/builder/test_backend_tir_builder.py index f65ce8066..c9bec630f 100644 --- a/testing/python/builder/test_backend_tir_builder.py +++ b/testing/python/builder/test_backend_tir_builder.py @@ -39,8 +39,9 @@ def matmul_backend_code_wrap( ) matmul = Matmul(config=matmul_config, enable_tuning=False) backend = TIRWrapper(arch=matmul.arch) - backend.assign_optimized_module(matmul.optimized_mod) - is_dynamic = (matmul.dynamic_range is not None and len(matmul.optimized_mod.functions) > 1) + backend.assign_optimized_module(matmul.scheduled_ir_module) + is_dynamic = ( + matmul.dynamic_range is not None and len(matmul.scheduled_ir_module.functions) > 1) wrapped_code = backend.wrap(matmul.get_source(kenrel_only=True), is_dynamic=is_dynamic) assert "void call" in wrapped_code diff --git a/testing/python/operators/test_general_matmul_ops_backend_tl.py b/testing/python/operators/test_general_matmul_ops_backend_tl.py new file mode 100644 index 000000000..90ed00c6e --- /dev/null +++ b/testing/python/operators/test_general_matmul_ops_backend_tl.py @@ -0,0 +1,50 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import bitblas +from bitblas import MatmulConfig, Matmul +import logging +from bitblas import set_log_level + +set_log_level(logging.DEBUG) + + +def get_codegen_result(ops): + code = ops.get_source() + return code + + +# fmt: off +def matmul_codegen_default(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, + group_size, with_scaling, with_zeros, zeros_mode): + + matmul_config = MatmulConfig( + M=M, + N=N, + K=K, + A_dtype=A_dtype, + W_dtype=W_dtype, + accum_dtype=accum_dtype, + out_dtype=out_dtype, + layout=layout, + with_bias=with_bias, + group_size=group_size, + with_scaling=with_scaling, + with_zeros=with_zeros, + zeros_mode=zeros_mode, + propagate_a=False, + propagate_b=False, + ) + matmul = Matmul(config=matmul_config, enable_tuning=False, backend="tl") + assert get_codegen_result(matmul) + + +def test_matmul_codegen_default(): + matmul_codegen_default(1, 768, 768, "float16", "float16", "float16", "float16", "nt", False, -1, + False, False, None), + matmul_codegen_default(768, 768, 768, "float16", "float16", "float16", "float16", "nt", False, + -1, False, False, None), + + +# fmt: on +if __name__ == "__main__": + bitblas.testing.main() diff --git a/testing/python/operators/test_general_matmul_tilelang_impl.py b/testing/python/operators/test_general_matmul_tilelang_impl.py index 45558ba69..1281361aa 100644 --- a/testing/python/operators/test_general_matmul_tilelang_impl.py +++ b/testing/python/operators/test_general_matmul_tilelang_impl.py @@ -24,8 +24,8 @@ def assert_matmul_blocked_correctness(M, block_K=32, trans_A=False, trans_B=True, - dtypeAB="float16", - dtypeC="float16", + in_dtype="float16", + out_dtype="float16", accum_dtype="float16", num_stages=2, threads=128, @@ -39,8 +39,8 @@ def assert_matmul_blocked_correctness(M, block_K=block_K, trans_A=trans_A, trans_B=trans_B, - dtypeAB=dtypeAB, - dtypeC=dtypeC, + in_dtype=in_dtype, + out_dtype=out_dtype, accum_dtype=accum_dtype, num_stages=num_stages, threads=threads, @@ -53,8 +53,8 @@ def assert_matmul_blocked_correctness(M, # src_code is the generated cuda source assert src_code is not None - A = torch.rand(M, K, device="cuda", dtype=getattr(torch, dtypeAB)) - B = torch.rand(N, K, device="cuda", dtype=getattr(torch, dtypeAB)) + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) @@ -75,8 +75,8 @@ def assert_matmul_macro_tensorcore_correctness( M, N, K, - dtypeAB="float16", - dtypeC="float16", + in_dtype="float16", + out_dtype="float16", trans_A=False, trans_B=True, accum_dtype="float16", @@ -92,8 +92,8 @@ def assert_matmul_macro_tensorcore_correctness( M=M, N=N, K=K, - dtypeAB=dtypeAB, - dtypeC=dtypeC, + in_dtype=in_dtype, + out_dtype=out_dtype, trans_A=trans_A, trans_B=trans_B, accum_dtype=accum_dtype, @@ -111,8 +111,8 @@ def assert_matmul_macro_tensorcore_correctness( # src_code represents generated cuda source assert src_code is not None - A = torch.rand(M, K, device="cuda", dtype=getattr(torch, dtypeAB)) - B = torch.rand(N, K, device="cuda", dtype=getattr(torch, dtypeAB)) + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) @@ -133,8 +133,8 @@ def assert_tl_matmul_with_ladder_weight_only_transform_correctness( M, N, K, - dtypeAB="float16", - dtypeC="float16", + in_dtype="float16", + out_dtype="float16", trans_A=False, trans_B=True, accum_dtype="float16", @@ -150,8 +150,8 @@ def assert_tl_matmul_with_ladder_weight_only_transform_correctness( M=M, N=N, K=K, - dtypeAB=dtypeAB, - dtypeC=dtypeC, + in_dtype=in_dtype, + out_dtype=out_dtype, trans_A=trans_A, trans_B=trans_B, accum_dtype=accum_dtype, @@ -170,8 +170,8 @@ def assert_tl_matmul_with_ladder_weight_only_transform_correctness( # src_code is the generated cuda source assert src_code is not None - A = torch.rand(M, K, device="cuda", dtype=getattr(torch, dtypeAB)) - B = torch.rand(N, K, device="cuda", dtype=getattr(torch, dtypeAB)) + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) ladder_permutate_config = bitblas.ops.LadderPermutateConfig( diff --git a/testing/python/operators/test_general_matmul_tilelang_kernel.py b/testing/python/operators/test_general_matmul_tilelang_kernel.py index 18115f450..10e9ade7c 100644 --- a/testing/python/operators/test_general_matmul_tilelang_kernel.py +++ b/testing/python/operators/test_general_matmul_tilelang_kernel.py @@ -21,8 +21,8 @@ def assert_matmul_blocked_with_default_correctness(M, K, trans_A=False, trans_B=True, - dtypeAB="float16", - dtypeC="float16", + in_dtype="float16", + out_dtype="float16", accum_dtype="float16"): matmul = MatmulScheduler( M=M, @@ -30,8 +30,8 @@ def assert_matmul_blocked_with_default_correctness(M, K=K, trans_A=trans_A, trans_B=trans_B, - dtypeAB=dtypeAB, - dtypeC=dtypeC, + in_dtype=in_dtype, + out_dtype=out_dtype, accum_dtype=accum_dtype, ).with_default_config() @@ -41,8 +41,8 @@ def assert_matmul_blocked_with_default_correctness(M, # src_code is the generated cuda source assert src_code is not None - A = torch.rand(M, K, device="cuda", dtype=getattr(torch, dtypeAB)) - B = torch.rand(N, K, device="cuda", dtype=getattr(torch, dtypeAB)) + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) @@ -67,8 +67,8 @@ def assert_matmul_blocked_apply_config_correctness(M, block_K=32, trans_A=False, trans_B=True, - dtypeAB="float16", - dtypeC="float16", + in_dtype="float16", + out_dtype="float16", accum_dtype="float16", num_stages=2, threads=128, @@ -79,8 +79,8 @@ def assert_matmul_blocked_apply_config_correctness(M, K=K, trans_A=trans_A, trans_B=trans_B, - dtypeAB=dtypeAB, - dtypeC=dtypeC, + in_dtype=in_dtype, + out_dtype=out_dtype, accum_dtype=accum_dtype, ).apply_config( block_M=block_M, @@ -97,8 +97,8 @@ def assert_matmul_blocked_apply_config_correctness(M, # src_code is the generated cuda source assert src_code is not None - A = torch.rand(M, K, device="cuda", dtype=getattr(torch, dtypeAB)) - B = torch.rand(N, K, device="cuda", dtype=getattr(torch, dtypeAB)) + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) @@ -120,8 +120,8 @@ def assert_matmul_fine_grained_with_default_correctness(M, K, trans_A=False, trans_B=True, - dtypeAB="float16", - dtypeC="float16", + in_dtype="float16", + out_dtype="float16", accum_dtype="float16"): matmul = MatmulFineGrainScheduler( @@ -130,8 +130,8 @@ def assert_matmul_fine_grained_with_default_correctness(M, K=K, trans_A=trans_A, trans_B=trans_B, - dtypeAB=dtypeAB, - dtypeC=dtypeC, + in_dtype=in_dtype, + out_dtype=out_dtype, accum_dtype=accum_dtype, ).with_default_config() @@ -141,9 +141,9 @@ def assert_matmul_fine_grained_with_default_correctness(M, # src_code is the generated cuda source assert src_code is not None - A = torch.rand(M, K, device="cuda", dtype=getattr(torch, dtypeAB)) - 0.5 - B = torch.rand(N, K, device="cuda", dtype=getattr(torch, dtypeAB)) - 0.5 - C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, dtypeC)) + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) @@ -155,7 +155,7 @@ def assert_matmul_fine_grained_with_default_correctness(M, assert latency is not None # Get Reference Result - ref_c = torch.matmul(A, B.T).to(getattr(torch, dtypeC)) + ref_c = torch.matmul(A, B.T).to(getattr(torch, out_dtype)) print(C) print(ref_c) torch.testing.assert_close(C, ref_c, rtol=1e-1, atol=1e-1) @@ -167,8 +167,8 @@ def assert_matmul_fine_grained_apply_config_correctness( K, trans_A=False, trans_B=True, - dtypeAB="float16", - dtypeC="float16", + in_dtype="float16", + out_dtype="float16", accum_dtype="float16", block_row_warps=1, block_col_warps=1, @@ -185,8 +185,8 @@ def assert_matmul_fine_grained_apply_config_correctness( K=K, trans_A=trans_A, trans_B=trans_B, - dtypeAB=dtypeAB, - dtypeC=dtypeC, + in_dtype=in_dtype, + out_dtype=out_dtype, accum_dtype=accum_dtype, ).apply_config( block_row_warps=block_row_warps, @@ -204,8 +204,8 @@ def assert_matmul_fine_grained_apply_config_correctness( # src_code is the generated cuda source assert src_code is not None - A = torch.rand(M, K, device="cuda", dtype=getattr(torch, dtypeAB)) - B = torch.rand(N, K, device="cuda", dtype=getattr(torch, dtypeAB)) + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) @@ -227,8 +227,8 @@ def assert_matmul_weight_propagation_with_default_correctness(M, K, trans_A=False, trans_B=True, - dtypeAB="float16", - dtypeC="float16", + in_dtype="float16", + out_dtype="float16", accum_dtype="float16"): matmul = MatmulWeightPropagationScheduler( @@ -237,8 +237,8 @@ def assert_matmul_weight_propagation_with_default_correctness(M, K=K, trans_A=trans_A, trans_B=trans_B, - dtypeAB=dtypeAB, - dtypeC=dtypeC, + in_dtype=in_dtype, + out_dtype=out_dtype, accum_dtype=accum_dtype, ).with_default_config() @@ -248,9 +248,9 @@ def assert_matmul_weight_propagation_with_default_correctness(M, # src_code is the generated cuda source assert src_code is not None - A = torch.rand(M, K, device="cuda", dtype=getattr(torch, dtypeAB)) - 0.5 - B = torch.rand(N, K, device="cuda", dtype=getattr(torch, dtypeAB)) - 0.5 - C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, dtypeC)) + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) ladder_permutate_config = bitblas.ops.LadderPermutateConfig( M=N, @@ -273,7 +273,7 @@ def assert_matmul_weight_propagation_with_default_correctness(M, assert latency is not None # Get Reference Result - ref_c = torch.matmul(A, B.T).to(getattr(torch, dtypeC)) + ref_c = torch.matmul(A, B.T).to(getattr(torch, out_dtype)) print(C) print(ref_c) torch.testing.assert_close(C, ref_c, rtol=1e0, atol=1e0) @@ -285,8 +285,8 @@ def assert_matmul_weight_propagation_apply_config_correctness( K, trans_A=False, trans_B=True, - dtypeAB="float16", - dtypeC="float16", + in_dtype="float16", + out_dtype="float16", accum_dtype="float16", block_row_warps=1, block_col_warps=1, @@ -303,8 +303,8 @@ def assert_matmul_weight_propagation_apply_config_correctness( K=K, trans_A=trans_A, trans_B=trans_B, - dtypeAB=dtypeAB, - dtypeC=dtypeC, + in_dtype=in_dtype, + out_dtype=out_dtype, accum_dtype=accum_dtype, ).apply_config( block_row_warps=block_row_warps, @@ -322,9 +322,9 @@ def assert_matmul_weight_propagation_apply_config_correctness( # src_code is the generated cuda source assert src_code is not None - A = torch.rand(M, K, device="cuda", dtype=getattr(torch, dtypeAB)) - 0.5 - B = torch.rand(N, K, device="cuda", dtype=getattr(torch, dtypeAB)) - 0.5 - C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, dtypeC)) + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) ladder_permutate_config = bitblas.ops.LadderPermutateConfig( M=N, @@ -347,7 +347,7 @@ def assert_matmul_weight_propagation_apply_config_correctness( assert latency is not None # Get Reference Result - ref_c = torch.matmul(A, B.T).to(getattr(torch, dtypeC)) + ref_c = torch.matmul(A, B.T).to(getattr(torch, out_dtype)) torch.testing.assert_close(C, ref_c, rtol=1e0, atol=1e0) diff --git a/testing/python/operators/test_general_matmul_tilelang_scheduler.py b/testing/python/operators/test_general_matmul_tilelang_scheduler.py index 1e6bd6466..87c685e08 100644 --- a/testing/python/operators/test_general_matmul_tilelang_scheduler.py +++ b/testing/python/operators/test_general_matmul_tilelang_scheduler.py @@ -13,8 +13,8 @@ def assert_scheduler_simplify(M, K, trans_A=False, trans_B=True, - dtypeAB="float16", - dtypeC="float16", + in_dtype="float16", + out_dtype="float16", accum_dtype="float16"): matmul = MatmulScheduler( M=M, @@ -22,8 +22,8 @@ def assert_scheduler_simplify(M, K=K, trans_A=trans_A, trans_B=trans_B, - dtypeAB=dtypeAB, - dtypeC=dtypeC, + in_dtype=in_dtype, + out_dtype=out_dtype, accum_dtype=accum_dtype, ).deactivate_simplify().with_default_config() diff --git a/testing/python/tilelang/test_tilelang_dequantize_gemm.py b/testing/python/tilelang/test_tilelang_dequantize_gemm.py index 27af4bd54..1f9f44ab5 100644 --- a/testing/python/tilelang/test_tilelang_dequantize_gemm.py +++ b/testing/python/tilelang/test_tilelang_dequantize_gemm.py @@ -39,8 +39,8 @@ def matmul( block_M, block_N, block_K, - dtypeAB, - dtypeC, + in_dtype, + out_dtype, accum_dtype, num_stages, threads, @@ -58,16 +58,16 @@ def matmul( @T.prim_func def main( - A: T.Buffer(A_shape, dtypeAB), + A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, storage_dtype), - C: T.Buffer((M, N), dtypeC), + C: T.Buffer((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, dtypeAB) + A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype) B_local = T.alloc_local([8], storage_dtype) - B_dequantize_local = T.alloc_local([16], dtypeAB) - B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, dtypeAB) + B_dequantize_local = T.alloc_local([16], in_dtype) + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) tx = T.thread_binding(0, threads, thread="threadIdx.x") @@ -89,7 +89,7 @@ def main( num_bits, B_local[v // 2], v % 2, - dtype=dtypeAB, + dtype=in_dtype, ) for v in T.vectorized(0, 8): vi = (i * threads * 8 + tx * 8 + v) // (block_K) @@ -105,8 +105,8 @@ def run_gemm( M, N, K, - dtypeAB, - dtypeC, + in_dtype, + out_dtype, dtypeAccum, block_M, block_N, @@ -121,8 +121,8 @@ def run_gemm( block_M, block_N, block_K, - dtypeAB, - dtypeC, + in_dtype, + out_dtype, dtypeAccum, num_stages, num_threads, @@ -144,7 +144,7 @@ def ref_program(A, qB): for j in range(B.shape[1]): B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) - C = C.to(torch.__getattribute__(dtypeC)) + C = C.to(torch.__getattribute__(out_dtype)) return C mod.assert_allclose(ref_program) @@ -154,16 +154,16 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( M, N, K, - dtypeAB, - dtypeC, + in_dtype, + out_dtype, accum_dtype, transform_b, ): - assert dtypeAB in [ + assert in_dtype in [ "float16", "int8", ], "Currently only float16 and int8 are supported" - assert dtypeC in [ + assert out_dtype in [ "float16", "float32", "int32", @@ -174,7 +174,7 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( micro_size_x = micro_size_y = micro_size_k = 16 - if dtypeC == "int32": + if out_dtype == "int32": micro_size_k = 32 # This is a debug config @@ -193,11 +193,11 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( block_M = block_row_warps * warp_row_tiles block_N = block_col_warps * warp_col_tiles - block_K = 32 if dtypeAB == "float16" else 64 + block_K = 32 if in_dtype == "float16" else 64 chunk = block_K // reduce_k is_smooth_a = False - can_swizzle = block_K * DataType(dtypeAB).bits == 512 + can_swizzle = block_K * DataType(in_dtype).bits == 512 apply_pad_a = not (is_smooth_a or can_swizzle) pad_factor = 8 @@ -226,8 +226,8 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( # MMA Wrapper to Auto Generate Code for MMA mma_emitter = TensorCoreIntrinEmitterWithLadderTransform( - a_dtype=dtypeAB, - b_dtype=dtypeAB, + a_dtype=in_dtype, + b_dtype=in_dtype, accum_dtype=accum_dtype, a_transposed=False, b_transposed=True, @@ -246,20 +246,20 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( @T.prim_func def main( - A: T.Buffer(A_shape, dtypeAB), + A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, storage_dtype), - C: T.Buffer((M, N), dtypeC), + C: T.Buffer((M, N), out_dtype), ): with T.Kernel( T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads, prelude=decode_i4_to_f16) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, dtypeAB, scope=shared_scope) + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) - C_shared = T.alloc_shared(C_shared_shape, dtypeC, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size), dtypeAB) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size), in_dtype) B_local = T.alloc_local((warp_cols * local_size // num_elems_per_byte), storage_dtype) - B_dequantize_local = T.alloc_local((warp_cols * local_size), dtypeAB) + B_dequantize_local = T.alloc_local((warp_cols * local_size), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) reduced_accum_res = T.alloc_local(0, accum_dtype) thread_bindings = T.thread_binding(0, threads, "threadIdx.x") @@ -365,12 +365,12 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct N, K, in_dtype, - dtypeC, + out_dtype, accum_dtype, transform_b, ): matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( - M, N, K, in_dtype, dtypeC, accum_dtype, transform_b) + M, N, K, in_dtype, out_dtype, accum_dtype, transform_b) mod, params = TL.lower(matmul) src_code = mod.imported_modules[0].get_source() diff --git a/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py b/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py index 9af34e037..4d7be551b 100644 --- a/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py +++ b/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py @@ -32,15 +32,15 @@ def transform_func(i, j): def tl_matmul_macro( N, K, - dtypeAB, - dtypeC, + in_dtype, + out_dtype, accum_dtype, ): - assert dtypeAB in [ + assert in_dtype in [ "float16", "int8", ], "Currently only float16 and int8 are supported" - assert dtypeC in [ + assert out_dtype in [ "float16", "float32", "int32", @@ -48,7 +48,7 @@ def tl_matmul_macro( micro_size_x = micro_size_y = micro_size_k = 16 - if dtypeC == "int32": + if out_dtype == "int32": micro_size_k = 32 # This is a debug config @@ -56,7 +56,7 @@ def tl_matmul_macro( block_col_warps = 1 warp_row_tiles = 16 warp_col_tiles = 16 - chunk = 32 if dtypeAB == "float16" else 64 + chunk = 32 if in_dtype == "float16" else 64 shared_scope = "shared.dyn" # Pipeline Stage @@ -86,8 +86,8 @@ def tl_matmul_macro( # MMA Wrapper to Auto Generate Code for MMA mma_emitter = TensorCoreIntrinEmitter( - a_dtype=dtypeAB, - b_dtype=dtypeAB, + a_dtype=in_dtype, + b_dtype=in_dtype, accum_dtype=accum_dtype, a_transposed=False, b_transposed=True, @@ -100,17 +100,17 @@ def tl_matmul_macro( @T.prim_func def main( - A: T.Buffer(A_shape, dtypeAB), - B: T.Buffer(B_shape, dtypeAB), - C: T.Buffer((M, N), dtypeC), + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, dtypeAB, scope=shared_scope) - B_shared = T.alloc_shared(B_shared_shape, dtypeAB, scope=shared_scope) - C_shared = T.alloc_shared(C_shared_shape, dtypeC, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size), dtypeAB) - B_local = T.alloc_local((warp_cols * local_size), dtypeAB) + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size), in_dtype) + B_local = T.alloc_local((warp_cols * local_size), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) thread_bindings = T.thread_binding(0, threads, "threadIdx.x") @@ -172,8 +172,8 @@ def main( return main -def assert_tl_matmul_macro_correctness(M, N, K, in_dtype, dtypeC, accum_dtype): - matmul = tl_matmul_macro(N, K, in_dtype, dtypeC, accum_dtype) +def assert_tl_matmul_macro_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): + matmul = tl_matmul_macro(N, K, in_dtype, out_dtype, accum_dtype) mod, params = TL.lower(matmul) src_code = mod.imported_modules[0].get_source() @@ -202,8 +202,8 @@ def tl_matmul_block( block_K, trans_A, trans_B, - dtypeAB, - dtypeC, + in_dtype, + out_dtype, accum_dtype, num_stages, threads, @@ -217,11 +217,11 @@ def tl_matmul_block( import tvm.tl.language as T @T.prim_func - def main(A: T.Buffer(A_shape, dtypeAB), B: T.Buffer(B_shape, dtypeAB), C: T.Buffer((M, N), - dtypeC)): + def main(A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), C: T.Buffer( + (M, N), out_dtype)): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, dtypeAB) - B_shared = T.alloc_shared(B_shared_shape, dtypeAB) + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) T.clear(C_local) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): @@ -245,8 +245,8 @@ def assert_tl_matmul_block_correctness( K, trans_A, trans_B, - dtypeAB, - dtypeC, + in_dtype, + out_dtype, dtypeAccum, block_M, block_N, @@ -262,17 +262,17 @@ def assert_tl_matmul_block_correctness( block_K, trans_A, trans_B, - dtypeAB, - dtypeC, + in_dtype, + out_dtype, dtypeAccum, num_stages, num_threads, ) mod, params = TL.lower(program) - A = torch.rand(M, K, device="cuda", dtype=getattr(torch, dtypeAB)) - B = torch.rand(N, K, device="cuda", dtype=getattr(torch, dtypeAB)) - C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, dtypeC)) + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) mod(A, B, C) @@ -285,7 +285,7 @@ def ref_program(A, B): if trans_B: B = B.T C = torch.matmul(A.to(torch.float), B.to(torch.float)) - C = C.to(torch.__getattribute__(dtypeC)) + C = C.to(torch.__getattribute__(out_dtype)) return C # Get Reference Result @@ -300,8 +300,8 @@ def tl_matmul_block_all_dynamic( block_K, trans_A, trans_B, - dtypeAB, - dtypeC, + in_dtype, + out_dtype, accum_dtype, num_stages, threads, @@ -318,11 +318,11 @@ def tl_matmul_block_all_dynamic( import tvm.tl.language as T @T.prim_func - def main(A: T.Buffer(A_shape, dtypeAB), B: T.Buffer(B_shape, dtypeAB), C: T.Buffer((M, N), - dtypeC)): + def main(A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), C: T.Buffer( + (M, N), out_dtype)): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, dtypeAB) - B_shared = T.alloc_shared(B_shared_shape, dtypeAB) + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) T.clear(C_local) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): @@ -346,8 +346,8 @@ def assert_tl_matmul_block_all_dynamic_correctness( K, trans_A, trans_B, - dtypeAB, - dtypeC, + in_dtype, + out_dtype, dtypeAccum, block_M, block_N, @@ -361,17 +361,17 @@ def assert_tl_matmul_block_all_dynamic_correctness( block_K, trans_A, trans_B, - dtypeAB, - dtypeC, + in_dtype, + out_dtype, dtypeAccum, num_stages, num_threads, ) mod, params = TL.lower(program) - A = torch.rand(M, K, device="cuda", dtype=getattr(torch, dtypeAB)) - B = torch.rand(N, K, device="cuda", dtype=getattr(torch, dtypeAB)) - C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, dtypeC)) + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) mod(A, B, C) @@ -385,7 +385,7 @@ def ref_program(A, B): if trans_B: B = B.T C = torch.matmul(A.to(torch.float), B.to(torch.float)) - C = C.to(torch.__getattribute__(dtypeC)) + C = C.to(torch.__getattribute__(out_dtype)) return C # Get Reference Result diff --git a/testing/python/tilelang/test_tilelang_gemm.py b/testing/python/tilelang/test_tilelang_gemm.py index c75e4ccc1..b387f916b 100644 --- a/testing/python/tilelang/test_tilelang_gemm.py +++ b/testing/python/tilelang/test_tilelang_gemm.py @@ -15,8 +15,8 @@ def matmul( block_K, trans_A, trans_B, - dtypeAB, - dtypeC, + in_dtype, + out_dtype, accum_dtype, num_stages, threads, @@ -29,11 +29,11 @@ def matmul( import tvm.tl.language as T @T.prim_func - def main(A: T.Buffer(A_shape, dtypeAB), B: T.Buffer(B_shape, dtypeAB), C: T.Buffer((M, N), - dtypeC)): + def main(A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), C: T.Buffer( + (M, N), out_dtype)): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, dtypeAB) - B_shared = T.alloc_shared(B_shared_shape, dtypeAB) + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) T.clear(C_local) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): @@ -57,8 +57,8 @@ def run_gemm( K, trans_A, trans_B, - dtypeAB, - dtypeC, + in_dtype, + out_dtype, dtypeAccum, block_M, block_N, @@ -75,8 +75,8 @@ def run_gemm( block_K, trans_A, trans_B, - dtypeAB, - dtypeC, + in_dtype, + out_dtype, dtypeAccum, num_stages, num_threads, @@ -92,7 +92,7 @@ def ref_program(A, B): if trans_B: B = B.T C = torch.matmul(A.to(torch.float), B.to(torch.float)) - C = C.to(torch.__getattribute__(dtypeC)) + C = C.to(torch.__getattribute__(out_dtype)) return C mod.assert_allclose(ref_program) diff --git a/testing/python/tilelang/test_tilelang_macro_gemm.py b/testing/python/tilelang/test_tilelang_macro_gemm.py index 9d797ff66..9ef592d2d 100644 --- a/testing/python/tilelang/test_tilelang_macro_gemm.py +++ b/testing/python/tilelang/test_tilelang_macro_gemm.py @@ -37,15 +37,15 @@ def tl_matmul( M, N, K, - dtypeAB, - dtypeC, + in_dtype, + out_dtype, accum_dtype, ): - assert dtypeAB in [ + assert in_dtype in [ "float16", "int8", ], "Currently only float16 and int8 are supported" - assert dtypeC in [ + assert out_dtype in [ "float16", "float32", "int32", @@ -53,7 +53,7 @@ def tl_matmul( micro_size_x = micro_size_y = micro_size_k = 16 - if dtypeC == "int32": + if out_dtype == "int32": micro_size_k = 32 # This is a debug config @@ -61,7 +61,7 @@ def tl_matmul( block_col_warps = 1 warp_row_tiles = 16 warp_col_tiles = 16 - chunk = 32 if dtypeAB == "float16" else 64 + chunk = 32 if in_dtype == "float16" else 64 shared_scope = "shared.dyn" # Pipeline Stage @@ -90,8 +90,8 @@ def tl_matmul( # MMA Wrapper to Auto Generate Code for MMA mma_emitter = TensorCoreIntrinEmitter( - a_dtype=dtypeAB, - b_dtype=dtypeAB, + a_dtype=in_dtype, + b_dtype=in_dtype, accum_dtype=accum_dtype, a_transposed=False, b_transposed=True, @@ -104,17 +104,17 @@ def tl_matmul( @T.prim_func def main( - A: T.Buffer(A_shape, dtypeAB), - B: T.Buffer(B_shape, dtypeAB), - C: T.Buffer((M, N), dtypeC), + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, dtypeAB, scope=shared_scope) - B_shared = T.alloc_shared(B_shared_shape, dtypeAB, scope=shared_scope) - C_shared = T.alloc_shared(C_shared_shape, dtypeC, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size), dtypeAB) - B_local = T.alloc_local((warp_cols * local_size), dtypeAB) + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size), in_dtype) + B_local = T.alloc_local((warp_cols * local_size), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) thread_bindings = T.thread_binding(0, threads, "threadIdx.x") @@ -176,8 +176,8 @@ def main( return main -def assert_tl_matmul_correctness(M, N, K, in_dtype, dtypeC, accum_dtype): - matmul = tl_matmul(M, N, K, in_dtype, dtypeC, accum_dtype) +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): + matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) mod, params = TL.lower(matmul) src_code = mod.imported_modules[0].get_source() @@ -207,15 +207,15 @@ def tl_matmul_with_block_reduce( M, N, K, - dtypeAB, - dtypeC, + in_dtype, + out_dtype, accum_dtype, ): - assert dtypeAB in [ + assert in_dtype in [ "float16", "int8", ], "Currently only float16 and int8 are supported" - assert dtypeC in [ + assert out_dtype in [ "float16", "float32", "int32", @@ -223,7 +223,7 @@ def tl_matmul_with_block_reduce( micro_size_x = micro_size_y = micro_size_k = 16 - if dtypeC == "int32": + if out_dtype == "int32": micro_size_k = 32 # This is a debug config @@ -238,7 +238,7 @@ def tl_matmul_with_block_reduce( block_M = block_row_warps * warp_row_tiles block_N = block_col_warps * warp_col_tiles - block_K = 32 if dtypeAB == "float16" else 64 + block_K = 32 if in_dtype == "float16" else 64 reduce_k = 2 chunk = block_K // reduce_k @@ -260,8 +260,8 @@ def tl_matmul_with_block_reduce( warp_cols = warp_col_tiles // micro_size_y mma_emitter = TensorCoreIntrinEmitter( - a_dtype=dtypeAB, - b_dtype=dtypeAB, + a_dtype=in_dtype, + b_dtype=in_dtype, accum_dtype=accum_dtype, a_transposed=False, b_transposed=True, @@ -274,17 +274,17 @@ def tl_matmul_with_block_reduce( @T.prim_func def main( - A: T.Buffer(A_shape, dtypeAB), - B: T.Buffer(B_shape, dtypeAB), - C: T.Buffer((M, N), dtypeC), + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, dtypeAB, scope=shared_scope) - B_shared = T.alloc_shared(B_shared_shape, dtypeAB, scope=shared_scope) - C_shared = T.alloc_shared(C_shared_shape, dtypeC, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size), dtypeAB) - B_local = T.alloc_local((warp_cols * local_size), dtypeAB) + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size), in_dtype) + B_local = T.alloc_local((warp_cols * local_size), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) reduced_accum_res = T.alloc_local(0, accum_dtype) @@ -371,8 +371,8 @@ def main( return main -def assert_tl_matmul_with_block_reduce_correctness(M, N, K, in_dtype, dtypeC, accum_dtype): - matmul = tl_matmul_with_block_reduce(M, N, K, in_dtype, dtypeC, accum_dtype) +def assert_tl_matmul_with_block_reduce_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): + matmul = tl_matmul_with_block_reduce(M, N, K, in_dtype, out_dtype, accum_dtype) mod, params = TL.lower(matmul) src_code = mod.imported_modules[0].get_source() @@ -402,16 +402,16 @@ def tl_matmul_with_ladder_weight_only_transform( M, N, K, - dtypeAB, - dtypeC, + in_dtype, + out_dtype, accum_dtype, transform_b, ): - assert dtypeAB in [ + assert in_dtype in [ "float16", "int8", ], "Currently only float16 and int8 are supported" - assert dtypeC in [ + assert out_dtype in [ "float16", "float32", "int32", @@ -419,7 +419,7 @@ def tl_matmul_with_ladder_weight_only_transform( micro_size_x = micro_size_y = micro_size_k = 16 - if dtypeC == "int32": + if out_dtype == "int32": micro_size_k = 32 # This is a debug config @@ -431,7 +431,7 @@ def tl_matmul_with_ladder_weight_only_transform( warp_row_tiles = micro_size_x * warp_rows warp_col_tiles = micro_size_y * warp_cols - chunk = 64 if dtypeAB == "float16" else 128 + chunk = 64 if in_dtype == "float16" else 128 shared_scope = "shared.dyn" # Pipeline Stage @@ -442,7 +442,7 @@ def tl_matmul_with_ladder_weight_only_transform( block_K = chunk is_smooth_a = False - can_swizzle = block_K * DataType(dtypeAB).bits == 512 + can_swizzle = block_K * DataType(in_dtype).bits == 512 apply_pad_a = not (is_smooth_a or can_swizzle) pad_factor = 8 @@ -465,8 +465,8 @@ def tl_matmul_with_ladder_weight_only_transform( # MMA Wrapper to Auto Generate Code for MMA mma_emitter = TensorCoreIntrinEmitterWithLadderTransform( - a_dtype=dtypeAB, - b_dtype=dtypeAB, + a_dtype=in_dtype, + b_dtype=in_dtype, accum_dtype=accum_dtype, a_transposed=False, b_transposed=True, @@ -479,17 +479,17 @@ def tl_matmul_with_ladder_weight_only_transform( @T.prim_func def main( - A: T.Buffer(A_shape, dtypeAB), - B: T.Buffer(B_shape, dtypeAB), - C: T.Buffer((M, N), dtypeC), + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, dtypeAB, scope=shared_scope) - B_shared = T.alloc_shared(B_shared_shape, dtypeAB, scope=shared_scope) - C_shared = T.alloc_shared(C_shared_shape, dtypeC, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size), dtypeAB) - B_local = T.alloc_local((warp_cols * local_size), dtypeAB) + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size), in_dtype) + B_local = T.alloc_local((warp_cols * local_size), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) thread_bindings = T.thread_binding(0, threads, "threadIdx.x") @@ -544,9 +544,9 @@ def main( return main -def assert_tl_matmul_with_ladder_weight_only_transform_correctness(M, N, K, in_dtype, dtypeC, +def assert_tl_matmul_with_ladder_weight_only_transform_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, transform_b): - matmul = tl_matmul_with_ladder_weight_only_transform(M, N, K, in_dtype, dtypeC, accum_dtype, + matmul = tl_matmul_with_ladder_weight_only_transform(M, N, K, in_dtype, out_dtype, accum_dtype, transform_b) mod, params = TL.lower(matmul) @@ -588,16 +588,16 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( M, N, K, - dtypeAB, - dtypeC, + in_dtype, + out_dtype, accum_dtype, transform_b, ): - assert dtypeAB in [ + assert in_dtype in [ "float16", "int8", ], "Currently only float16 and int8 are supported" - assert dtypeC in [ + assert out_dtype in [ "float16", "float32", "int32", @@ -608,7 +608,7 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( micro_size_x = micro_size_y = micro_size_k = 16 - if dtypeC == "int32": + if out_dtype == "int32": micro_size_k = 32 # This is a debug config @@ -627,11 +627,11 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( block_M = block_row_warps * warp_row_tiles block_N = block_col_warps * warp_col_tiles - block_K = 32 if dtypeAB == "float16" else 64 + block_K = 32 if in_dtype == "float16" else 64 chunk = block_K // reduce_k is_smooth_a = False - can_swizzle = block_K * DataType(dtypeAB).bits == 512 + can_swizzle = block_K * DataType(in_dtype).bits == 512 apply_pad_a = not (is_smooth_a or can_swizzle) pad_factor = 8 @@ -660,8 +660,8 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( # MMA Wrapper to Auto Generate Code for MMA mma_emitter = TensorCoreIntrinEmitterWithLadderTransform( - a_dtype=dtypeAB, - b_dtype=dtypeAB, + a_dtype=in_dtype, + b_dtype=in_dtype, accum_dtype=accum_dtype, a_transposed=False, b_transposed=True, @@ -680,20 +680,20 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( @T.prim_func def main( - A: T.Buffer(A_shape, dtypeAB), + A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, storage_dtype), - C: T.Buffer((M, N), dtypeC), + C: T.Buffer((M, N), out_dtype), ): with T.Kernel( T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads, prelude=decode_i4_to_f16) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, dtypeAB, scope=shared_scope) + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) - C_shared = T.alloc_shared(C_shared_shape, dtypeC, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size), dtypeAB) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size), in_dtype) B_local = T.alloc_local((warp_cols * local_size // num_elems_per_byte), storage_dtype) - B_dequantize_local = T.alloc_local((warp_cols * local_size), dtypeAB) + B_dequantize_local = T.alloc_local((warp_cols * local_size), in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) reduced_accum_res = T.alloc_local(0, accum_dtype) thread_bindings = T.thread_binding(0, threads, "threadIdx.x") @@ -799,12 +799,12 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct N, K, in_dtype, - dtypeC, + out_dtype, accum_dtype, transform_b, ): matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( - M, N, K, in_dtype, dtypeC, accum_dtype, transform_b) + M, N, K, in_dtype, out_dtype, accum_dtype, transform_b) mod, params = TL.lower(matmul) src_code = mod.imported_modules[0].get_source() From c485b68a9982caa0c281997e1e31c7bbea38a054 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 29 Sep 2024 08:08:44 +0000 Subject: [PATCH 7/8] test fix --- bitblas/ops/operator.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index eb02fdf70..b723eabf8 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -59,6 +59,25 @@ def generate(self, hint: Hint = None) -> str: pass +class DefaultKernelNameGenerator(BaseKernelNameGenerator): + + DEFAULT_PREFIX = "main" + + def __init__(self, config: OperatorConfig, name: str): + self.DEFAULT_PREFIX = name + super().__init__(config) + + def generate(self, hint: Hint = None) -> str: + # hint is not used + assert hint is not None + return self.DEFAULT_PREFIX + + def is_valid_config(self, config: OperatorConfig) -> bool: + # hint is not used + assert config is not None + return True + + class Operator(object): def __init__(self, @@ -105,7 +124,7 @@ def is_tilelang_backend(self): return self.backend == "tl" def get_kernel_name_generator(self) -> Optional[BaseKernelNameGenerator]: - raise NotImplementedError + return DefaultKernelNameGenerator(self.config, self.name) def get_source(self, target: Optional[Target] = None, kenrel_only=False) -> str: if target is None: From ebe42a6f085ef6fe6f82ae01884c229d9a8866fb Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 29 Sep 2024 10:26:05 +0000 Subject: [PATCH 8/8] hardware tuning demo --- 3rdparty/tvm | 2 +- bitblas/base/utils.py | 2 +- bitblas/ops/base_scheduler.py | 5 + .../general_matmul/tilelang/dense/__init__.py | 8 +- .../tilelang/dense/matmul_simt.py | 62 ++++ .../dense/{matmul.py => matmul_tensorcore.py} | 19 +- bitblas/ops/operator.py | 121 +++++--- bitblas/tl/tuner.py | 284 ++++++++++++++++++ .../test_general_matmul_ops_backend_tl.py | 36 +++ 9 files changed, 490 insertions(+), 49 deletions(-) create mode 100644 bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py rename bitblas/ops/general_matmul/tilelang/dense/{matmul.py => matmul_tensorcore.py} (97%) create mode 100644 bitblas/tl/tuner.py diff --git a/3rdparty/tvm b/3rdparty/tvm index d0c06c764..1fa647dbf 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit d0c06c7641956a3bd9ab1174ed05a1aa2a624d2a +Subproject commit 1fa647dbff6a273cbdf2a6f0a64b3478ba553223 diff --git a/bitblas/base/utils.py b/bitblas/base/utils.py index 90fab86d0..2b887ba2d 100644 --- a/bitblas/base/utils.py +++ b/bitblas/base/utils.py @@ -193,7 +193,7 @@ def _apply_schedule(f, c): sch = None return sch - with ThreadPoolExecutor(max_workers=4) as scheduler: + with ThreadPoolExecutor(max_workers=max_workers) as scheduler: futures = {scheduler.submit(_apply_schedule, func, config) for config in configs} for future in as_completed(futures, timeout=timeout): _sched.append(future.result()) diff --git a/bitblas/ops/base_scheduler.py b/bitblas/ops/base_scheduler.py index 72a52937b..72ee1b29c 100644 --- a/bitblas/ops/base_scheduler.py +++ b/bitblas/ops/base_scheduler.py @@ -4,6 +4,7 @@ from dataclasses import dataclass, field from tvm.tir.transform import Simplify from abc import ABC, abstractmethod +from bitblas.base.arch import TileDevice @dataclass @@ -20,6 +21,10 @@ def Simplify(stmt: Union[PrimFunc, IRModule]): else: raise ValueError(f"Unsupported type: {type(stmt)}") + def get_hardware_aware_configs(self, arch: TileDevice = None): + raise NotImplementedError( + f"{self.__class__.__name__} does not support hardware-aware tuning for {arch}") + def activate_simplify(self): self._enable_simplify = True return self diff --git a/bitblas/ops/general_matmul/tilelang/dense/__init__.py b/bitblas/ops/general_matmul/tilelang/dense/__init__.py index 2a929355c..9ab9b6990 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/__init__.py +++ b/bitblas/ops/general_matmul/tilelang/dense/__init__.py @@ -1,13 +1,17 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .matmul import ( +from .matmul_simt import ( + MatmulFineGrainSIMTScheduler, # noqa: F401 +) + +from .matmul_tensorcore import ( matmul_blocked, # noqa: F401 matmul_macro_tensorcore, # noqa: F401 matmul_macro_tensorcore_weight_propagation_level_ldmatrix # noqa: F401 ) -from .matmul import ( +from .matmul_tensorcore import ( MatmulScheduler, # noqa: F401 MatmulFineGrainScheduler, # noqa: F401 MatmulWeightPropagationScheduler, # noqa: F401 diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py new file mode 100644 index 000000000..bc091f910 --- /dev/null +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py @@ -0,0 +1,62 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from bitblas import tvm as tvm +from tvm import DataType +import tvm.tl.language as T +from typing import Optional +from bitblas.tl.utils import ( + get_mma_micro_size, + make_swizzle_layout, +) + +from bitblas.ops.base_scheduler import BaseScheduler + +from dataclasses import dataclass + + +@dataclass +class MatmulFineGrainSIMTScheduler(BaseScheduler): + # Fine-grained matrix multiplication scheduler + # Allows for more detailed configuration. + + # Operation Configuration + M: Optional[int] = None + N: Optional[int] = None + K: Optional[int] = None + in_dtype: str = "float16" + out_dtype: str = "float16" + trans_A: bool = False + trans_B: bool = True + accum_dtype: str = "float16" + + # Tensor Core Warp Configuration + block_row_warps: int = 2 + block_col_warps: int = 2 + warp_row_tiles: int = 32 + warp_col_tiles: int = 32 + chunk: int = 32 # Usually determines the K-dimension split size + + # Tiling and Other Optimization Parameters + num_stages: int = 2 + enable_rasterization: bool = False + + def with_default_config(self): + raise NotImplementedError + + def apply_config( + self, + ): + + # M, N, K = self.M, self.N, self.K + # trans_A, trans_B = self.trans_A, self.trans_B + # in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype + + raise NotImplementedError + + + def __post_init__(self): + # Validate the matrix transpose settings + assert self.trans_A is False, "Currently only support Matrix A not transposed" + assert self.trans_B is True, "Currently only support Matrix B transposed" + + return diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py similarity index 97% rename from bitblas/ops/general_matmul/tilelang/dense/matmul.py rename to bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py index 1c28ff695..35a200527 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import itertools from bitblas import tvm as tvm from tvm import DataType import tvm.tl.language as T @@ -15,7 +16,7 @@ ) from bitblas.ops.common import TransformKind from bitblas.ops.base_scheduler import BaseScheduler - +from bitblas.base.arch import CUDA from dataclasses import dataclass @@ -40,6 +41,22 @@ class MatmulScheduler(BaseScheduler): threads: int = 128 enable_rasterization: bool = False # Enhance L2 Locality + def get_configs_sm80(self): + num_stages = 2 + configs = [ + {'block_M': 128, 'block_N': 256, 'block_K': 32, 'threads': 128}, + {'block_M': 256, 'block_N': 128, 'block_K': 32, 'threads': 128}, + {'block_M': 128, 'block_N': 128, 'block_K': 32, 'threads': 128}, + ] + configs = [{**c, 'num_stages': num_stages} for c in configs] + return configs + + def get_hardware_aware_configs(self, arch: CUDA = None): + # TODO(lei): implement only for SM80 Currently + sm_version: int = int(arch.sm_partition) + assert sm_version is not None, "Please provide a valid CUDA Arch" + return self.get_configs_sm80() + def with_default_config(self): block_M = getattr(self, "block_M", 64) block_N = getattr(self, "block_N", 64) diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index b723eabf8..eb173352f 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -10,9 +10,10 @@ from tvm.contrib.dlpack import to_pytorch_func import bitblas import ctypes -from typing import (List, Dict, Any, Optional, Tuple, Literal, Callable) +from typing import List, Dict, Any, Optional, Tuple, Literal, Callable import numpy as np from bitblas.base import fast_tune, fast_tune_with_dynamic_range +from bitblas.tl.tuner import apply_and_build as tl_apply_and_build from copy import deepcopy from bitblas.ops.base_scheduler import BaseScheduler from bitblas.base.arch import get_arch, TileDevice @@ -38,6 +39,7 @@ @dataclass(frozen=True) class OperatorConfig: """Base class for operator configurations. Used for typing.""" + pass @@ -55,7 +57,7 @@ def is_valid_config(self, config: OperatorConfig): @abstractmethod def generate(self, hint: Hint = None) -> str: - '''Generate the kernel name based on the config and hint''' + """Generate the kernel name based on the config and hint""" pass @@ -73,18 +75,20 @@ def generate(self, hint: Hint = None) -> str: return self.DEFAULT_PREFIX def is_valid_config(self, config: OperatorConfig) -> bool: - # hint is not used + # config is not used assert config is not None return True class Operator(object): - def __init__(self, - name, - config: OperatorConfig, - target: Target = None, - backend: Literal["tir", "tl"] = "tir"): + def __init__( + self, + name, + config: OperatorConfig, + target: Target = None, + backend: Literal["tir", "tl"] = "tir", + ): if isinstance(target, str): target = Target(target) self.name = name @@ -169,7 +173,7 @@ def tvm_callback_cuda_postproc(code, _): config={ "tir.use_async_copy": True, "tir.disable_cse_tir": True, - **(self.pass_context if self.pass_context else {}) + **(self.pass_context if self.pass_context else {}), }): if self.is_tir_backend(): rt_mod = tvm.build(self.scheduled_ir_module, target=target) @@ -183,9 +187,12 @@ def tvm_callback_cuda_postproc(code, _): raise ValueError(f"Unsupported backend: {self.backend}") except Exception: # noqa: F841 logger.debug( - BUILD_RUNTIME_LIBRARY_FAILED_MESSAGE.format(self.__class__.__name__, target, - "optimized", - "Failed to build optimized module")) + BUILD_RUNTIME_LIBRARY_FAILED_MESSAGE.format( + self.__class__.__name__, + target, + "optimized", + "Failed to build optimized module", + )) else: # For non-CUDA platforms or when no optimized function is available, build with the primary function rt_mod = tvm.build(self.prim_func, target=target, name=self.name) @@ -248,10 +255,12 @@ def _build_default_module(self, target: Target): scheduled_mod = self.apply_default_schedule(self.ir_module, target) elif self.is_tilelang_backend(): scheduled_mod = self.scheduler_with_default(self.scheduler) - assert len(scheduled_mod.get_global_vars()) == 1, ( - "The optimized module should only have one global variable for default schedule.") - assert "main" in scheduled_mod, ( - "The optimized module should have a function named 'main' for default schedule.") + assert ( + len(scheduled_mod.get_global_vars()) == 1 + ), "The optimized module should only have one global variable for default schedule." + assert ( + "main" in scheduled_mod + ), "The optimized module should have a function named 'main' for default schedule." default_kernal_name = self.kernel_name_generator.generate() func = scheduled_mod["main"].with_attr("global_symbol", default_kernal_name) scheduled_ir_module = tvm.IRModule({default_kernal_name: func}) @@ -267,54 +276,77 @@ def _build_default_module(self, target: Target): def post_process(self, code: str) -> str: return code - def apply_fast_tuning(self, - func: PrimFunc, - target: Target, - topk: int = 20, - parallel_build=True) -> Tuple[IRModule, Hint]: - _, best = fast_tune(func, target, topk=topk, parallel_build=parallel_build) - # annotate the best pass context - # TODO(lei): actually we should remove this by enable pass through - # annotation in the func's attribute. - self.pass_context = best.config.pass_context - return ((best.sch.mod, best.config) if best is not None else (None, None)) + def get_tl_tuning_config(self): + assert self.is_tilelang_backend(), "Only support tilelang backend" + return self.scheduler.get_hardware_aware_configs(self.arch) + + def apply_fast_tuning( + self, + func_or_scheduler: PrimFunc, + target: Target, + topk: int = 20, + parallel_build=True, + ) -> Tuple[IRModule, Hint]: + if self.is_tir_backend(): + _, best = fast_tune(func_or_scheduler, target, topk=topk, parallel_build=parallel_build) + # annotate the best pass context + # TODO(lei): actually we should remove this by enable pass through + # annotation in the func's attribute. + self.pass_context = best.config.pass_context + return (best.sch.mod, best.config) if best is not None else (None, None) + elif self.is_tilelang_backend(): + # Finetune the schedule + tuning_configs = self.get_tl_tuning_config() + _, best = tl_apply_and_build( + func_or_scheduler, tuning_configs, arch=self.arch, parallel_build=False) + # Return the best Config as Hint + return (best.sch.mod, best.config) if best is not None else (None, None) def apply_fast_tuning_with_dynamic_range( self, - func: PrimFunc, + func_or_scheduler: PrimFunc, target: Target, topk: int = 20, dynamic_range: Dict[str, List[int]] = None, ): scheduled_ir_module = fast_tune_with_dynamic_range( - func, + func_or_scheduler, target, topk=topk, parallel_build=True, dynamic_range=dynamic_range, - kernel_name_generator=self.kernel_name_generator) + kernel_name_generator=self.kernel_name_generator, + ) if scheduled_ir_module is not None: return scheduled_ir_module return None - def hardware_aware_finetune(self, - topk: int = 20, - target: Optional[tvm.target.Target] = None, - parallel_build=True): + def hardware_aware_finetune( + self, + topk: int = 20, + target: Optional[tvm.target.Target] = None, + parallel_build=True, + ): if target is None: target = self.target dynamic_range = self.dynamic_range - func = self.prim_func if dynamic_range is not None: - self.scheduled_ir_module = self.apply_fast_tuning_with_dynamic_range( - func, target, topk, dynamic_range) + if self.is_tir_backend(): + func = self.prim_func + self.scheduled_ir_module = self.apply_fast_tuning_with_dynamic_range( + func, target, topk, dynamic_range) + elif self.is_tilelang_backend(): + raise NotImplementedError("Not support dynamic range for tilelang backend") else: + func_or_scheduler = (self.prim_func if self.is_tir_backend() else self.scheduler) scheduled_mod, best_hint = self.apply_fast_tuning( - func, target, topk, parallel_build=parallel_build) - assert len(scheduled_mod.get_global_vars()) == 1, ( - "The optimized module should only have one global variable for default schedule.") - assert "main" in scheduled_mod, ( - "The optimized module should have a function named 'main' for default schedule.") + func_or_scheduler, target, topk, parallel_build=parallel_build) + assert ( + len(scheduled_mod.get_global_vars()) == 1 + ), "The optimized module should only have one global variable for default schedule." + assert ( + "main" in scheduled_mod + ), "The optimized module should have a function named 'main' for default schedule." default_kernal_name = self.kernel_name_generator.generate(best_hint) func = scheduled_mod["main"].with_attr("global_symbol", default_kernal_name) scheduled_ir_module = tvm.IRModule({default_kernal_name: func}) @@ -341,8 +373,9 @@ def var_warpper(v): for i in func.attrs["opt_shapes"][v.name]: avg_shape += i.value avg_shape = avg_shape // len(func.attrs["opt_shapes"][v.name]) - _info_message = f"Doesn't provide dynamic symbolic constrains for {v.name} when do benchmarking, "\ - f"use average shape {avg_shape}" + _info_message = ( + f"Doesn't provide dynamic symbolic constrains for {v.name} when do benchmarking, " + f"use average shape {avg_shape}") logger.info(_info_message) return avg_shape else: diff --git a/bitblas/tl/tuner.py b/bitblas/tl/tuner.py new file mode 100644 index 000000000..8f9ab4f84 --- /dev/null +++ b/bitblas/tl/tuner.py @@ -0,0 +1,284 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from bitblas import tvm +import os +from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind +from concurrent.futures import ThreadPoolExecutor, as_completed +import numpy as np +from typing import List, Tuple, Optional, Dict, Union, Literal, Callable +from tvm import tir, IRModule +from tvm.runtime import Module +from tvm.tir import Schedule +from tvm.relax.expr import Function +import tvm.tl as tl +import bitblas +from bitblas.ops.base_scheduler import BaseScheduler +from bitblas.base.arch import CUDA +from bitblas.base import Hint +from bitblas.base.utils import get_dummy_input_arrays +from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy +from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags +import tempfile +import itertools +from tvm.ir.supply import GlobalVarSupply +from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 +from bitblas.utils.tensor_adapter import ( + np_float2np_bf16,) +import logging + +logger = logging.getLogger(__name__) + + +def get_rasterization_code(pannel_width: int = 8) -> str: + return f""" + const int MAX_BLOCK_N = {pannel_width}; + const auto baseBlockIdx = blockIdx.x + gridDim.x *blockIdx.y; + const auto totalPanel = (gridDim.x * gridDim.y +MAX_BLOCK_N * gridDim.x - 1) / (MAX_BLOCK_N * gridDim.x); + const auto totalBlock = gridDim.x * gridDim.y; + const auto panelIdx = baseBlockIdx / (MAX_BLOCK_N *gridDim.x); + const auto strideLd = panelIdx + 1 < totalPanel ?MAX_BLOCK_N : (totalBlock - panelIdx * (MAX_BLOCK_N *gridDim.x)) / gridDim.x; + const auto bx = (panelIdx & 1) ? gridDim.x -(baseBlockIdx - panelIdx * MAX_BLOCK_N * gridDim.x) /strideLd - 1 : (baseBlockIdx - panelIdx * MAX_BLOCK_N *gridDim.x) / strideLd; + const auto by = (baseBlockIdx - panelIdx * MAX_BLOCK_N *gridDim.x) % strideLd + panelIdx * MAX_BLOCK_N; + const auto bz = blockIdx.z; + const dim3 blockIdx(bx, by, bz); + """ + + +class CompileResult: + """ + Class to store the result of compilation + """ + + def __init__(self, config, sch, mod: Module): + self.config = config + self.sch = sch + self.mod = mod + self.code = mod.imported_modules[0].get_source() if mod else None + self.latency = 1e9 + self.time_evaluator = None + + def profile(self, data_distribution="uniform"): + func = self.sch.mod["main"] + device = self.config.arch.device + profile_tensors = get_dummy_input_arrays(func, device, distribution=data_distribution) + latency = self.time_evaluator(*profile_tensors).mean * 1e3 + return latency + + +def _apply_config( + scheduler: BaseScheduler, + config: Dict = None, +) -> Optional[IRModule]: + """ + find rules: + case 1. if the main block has no reduce op, then use the Elementwise rule. + case 2. if the config enabled tensorcore, then use the TensorCore rule. + case 3. if any([t > 1 for t in config.reduce_thread]), we should use the InnerThread Reduction Rule. + case 4. else we should use general reduction rule. + """ + logger.debug("Scheduler Apply config {}".format(config)) + scheduled_func = scheduler.apply_config(**config) + if scheduled_func is None: + return None + else: + return tvm.IRModule.from_expr(scheduled_func) + + +def apply_and_build_parallel(scheduler, + configs, + arch, + num_repeats=3, + max_workers=10, + timeout=30, + data_distribution="uniform") -> CompileResult: + cpresults = [] + + max_workers = min(len(configs), os.cpu_count(), max_workers) + + # apply config in thread parallel + _scheduled_ir_modules: List[Schedule] = [] + + def _submit_config(f, c): + try: + scheduled_ir_module = _apply_config(f, c) + except Exception as apply_schedule_error: + logger.debug("Apply schedule failed: {}".format(apply_schedule_error)) + scheduled_ir_module = None + return scheduled_ir_module + + with ThreadPoolExecutor(max_workers=max_workers) as _scheduler: + futures = {_scheduler.submit(_submit_config, scheduler, config) for config in configs} + for future in as_completed(futures, timeout=timeout): + _scheduled_ir_modules.append(future.result()) + + builder = PopenPoolExecutor(max_workers=max_workers, timeout=timeout) + + # build in process parallel + def _build(context) -> str: + idx, mod, arch = context + if mod is None: + return idx, None, None + + config = configs[idx] + + @tvm.register_func(func_name="tvm_callback_cuda_postproc", override=True) + def tvm_callback_cuda_postproc(code, _): + code = tensor_replace_dp4a(code) + code = tensor_remove_make_int4(code) + code = tensor_remove_make_int2(code) + return code + # check only have one function in the module + if len(mod.functions) > 1: + raise ValueError("Only support one function in the module") + tl_prim_func = list(mod.functions.values())[0] + with tvm.transform.PassContext(config={ + "tir.use_async_copy": True, + "tir.disable_cse_tir": True, + }): + rt_mod = tl.lower(tl_prim_func, arch.target, runtime_only=True) + + from tvm.contrib.tar import tar # pylint: disable=import-outside-toplevel + + artifact_path = os.path.join(tempfile.mkdtemp(), "tvm_tmp_mod." + tar.output_format) + code = rt_mod.imported_modules[0].get_source() + rt_mod.export_library(artifact_path, fcompile=tar) + return idx, code, artifact_path + + _mods = [mod for mod in _scheduled_ir_modules] + + for map_result in builder.map_with_error_catching( + _build, + [(i, mod, arch) for i, mod in enumerate(_mods)], + ): + if map_result.status == StatusKind.TIMEOUT: + logger.debug("LocalBuilder: Timeout") + elif map_result.status == StatusKind.EXCEPTION: + # TODO(lei): redirect the exception to file if needed + logger.debug("LocalBuilder: An exception occurred {}".format(map_result.value)) + continue + elif map_result.status == StatusKind.COMPLETE: + idx, code, artifact_path = map_result.value + ir_module = _scheduled_ir_modules[idx] + sch = tvm.tir.Schedule(ir_module) + config = configs[idx] + if artifact_path is None: + ARTIFACT_NOT_FOUND = f"Apply config {config} failed, artifact path is None" + logger.debug(ARTIFACT_NOT_FOUND) + continue + rt_mod = tvm.runtime.load_module(artifact_path) + # Transform Tuning Config to Hint + hint = Hint.from_dict( + { + **{"arch": arch}, + **config, + } + ) + cpresult = CompileResult(hint, sch, rt_mod) + timer_cuda_mod = rt_mod.time_evaluator( + rt_mod.entry_name, arch.device, number=num_repeats) + cpresult.time_evaluator = timer_cuda_mod + cpresult.code = code + cpresults.append(cpresult) + else: + raise ValueError(f"Unreachable: unexpected result: {map_result}") + + del builder + + best = None + best_latency = 1e9 + for cpresult in cpresults: + config = cpresult.config + try: + latency = cpresult.profile(data_distribution=data_distribution) + except Exception as e_mesg: + logger.debug(f"Evaluation with config failed {e_mesg}") + continue + logger.info("Evaluation with config {}".format(config)) + logger.info("Time cost of this config: {:.3f} ms".format(latency)) + + cpresult.latency = latency + if latency < best_latency: + best_latency = latency + best = cpresult + + return cpresults, best + + +def apply_and_build( + scheduler, + configs, + arch, + parallel_build=False, + data_distribution="uniform", +) -> Tuple[List[CompileResult], CompileResult]: + max_workers = 10 if parallel_build else 1 + return apply_and_build_parallel( + scheduler, configs, arch, max_workers=max_workers, data_distribution=data_distribution) + + +def fast_tune( + func: tir.PrimFunc, + target: tvm.target.Target, + topk: int = 10, + parallel_build: bool = True, + data_distribution: Literal["uniform", "onefill"] = "uniform", +): + # check the function is a primfunc + if not isinstance(func, tir.PrimFunc): + raise ValueError("Only support func is PrimFunc") # pragma: no cover + + if target.kind.name != "cuda": + logger.error("Only support CUDA target") + return None, None + + specilized_func = func + if func.attrs is not None and "opt_shapes" in func.attrs: + opt_shapes = func.attrs["opt_shapes"] + # should be int value + if not all([isinstance(v.value, int) for v in opt_shapes.values()]): + logger.error("The opt_shapes should be int value") + return None, None + # currently only support one dynamic range + if len(opt_shapes) > 1: + logger.error("Currently only support one dynamic range") + return None, None + + for buffer in func.buffer_map.values(): + for axis in buffer.shape: + if isinstance(axis, tvm.tir.Var) and axis.name not in opt_shapes: + raise NotImplementedError( + "Currently do not support fast tune with none-dynamic range set") + if opt_shapes: + for name, shape in opt_shapes.items(): + var = find_var_from_func(func, name) + specilized_func = func.specialize({ + var: shape.astype(var.dtype) + }).with_attr("is_specialized") + + arch = CUDA(target) + + policy = DefaultPolicy(func=func, arch=arch) + try: + specilized_func, tags = get_tensorized_func_and_tags(specilized_func, arch.target) + except Exception as e_msg: + logger.debug("Get tensorized func and tags failed: ", e_msg) + tags = None + if tags: + policy = TensorCorePolicy(func=specilized_func, arch=arch, tags=tags) + + configs = policy.emit_config(topk) + + if len(configs) == 0: + raise ValueError("No valid config generated") + + cpresults, best = apply_and_build( + func, + configs, + arch, + parallel_build=parallel_build, + data_distribution=data_distribution, + ) + + return cpresults, best + diff --git a/testing/python/operators/test_general_matmul_ops_backend_tl.py b/testing/python/operators/test_general_matmul_ops_backend_tl.py index 90ed00c6e..eccb8ebb3 100644 --- a/testing/python/operators/test_general_matmul_ops_backend_tl.py +++ b/testing/python/operators/test_general_matmul_ops_backend_tl.py @@ -38,11 +38,47 @@ def matmul_codegen_default(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, la assert get_codegen_result(matmul) +def matmul_finetune(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, + group_size, with_scaling, with_zeros, zeros_mode): + + matmul_config = MatmulConfig( + M=M, + N=N, + K=K, + A_dtype=A_dtype, + W_dtype=W_dtype, + accum_dtype=accum_dtype, + out_dtype=out_dtype, + layout=layout, + with_bias=with_bias, + group_size=group_size, + with_scaling=with_scaling, + with_zeros=with_zeros, + zeros_mode=zeros_mode, + propagate_a=False, + propagate_b=False, + ) + matmul = Matmul(config=matmul_config, enable_tuning=False, backend="tl") + matmul.hardware_aware_finetune(topk=10) + assert get_codegen_result(matmul) + + def test_matmul_codegen_default(): matmul_codegen_default(1, 768, 768, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, None), matmul_codegen_default(768, 768, 768, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, None), + # FP32 Accum + matmul_codegen_default(768, 768, 768, "float16", "float16", "float32", "float16", "nt", False, + -1, False, False, None), + # INT32 Accum + matmul_codegen_default(768, 768, 768, "int8", "int8", "int32", "int8", "nt", False, -1, False, + False, None), + + +def test_matmul_finetune(): + matmul_finetune(768, 768, 768, "float16", "float16", "float16", "float16", "nt", False, -1, + False, False, None), # fmt: on