diff --git a/README.md b/README.md index bf1b7b04e..24da9f8f9 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ Some of the key features of BitBLAS include: - Support customizing mixed-precision DNN operations for your specific scenarios via the flexible DSL (TIR Script). ## Latest News - +- 11/04/2024 🚀🚀: We've added support for high-performance A INT4 x W INT4/INT2 Matmul. - 10/02/2024 🚀🚀: We've added initial Flash Attention Ops and its implementation in Tilelang! Please refer to [PythonAPI](https://github.com/microsoft/BitBLAS/blob/main/docs/PythonAPI.md) and [QuickStart](https://github.com/microsoft/BitBLAS/blob/main/docs/QuickStart.md) docs and [PR #202](https://github.com/microsoft/BitBLAS/pull/202). - 08/12/2024 🚀🚀: We've improved performance for contiguous batching. To enable it, you'll need to set specific flags. For more details, please refer to [PR #133](https://github.com/microsoft/BitBLAS/pull/133). - 07/11/2024 ✨: Ladder is published and presented in OSDI'24. Please find [Ladder paper and presentation](https://www.usenix.org/conference/osdi24/presentation/wang-lei) if you are interested in the technical details of BitBLAS. @@ -84,6 +84,8 @@ For more detailed information on benchmark sets with other formats (NF4/FP4) and | INT8 | UINT1 | INT32 | FP32/INT32/FP16/INT8 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | | FP8_E4M3 | FP8_E4M3 | FP32 | FP32/FP16 | **√** | RTX 4090(SM_89) | | FP8_E5M2 | FP8_E5M2 | FP32 | FP32/FP16 | **√** | RTX 4090(SM_89) | +| INT4 | INT4 | INT32 | FP32/FP16 | **√** | RTX 4090(SM_89) | +| INT4 | INT2 | INT32 | FP32/FP16 | **√** | RTX 4090(SM_89) | We are continuously expanding the support matrix. If you have any specific requirements, please feel free to open an issue or PR. diff --git a/bitblas/gpu/intrin/lop3.py b/bitblas/gpu/intrin/lop3.py index 75f4b1757..0ef6d2df3 100644 --- a/bitblas/gpu/intrin/lop3.py +++ b/bitblas/gpu/intrin/lop3.py @@ -1001,7 +1001,7 @@ { // TODO(lei): uint4 sub should be enhanced. // 0x03 0x03 0x03 0x03 - i4s[i] = (((i4s[i] << 1) | i4s[i]) << 1) | i4s[i]; + // i4s[i] = (((i4s[i] << 1) | i4s[i]) << 1) | i4s[i]; } } } @@ -1625,7 +1625,7 @@ def initialize_tensor_intrin(): def get_lop3_intrin_group( - out_dtype: Literal["float16", "int8"], + out_dtype: Literal["float16", "int8", "int4"], source_format: Literal["int", "uint"] = "uint", source_bit: int = 4, storage_dtype: Literal["int32", "int8"] = "int8", @@ -1644,8 +1644,8 @@ def get_lop3_intrin_group( in_dtype : Literal["int8"] The data type of the input. It should be "int8". - out_dtype : Literal["float16", "int8"] - The data type of the output. It can be either "float16" or "int8". + out_dtype : Literal["float16", "int8", "int4"] + The data type of the output. It can be "float16", "int8", or "int4". storage_nbit : int, optional The number of bits used for storage. By default, it is 4. @@ -1667,10 +1667,11 @@ def get_lop3_intrin_group( Dict[str, str] A dictionary mapping the names of the intrinsics to their corresponding implementations. """ - assert out_dtype in ["float16", - "int8"], (f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'int8'.") + assert out_dtype in [ + "float16", "int8", "int4" + ], (f"Invalid out_dtype: {out_dtype}. 
Expected 'float16' or 'int8' or 'int4' .") - dtype_mapping = {"float16": "f16", "int8": "i8", "int32": "i32"} + dtype_mapping = {"float16": "f16", "int4": "i4", "int8": "i8", "int32": "i32"} target_dtype = dtype_mapping[out_dtype] target_bits = tvm.DataType(out_dtype).bits loop_extent = 128 // target_bits @@ -1707,6 +1708,7 @@ def get_lop3_intrin_group( "i1_to_i8": decode_i1s_to_i8s, "i2_to_i8": decode_i2s_to_i8s, "i4_to_i8": decode_i4s_to_i8s, + "i2_to_i4": decode_i2s_to_i4s, } key = f"i{source_bit}_to_{target_dtype}" if with_scaling: @@ -1722,6 +1724,8 @@ def get_lop3_intrin_group( d4f = "f16" elif out_dtype == "int8": d4f = "i8s" + elif out_dtype == "int4": + d4f = "i4s" else: raise ValueError("Unsupported target dtype: {}".format(target_dtype)) source_symbol = "u" if source_format == "uint" else "s" diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py index e71b18971..327cd1a3d 100644 --- a/bitblas/ops/general_matmul/__init__.py +++ b/bitblas/ops/general_matmul/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from bitblas import tvm +from tvm import DataType from tvm.target import Target import operator from functools import reduce @@ -35,6 +36,9 @@ ("float16", "float16"), ("bfloat16", "bfloat16"), ("int8", "int8"), + ("uint8", "uint8"), + ("int4", "int4"), + ("uint4", "uint4"), ("e4m3_float8", "e4m3_float8"), ("e4m3_float8", "e5m2_float8"), ("e5m2_float8", "e4m3_float8"), @@ -142,6 +146,11 @@ def __initialize_propagate(self, propagate_a: Optional[TransformKind], if self.A_dtype in ["e4m3_float8", "e5m2_float8", "bfloat16"]: object.__setattr__(self, "propagate_a", TransformKind.NonTransform) object.__setattr__(self, "propagate_b", TransformKind.NonTransform) + if self.A_dtype in ["int4", "uint4"]: + object.__setattr__(self, "propagate_a", TransformKind.NonTransform) + # TODO(lei): tl doesn't implement IntraWarpTransform + if self.propagate_b == TransformKind.IntraWarpTransform: + object.__setattr__(self, "propagate_b", TransformKind.LDMatrixTransform) # TODO(lei): propagation can only be enabled on SM80+ Devices and MI200+ # We should add a check here to disable the propagation if the device is not supported. @@ -358,6 +367,10 @@ def __init__( self.source_format = source_format self.bit = bit + + # This is a hack to support the int4 and uint4 + if config.A_dtype in ["int4", "uint4"]: + backend = "tl" super().__init__(name, config, target, backend) if source_format == "int" and self.with_zeros: @@ -471,10 +484,13 @@ def _assign_ladder_permutate_b(self, target: Target, enable_tuning: bool): # weight transform should be done in the unpacked level # otherwise the bit trick should be applied and that is # too complex to be implemented in the ladder permutation. 
+ datatype = self.A_dtype + if DataType(datatype).bits < 8: + datatype = self.storage_dtype ladder_permutate_config = LadderPermutateConfig( M=self.N, N=self.K, - datatype=self.A_dtype, + datatype=datatype, dequantize_bits=-1, storage_dtype=self.storage_dtype, propagate_kind="B", diff --git a/bitblas/ops/general_matmul/tilelang/dense/__init__.py b/bitblas/ops/general_matmul/tilelang/dense/__init__.py index fe603be51..d3c9b38aa 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/__init__.py +++ b/bitblas/ops/general_matmul/tilelang/dense/__init__.py @@ -17,6 +17,11 @@ MatmulWeightPropagationScheduler, # noqa: F401 ) +from .matmul_tensorcore_s4 import ( + MatmulINT4FineGrainScheduler, # noqa: F401 + MatmulINT4WeightPropagationScheduler, # noqa: F401 +) + from bitblas.ops.common import TransformKind from typing import Union @@ -82,8 +87,13 @@ def can_apply_weight_propagation_scheduler(trans_A, trans_B, propagate_a, propag conditions.append(propagate_b == TransformKind.LDMatrixTransform) return all(conditions) + def is_int4_dtype(dtype): + return dtype == "int4" or dtype == "uint4" + if can_apply_weight_propagation_scheduler(trans_A, trans_B, propagate_a, propagate_b): - return MatmulWeightPropagationScheduler( + Scheduler = MatmulWeightPropagationScheduler if not is_int4_dtype( + in_dtype) else MatmulINT4WeightPropagationScheduler + return Scheduler( M=M, N=N, K=K, @@ -94,7 +104,9 @@ def can_apply_weight_propagation_scheduler(trans_A, trans_B, propagate_a, propag accum_dtype=accum_dtype, ) if can_apply_fine_grain_scheduler(trans_A, trans_B, propagate_a, propagate_b): - return MatmulFineGrainScheduler( + Scheduler = MatmulFineGrainScheduler if not is_int4_dtype( + in_dtype) else MatmulINT4FineGrainScheduler + return Scheduler( M=M, N=N, K=K, diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py index 13658aab4..7833865b8 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py @@ -368,6 +368,9 @@ def with_default_config(self): warp_row_tiles = getattr(self, "warp_row_tiles", 32) warp_col_tiles = getattr(self, "warp_col_tiles", 32) chunk = getattr(self, "chunk", 32) + # Swizzle size for INT8 Storage is 64 + if DataType(self.in_dtype).bits <= 8: + chunk = 64 num_stages = getattr(self, "num_stages", 2) enable_rasterization = getattr(self, "enable_rasterization", False) @@ -597,7 +600,9 @@ def apply_config( threads = warp_size * (block_row_warps * block_col_warps) # Calculate local fragment sizes for tensor core - local_size = (micro_size_x * micro_size_y) // warp_size + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size warp_rows = warp_row_tiles // micro_size_x warp_cols = warp_col_tiles // micro_size_y @@ -633,9 +638,9 @@ def main( A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size), in_dtype) - B_local = T.alloc_local((warp_cols * local_size), in_dtype) - C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = 
T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) # Thread-level parallelism for Tensor Cores thread_bindings = T.thread_binding(0, threads, "threadIdx.x") @@ -809,7 +814,9 @@ def matmul_macro_tensorcore( warp_size = 32 # nvidia gpu warp size is 32 threads = warp_size * (block_row_warps * block_col_warps) - local_size = (micro_size_x * micro_size_y) // warp_size + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size warp_rows = warp_row_tiles // micro_size_x warp_cols = warp_col_tiles // micro_size_y @@ -838,9 +845,9 @@ def main( A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size), in_dtype) - B_local = T.alloc_local((warp_cols * local_size), in_dtype) - C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) thread_bindings = T.thread_binding(0, threads, "threadIdx.x") T.annotate_layout({ @@ -947,7 +954,9 @@ def matmul_macro_tensorcore_weight_propagation_level_ldmatrix( warp_size = 32 # nvidia gpu warp size is 32 threads = warp_size * (block_row_warps * block_col_warps) - local_size = (micro_size_x * micro_size_y) // warp_size + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size warp_rows = warp_row_tiles // micro_size_x warp_cols = warp_col_tiles // micro_size_y @@ -977,9 +986,9 @@ def main( A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size), in_dtype) - B_local = T.alloc_local((warp_cols * local_size), in_dtype) - C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) thread_bindings = T.thread_binding(0, threads, "threadIdx.x") T.annotate_layout({ diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore_s4.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore_s4.py new file mode 100644 index 000000000..efdfd58ea --- /dev/null +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore_s4.py @@ -0,0 +1,409 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+# INT4 Tensor Core Implementation for NVIDIA GPUs +from bitblas import tvm as tvm +from tvm import DataType +import tvm.tl.language as T +from typing import Optional, List +from bitblas.tl.utils import ( + get_mma_micro_size, + make_swizzle_layout, +) +from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import ( + MatmulFineGrainScheduler, + MatmulWeightPropagationScheduler, +) +from bitblas.tl.macro_generator import ( + INT4TensorCoreIntrinEmitter, + INT4TensorCoreIntrinEmitterWithLadderTransform, +) +from bitblas.base.arch import TileDevice +from bitblas.base.roller.hint import Hint +from bitblas.base.utils import get_roller_hints_from_func +from dataclasses import dataclass +from bitblas.ops.general_matmul.tirscript import (matmul_select_implementation) + +# GPU warp configuration for NVIDIA GPUs +warp_size = 32 + + +@dataclass +class MatmulINT4FineGrainScheduler(MatmulFineGrainScheduler): + + def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): + layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" + K = self.K // 2 # 2xint4 should be packed into one single int8 + # Simple TIR Compute Expression + storage_dtype = "int8" + ir_module = matmul_select_implementation( + M=self.M, + N=self.N, + K=K, + in_dtype=storage_dtype, + out_dtype=self.out_dtype, + accum_dtype=self.accum_dtype, + layout=layout, + ) + + roller_hints = get_roller_hints_from_func( + ir_module["main"], + arch, + topk, + tensorcore_only=True, + allow_gemv=True, + ) + + if roller_hints is None: + raise ValueError("No Roller Hints Found for TensorCore Scheduling") + + def serialze_hints_to_configs(hints: List[Hint]): + configs = [] + for hint in hints: + config = self.TLHint.from_roller_hint(hint) + configs.append(config) + return configs + + return serialze_hints_to_configs(roller_hints) + + def apply_config( + self, + block_row_warps: Optional[int] = None, + block_col_warps: Optional[int] = None, + warp_row_tiles: Optional[int] = None, + warp_col_tiles: Optional[int] = None, + chunk: Optional[int] = None, + num_stages: Optional[int] = None, + enable_rasterization=False, + ): + assert block_row_warps is not None, "block_row_warps is required" + assert block_col_warps is not None, "block_col_warps is required" + assert warp_row_tiles is not None, "warp_row_tiles is required" + assert warp_col_tiles is not None, "warp_col_tiles is required" + assert chunk is not None, "chunk is required" + assert num_stages is not None, "num_stages is required" + + M, N, K = self.M, self.N, self.K + K = K // 2 # 2xint4 should be packed into one single int8 + trans_A, trans_B = self.trans_A, self.trans_B + in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype + assert in_dtype == "int4", "Only support int4 input" + assert accum_dtype == "int32", "Only support int32 accumulation" + storage_dtype = "int8" + + # Calculate the micro size per warp using a helper function + micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(storage_dtype) + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + # Define the shapes of matrices and shared memory buffers + A_shape = (M, K) + B_shape = (N, K) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + threads = warp_size * (block_row_warps * block_col_warps) + + # Calculate local fragment sizes for tensor core + local_size_a = 
(micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + shared_scope = "shared.dyn" + + # Configure the tensor core intrinsic emitter + mma_emitter = INT4TensorCoreIntrinEmitter( + a_dtype=storage_dtype, + b_dtype=storage_dtype, + accum_dtype=accum_dtype, + a_transposed=trans_A, + b_transposed=trans_B, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + # Define the main kernel using the generated configuration + @T.prim_func + def main( + A: T.Buffer(A_shape, storage_dtype), + B: T.Buffer(B_shape, storage_dtype), + C: T.Buffer((M, N), out_dtype), + ): + # Grid and thread configuration for CUDA kernel + with T.Kernel( + T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + + # Allocate shared memory and local fragments + A_shared = T.alloc_shared(A_shared_shape, storage_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), storage_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), storage_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + + # Thread-level parallelism for Tensor Cores + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + # Apply memory layout optimizations + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + }) + + # Optional rasterization for L2 locality enhancement + T.use_swizzle(panel_size=10, enable=enable_rasterization) + + # Initialize accumulation buffer to zero + T.clear(C_local) + + # Main matrix multiplication pipeline with multiple stages + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + + # Load A matrix into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B matrix into shared memory + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] + + # Perform the matrix multiplication on tensor core fragments + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Load B fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Matrix multiplication on fragments + mma_emitter.mma(A_local, B_local, C_local) + + # Store the result back to C shared memory + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + # Store results from shared memory to global memory + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return self.maybe_simplify(main) + + def __post_init__(self): + # Validate the matrix transpose settings + assert self.trans_A is False, "Currently only support Matrix A not transposed" + assert self.trans_B is True, "Currently only support Matrix B transposed" + + return + + +@dataclass +class MatmulINT4WeightPropagationScheduler(MatmulWeightPropagationScheduler): + + 
def apply_config( + self, + block_row_warps=2, + block_col_warps=2, + warp_row_tiles=32, + warp_col_tiles=32, + chunk=16, + num_stages=2, + enable_rasterization=False, + ): + + M, N, K = self.M, self.N, self.K + K = K // 2 # 2xint4 should be packed into one single int8 + trans_A, trans_B = self.trans_A, self.trans_B + in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype + assert in_dtype == "int4", "Only support int4 input" + assert accum_dtype == "int32", "Only support int32 accumulation" + storage_dtype = "int8" + + # Calculate the micro size per warp using a helper function + micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(storage_dtype) + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + # TODO(lei): Can be generalized to analyzed from bank size + pad_factor = 8 if storage_dtype == "float16" else 16 + + can_swizzle_a = block_K * DataType(storage_dtype).bits == 512 + apply_pad_a = not can_swizzle_a + + # Define the shapes of matrices and shared memory buffers + A_shape = (M, K) + B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k) + A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K) + B_shared_shape = ( + block_N // micro_size_y, + block_K // micro_size_k, + micro_size_y, + micro_size_k, + ) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + # GPU warp configuration for NVIDIA GPUs + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + + # Calculate local fragment sizes for tensor core + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + shared_scope = "shared.dyn" + + # Configure the tensor core intrinsic emitter + mma_emitter = INT4TensorCoreIntrinEmitterWithLadderTransform( + a_dtype=storage_dtype, + b_dtype=storage_dtype, + accum_dtype=accum_dtype, + a_transposed=trans_A, + b_transposed=trans_B, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + transform_kind_b=self.weight_transform_kind, + ) + + # Define the main kernel using the generated configuration + @T.prim_func + def main( + A: T.Buffer(A_shape, storage_dtype), + B: T.Buffer(B_shape, storage_dtype), + C: T.Buffer((M, N), out_dtype), + ): + # Grid and thread configuration for CUDA kernel + with T.Kernel( + T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + + # Allocate shared memory and local fragments + A_shared = T.alloc_shared(A_shared_shape, storage_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), storage_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), storage_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + + # Thread-level parallelism for Tensor Cores + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + # Apply memory layout optimizations + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + # B_shared: make_swizzle_layout(B_shared), + }) + + T.use_swizzle(panel_size=10, 
enable=enable_rasterization) + + # Initialize accumulation buffer to zero + T.clear(C_local) + + # Main matrix multiplication pipeline with multiple stages + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + + # Load A matrix into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B matrix into shared memory + for j, k, jj, kk in T.Parallel( + block_N // micro_size_y, + block_K // micro_size_k, + micro_size_y, + micro_size_k, + ): + B_shared[j, k, jj, kk] = B[ + bx * (block_N // micro_size_y) + j, + ko * (block_K // micro_size_k) + k, + jj, + kk, + ] + + # Perform the matrix multiplication on tensor core fragments + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Load B fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Matrix multiplication on fragments + mma_emitter.mma(A_local, B_local, C_local) + + # Store the result back to C shared memory + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + # Store results from shared memory to global memory + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return self.maybe_simplify(main) + + def __post_init__(self): + # Validate the matrix transpose settings + assert self.trans_A is False, "Currently only support Matrix A not transposed" + assert self.trans_B is True, "Currently only support Matrix B transposed" + + return diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py b/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py index 9fe99512c..f4943bfe0 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py @@ -13,6 +13,14 @@ MatmulDequantizeWeightPropagationScheduler, # noqa: F401 ) +from .finegrained_primitive_tensorcore_s4 import ( + MatmulINT4DequantizeFineGrainedScheduler, # noqa: F401 +) + +from .ladder_weight_transform_tensorcore_s4 import ( + MatmulINT4DequantizeWeightPropagationScheduler, # noqa: F401 +) + from bitblas.ops.common import TransformKind from typing import Union @@ -86,6 +94,53 @@ def can_apply_block_scheduler(propagate_a, propagate_b): conditions.append(propagate_b == TransformKind.NonTransform) return all(conditions) + def is_int4_dtype(dtype): + return dtype == "int4" or dtype == "uint4" + + if can_apply_weight_propagation_scheduler(trans_A, trans_B, propagate_a, propagate_b): + Scheduler = MatmulDequantizeWeightPropagationScheduler if not is_int4_dtype( + in_dtype) else MatmulINT4DequantizeWeightPropagationScheduler + return Scheduler( + M=M, + N=N, + K=K, + trans_A=trans_A, + trans_B=trans_B, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + num_bits=bit, + storage_dtype=storage_dtype, + source_format=source_format, + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=group_size, + fast_decoding=fast_decoding, + with_bias=with_bias, + zeros_mode=zeros_mode, + ) + if can_apply_fine_grain_scheduler(trans_A, trans_B, propagate_a, propagate_b): + Scheduler = MatmulDequantizeFineGrainedScheduler if not is_int4_dtype( + in_dtype) else MatmulINT4DequantizeFineGrainedScheduler + return Scheduler( + M=M, + N=N, + K=K, + trans_A=trans_A, + 
trans_B=trans_B, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + num_bits=bit, + storage_dtype=storage_dtype, + source_format=source_format, + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=group_size, + fast_decoding=fast_decoding, + with_bias=with_bias, + zeros_mode=zeros_mode, + ) if can_apply_block_scheduler(propagate_a, propagate_b): return MatmulDequantizeScheduler( M=M, diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/block_primitive_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/block_primitive_tensorcore.py index 5f1a8f5ed..4d45e1204 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/block_primitive_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/block_primitive_tensorcore.py @@ -233,7 +233,7 @@ def apply_config( func_name: str = "" if fast_decoding is True: lop3_intrin_info = get_lop3_intrin_group( - out_dtype=out_dtype, + out_dtype=in_dtype, source_format=source_format, source_bit=num_bits, storage_dtype=storage_dtype, @@ -297,12 +297,9 @@ def general_dequant_matmul( Qzeros, func_name, by, - tx, k, - i, block_N, block_K, - threads, ) else: self._normal_dequant( diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py index d57951455..942a66a90 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py @@ -182,6 +182,9 @@ def with_default_config(self): warp_row_tiles = getattr(self, "warp_row_tiles", 32) warp_col_tiles = getattr(self, "warp_col_tiles", 32) chunk = getattr(self, "chunk", 32) + if DataType(self.in_dtype).bits <= 8: + chunk = 64 + num_stages = getattr(self, "num_stages", 2) enable_rasterization = getattr(self, "enable_rasterization", False) @@ -275,7 +278,7 @@ def apply_config( func_name: str = "" if fast_decoding is True: lop3_intrin_info = get_lop3_intrin_group( - out_dtype=out_dtype, + out_dtype=in_dtype, source_format=source_format, source_bit=num_bits, storage_dtype=storage_dtype, diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore_s4.py b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore_s4.py new file mode 100644 index 000000000..e7fb80d24 --- /dev/null +++ b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore_s4.py @@ -0,0 +1,299 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+from bitblas import tvm as tvm +from tvm import DataType +import tvm.tl.language as T +from typing import Optional, List +from bitblas.tl.utils import ( + get_mma_micro_size, # noqa: F401 + make_swizzle_layout, # noqa: F401 + index_to_coordinates, # noqa: F401 +) + +from bitblas.tl.macro_generator import ( + INT4TensorCoreIntrinEmitter, # noqa: F401 +) +from bitblas.base.arch import TileDevice +from bitblas.base.roller.hint import Hint +from bitblas.base.utils import get_roller_hints_from_func +from dataclasses import dataclass +from bitblas.ops.general_matmul.tirscript import ( + matmul_dequantize_select_implementation,) +from bitblas.ops.general_matmul.tilelang.dequantize.finegrained_primitive_tensorcore import ( + MatmulDequantizeFineGrainedScheduler,) +from bitblas.gpu.intrin.lop3 import get_lop3_intrin_group + +# GPU warp configuration for NVIDIA GPUs +warp_size = 32 + + +@dataclass +class MatmulINT4DequantizeFineGrainedScheduler(MatmulDequantizeFineGrainedScheduler): + + def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): + layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" + K = self.K // 2 # 2xint4 should be packed into one single int8 + storage_dtype = "int8" + num_bits = self.num_bits * 2 + # INT4XINT2 is equal to int8xint4 with reduced shape + # Simple TIR Compute Expression + ir_module = matmul_dequantize_select_implementation( + M=self.M, + N=self.N, + K=K, + in_dtype=storage_dtype, + out_dtype=self.out_dtype, + accum_dtype=self.accum_dtype, + layout=layout, + bit=num_bits, + storage_dtype=self.storage_dtype, + source_format=self.source_format, + with_scaling=self.with_scaling, + with_zeros=self.with_zeros, + group_size=self.group_size, + fast_decoding=self.fast_decoding, + with_bias=self.with_bias, + zeros_mode=self.zeros_mode) + + roller_hints = get_roller_hints_from_func( + ir_module["main"], + arch, + topk, + tensorcore_only=True, + allow_gemv=True, + ) + + if roller_hints is None: + raise ValueError("No Roller Hints Found for TensorCore Scheduling") + + def serialze_hints_to_configs(hints: List[Hint]): + configs = [] + for hint in hints: + config = self.TLHint.from_roller_hint(hint) + configs.append(config) + return configs + + return serialze_hints_to_configs(roller_hints) + + def apply_config( + self, + block_row_warps: Optional[int] = None, + block_col_warps: Optional[int] = None, + warp_row_tiles: Optional[int] = None, + warp_col_tiles: Optional[int] = None, + chunk: Optional[int] = None, + num_stages: Optional[int] = None, + enable_rasterization=False, + ): + assert block_row_warps is not None, "block_row_warps is required" + assert block_col_warps is not None, "block_col_warps is required" + assert warp_row_tiles is not None, "warp_row_tiles is required" + assert warp_col_tiles is not None, "warp_col_tiles is required" + assert chunk is not None, "chunk is required" + assert num_stages is not None, "num_stages is required" + + M, N, K = self.M, self.N, self.K + K = K // 2 # 2xint4 should be packed into one single int8 + + trans_A, trans_B = self.trans_A, self.trans_B + + assert trans_A is False, "Dequantize only implement for trans_A=False currently" + assert trans_B is True, "Dequantize only implement for trans_B=TRue currently" + + in_dtype, out_dtype, accum_dtype = ( + self.in_dtype, + self.out_dtype, + self.accum_dtype, + ) + + assert in_dtype == "int4", "Only support int4 input" + assert accum_dtype == "int32", "Only support int32 accumulation" + storage_dtype = self.storage_dtype + + # Calculate the micro size per warp 
using a helper function + micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(storage_dtype) + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + threads = warp_size * (block_row_warps * block_col_warps) + + fragement_size_a = (micro_size_x * micro_size_k) // warp_size + fragement_size_b = (micro_size_y * micro_size_k) // warp_size + fragement_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + fast_decoding = self.fast_decoding + + num_bits = self.num_bits + source_format = self.source_format + num_elems_per_byte = self.num_elems_per_byte + + MAX_TRANSACTION_SIZE_IN_BITS = 128 + local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(storage_dtype).bits + local_size_compressed = local_size // num_elems_per_byte + + group_size = self.group_size + if group_size == -1: + group_size = K + + A_shape = (M, K) + B_shape = (N, K // num_elems_per_byte) + + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K // num_elems_per_byte) + B_dequantize_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + import_source: Optional[str] = None + func_name: str = "" + if fast_decoding is True: + lop3_intrin_info = get_lop3_intrin_group( + out_dtype=in_dtype, + source_format=source_format, + source_bit=num_bits, + storage_dtype=storage_dtype, + with_scaling=self.with_scaling, + with_zeros=self.with_zeros, + ) + import_source = lop3_intrin_info["c_source"] + func_name = lop3_intrin_info["func_name"] + assert import_source is not None, "lop3_intrin_info is not found" + assert func_name is not None, "lop3_intrin_info is not found" + import_source = self.common_header + import_source + + # Configure the tensor core intrinsic emitter + mma_emitter = INT4TensorCoreIntrinEmitter( + a_dtype=storage_dtype, + b_dtype=storage_dtype, + accum_dtype=accum_dtype, + a_transposed=trans_A, + b_transposed=trans_B, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + @T.prim_func + def general_dequant_matmul( + A: T.Buffer(A_shape, storage_dtype), + B: T.Buffer(B_shape, storage_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel( + T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, storage_dtype) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, storage_dtype) + C_shared = T.alloc_shared(C_shared_shape, out_dtype) + + A_frag = T.alloc_local((warp_rows * fragement_size_a), storage_dtype) + B_frag = T.alloc_local((warp_cols * fragement_size_b), storage_dtype) + C_frag = T.alloc_local((warp_rows * warp_cols * fragement_size_c), accum_dtype) + + B_local = T.alloc_local([local_size_compressed], storage_dtype) + B_dequantize_local = T.alloc_local([local_size], storage_dtype) + + tx = T.thread_binding(0, threads, thread="threadIdx.x") + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + B_dequantize_shared: make_swizzle_layout(B_dequantize_shared), + }) + + T.use_swizzle(10, enable=enable_rasterization) + + T.import_source(import_source) + + T.clear(C_frag) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[bx 
* block_N, ko * block_K // num_elems_per_byte], B_shared) + + for i in T.serial(block_N * block_K // num_elems_per_byte // + (threads * local_size_compressed)): + for v in T.vectorized(0, local_size_compressed): + index = ( + i * threads * local_size_compressed + tx * local_size_compressed + + v) + vi, vj = index_to_coordinates(index, B_shared_shape) + B_local[v] = B_shared[vi, vj] + + if fast_decoding: + T.call_extern('handle', func_name, T.address_of(B_local[0]), + T.address_of(B_dequantize_local[0]), 32) + else: + for v in T.serial(0, local_size): + int2x2_value = (B_local[v // 2] >> ((v % 2) * 4)) & 0x0F + + int4_0 = (int2x2_value >> 0) & 0x03 + int4_1 = (int2x2_value >> 2) & 0x03 + + B_dequantize_local[v] = (int4_1 << 4) | int4_0 + + for v in T.vectorized(0, local_size): + index = i * threads * local_size + tx * local_size + v + vi, vj = index_to_coordinates(index, B_dequantize_shared_shape) + B_dequantize_shared[vi, vj] = B_dequantize_local[v] + + # Perform the matrix multiplication on tensor core fragments + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A fragment + mma_emitter.ldmatrix_a( + A_frag, + A_shared, + ki, + thread_bindings=tx, + ) + + # Load B fragment + mma_emitter.ldmatrix_b( + B_frag, + B_dequantize_shared, + ki, + thread_bindings=tx, + ) + + # Matrix multiplication on fragments + mma_emitter.mma(A_frag, B_frag, C_frag) + + # Store the result back to C shared memory + mma_emitter.stmatrix( + C_frag, + C_shared, + thread_bindings=tx, + ) + + # Store results from shared memory to global memory + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return self.maybe_simplify(general_dequant_matmul) + + @property + def num_elems_per_byte(self): + # force value for int4 + storage_nbit = 4 + num_bits = self.num_bits + return storage_nbit // num_bits + + def __post_init__(self): + # Legalize group_size + if self.with_scaling and self.group_size == -1: + object.__setattr__(self, "group_size", self.K) diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py index 7f8920575..4652566c6 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py @@ -124,7 +124,7 @@ def apply_config( func_name: str = "" if fast_decoding is True: lop3_intrin_info = get_lop3_intrin_group( - out_dtype=out_dtype, + out_dtype=in_dtype, source_format=source_format, source_bit=num_bits, storage_dtype=storage_dtype, diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore_s4.py b/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore_s4.py new file mode 100644 index 000000000..153e1f64a --- /dev/null +++ b/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore_s4.py @@ -0,0 +1,282 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+from bitblas import tvm as tvm +from tvm import DataType +import tvm.tl.language as T +from typing import Optional +from bitblas.tl.utils import ( + get_mma_micro_size, # noqa: F401 + make_swizzle_layout, # noqa: F401 + index_to_coordinates, # noqa: F401 +) +from bitblas.tl.macro_generator import ( + INT4TensorCoreIntrinEmitterWithLadderTransform, # noqa: F401 +) +from bitblas.ops.common import TransformKind # noqa: F401 +from dataclasses import dataclass +from bitblas.gpu.intrin.lop3 import get_lop3_intrin_group +from bitblas.ops.general_matmul.tilelang.dequantize.ladder_weight_transform_tensorcore import ( + MatmulDequantizeWeightPropagationScheduler,) + +# GPU warp configuration for NVIDIA GPUs +warp_size = 32 + + +@dataclass +class MatmulINT4DequantizeWeightPropagationScheduler(MatmulDequantizeWeightPropagationScheduler): + + def apply_config( + self, + block_row_warps: Optional[int] = None, + block_col_warps: Optional[int] = None, + warp_row_tiles: Optional[int] = None, + warp_col_tiles: Optional[int] = None, + chunk: Optional[int] = None, + num_stages: Optional[int] = None, + enable_rasterization=False, + ): + assert block_row_warps is not None, "block_row_warps is required" + assert block_col_warps is not None, "block_col_warps is required" + assert warp_row_tiles is not None, "warp_row_tiles is required" + assert warp_col_tiles is not None, "warp_col_tiles is required" + assert chunk is not None, "chunk is required" + assert num_stages is not None, "num_stages is required" + + M, N, K = self.M, self.N, self.K + K = K // 2 # 2xint4 should be packed into one single int8 + + trans_A, trans_B = self.trans_A, self.trans_B + weight_transform_kind = self.weight_transform_kind + + assert trans_A is False, "Dequantize only implement for trans_A=False currently" + assert trans_B is True, "Dequantize only implement for trans_B=TRue currently" + assert (weight_transform_kind == TransformKind.LDMatrixTransform + ), "Dequantize only implement for LDMatrixTransform currently" + + in_dtype, out_dtype, accum_dtype = ( + self.in_dtype, + self.out_dtype, + self.accum_dtype, + ) + assert in_dtype == "int4", "Only support int4 input" + assert accum_dtype == "int32", "Only support int32 accumulation" + storage_dtype = self.storage_dtype + + # Calculate the micro size per warp using a helper function + micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(storage_dtype) + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + threads = warp_size * (block_row_warps * block_col_warps) + + fragement_size_a = (micro_size_x * micro_size_k) // warp_size + fragement_size_b = (micro_size_y * micro_size_k) // warp_size + fragement_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + fast_decoding = self.fast_decoding + + num_bits = self.num_bits + source_format = self.source_format + num_elems_per_byte = self.num_elems_per_byte + + MAX_TRANSACTION_SIZE_IN_BITS = 128 + local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(storage_dtype).bits + local_size_compressed = local_size // num_elems_per_byte + + group_size = self.group_size + if group_size == -1: + group_size = K + + A_shape = (M, K) + B_shape = ( + N // micro_size_y, + K // micro_size_k, + micro_size_y, + micro_size_k // num_elems_per_byte, + ) + B_dequantize_shared_shape = ( + block_N // micro_size_y, + block_K // micro_size_k, + micro_size_y, + micro_size_k, + ) + A_shared_shape = (block_M, block_K) 
+ B_shared_shape = ( + block_N // micro_size_y, + block_K // micro_size_k, + micro_size_y, + micro_size_k // num_elems_per_byte, + ) + + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + import_source: Optional[str] = None + func_name: str = "" + if fast_decoding is True: + lop3_intrin_info = get_lop3_intrin_group( + out_dtype=in_dtype, + source_format=source_format, + source_bit=num_bits, + storage_dtype=storage_dtype, + with_scaling=self.with_scaling, + with_zeros=self.with_zeros, + storage_scope="warp", # to get the ladder transform lop3 intrin + ) + import_source = lop3_intrin_info["c_source"] + func_name = lop3_intrin_info["func_name"] + assert import_source is not None, "lop3_intrin_info is not found" + assert func_name is not None, "lop3_intrin_info is not found" + import_source = self.common_header + import_source + + # Configure the tensor core intrinsic emitter with ladder transform + mma_emitter = INT4TensorCoreIntrinEmitterWithLadderTransform( + a_dtype=storage_dtype, + b_dtype=storage_dtype, + accum_dtype=accum_dtype, + a_transposed=trans_A, + b_transposed=trans_B, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + transform_kind_b=weight_transform_kind, + ) + + vec_load_qb = 16 + if block_N * block_K // num_elems_per_byte // threads < vec_load_qb: + vec_load_qb = block_N * block_K // num_elems_per_byte // threads + + @T.prim_func + def general_dequant_matmul( + A: T.Buffer(A_shape, storage_dtype), + B: T.Buffer(B_shape, storage_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel( + T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, storage_dtype) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, storage_dtype) + C_shared = T.alloc_shared(C_shared_shape, out_dtype) + + A_frag = T.alloc_local((warp_rows * fragement_size_a), storage_dtype) + B_frag = T.alloc_local((warp_cols * fragement_size_b), storage_dtype) + C_frag = T.alloc_local((warp_rows * warp_cols * fragement_size_c), accum_dtype) + B_local = T.alloc_local([local_size_compressed], storage_dtype) + B_dequantize_local = T.alloc_local([local_size], storage_dtype) + + tx = T.thread_binding(0, threads, thread="threadIdx.x") + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + }) + + T.use_swizzle(10, enable=enable_rasterization) + + T.import_source(import_source) + + T.clear(C_frag) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + + T.copy(A[by * block_M, ko * block_K], A_shared) + + # Load B into shared memory + # TODO(lei): Layout Inference Pass is not efficient to handle the four dims int8 load + for i in T.serial(block_N * block_K // num_elems_per_byte // + (threads * vec_load_qb)): + for v in T.vectorized(0, vec_load_qb): + idx = i * threads * vec_load_qb + threads * vec_load_qb + tx * vec_load_qb + v + vj, vk, vjj, vkk = index_to_coordinates(idx, B_shared_shape) + B_shared[vj, vk, vjj, + vkk] = B[bx * (block_N // micro_size_y) + vj, + ko * (block_K // micro_size_k) + vk, vjj, vkk] + + for i in T.serial(block_N * block_K // num_elems_per_byte // + (threads * local_size_compressed)): + for v in T.vectorized(0, local_size_compressed): + index = ( + i * threads * local_size_compressed + tx * local_size_compressed + + v) + vi, vj, vii, vjj = 
index_to_coordinates(index, B_shared_shape) + B_local[v] = B_shared[vi, vj, vii, vjj] + + if fast_decoding: + # Simulated dequantization + T.call_extern('handle', func_name, T.address_of(B_local[0]), + T.address_of(B_dequantize_local[0]), 32) + else: + for v in T.serial(0, local_size): + int2x2_value = (B_local[v // 2] >> ((v % 2) * 4)) & 0x0F + + int4_0 = (int2x2_value >> 0) & 0x03 + int4_1 = (int2x2_value >> 2) & 0x03 + + B_dequantize_local[v] = (int4_1 << 4) | int4_0 + + for v in T.vectorized(0, local_size): + index = i * threads * local_size + tx * local_size + v + vi, vj, vii, vjj = index_to_coordinates(index, + B_dequantize_shared_shape) + B_dequantize_shared[vi, vj, vii, vjj] = B_dequantize_local[v] + + # Perform the matrix multiplication on tensor core fragments + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A fragment + mma_emitter.ldmatrix_a( + A_frag, + A_shared, + ki, + thread_bindings=tx, + ) + + # Load B fragment + mma_emitter.ldmatrix_b( + B_frag, + B_dequantize_shared, + ki, + thread_bindings=tx, + ) + + # Matrix multiplication on fragments + mma_emitter.mma(A_frag, B_frag, C_frag) + + # Store the result back to C shared memory + mma_emitter.stmatrix( + C_frag, + C_shared, + thread_bindings=tx, + ) + + # Store results from shared memory to global memory + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return self.maybe_simplify(general_dequant_matmul) + + @property + def num_elems_per_byte(self): + # force value for int4 + storage_nbit = 4 + num_bits = self.num_bits + return storage_nbit // num_bits + + def __post_init__(self): + # Legalize group_size + if self.with_scaling and self.group_size == -1: + object.__setattr__(self, "group_size", self.K) diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index e2113fa15..1b22491f5 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -248,6 +248,7 @@ def tvm_callback_cuda_postproc(code, _): def scheduler_with_default(self, scheduler: BaseScheduler): scheduled_ir_module = IRModule.from_expr(scheduler.with_default_config()) if scheduled_ir_module is not None: + self.ir_module = scheduled_ir_module return scheduled_ir_module return None diff --git a/install.sh b/install.sh index 99fd89b8c..49d1fa815 100755 --- a/install.sh +++ b/install.sh @@ -3,14 +3,26 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# install requirements +echo "Starting installation script..." + +# Step 1: Install Python requirements +echo "Installing Python requirements from requirements.txt..." pip install -r requirements.txt +if [ $? -ne 0 ]; then + echo "Error: Failed to install Python requirements." + exit 1 +else + echo "Python requirements installed successfully." +fi -# install llvm +# Step 2: Define LLVM version and architecture LLVM_VERSION="10.0.1" IS_AARCH64=false EXTRACT_PATH="3rdparty" +echo "LLVM version set to ${LLVM_VERSION}." +echo "Is AARCH64 architecture: $IS_AARCH64" +# Step 3: Determine the correct Ubuntu version based on LLVM version UBUNTU_VERSION="16.04" if [[ "$LLVM_VERSION" > "17.0.0" ]]; then UBUNTU_VERSION="22.04" @@ -19,7 +31,9 @@ elif [[ "$LLVM_VERSION" > "16.0.0" ]]; then elif [[ "$LLVM_VERSION" > "13.0.0" ]]; then UBUNTU_VERSION="18.04" fi +echo "Ubuntu version for LLVM set to ${UBUNTU_VERSION}." 
+# Step 4: Set download URL and file name for LLVM BASE_URL="https://github.com/llvm/llvm-project/releases/download/llvmorg-${LLVM_VERSION}" if $IS_AARCH64; then FILE_NAME="clang+llvm-${LLVM_VERSION}-aarch64-linux-gnu.tar.xz" @@ -27,45 +41,100 @@ else FILE_NAME="clang+llvm-${LLVM_VERSION}-x86_64-linux-gnu-ubuntu-${UBUNTU_VERSION}.tar.xz" fi DOWNLOAD_URL="${BASE_URL}/${FILE_NAME}" +echo "Download URL for LLVM: ${DOWNLOAD_URL}" +# Step 5: Create extraction directory +echo "Creating extraction directory at ${EXTRACT_PATH}..." mkdir -p "$EXTRACT_PATH" +if [ $? -ne 0 ]; then + echo "Error: Failed to create extraction directory." + exit 1 +else + echo "Extraction directory created successfully." +fi -echo "Downloading $FILE_NAME from $DOWNLOAD_URL" +# Step 6: Download LLVM +echo "Downloading $FILE_NAME from $DOWNLOAD_URL..." curl -L -o "${EXTRACT_PATH}/${FILE_NAME}" "$DOWNLOAD_URL" - if [ $? -ne 0 ]; then - echo "Download failed!" + echo "Error: Download failed!" exit 1 +else + echo "Download completed successfully." fi -echo "Extracting $FILE_NAME to $EXTRACT_PATH" +# Step 7: Extract LLVM +echo "Extracting $FILE_NAME to $EXTRACT_PATH..." tar -xJf "${EXTRACT_PATH}/${FILE_NAME}" -C "$EXTRACT_PATH" - if [ $? -ne 0 ]; then - echo "Extraction failed!" + echo "Error: Extraction failed!" exit 1 +else + echo "Extraction completed successfully." fi -echo "Download and extraction completed successfully." - +# Step 8: Determine LLVM config path LLVM_CONFIG_PATH="$(realpath ${EXTRACT_PATH}/$(basename ${FILE_NAME} .tar.xz)/bin/llvm-config)" -echo "LLVM config path: $LLVM_CONFIG_PATH" +echo "LLVM config path determined as: $LLVM_CONFIG_PATH" -# clone and build tvm +# Step 9: Clone and build TVM +echo "Cloning TVM repository and initializing submodules..." git submodule update --init --recursive +if [ $? -ne 0 ]; then + echo "Error: Failed to initialize submodules." + exit 1 +else + echo "Submodules initialized successfully." +fi +# Step 10: Build TVM +echo "Starting TVM build process..." cd 3rdparty/tvm if [ -d build ]; then + echo "Existing build directory found. Removing it..." rm -rf build fi +echo "Creating new build directory for TVM..." mkdir build cp cmake/config.cmake build cd build -echo "set(USE_LLVM $LLVM_CONFIG_PATH)" >> config.cmake && echo "set(USE_CUDA /usr/local/cuda)" >> config.cmake -cmake .. && make -j && cd ../../.. +echo "Configuring TVM build with LLVM and CUDA paths..." +echo "set(USE_LLVM $LLVM_CONFIG_PATH)" >> config.cmake +echo "set(USE_CUDA /usr/local/cuda)" >> config.cmake + +echo "Running CMake for TVM..." +cmake .. +if [ $? -ne 0 ]; then + echo "Error: CMake configuration failed." + exit 1 +fi +echo "Building TVM with make..." +make -j +if [ $? -ne 0 ]; then + echo "Error: TVM build failed." + exit 1 +else + echo "TVM build completed successfully." +fi + +cd ../../.. + +# Step 11: Set environment variables +echo "Configuring environment variables for TVM..." echo "export TVM_HOME=$(pwd)/3rdparty/tvm" >> ~/.bashrc echo "export PYTHONPATH=\$TVM_HOME/python:$(pwd):\$PYTHONPATH" >> ~/.bashrc echo "export CUDA_DEVICE_ORDER=PCI_BUS_ID" >> ~/.bashrc + +# Step 12: Source .bashrc to apply changes +echo "Applying environment changes by sourcing .bashrc..." source ~/.bashrc +if [ $? -ne 0 ]; then + echo "Error: Failed to source .bashrc." + exit 1 +else + echo "Environment configured successfully." +fi + +echo "Installation script completed successfully." 
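Note on the integration kernels that follow: the new `*_s4` schedulers and tile-lang kernels assume that two int4 values (or four int2 values) are packed into a single int8 storage element along K, which is why they halve K and switch to an int8 storage dtype. A minimal packing sketch in PyTorch, mirroring the packing used by the tests in this diff (tensor names and shapes here are illustrative only), is shown below.

```python
import torch

# Illustrative shapes only; the kernels expect K to be a multiple of the packing factor.
M, N, K = 256, 256, 256

# int4 path: two 4-bit values per int8 element, low nibble first (packed K becomes K // 2).
A = torch.randint(0, 4, (M, K), device="cuda", dtype=torch.int8)
compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4)

# int2 path: four 2-bit values per int8 element (packed K becomes K // 4).
B = torch.randint(0, 2, (N, K), device="cuda", dtype=torch.int8)
compressed_B = ((B[:, ::4] & 0x03) + ((B[:, 1::4] & 0x03) << 2) +
                ((B[:, 2::4] & 0x03) << 4) + ((B[:, 3::4] & 0x03) << 6))
```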
diff --git a/integration/BitNet/int4_kernel/tl_int4xint2.py b/integration/BitNet/int4_kernel/tl_int4xint2.py index 16797501a..670f72b07 100644 --- a/integration/BitNet/int4_kernel/tl_int4xint2.py +++ b/integration/BitNet/int4_kernel/tl_int4xint2.py @@ -53,8 +53,8 @@ def tl_matmul( # This is a debug config block_row_warps = 2 block_col_warps = 2 - warp_row_tiles = 64 - warp_col_tiles = 64 + warp_row_tiles = 32 + warp_col_tiles = 32 chunk = 32 if in_dtype == "float16" else 64 shared_scope = "shared.dyn" storage_dtype = "int8" @@ -271,5 +271,5 @@ def test_assert_tl_matmul(): # assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") # assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32") # assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32") - # assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32") - assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32") + assert_tl_matmul_correctness(256, 256, 256, "int8", "int32", "int32") + # assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32") diff --git a/integration/BitNet/int4_kernel/tl_int4xint2_ladder_weight_only.py b/integration/BitNet/int4_kernel/tl_int4xint2_ladder_weight_only.py index d44717e7f..e879f1524 100644 --- a/integration/BitNet/int4_kernel/tl_int4xint2_ladder_weight_only.py +++ b/integration/BitNet/int4_kernel/tl_int4xint2_ladder_weight_only.py @@ -257,7 +257,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast # A = torch.ones(M, K, device="cuda", dtype=getattr(torch, in_dtype)) # B = torch.ones(N, K, device="cuda", dtype=getattr(torch, in_dtype)) A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype)) - B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.randint(0, 2, (N, K), device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) @@ -328,4 +328,4 @@ def test_assert_tl_matmul(): # assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") # assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32") # assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32") - assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32", False) + assert_tl_matmul_correctness(256, 256, 256, "int8", "int32", "int32", False) diff --git a/integration/BitNet/int4_kernel/tl_int4xint4_ladder_weight_only.py b/integration/BitNet/int4_kernel/tl_int4xint4_ladder_weight_only.py index 1603698b2..b0e0c4d5d 100644 --- a/integration/BitNet/int4_kernel/tl_int4xint4_ladder_weight_only.py +++ b/integration/BitNet/int4_kernel/tl_int4xint4_ladder_weight_only.py @@ -46,8 +46,8 @@ def tl_matmul( # This is a debug config block_row_warps = 2 block_col_warps = 2 - warp_row_tiles = 64 - warp_col_tiles = 64 + warp_row_tiles = 32 + warp_col_tiles = 32 chunk = 32 if in_dtype == "float16" else 64 shared_scope = "shared.dyn" @@ -185,6 +185,7 @@ def main( def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + print(matmul) mod, params = TL.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -219,7 +220,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): assert 
latency is not None # Get Reference Result - ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype)) print(C) print(ref_c) torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) @@ -235,4 +236,4 @@ def test_assert_tl_matmul(): # assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") # assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32") # assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32") - assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32") + assert_tl_matmul_correctness(256, 256, 256, "int8", "int32", "int32") diff --git a/testing/python/operators/test_general_matmul_ops_int4.py b/testing/python/operators/test_general_matmul_ops_int4.py new file mode 100644 index 000000000..667a7aca8 --- /dev/null +++ b/testing/python/operators/test_general_matmul_ops_int4.py @@ -0,0 +1,159 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import bitblas +import logging +from bitblas import set_log_level + +set_log_level(logging.DEBUG) + + +def get_codegen_result(ops): + code = ops.get_source() + return code + + +def matmul_int4_torch_forward(M, + N, + K, + A_dtype, + W_dtype, + accum_dtype, + out_dtype, + layout, + propagate_b, + fast_decoding=False): + import torch + matmul_config = bitblas.MatmulConfig( + M=M, # M dimension + N=N, # N dimension + K=K, # K dimension + A_dtype=A_dtype, # activation A dtype + W_dtype=W_dtype, # weight W dtype + accum_dtype=accum_dtype, # accumulation dtype + out_dtype=out_dtype, # output dtype + layout=layout, # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose + propagate_b=propagate_b, # propagate B matrix + fast_decoding=fast_decoding, + ) + + matmul = bitblas.Matmul(config=matmul_config, enable_tuning=False) + + # if finetuning is needed, uncomment the following line + # matmul.hardware_aware_finetune(topk=20) + + print(matmul.get_source()) + storage_dtype = "int8" + A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, storage_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) + + if W_dtype == "int4": + B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, storage_dtype)) + compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) + compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) + if propagate_b: + ladder_permutate_config = bitblas.ops.LadderPermutateConfig( + M=N, + N=(K // 2), + datatype="int8", + storage_dtype="int8", + transform_kind=3, + transpose_matrix=True, + ) + + ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) + LB = ladder_permutate(compressed_B.cpu()).cuda() + matmul(compressed_A, LB, output=C) + else: + matmul(compressed_A, compressed_B, output=C) + elif W_dtype == "int2": + B = torch.randint(0, 2, (N, K), device="cuda", dtype=getattr(torch, storage_dtype)) + compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) + compressed_B = (B[:, ::4] & 0x03) + ((B[:, 1::4] & 0x03) << 2) + ( + (B[:, 2::4] & 0x03) << 4) + ((B[:, 3::4] & 0x03) << 6) + if propagate_b: + compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) + + lop3_permutate_config = bitblas.ops.LOP3PermutateConfig( + M=N, + N=K, + datatype="int4", + dequantize_bits=2, + storage_dtype="int8", + ) + lop3_permutate = bitblas.ops.LOP3Permutate( + config=lop3_permutate_config, + target="llvm", 
+ ) + + ladder_permutate_config = bitblas.ops.LadderPermutateConfig( + M=N, + N=(K // 2), + datatype="int8", + storage_dtype="int8", + transform_kind=3, + transpose_matrix=True, + ) + + ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) + + compressed_B_ladder = ladder_permutate(compressed_B.cpu()).cuda() + ladder_shape = compressed_B_ladder.shape + int2_shape = (ladder_shape[:-1] + (ladder_shape[-1] // 2,)) + int2_tensor = torch.zeros(int2_shape, device="cuda", dtype=torch.int8) + for i in range(int2_tensor.shape[-1]): + int2_tensor[..., i] = (compressed_B_ladder[..., 2 * i] & 0x03) | ( + (compressed_B_ladder[..., 2 * i] >> 4) & 0x03) << 2 | ( + (compressed_B_ladder[..., 2 * i + 1] & 0x03) << 4) | ( + (compressed_B_ladder[..., 2 * i + 1] >> 4) << 6) + + raw_tensor_shape = int2_tensor.shape + print(f"{raw_tensor_shape=}") + if fast_decoding: + lop3_compressed_B = lop3_permutate(int2_tensor.cpu()).cuda() + lop3_compressed_B = lop3_compressed_B.view(raw_tensor_shape) + else: + lop3_compressed_B = int2_tensor + matmul(compressed_A, lop3_compressed_B, output=C) + else: + if fast_decoding: + lop3_permutate_config = bitblas.ops.LOP3PermutateConfig( + M=N, + N=K, + datatype="int4", + dequantize_bits=2, + storage_dtype="int8", + ) + lop3_permutate = bitblas.ops.LOP3Permutate( + config=lop3_permutate_config, + target="llvm", + ) + lop3_compressed_B = lop3_permutate(compressed_B.cpu()).cuda() + matmul(compressed_A, lop3_compressed_B, output=C) + else: + matmul(compressed_A, compressed_B, output=C) + + print(C) + latency = matmul.profile_latency() + print(latency) + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype)) + + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def test_matmul_torch_forward(): + matmul_int4_torch_forward(128, 128, 128, "int4", "int4", "int32", "int32", "nt", False) + matmul_int4_torch_forward(128, 128, 128, "int4", "int4", "int32", "int32", "nt", True) + matmul_int4_torch_forward(128, 128, 128, "int4", "int2", "int32", "int32", "nt", False, False) + matmul_int4_torch_forward(128, 128, 128, "int4", "int2", "int32", "int32", "nt", False, True) + matmul_int4_torch_forward(128, 128, 128, "int4", "int2", "int32", "int32", "nt", True, False) + matmul_int4_torch_forward(128, 128, 128, "int4", "int2", "int32", "int32", "nt", True, True) + + +# fmt: on +if __name__ == "__main__": + bitblas.testing.main() diff --git a/testing/python/operators/test_general_matmul_tilelang_kernel.py b/testing/python/operators/test_general_matmul_tilelang_kernel.py index 31c3de7d1..857b22270 100644 --- a/testing/python/operators/test_general_matmul_tilelang_kernel.py +++ b/testing/python/operators/test_general_matmul_tilelang_kernel.py @@ -14,6 +14,13 @@ MatmulDequantizeScheduler, MatmulDequantizeFineGrainedScheduler, MatmulDequantizeWeightPropagationScheduler, + MatmulINT4DequantizeFineGrainedScheduler, + MatmulINT4DequantizeWeightPropagationScheduler, +) + +from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore_s4 import ( + MatmulINT4FineGrainScheduler, + MatmulINT4WeightPropagationScheduler, ) import torch @@ -78,14 +85,491 @@ def assert_matmul_blocked_apply_config_correctness( block_K=32, trans_A=False, trans_B=True, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + num_stages=2, + threads=128, + 
enable_rasterization=False, +): + matmul = MatmulScheduler( + M=M, + N=N, + K=K, + trans_A=trans_A, + trans_B=trans_B, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + ).apply_config( + block_M=block_M, + block_N=block_N, + block_K=block_K, + num_stages=num_stages, + threads=threads, + enable_rasterization=enable_rasterization, + ) + + mod, params = tl.lower(matmul) + src_code = mod.imported_modules[0].get_source() + + # src_code is the generated cuda source + assert src_code is not None + + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + + mod(A, B, C) + + latency = mod.do_bench(mod.func, warmup=25) + + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype)) + torch.testing.assert_close(C, ref_c, rtol=1e0, atol=1e-1) + + +def assert_matmul_fine_grained_with_default_correctness( + M, + N, + K, + trans_A=False, + trans_B=True, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", +): + + matmul = MatmulFineGrainScheduler( + M=M, + N=N, + K=K, + trans_A=trans_A, + trans_B=trans_B, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + ).with_default_config() + + mod, params = tl.lower(matmul) + src_code = mod.imported_modules[0].get_source() + # src_code is the generated cuda source + assert src_code is not None + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 + B = (torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) if trans_B else torch.rand( + K, N, device="cuda", dtype=getattr(torch, in_dtype))) - 0.5 + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) + mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + + latency = mod.do_bench(mod.func, warmup=25) + + # Ensure that the latency is not None + assert latency is not None + + mod(A, B, C) + + # Get Reference Result + ref_c = ( + torch.matmul(A, B.T).to(getattr(torch, out_dtype)) if trans_B else torch.matmul(A, B).to( + getattr(torch, out_dtype))) + + torch.testing.assert_close(C, ref_c, rtol=1e0, atol=1e-1) + + +def assert_matmul_fine_grained_apply_config_correctness( + M, + N, + K, + trans_A=False, + trans_B=True, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + block_row_warps=1, + block_col_warps=1, + warp_row_tiles=16, + warp_col_tiles=16, + chunk=32, + num_stages=2, + enable_rasterization=False, +): + + matmul = MatmulFineGrainScheduler( + M=M, + N=N, + K=K, + trans_A=trans_A, + trans_B=trans_B, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + ).apply_config( + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + num_stages=num_stages, + enable_rasterization=enable_rasterization, + ) + + mod, params = tl.lower(matmul) + src_code = mod.imported_modules[0].get_source() + + # src_code is the generated cuda source + assert src_code is not None + + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + + 
mod(A, B, C) + + latency = mod.do_bench(mod.func, warmup=25) + + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype)) + torch.testing.assert_close(C, ref_c, rtol=1e-1, atol=1e-1) + + +def assert_matmul_weight_propagation_with_default_correctness( + M, + N, + K, + trans_A=False, + trans_B=True, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", +): + + matmul = MatmulWeightPropagationScheduler( + M=M, + N=N, + K=K, + trans_A=trans_A, + trans_B=trans_B, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + ).with_default_config() + + mod, params = tl.lower(matmul) + src_code = mod.imported_modules[0].get_source() + + # src_code is the generated cuda source + assert src_code is not None + + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) + + ladder_permutate_config = bitblas.ops.LadderPermutateConfig( + M=N, + N=K, + transform_kind=3, + transpose_matrix=True, + ) + + ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) + + LB = ladder_permutate(B.cpu()).cuda() + + mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + + mod(A, LB, C) + + latency = mod.do_bench(mod.func, warmup=25) + + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A, B.T).to(getattr(torch, out_dtype)) + print(C) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e0, atol=1e0) + + +def assert_matmul_weight_propagation_apply_config_correctness( + M, + N, + K, + trans_A=False, + trans_B=True, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + block_row_warps=1, + block_col_warps=1, + warp_row_tiles=16, + warp_col_tiles=16, + chunk=32, + num_stages=2, + enable_rasterization=False, +): + + matmul = MatmulWeightPropagationScheduler( + M=M, + N=N, + K=K, + trans_A=trans_A, + trans_B=trans_B, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + ).apply_config( + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + num_stages=num_stages, + enable_rasterization=enable_rasterization, + ) + + mod, params = tl.lower(matmul) + src_code = mod.imported_modules[0].get_source() + + # src_code is the generated cuda source + assert src_code is not None + + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) + + ladder_permutate_config = bitblas.ops.LadderPermutateConfig( + M=N, + N=K, + transform_kind=3, + transpose_matrix=True, + ) + + ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) + + LB = ladder_permutate(B.cpu()).cuda() + + mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + + mod(A, LB, C) + + latency = mod.do_bench(mod.func, warmup=25) + + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A, B.T).to(getattr(torch, out_dtype)) + torch.testing.assert_close(C, ref_c, rtol=1e0, atol=1e0) + + +def assert_matmul_int4_fine_grained_with_default_correctness( + M, + N, + K, + trans_A=False, + trans_B=True, + in_dtype="int4", 
+ out_dtype="int32", + accum_dtype="int32", +): + + matmul = MatmulINT4FineGrainScheduler( + M=M, + N=N, + K=K, + trans_A=trans_A, + trans_B=trans_B, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + ).with_default_config() + + mod, params = tl.lower(matmul) + src_code = mod.imported_modules[0].get_source() + # src_code is the generated cuda source + assert src_code is not None + storage_dtype = "int8" + A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, storage_dtype)) + B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, storage_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) + compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) + mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + + latency = mod.do_bench(mod.func, warmup=25, profiler="tvm") + print(latency) + + # Ensure that the latency is not None + assert latency is not None + + mod(compressed_A, compressed_B, C) + print(C) + + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype)) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-1) + + +def assert_matmul_int4_fine_grained_apply_config_correctness( + M, + N, + K, + trans_A=False, + trans_B=True, + in_dtype="int4", + out_dtype="int32", + accum_dtype="int32", + block_row_warps=1, + block_col_warps=1, + warp_row_tiles=16, + warp_col_tiles=16, + chunk=32, + num_stages=2, + enable_rasterization=False, +): + + matmul = MatmulINT4FineGrainScheduler( + M=M, + N=N, + K=K, + trans_A=trans_A, + trans_B=trans_B, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + ).apply_config( + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + num_stages=num_stages, + enable_rasterization=enable_rasterization, + ) + + mod, params = tl.lower(matmul) + src_code = mod.imported_modules[0].get_source() + # src_code is the generated cuda source + assert src_code is not None + storage_dtype = "int8" + A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, storage_dtype)) + B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, storage_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) + compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) + mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + + latency = mod.do_bench(mod.func, warmup=25, profiler="tvm") + print(latency) + + # Ensure that the latency is not None + assert latency is not None + + mod(compressed_A, compressed_B, C) + print(C) + + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype)) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-1) + + +def assert_matmul_int4_weight_propagation_with_default_correctness( + M, + N, + K, + trans_A=False, + trans_B=True, + in_dtype="int4", + out_dtype="int32", + accum_dtype="int32", +): + + matmul = MatmulINT4WeightPropagationScheduler( + M=M, + N=N, + K=K, + trans_A=trans_A, + trans_B=trans_B, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + ).with_default_config() + print(matmul) + mod, params = tl.lower(matmul) + src_code = mod.imported_modules[0].get_source() + + # src_code is 
the generated cuda source + assert src_code is not None + storage_dtype = "int8" + A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, storage_dtype)) + B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, storage_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) + compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) + + ladder_permutate_config = bitblas.ops.LadderPermutateConfig( + M=N, + N=(K // 2), + datatype="int8", + storage_dtype="int8", + transform_kind=3, + transpose_matrix=True, + ) + + ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) + + LB = ladder_permutate(compressed_B.cpu()).cuda() + + mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + + mod(compressed_A, LB, C) + + latency = mod.do_bench(mod.func, warmup=25) + + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype)) + print(C) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def assert_matmul_int4_weight_propagation_apply_config__correctness( + M, + N, + K, + trans_A=False, + trans_B=True, + in_dtype="int4", + out_dtype="int32", + accum_dtype="int32", + block_row_warps=1, + block_col_warps=1, + warp_row_tiles=16, + warp_col_tiles=16, + chunk=32, num_stages=2, - threads=128, enable_rasterization=False, ): - matmul = MatmulScheduler( + + matmul = MatmulINT4WeightPropagationScheduler( M=M, N=N, K=K, @@ -95,27 +579,44 @@ def assert_matmul_blocked_apply_config_correctness( out_dtype=out_dtype, accum_dtype=accum_dtype, ).apply_config( - block_M=block_M, - block_N=block_N, - block_K=block_K, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, num_stages=num_stages, - threads=threads, enable_rasterization=enable_rasterization, ) + print(matmul) mod, params = tl.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None - - A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) - B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + storage_dtype = "int8" + A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, storage_dtype)) + B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, storage_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) + compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) + + ladder_permutate_config = bitblas.ops.LadderPermutateConfig( + M=N, + N=(K // 2), + datatype="int8", + storage_dtype="int8", + transform_kind=3, + transpose_matrix=True, + ) + + ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) + + LB = ladder_permutate(compressed_B.cpu()).cuda() mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) - mod(A, B, C) + mod(compressed_A, LB, C) latency = mod.do_bench(mod.func, warmup=25) @@ -123,22 +624,31 @@ def assert_matmul_blocked_apply_config_correctness( assert latency is not None # Get Reference Result - ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype)) - torch.testing.assert_close(C, ref_c, rtol=1e0, atol=1e-1) + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, 
out_dtype)) + print(C) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) -def assert_matmul_fine_grained_with_default_correctness( +def assert_matmul_fine_grained_dequant_int4_with_default_correctness( M, N, K, trans_A=False, trans_B=True, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", + in_dtype="int4", + out_dtype="int32", + accum_dtype="int32", + bit=2, + storage_dtype="int8", + source_format="int", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + zeros_mode="original", ): - - matmul = MatmulFineGrainScheduler( + matmul = MatmulINT4DequantizeFineGrainedScheduler( M=M, N=N, K=K, @@ -147,42 +657,80 @@ def assert_matmul_fine_grained_with_default_correctness( in_dtype=in_dtype, out_dtype=out_dtype, accum_dtype=accum_dtype, + num_bits=bit, + storage_dtype=storage_dtype, + source_format=source_format, + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=group_size, + fast_decoding=fast_decoding, + zeros_mode=zeros_mode, ).with_default_config() mod, params = tl.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None - A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 - B = (torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) if trans_B else torch.rand( - K, N, device="cuda", dtype=getattr(torch, in_dtype))) - 0.5 + + A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, storage_dtype)) + B = torch.randint(0, 2, (N, K), device="cuda", dtype=getattr(torch, storage_dtype)) + + lop3_permutate_config = bitblas.ops.LOP3PermutateConfig( + M=N, + N=K, + datatype=in_dtype, + dequantize_bits=bit, + storage_dtype=storage_dtype, + ) + lop3_permutate = bitblas.ops.LOP3Permutate( + config=lop3_permutate_config, + target=tvm.target.Target("llvm"), + ) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) - latency = mod.do_bench(mod.func, warmup=25) + compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) + compressed_B = (B[:, ::4] & 0x03) + ((B[:, 1::4] & 0x03) << 2) + ((B[:, 2::4] & 0x03) << 4) + ( + (B[:, 3::4] & 0x03) << 6) + mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + print(f"{compressed_B=}") + if fast_decoding: + lop3_compressed_B = lop3_permutate(compressed_B.cpu()).cuda() + else: + lop3_compressed_B = compressed_B + print(f"{lop3_compressed_B=}") + mod(compressed_A, lop3_compressed_B, C) + print(C) + latency = mod.do_bench(mod.func, warmup=25, profiler="tvm") + print(latency) # Ensure that the latency is not None assert latency is not None - mod(A, B, C) - # Get Reference Result - ref_c = ( - torch.matmul(A, B.T).to(getattr(torch, out_dtype)) if trans_B else torch.matmul(A, B).to( - getattr(torch, out_dtype))) + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype)) - torch.testing.assert_close(C, ref_c, rtol=1e0, atol=1e-1) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) -def assert_matmul_fine_grained_apply_config_correctness( +def assert_matmul_fine_grained_dequant_int4_apply_config_correctness( M, N, K, trans_A=False, trans_B=True, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", + in_dtype="int4", + out_dtype="int32", + accum_dtype="int32", + bit=2, + storage_dtype="int8", + source_format="int", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, 
+ zeros_mode="original", block_row_warps=1, block_col_warps=1, warp_row_tiles=16, @@ -191,8 +739,7 @@ def assert_matmul_fine_grained_apply_config_correctness( num_stages=2, enable_rasterization=False, ): - - matmul = MatmulFineGrainScheduler( + matmul = MatmulINT4DequantizeFineGrainedScheduler( M=M, N=N, K=K, @@ -201,6 +748,14 @@ def assert_matmul_fine_grained_apply_config_correctness( in_dtype=in_dtype, out_dtype=out_dtype, accum_dtype=accum_dtype, + num_bits=bit, + storage_dtype=storage_dtype, + source_format=source_format, + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=group_size, + fast_decoding=fast_decoding, + zeros_mode=zeros_mode, ).apply_config( block_row_warps=block_row_warps, block_col_warps=block_col_warps, @@ -213,40 +768,70 @@ def assert_matmul_fine_grained_apply_config_correctness( mod, params = tl.lower(matmul) src_code = mod.imported_modules[0].get_source() - # src_code is the generated cuda source assert src_code is not None - A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 - B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 - C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, storage_dtype)) + B = torch.randint(0, 2, (N, K), device="cuda", dtype=getattr(torch, storage_dtype)) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + lop3_permutate_config = bitblas.ops.LOP3PermutateConfig( + M=N, + N=K, + datatype=in_dtype, + dequantize_bits=bit, + storage_dtype=storage_dtype, + ) + lop3_permutate = bitblas.ops.LOP3Permutate( + config=lop3_permutate_config, + target=tvm.target.Target("llvm"), + ) - mod(A, B, C) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) - latency = mod.do_bench(mod.func, warmup=25) + compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) + compressed_B = (B[:, ::4] & 0x03) + ((B[:, 1::4] & 0x03) << 2) + ((B[:, 2::4] & 0x03) << 4) + ( + (B[:, 3::4] & 0x03) << 6) + mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + print(f"{compressed_B=}") + if fast_decoding: + lop3_compressed_B = lop3_permutate(compressed_B.cpu()).cuda() + else: + lop3_compressed_B = compressed_B + print(f"{lop3_compressed_B=}") + mod(compressed_A, lop3_compressed_B, C) + print(C) + latency = mod.do_bench(mod.func, warmup=25, profiler="tvm") + print(latency) # Ensure that the latency is not None assert latency is not None # Get Reference Result - ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype)) - torch.testing.assert_close(C, ref_c, rtol=1e-1, atol=1e-1) + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype)) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) -def assert_matmul_weight_propagation_with_default_correctness( + +def assert_matmul_weight_transform_dequant_int4_with_default_correctness( M, N, K, trans_A=False, trans_B=True, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", + in_dtype="int4", + out_dtype="int32", + accum_dtype="int32", + bit=2, + storage_dtype="int8", + source_format="int", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + zeros_mode="original", ): - - matmul = MatmulWeightPropagationScheduler( + matmul = MatmulINT4DequantizeWeightPropagationScheduler( M=M, N=N, K=K, @@ -255,54 +840,101 @@ def assert_matmul_weight_propagation_with_default_correctness( in_dtype=in_dtype, out_dtype=out_dtype, 
accum_dtype=accum_dtype, + num_bits=bit, + storage_dtype=storage_dtype, + source_format=source_format, + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=group_size, + fast_decoding=fast_decoding, + zeros_mode=zeros_mode, ).with_default_config() - + print(matmul) mod, params = tl.lower(matmul) src_code = mod.imported_modules[0].get_source() - # src_code is the generated cuda source assert src_code is not None + transform_b = 3 # assume ladder stage 3 transform + A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, storage_dtype)) + B = torch.randint(0, 2, (N, K), device="cuda", dtype=getattr(torch, storage_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 - B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 - C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) + compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) + compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) - ladder_permutate_config = bitblas.ops.LadderPermutateConfig( + lop3_permutate_config = bitblas.ops.LOP3PermutateConfig( M=N, N=K, - transform_kind=3, + datatype="int4", + dequantize_bits=2, + storage_dtype="int8", + ) + lop3_permutate = bitblas.ops.LOP3Permutate( + config=lop3_permutate_config, + target=tvm.target.Target("llvm"), + ) + + ladder_permutate_config = bitblas.ops.LadderPermutateConfig( + M=N, + N=(K // 2), + datatype="int8", + storage_dtype="int8", + transform_kind=transform_b, transpose_matrix=True, ) ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) - LB = ladder_permutate(B.cpu()).cuda() - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + compressed_B_ladder = ladder_permutate(compressed_B.cpu()).cuda() + ladder_shape = compressed_B_ladder.shape + int2_shape = (ladder_shape[:-1] + (ladder_shape[-1] // 2,)) + int2_tensor = torch.zeros(int2_shape, device="cuda", dtype=torch.int8) + for i in range(int2_tensor.shape[-1]): + int2_tensor[..., i] = (compressed_B_ladder[..., 2 * i] & 0x03) | ( + (compressed_B_ladder[..., 2 * i] >> 4) & 0x03) << 2 | ( + (compressed_B_ladder[..., 2 * i + 1] & 0x03) << 4) | ( + (compressed_B_ladder[..., 2 * i + 1] >> 4) << 6) + + raw_tensor_shape = int2_tensor.shape + print(f"{raw_tensor_shape=}") + if fast_decoding: + lop3_compressed_B = lop3_permutate(int2_tensor.cpu()).cuda() + lop3_compressed_B = lop3_compressed_B.view(raw_tensor_shape) + else: + lop3_compressed_B = int2_tensor - mod(A, LB, C) + mod(compressed_A, lop3_compressed_B, C) latency = mod.do_bench(mod.func, warmup=25) - + print(f"Latency: {latency}") # Ensure that the latency is not None assert latency is not None # Get Reference Result - ref_c = torch.matmul(A, B.T).to(getattr(torch, out_dtype)) + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) print(C) print(ref_c) - torch.testing.assert_close(C, ref_c, rtol=1e0, atol=1e0) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) -def assert_matmul_weight_propagation_apply_config_correctness( +def assert_matmul_weight_transform_dequant_int4_apply_config_correctness( M, N, K, trans_A=False, trans_B=True, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", + in_dtype="int4", + out_dtype="int32", + accum_dtype="int32", + bit=2, + storage_dtype="int8", + source_format="int", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + 
zeros_mode="original", block_row_warps=1, block_col_warps=1, warp_row_tiles=16, @@ -311,8 +943,7 @@ def assert_matmul_weight_propagation_apply_config_correctness( num_stages=2, enable_rasterization=False, ): - - matmul = MatmulWeightPropagationScheduler( + matmul = MatmulINT4DequantizeWeightPropagationScheduler( M=M, N=N, K=K, @@ -321,6 +952,14 @@ def assert_matmul_weight_propagation_apply_config_correctness( in_dtype=in_dtype, out_dtype=out_dtype, accum_dtype=accum_dtype, + num_bits=bit, + storage_dtype=storage_dtype, + source_format=source_format, + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=group_size, + fast_decoding=fast_decoding, + zeros_mode=zeros_mode, ).apply_config( block_row_warps=block_row_warps, block_col_warps=block_col_warps, @@ -331,39 +970,73 @@ def assert_matmul_weight_propagation_apply_config_correctness( enable_rasterization=enable_rasterization, ) + print(matmul) mod, params = tl.lower(matmul) src_code = mod.imported_modules[0].get_source() - # src_code is the generated cuda source assert src_code is not None + transform_b = 3 # assume ladder stage 3 transform + A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, storage_dtype)) + B = torch.randint(0, 2, (N, K), device="cuda", dtype=getattr(torch, storage_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 - B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 - C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) + compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) + compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) - ladder_permutate_config = bitblas.ops.LadderPermutateConfig( + lop3_permutate_config = bitblas.ops.LOP3PermutateConfig( M=N, N=K, - transform_kind=3, + datatype="int4", + dequantize_bits=2, + storage_dtype="int8", + ) + lop3_permutate = bitblas.ops.LOP3Permutate( + config=lop3_permutate_config, + target=tvm.target.Target("llvm"), + ) + + ladder_permutate_config = bitblas.ops.LadderPermutateConfig( + M=N, + N=(K // 2), + datatype="int8", + storage_dtype="int8", + transform_kind=transform_b, transpose_matrix=True, ) ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) - LB = ladder_permutate(B.cpu()).cuda() - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + compressed_B_ladder = ladder_permutate(compressed_B.cpu()).cuda() + ladder_shape = compressed_B_ladder.shape + int2_shape = (ladder_shape[:-1] + (ladder_shape[-1] // 2,)) + int2_tensor = torch.zeros(int2_shape, device="cuda", dtype=torch.int8) + for i in range(int2_tensor.shape[-1]): + int2_tensor[..., i] = (compressed_B_ladder[..., 2 * i] & 0x03) | ( + (compressed_B_ladder[..., 2 * i] >> 4) & 0x03) << 2 | ( + (compressed_B_ladder[..., 2 * i + 1] & 0x03) << 4) | ( + (compressed_B_ladder[..., 2 * i + 1] >> 4) << 6) + + raw_tensor_shape = int2_tensor.shape + print(f"{raw_tensor_shape=}") + if fast_decoding: + lop3_compressed_B = lop3_permutate(int2_tensor.cpu()).cuda() + lop3_compressed_B = lop3_compressed_B.view(raw_tensor_shape) + else: + lop3_compressed_B = int2_tensor - mod(A, LB, C) + mod(compressed_A, lop3_compressed_B, C) latency = mod.do_bench(mod.func, warmup=25) - + print(f"Latency: {latency}") # Ensure that the latency is not None assert latency is not None # Get Reference Result - ref_c = torch.matmul(A, B.T).to(getattr(torch, out_dtype)) - torch.testing.assert_close(C, ref_c, rtol=1e0, atol=1e0) + ref_c 
= torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) + print(C) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) def assert_matmul_blocked_dequant_with_default_correctness( @@ -611,11 +1284,9 @@ def assert_matmul_fine_grained_dequant_with_default_correctness( for j in range(K): if with_zeros: if zeros_mode == "original": - rescale_b[i, j] = (b[i, j] - zeros[i, j // group_size]) * scale[i, j // - group_size] + rescale_b[i, j] = (b[i, j] - zeros) * scale[i, j // group_size] elif zeros_mode == "rescale": - rescale_b[i, j] = ( - b[i, j] * scale[i, j // group_size] + zeros[i, j // group_size]) + rescale_b[i, j] = (b[i, j] * scale[i, j // group_size] + zeros) else: raise NotImplementedError else: @@ -804,6 +1475,55 @@ def test_matmul_fine_grained(): assert_matmul_fine_grained_apply_config_correctness(1024, 1024, 1024, enable_rasterization=True) +def test_matmul_int4_fine_grained(): + # Default + assert_matmul_int4_fine_grained_with_default_correctness(256, 256, 256) + # Pipeline + assert_matmul_int4_fine_grained_apply_config_correctness(1024, 1024, 1024, num_stages=2) + assert_matmul_int4_fine_grained_apply_config_correctness(1024, 1024, 1024, num_stages=1) + # L2 Cache + assert_matmul_int4_fine_grained_apply_config_correctness( + 1024, 1024, 1024, enable_rasterization=True) + + +def test_matmul_int4_weight_propagation(): + # Default + assert_matmul_int4_weight_propagation_with_default_correctness(256, 256, 256) + # Pipeline + assert_matmul_int4_weight_propagation_apply_config__correctness(1024, 1024, 1024, num_stages=2) + assert_matmul_int4_weight_propagation_apply_config__correctness(1024, 1024, 1024, num_stages=1) + # L2 Cache + assert_matmul_int4_weight_propagation_apply_config__correctness( + 1024, 1024, 1024, enable_rasterization=True) + + +def test_matmul_int4xint2_fine_grained(): + # Default + assert_matmul_fine_grained_dequant_int4_with_default_correctness(256, 256, 256) + assert_matmul_fine_grained_dequant_int4_with_default_correctness( + 256, 256, 256, fast_decoding=True) + # Pipeline + assert_matmul_fine_grained_dequant_int4_apply_config_correctness(1024, 1024, 1024, num_stages=2) + # L2 Cache + assert_matmul_fine_grained_dequant_int4_apply_config_correctness( + 1024, 1024, 1024, enable_rasterization=True) + + +def test_matmul_int4_weight_transform_dequant(): + # Default + assert_matmul_weight_transform_dequant_int4_with_default_correctness(256, 256, 256) + assert_matmul_weight_transform_dequant_int4_with_default_correctness( + 256, 256, 256, fast_decoding=True) + # Pipeline + assert_matmul_weight_transform_dequant_int4_apply_config_correctness( + 1024, 1024, 1024, num_stages=2) + assert_matmul_weight_transform_dequant_int4_apply_config_correctness( + 1024, 1024, 1024, num_stages=1) + # L2 Cache + assert_matmul_weight_transform_dequant_int4_apply_config_correctness( + 1024, 1024, 1024, enable_rasterization=True) + + def test_matmul_weight_propagation(): # Default assert_matmul_weight_propagation_with_default_correctness(1024, 1024, 1024) @@ -833,25 +1553,6 @@ def test_matmul_blocked_dequant_with_default(): ) assert_matmul_blocked_dequant_with_default_correctness( 1024, 1024, 1024, source_format="uint", bit=4, fast_decoding=True) - assert_matmul_blocked_dequant_with_default_correctness( - 1024, - 1024, - 1024, - source_format="uint", - bit=4, - with_scaling=True, - fast_decoding=True, - ) - assert_matmul_blocked_dequant_with_default_correctness( - 1024, - 1024, - 1024, - source_format="uint", - bit=4, - 
with_scaling=True, - with_zeros=True, - fast_decoding=True, - ) def test_matmul_fine_grained_dequant_with_default():