diff --git a/bitblas/gpu/matmul_mma_dequantize.py b/bitblas/gpu/matmul_mma_dequantize.py index c97f5c9ec..d4ffbafc9 100644 --- a/bitblas/gpu/matmul_mma_dequantize.py +++ b/bitblas/gpu/matmul_mma_dequantize.py @@ -693,6 +693,10 @@ def sch_dequantize_in_register_with_config( V compute """ + weight_transform_kind = config.intrin_info.weight_transform_kind + if weight_transform_kind == TransformKind.LDMatrixTransform: + return self.sch_warp_memory_prefetch_with_config(func, config) + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel get_mma_intrin_group,) from .intrin import get_lop3_intrin_group diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py index 227de7ad3..eea256fd9 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py @@ -383,14 +383,20 @@ def with_default_config(self): def apply_config( self, - block_row_warps=2, - block_col_warps=2, - warp_row_tiles=32, - warp_col_tiles=32, - chunk=16, - num_stages=2, + block_row_warps: Optional[int] = None, + block_col_warps: Optional[int] = None, + warp_row_tiles: Optional[int] = None, + warp_col_tiles: Optional[int] = None, + chunk: Optional[int] = None, + num_stages: Optional[int] = None, enable_rasterization=False, ): + assert block_row_warps is not None, "block_row_warps is required" + assert block_col_warps is not None, "block_col_warps is required" + assert warp_row_tiles is not None, "warp_row_tiles is required" + assert warp_col_tiles is not None, "warp_col_tiles is required" + assert chunk is not None, "chunk is required" + assert num_stages is not None, "num_stages is required" M, N, K = self.M, self.N, self.K trans_A, trans_B = self.trans_A, self.trans_B @@ -534,6 +540,9 @@ def __post_init__(self): @dataclass class MatmulWeightPropagationScheduler(MatmulFineGrainScheduler): + # Ladder Transform Config + weight_transform_kind: TransformKind = TransformKind.LDMatrixTransform + def apply_config( self, block_row_warps=2, @@ -604,7 +613,7 @@ def apply_config( warp_row_tiles=warp_row_tiles, warp_col_tiles=warp_col_tiles, chunk=chunk, - transform_kind_b=TransformKind.LDMatrixTransform, + transform_kind_b=self.weight_transform_kind, ) # Define the main kernel using the generated configuration diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py b/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py index bc13c9d4c..9fe99512c 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py @@ -5,6 +5,14 @@ MatmulDequantizeScheduler, # noqa: F401 ) +from .finegrained_primitive_tensorcore import ( + MatmulDequantizeFineGrainedScheduler, # noqa: F401 +) + +from .ladder_weight_transform_tensorcore import ( + MatmulDequantizeWeightPropagationScheduler, # noqa: F401 +) + from bitblas.ops.common import TransformKind from typing import Union diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/block_primitive_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/block_primitive_tensorcore.py index fce026c51..5f1a8f5ed 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/block_primitive_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/block_primitive_tensorcore.py @@ -312,7 +312,6 @@ def general_dequant_matmul( Zeros, Qzeros, local_size, - local_size_compressed, bx, tx, k, @@ -384,7 +383,6 @@ def _normal_dequant( zeros_buffer: T.Buffer, qzeros_buffer: T.Buffer, local_size: int, - local_size_compressed: int, pid_n: T.Var, tx: T.Var, k: T.Var, @@ -413,9 +411,9 @@ def _normal_dequant_impl( qzeros_buffer: T.Buffer, ): for v in T.serial(0, local_size): - index = (i * threads * local_size_compressed + tx * local_size_compressed + v) - vi = index // (stride_k // num_elems_per_byte) - vj = index % (stride_k // num_elems_per_byte) + index = (i * threads * local_size + tx * local_size + v) + vi = index // stride_k + vj = index % stride_k if not with_scaling: dequant_weight_local[v] = self._decode_func( num_bits, @@ -486,12 +484,9 @@ def _normal_fast_dequant( qzeros_buffer: T.Buffer, func_name: str, pid_n: T.Var, - tx: T.Var, k: T.Var, - i: T.Var, stride_n: int, stride_k: int, - threads: int, ): num_elems_per_byte = self.num_elems_per_byte with_scaling = self.with_scaling diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py index c98474ec0..d755ba2f8 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py @@ -11,7 +11,6 @@ from bitblas.tl.macro_generator import ( TensorCoreIntrinEmitter, # noqa: F401 - TensorCoreIntrinEmitterWithLadderTransform, # noqa: F401 ) from bitblas.ops.common import TransformKind # noqa: F401 from bitblas.ops.base_scheduler import BaseScheduler @@ -31,13 +30,14 @@ _tir_u8_to_f8_e4m3_to_f16, _tir_packed_to_unsigned_convert_with_zeros, ) +from bitblas.gpu.intrin.lop3 import get_lop3_intrin_group # GPU warp configuration for NVIDIA GPUs warp_size = 32 @dataclass -class MatmulDequantizeScheduler(BaseScheduler): +class MatmulDequantizeFineGrainedScheduler(BaseScheduler): # OP Related Config M: Optional[int] = None @@ -60,12 +60,15 @@ class MatmulDequantizeScheduler(BaseScheduler): with_bias: bool = False zeros_mode: Literal["original", "rescale", "quantized"] = "original", - # Default Tile Related Params - block_M: int = 64 - block_N: int = 64 - block_K: int = 32 + # Tensor Core Warp Configuration + block_row_warps: int = 2 + block_col_warps: int = 2 + warp_row_tiles: int = 64 + warp_col_tiles: int = 64 + chunk: int = 32 # Usually determines the K-dimension split size + + # Other Optimization Parameters num_stages: int = 2 - threads: int = 128 enable_rasterization: bool = False # Enhance L2 Locality class TLHint(BaseTLHint): @@ -88,36 +91,43 @@ def from_roller_hint(cls, hint: Hint): block_row_warps = block[0] // warp[0] block_col_warps = block[1] // warp[1] - warp_size = 32 # NVIDIA GPU warp size is 32 + warp_row_tiles = warp[0] + warp_col_tiles = warp[1] + chunk = rstep[0] + if num_stages == 1: num_stages = 0 # disable pipelining - tl_hint.block_M = block[0] - tl_hint.block_N = block[1] - tl_hint.block_K = rstep[0] + tl_hint.block_row_warps = block_row_warps + tl_hint.block_col_warps = block_col_warps + tl_hint.warp_row_tiles = warp_row_tiles + tl_hint.warp_col_tiles = warp_col_tiles + tl_hint.chunk = chunk tl_hint.num_stages = num_stages - tl_hint.threads = warp_size * block_row_warps * block_col_warps tl_hint.enable_rasterization = enable_rasterization return tl_hint def get_config_params(self): return { - "block_M": self.block_M, - "block_N": self.block_N, - "block_K": self.block_K, + "block_row_warps": self.block_row_warps, + "block_col_warps": self.block_col_warps, + "warp_row_tiles": self.warp_row_tiles, + "warp_col_tiles": self.warp_col_tiles, + "chunk": self.chunk, "num_stages": self.num_stages, - "threads": self.threads, "enable_rasterization": self.enable_rasterization, } def __repr__(self): return ("{" - f"block_M={self.block_M}," - f"block_N={self.block_N}," - f"block_K={self.block_K}," + f"block_M={self.block_row_warps * self.warp_row_tiles}," + f"block_N={self.block_col_warps * self.warp_col_tiles}," + f"warp_M={self.warp_row_tiles}," + f"warp_N={self.warp_col_tiles}," + f"block_K={self.chunk}," + f"threads={self.block_row_warps * self.block_col_warps * warp_size}," f"num_stages={self.num_stages}," - f"threads={self.threads}," f"enable_rasterization={self.enable_rasterization}" "}") @@ -167,56 +177,71 @@ def get_hardware_aware_configs(self, arch: TileDevice = None, topk=10): return self.get_roller_configs(arch, topk) def with_default_config(self): - block_M = getattr(self, "block_M", 64) - block_N = getattr(self, "block_N", 64) - block_K = getattr(self, "block_K", 32) + block_row_warps = getattr(self, "block_row_warps", 2) + block_col_warps = getattr(self, "block_col_warps", 2) + warp_row_tiles = getattr(self, "warp_row_tiles", 32) + warp_col_tiles = getattr(self, "warp_col_tiles", 32) + chunk = getattr(self, "chunk", 32) num_stages = getattr(self, "num_stages", 2) - threads = getattr(self, "threads", 128) enable_rasterization = getattr(self, "enable_rasterization", False) return self.apply_config( - block_M=block_M, - block_N=block_N, - block_K=block_K, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, num_stages=num_stages, - threads=threads, enable_rasterization=enable_rasterization, ) - def _apply_config_dequant_only( + def apply_config( self, - block_M: Optional[int] = None, - block_N: Optional[int] = None, - block_K: Optional[int] = None, + block_row_warps: Optional[int] = None, + block_col_warps: Optional[int] = None, + warp_row_tiles: Optional[int] = None, + warp_col_tiles: Optional[int] = None, + chunk: Optional[int] = None, num_stages: Optional[int] = None, - threads: Optional[int] = None, - # Enhance L2 Locality - enable_rasterization: bool = False, + enable_rasterization=False, ): - assert block_M is not None, "block_M is required" - assert block_N is not None, "block_N is required" - assert block_K is not None, "block_K is required" + assert block_row_warps is not None, "block_row_warps is required" + assert block_col_warps is not None, "block_col_warps is required" + assert warp_row_tiles is not None, "warp_row_tiles is required" + assert warp_col_tiles is not None, "warp_col_tiles is required" + assert chunk is not None, "chunk is required" assert num_stages is not None, "num_stages is required" - assert threads is not None, "threads is required" + M, N, K = self.M, self.N, self.K trans_A, trans_B = self.trans_A, self.trans_B + assert trans_A is False, "Dequantize only implement for trans_A=False currently" assert trans_B is True, "Dequantize only implement for trans_B=TRue currently" - # check is dequantize only + in_dtype, out_dtype, accum_dtype = ( + self.in_dtype, + self.out_dtype, + self.accum_dtype, + ) + # Calculate the micro size per warp using a helper function + micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(in_dtype) - def check_is_dequantize_only(): - return not self.with_scaling + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + threads = warp_size * (block_row_warps * block_col_warps) - if not check_is_dequantize_only(): - raise ValueError("Not a Dequantize Only Configuration") + fragement_size = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y - in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype + fast_decoding = self.fast_decoding num_bits = self.num_bits storage_dtype = self.storage_dtype + source_format = self.source_format storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) - num_elems_per_byte = 8 // num_bits + num_elems_per_byte = self.num_elems_per_byte MAX_TRANSACTION_SIZE_IN_BITS = 128 local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits @@ -227,50 +252,139 @@ def check_is_dequantize_only(): group_size = K A_shape = (M, K) - B_shape = (N, K // storage_nbit * num_bits) + B_shape = (N, K // num_elems_per_byte) + LUT_shape = (group_size, K // num_elems_per_byte) + Scale_shape = (N, K // group_size) + Zeros_shape = (N, K // group_size) + Qzeros_shape = ((K // group_size), N // storage_nbit * num_bits) + Bias_shape = (N,) A_shared_shape = (block_M, block_K) B_shared_shape = (block_N, block_K // num_elems_per_byte) B_dequantize_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + import_source: Optional[str] = None + func_name: str = "" + if fast_decoding is True: + lop3_intrin_info = get_lop3_intrin_group( + out_dtype=out_dtype, + source_format=source_format, + source_bit=num_bits, + storage_dtype=storage_dtype, + with_scaling=self.with_scaling, + with_zeros=self.with_zeros, + ) + import_source = lop3_intrin_info["c_source"] + func_name = lop3_intrin_info["func_name"] + assert import_source is not None, "lop3_intrin_info is not found" + assert func_name is not None, "lop3_intrin_info is not found" + import_source = self.common_header + import_source + + # Configure the tensor core intrinsic emitter + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=trans_A, + b_transposed=trans_B, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) @T.prim_func - def main( + def general_dequant_matmul( A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, storage_dtype), + LUT: T.Buffer(LUT_shape, in_dtype), + Scale: T.Buffer(Scale_shape, in_dtype), + Qzeros: T.Buffer(Qzeros_shape, storage_dtype), + Zeros: T.Buffer(Zeros_shape, in_dtype), + Bias: T.Buffer(Bias_shape, in_dtype), C: T.Buffer((M, N), out_dtype), ): with T.Kernel( T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) + C_shared = T.alloc_shared(C_shared_shape, out_dtype) + + A_frag = T.alloc_local((warp_rows * fragement_size), in_dtype) + B_frag = T.alloc_local((warp_cols * fragement_size), in_dtype) + C_frag = T.alloc_local((warp_rows * warp_cols * fragement_size), accum_dtype) + B_local = T.alloc_local([local_size_compressed], storage_dtype) B_dequantize_local = T.alloc_local([local_size], in_dtype) - B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) tx = T.thread_binding(0, threads, thread="threadIdx.x") + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + B_dequantize_shared: make_swizzle_layout(B_dequantize_shared), + }) + T.use_swizzle(10, enable=enable_rasterization) - T.clear(C_local) + T.import_source(import_source) + + T.clear(C_frag) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(A[by * block_M, k * block_K], A_shared) - T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[bx * block_N, ko * block_K // num_elems_per_byte], B_shared) for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * local_size_compressed)): for v in T.vectorized(0, local_size_compressed): - index = i * threads * local_size_compressed + tx * local_size_compressed + v + index = ( + i * threads * local_size_compressed + tx * local_size_compressed + + v) vi = index // (block_K // num_elems_per_byte) vj = index % (block_K // num_elems_per_byte) B_local[v] = B_shared[vi, vj] - for v in T.serial(0, local_size): - B_dequantize_local[v] = self._decode_func( - num_bits, - B_local[v // num_elems_per_byte], - v % num_elems_per_byte, - dtype=in_dtype, + + if fast_decoding is True: + self._normal_fast_dequant( + B_local, + B_dequantize_local, + Scale, + Zeros, + Qzeros, + func_name, + by, + tx, + ko, + i, + block_N, + block_K, + threads, + ) + else: + self._normal_dequant( + B_local, + B_dequantize_local, + Scale, + Zeros, + Qzeros, + local_size, + local_size_compressed, + bx, + tx, + ko, + i, + block_N, + block_K, + threads, ) for v in T.vectorized(0, local_size): index = i * threads * local_size + tx * local_size + v @@ -278,92 +392,45 @@ def main( vj = index % block_K B_dequantize_shared[vi, vj] = B_dequantize_local[v] - T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) - - T.copy(C_local, C[by * block_M, bx * block_N]) - - return main - - def _apply_config_with_scaling( - self, - block_M: Optional[int] = None, - block_N: Optional[int] = None, - block_K: Optional[int] = None, - num_stages: Optional[int] = None, - threads: Optional[int] = None, - # Enhance L2 Locality - enable_rasterization: bool = False, - ): - raise NotImplementedError("Scaling Configuration is not implemented") - - def _apply_config_with_scaling_zeros_original_or_rescale( - self, - block_M: Optional[int] = None, - block_N: Optional[int] = None, - block_K: Optional[int] = None, - num_stages: Optional[int] = None, - threads: Optional[int] = None, - # Enhance L2 Locality - enable_rasterization: bool = False, - ): - raise NotImplementedError("Scaling and Zeros Original Configuration is not implemented") - - def _apply_config_with_scaling_zeros_quantized( - self, - block_M: Optional[int] = None, - block_N: Optional[int] = None, - block_K: Optional[int] = None, - num_stages: Optional[int] = None, - threads: Optional[int] = None, - # Enhance L2 Locality - enable_rasterization: bool = False, - ): - raise NotImplementedError("Scaling and Zeros Rescale Configuration is not implemented") - - def apply_config( - self, - block_M: Optional[int] = None, - block_N: Optional[int] = None, - block_K: Optional[int] = None, - num_stages: Optional[int] = None, - threads: Optional[int] = None, - # Enhance L2 Locality - enable_rasterization: bool = False, - ): - assert block_M is not None, "block_M is required" - assert block_N is not None, "block_N is required" - assert block_K is not None, "block_K is required" - assert num_stages is not None, "num_stages is required" - assert threads is not None, "threads is required" - trans_A, trans_B = self.trans_A, self.trans_B - - assert trans_A is False, "Dequantize only implement for trans_A=False currently" - assert trans_B is True, "Dequantize only implement for trans_B=TRue currently" - - with_scaling = self.with_scaling - with_zeros = self.with_zeros - zeros_mode = self.zeros_mode - - args = [block_M, block_N, block_K, num_stages, threads, enable_rasterization] - - dequant_prim_func = None - if not with_scaling: - dequant_prim_func = self._apply_config_dequant_only(*args) - - if not with_zeros: - dequant_prim_func = self._apply_config_with_scaling(*args) - - if zeros_mode in ["original", "rescale"]: - dequant_prim_func = self._apply_config_with_scaling_zeros_original_or_rescale(*args) - elif zeros_mode == "quantized": - dequant_prim_func = self._apply_config_with_scaling_zeros_quantized(*args) - else: - raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) - - if dequant_prim_func is None: - raise ValueError("Unsupported Configuration") - - return self.maybe_simplify(dequant_prim_func) + # Perform the matrix multiplication on tensor core fragments + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A fragment + mma_emitter.ldmatrix_a( + A_frag, + A_shared, + ki, + thread_bindings=tx, + ) + + # Load B fragment + mma_emitter.ldmatrix_b( + B_frag, + B_dequantize_shared, + ki, + thread_bindings=tx, + ) + + # Matrix multiplication on fragments + mma_emitter.mma(A_frag, B_frag, C_frag) + + # Store the result back to C shared memory + mma_emitter.stmatrix( + C_frag, + C_shared, + thread_bindings=tx, + ) + + # Store results from shared memory to global memory + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return self.maybe_simplify(general_dequant_matmul) @property def _decode_func(self): @@ -375,7 +442,7 @@ def _decode_func(self): source_format = self.source_format storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) - bit = self.bit + num_bits = self.num_bits dequant_func = None @@ -385,17 +452,17 @@ def naive_cast_dequant(x): if with_zeros and zeros_mode == "quantized": dequant_func = _tir_packed_to_unsigned_convert_with_zeros(storage_type, storage_nbit) elif source_format == "uint": - if bit == 8: - # 8 bit does not need to be compressed + if num_bits == 8: + # 8 num_bits does not need to be compressed dequant_func = naive_cast_dequant else: dequant_func = _tir_packed_to_unsigned_convert(storage_type, storage_nbit) elif source_format == "int": - if bit == 1: + if num_bits == 1: # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. dequant_func = _tir_packed_int_to_int_convert(storage_type, storage_nbit) - elif bit == 8: - # 8 bit does not need to be compressed + elif num_bits == 8: + # 8 num_bits does not need to be compressed dequant_func = naive_cast_dequant else: dequant_func = _tir_packed_to_signed_convert(storage_type, storage_nbit) @@ -408,6 +475,190 @@ def naive_cast_dequant(x): return dequant_func + def _normal_dequant( + self, + compressed_weight_local: T.Buffer, + dequant_weight_local: T.Buffer, + scale_buffer: T.Buffer, + zeros_buffer: T.Buffer, + qzeros_buffer: T.Buffer, + local_size: int, + local_size_compressed: int, + pid_n: T.Var, + tx: T.Var, + k: T.Var, + i: T.Var, + stride_n: int, + stride_k: int, + threads: int, + ): + num_elems_per_byte = self.num_elems_per_byte + with_scaling = self.with_scaling + with_zeros = self.with_zeros + zeros_mode = self.zeros_mode + num_bits = self.num_bits + in_dtype = self.in_dtype + group_size = self.group_size + storage_dtype = self.storage_dtype + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + + @T.macro + def _normal_dequant_impl( + compressed_weight_local: T.Buffer, + dequant_weight_local: T.Buffer, + scale_buffer: T.Buffer, + zeros_buffer: T.Buffer, + qzeros_buffer: T.Buffer, + ): + for v in T.serial(0, local_size): + index = (i * threads * local_size + tx * local_size + v) + vi = index // (stride_k) + vj = index % (stride_k) + if not with_scaling: + dequant_weight_local[v] = self._decode_func( + num_bits, + compressed_weight_local[v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) + elif not with_zeros: + # Scaling only + dequant_weight_local[v] = ( + self._decode_func( + num_bits, + compressed_weight_local[v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) * scale_buffer[pid_n * stride_n + vi, (k * stride_k + vj) // group_size]) + elif zeros_mode == "original": + dequant_weight_local[v] = (self._decode_func( + num_bits, + compressed_weight_local[v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) - zeros_buffer[pid_n * stride_n + vi, (k * stride_k + vj) // + group_size]) * scale_buffer[pid_n * stride_n + vi, + (k * stride_k + vj) // group_size] + elif zeros_mode == "rescale": + dequant_weight_local[v] = ( + self._decode_func( + num_bits, + compressed_weight_local[v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) * scale_buffer[pid_n * stride_n + vi, (k * stride_k + vj) // group_size] - + zeros_buffer[pid_n * stride_n + vi, (k * stride_k + vj) // group_size]) + elif zeros_mode == "quantized": + dequant_qzeros = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + num_bits, + qzeros_buffer[ + (k * stride_k + vj) // group_size, + (pid_n * stride_n + vi) // num_elems_per_byte, + ], + (pid_n * stride_n + vi) % num_elems_per_byte, + dtype=storage_dtype, + ) + + dequant_weight_local[v] = (self._decode_func( + num_bits, + compressed_weight_local[v // num_elems_per_byte], + v % num_elems_per_byte, + zero=dequant_qzeros, + dtype=in_dtype, + )) * scale_buffer[pid_n * stride_n + vi, (k * stride_k + vj) // group_size] + + return _normal_dequant_impl( + compressed_weight_local, + dequant_weight_local, + scale_buffer, + zeros_buffer, + qzeros_buffer, + ) + + def _normal_fast_dequant( + self, + compressed_weight_local: T.Buffer, + dequant_weight_local: T.Buffer, + scale_buffer: T.Buffer, + zeros_buffer: T.Buffer, + qzeros_buffer: T.Buffer, + func_name: str, + pid_n: T.Var, + tx: T.Var, + k: T.Var, + i: T.Var, + stride_n: int, + stride_k: int, + threads: int, + ): + # TODO(lei): un-used arguments should be removed + num_elems_per_byte = self.num_elems_per_byte + with_scaling = self.with_scaling + with_zeros = self.with_zeros + zeros_mode = self.zeros_mode + in_dtype = self.in_dtype + group_size = self.group_size + + @T.macro + def _normal_fast_dequant_impl( + compressed_weight_local: T.Buffer, + dequant_weight_local: T.Buffer, + scale_buffer: T.Buffer, + zeros_buffer: T.Buffer, + qzeros_buffer: T.Buffer, + ): + if not with_scaling: + T.call_extern( + func_name, + T.address_of(compressed_weight_local[0]), + T.address_of(dequant_weight_local[0]), + dtype=in_dtype, + ) + elif not with_zeros: + T.call_extern( + func_name, + T.address_of(compressed_weight_local[0]), + T.address_of(dequant_weight_local[0]), + T.address_of(scale_buffer[pid_n * stride_n, k * stride_k // group_size]), + dtype=in_dtype, + ) + elif zeros_mode in ["original", "rescale"]: + T.call_extern( + func_name, + T.address_of(compressed_weight_local[0]), + T.address_of(dequant_weight_local[0]), + T.address_of(scale_buffer[pid_n * stride_n, k * stride_k // group_size]), + T.address_of(zeros_buffer[pid_n * stride_n, k * stride_k // group_size]), + dtype=in_dtype, + ) + elif zeros_mode == "quantized": + T.call_extern( + func_name, + T.address_of(compressed_weight_local[0]), + T.address_of(dequant_weight_local[0]), + T.address_of(scale_buffer[pid_n * stride_n, k * stride_k // group_size]), + T.address_of(zeros_buffer[pid_n * stride_n, k * stride_k // group_size]), + T.address_of(qzeros_buffer[k * stride_k // group_size, + pid_n * stride_n // num_elems_per_byte]), + dtype=in_dtype, + ) + + return _normal_fast_dequant_impl( + compressed_weight_local, + dequant_weight_local, + scale_buffer, + zeros_buffer, + qzeros_buffer, + ) + + @property + def num_elems_per_byte(self): + storage_nbit = int("".join(c for c in self.storage_dtype if c.isdigit())) + num_bits = self.num_bits + return storage_nbit // num_bits + def __post_init__(self): - # Add Config Validation - return + # Legalize group_size + if self.with_scaling and self.group_size == -1: + object.__setattr__(self, "group_size", self.K) diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py index e69de29bb..bb463e59a 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py @@ -0,0 +1,540 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from bitblas import tvm as tvm +from tvm import DataType +import tvm.tl.language as T +from typing import Optional +from bitblas.tl.utils import ( + get_mma_micro_size, # noqa: F401 + make_swizzle_layout, # noqa: F401 +) +from .finegrained_primitive_tensorcore import MatmulDequantizeFineGrainedScheduler +from bitblas.tl.macro_generator import ( + TensorCoreIntrinEmitterWithLadderTransform, # noqa: F401 +) +from bitblas.ops.common import TransformKind # noqa: F401 +from dataclasses import dataclass +from bitblas.quantization import ( + _tir_packed_to_unsigned_convert,) +from bitblas.gpu.intrin.lop3 import get_lop3_intrin_group +from bitblas.gpu.matmul_analysis import ( + get_propagate_map, + get_ladder_stage3_map, +) + +# GPU warp configuration for NVIDIA GPUs +warp_size = 32 + + +@dataclass +class MatmulDequantizeWeightPropagationScheduler(MatmulDequantizeFineGrainedScheduler): + + # Ladder Transform Config + weight_transform_kind: TransformKind = TransformKind.LDMatrixTransform + + def apply_config( + self, + block_row_warps: Optional[int] = None, + block_col_warps: Optional[int] = None, + warp_row_tiles: Optional[int] = None, + warp_col_tiles: Optional[int] = None, + chunk: Optional[int] = None, + num_stages: Optional[int] = None, + enable_rasterization=False, + ): + assert block_row_warps is not None, "block_row_warps is required" + assert block_col_warps is not None, "block_col_warps is required" + assert warp_row_tiles is not None, "warp_row_tiles is required" + assert warp_col_tiles is not None, "warp_col_tiles is required" + assert chunk is not None, "chunk is required" + assert num_stages is not None, "num_stages is required" + + M, N, K = self.M, self.N, self.K + trans_A, trans_B = self.trans_A, self.trans_B + weight_transform_kind = self.weight_transform_kind + + assert trans_A is False, "Dequantize only implement for trans_A=False currently" + assert trans_B is True, "Dequantize only implement for trans_B=TRue currently" + assert (weight_transform_kind == TransformKind.LDMatrixTransform + ), "Dequantize only implement for LDMatrixTransform currently" + + in_dtype, out_dtype, accum_dtype = ( + self.in_dtype, + self.out_dtype, + self.accum_dtype, + ) + # Calculate the micro size per warp using a helper function + micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(in_dtype) + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + threads = warp_size * (block_row_warps * block_col_warps) + + fragement_size = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + fast_decoding = self.fast_decoding + + num_bits = self.num_bits + storage_dtype = self.storage_dtype + source_format = self.source_format + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + num_elems_per_byte = self.num_elems_per_byte + + MAX_TRANSACTION_SIZE_IN_BITS = 128 + local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits + + group_size = self.group_size + if group_size == -1: + group_size = K + + A_shape = (M, K) + B_shape = ( + N // micro_size_y, + K // micro_size_k, + micro_size_y, + micro_size_k // num_elems_per_byte, + ) + LUT_shape = (group_size, K // num_elems_per_byte) + Scale_shape = (N, K // group_size) + Zeros_shape = (N, K // group_size) + Qzeros_shape = ((K // group_size), N // storage_nbit * num_bits) + Bias_shape = (N,) + + A_shared_shape = (block_M, block_K) + B_shared_shape = ( + block_N // micro_size_y, + block_K // micro_size_k, + micro_size_y, + micro_size_k // num_elems_per_byte, + ) + + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + import_source: Optional[str] = None + func_name: str = "" + if fast_decoding is True: + lop3_intrin_info = get_lop3_intrin_group( + out_dtype=out_dtype, + source_format=source_format, + source_bit=num_bits, + storage_dtype=storage_dtype, + with_scaling=self.with_scaling, + with_zeros=self.with_zeros, + storage_scope="warp", # to get the ladder transform lop3 intrin + ) + import_source = lop3_intrin_info["c_source"] + func_name = lop3_intrin_info["func_name"] + assert import_source is not None, "lop3_intrin_info is not found" + assert func_name is not None, "lop3_intrin_info is not found" + import_source = self.common_header + import_source + + # Configure the tensor core intrinsic emitter with ladder transform + mma_emitter = TensorCoreIntrinEmitterWithLadderTransform( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=trans_A, + b_transposed=trans_B, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + transform_kind_b=weight_transform_kind, + num_elems_per_byte=num_elems_per_byte, + ) + + vec_load_qb = 16 + if block_N * block_K // num_elems_per_byte // threads < vec_load_qb: + vec_load_qb = block_N * block_K // num_elems_per_byte // threads + + @T.prim_func + def general_dequant_matmul( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, storage_dtype), + LUT: T.Buffer(LUT_shape, in_dtype), + Scale: T.Buffer(Scale_shape, in_dtype), + Qzeros: T.Buffer(Qzeros_shape, storage_dtype), + Zeros: T.Buffer(Zeros_shape, in_dtype), + Bias: T.Buffer(Bias_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel( + T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + C_shared = T.alloc_shared(C_shared_shape, out_dtype) + + A_frag = T.alloc_local((warp_rows * fragement_size), in_dtype) + B_frag = T.alloc_local((warp_cols * fragement_size // num_elems_per_byte), + storage_dtype) + B_dequantize_frag = T.alloc_local((warp_cols * fragement_size), in_dtype) + C_frag = T.alloc_local((warp_rows * warp_cols * fragement_size), accum_dtype) + + tx = T.thread_binding(0, threads, thread="threadIdx.x") + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + }) + + T.use_swizzle(10, enable=enable_rasterization) + + T.import_source(import_source) + + T.clear(C_frag) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + + T.copy(A[by * block_M, ko * block_K], A_shared) + + for i in T.serial(block_N * block_K // num_elems_per_byte // + (threads * vec_load_qb)): + for v in T.vectorized(0, vec_load_qb): + idx = i * threads * vec_load_qb + tx * vec_load_qb + v + vkk = idx % (micro_size_k // num_elems_per_byte) + vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y + vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % ( + block_K // micro_size_k) + vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y // + (block_K // micro_size_k)) % ( + block_N // micro_size_y) + B_shared[vj, vk, vjj, vkk] = B[ + bx * (block_N // micro_size_y) + vj, + ko * (block_K // micro_size_k) + vk, + vjj, + vkk, + ] + + # Perform the matrix multiplication on tensor core fragments + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A fragment + mma_emitter.ldmatrix_a( + A_frag, + A_shared, + ki, + thread_bindings=tx, + ) + + # Load B fragment + mma_emitter.ldmatrix_b( + B_frag, + B_shared, + ki, + thread_bindings=tx, + ) + + if fast_decoding is True: + self._normal_fast_dequant( + B_frag, + B_dequantize_frag, + Scale, + Zeros, + Qzeros, + func_name, + local_size, + warp_cols, + bx, + tx, + mma_emitter, + ko, + ki, + block_N, + block_K, + ) + else: + self._normal_dequant( + B_frag, + B_dequantize_frag, + Scale, + Zeros, + Qzeros, + local_size, + warp_cols, + bx, + tx, + mma_emitter, + ko, + ki, + block_N, + block_K, + ) + + # Matrix multiplication on fragments + mma_emitter.mma(A_frag, B_dequantize_frag, C_frag) + + # Store the result back to C shared memory + mma_emitter.stmatrix( + C_frag, + C_shared, + thread_bindings=tx, + ) + + # Store results from shared memory to global memory + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return self.maybe_simplify(general_dequant_matmul) + + def _normal_dequant( + self, + compressed_weight_local: T.Buffer, + dequant_weight_local: T.Buffer, + scale_buffer: T.Buffer, + zeros_buffer: T.Buffer, + qzeros_buffer: T.Buffer, + local_size: int, + warp_cols: int, + pid_n: T.Var, + thread_bindings: T.Var, + mma_emitter: TensorCoreIntrinEmitterWithLadderTransform, + ko: T.Var, + ki: T.Var, + stride_n: int, + stride_k: int, + ): + num_elems_per_byte = self.num_elems_per_byte + with_scaling = self.with_scaling + with_zeros = self.with_zeros + zeros_mode = self.zeros_mode + num_bits = self.num_bits + in_dtype = self.in_dtype + group_size = self.group_size + storage_dtype = self.storage_dtype + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + micro_size_k = mma_emitter.micro_size_k + k_inner_stride = micro_size_k // local_size + + @T.macro + def _normal_dequant_impl( + compressed_weight_local: T.Buffer, + dequant_weight_local: T.Buffer, + scale_buffer: T.Buffer, + zeros_buffer: T.Buffer, + qzeros_buffer: T.Buffer, + ): + for j in T.serial(warp_cols): + for v in T.serial(0, local_size): + tx = thread_bindings % mma_emitter.WARP_SIZE + tz = (thread_bindings // (mma_emitter.WARP_SIZE * mma_emitter.block_row_warps) + ) % mma_emitter.block_col_warps + vi = ( + tz * (warp_cols * mma_emitter.WARP_SIZE // k_inner_stride) + j * + (mma_emitter.WARP_SIZE // k_inner_stride) + (tx // k_inner_stride)) + vj = ki * micro_size_k + (tx % k_inner_stride) * local_size + v + remaped_i, remaped_j = self.get_param_indices( + pid_n * stride_n + vi, + ko * stride_k + vj, + transform_kind=TransformKind.LDMatrixTransform, + in_dtype=in_dtype, + matrix_name="B", + group_size=group_size, + ) + if not with_scaling: + dequant_weight_local[j * local_size + v] = self._decode_func( + num_bits, + compressed_weight_local[j * local_size // num_elems_per_byte + + v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) + elif not with_zeros: + dequant_weight_local[j * local_size + v] = ( + self._decode_func( + num_bits, + compressed_weight_local[j * local_size // num_elems_per_byte + + v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) * scale_buffer[remaped_i, remaped_j]) + elif zeros_mode == "original": + dequant_weight_local[j * local_size + v] = (self._decode_func( + num_bits, + compressed_weight_local[j * local_size // num_elems_per_byte + + v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) - zeros_buffer[remaped_i, remaped_j]) * scale_buffer[remaped_i, remaped_j] + elif zeros_mode == "rescale": + dequant_weight_local[j * local_size + v] = ( + self._decode_func( + num_bits, + compressed_weight_local[j * local_size // num_elems_per_byte + + v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) * scale_buffer[remaped_i, remaped_j] - + zeros_buffer[remaped_i, remaped_j]) + elif zeros_mode == "quantized": + dequant_qzeros = _tir_packed_to_unsigned_convert( + storage_type, storage_nbit)( + num_bits, + qzeros_buffer[ + remaped_i, + remaped_j // num_elems_per_byte, + ], + (pid_n * stride_n + vi) % num_elems_per_byte, + dtype=storage_dtype, + ) + + dequant_weight_local[j * local_size + v] = (self._decode_func( + num_bits, + compressed_weight_local[j * local_size // num_elems_per_byte + + v // num_elems_per_byte], + v % num_elems_per_byte, + zero=dequant_qzeros, + dtype=in_dtype, + )) * scale_buffer[remaped_i, remaped_j] + + return _normal_dequant_impl( + compressed_weight_local, + dequant_weight_local, + scale_buffer, + zeros_buffer, + qzeros_buffer, + ) + + def _normal_fast_dequant( + self, + compressed_weight_local: T.Buffer, + dequant_weight_local: T.Buffer, + scale_buffer: T.Buffer, + zeros_buffer: T.Buffer, + qzeros_buffer: T.Buffer, + func_name: str, + local_size: int, + warp_cols: int, + pid_n: T.Var, + thread_bindings: T.Var, + mma_emitter: TensorCoreIntrinEmitterWithLadderTransform, + ko: T.Var, + ki: T.Var, + stride_n: int, + stride_k: int, + ): + num_elems_per_byte = self.num_elems_per_byte + with_scaling = self.with_scaling + with_zeros = self.with_zeros + zeros_mode = self.zeros_mode + in_dtype = self.in_dtype + group_size = self.group_size + micro_size_k = mma_emitter.micro_size_k + k_inner_stride = micro_size_k // local_size + grouped_k = scale_buffer.shape[-1] + + @T.macro + def _normal_fast_dequant_impl( + compressed_weight_local: T.Buffer, + dequant_weight_local: T.Buffer, + scale_buffer: T.Buffer, + zeros_buffer: T.Buffer, + qzeros_buffer: T.Buffer, + ): + for j in T.serial(warp_cols): + tx = thread_bindings % mma_emitter.WARP_SIZE + tz = (thread_bindings // (mma_emitter.WARP_SIZE * mma_emitter.block_row_warps) + ) % mma_emitter.block_col_warps + vi = ( + tz * (warp_cols * mma_emitter.WARP_SIZE // k_inner_stride) + j * + (mma_emitter.WARP_SIZE // k_inner_stride) + (tx // k_inner_stride)) + vj = ki * micro_size_k + (tx % k_inner_stride) * local_size + remapped_i, remapped_j = self.get_param_indices( + pid_n * stride_n + vi, + ko * stride_k + vj, + transform_kind=TransformKind.LDMatrixTransform, + in_dtype=in_dtype, + matrix_name="B", + group_size=group_size, + ) + if not with_scaling: + T.call_extern( + func_name, + T.address_of(compressed_weight_local[j * local_size // num_elems_per_byte]), + T.address_of(dequant_weight_local[j * local_size]), + dtype=in_dtype, + ) + elif not with_zeros: + # Scaling only + T.call_extern( + func_name, + T.address_of(compressed_weight_local[j * local_size // num_elems_per_byte]), + T.address_of(dequant_weight_local[j * local_size]), + T.address_of(scale_buffer[remapped_i, remapped_j]), + local_size * grouped_k, + local_size, + dtype=in_dtype, + ) + elif zeros_mode in ["original", "rescale"]: + T.call_extern( + func_name, + T.address_of(compressed_weight_local[j * local_size // num_elems_per_byte]), + T.address_of(dequant_weight_local[j * local_size]), + T.address_of(scale_buffer[remapped_i, remapped_j]), + T.address_of(zeros_buffer[remapped_i, remapped_j]), + local_size * grouped_k, + local_size, + dtype=in_dtype, + ) + # TODO: Implement quantized zeros + + return _normal_fast_dequant_impl( + compressed_weight_local, + dequant_weight_local, + scale_buffer, + zeros_buffer, + qzeros_buffer, + ) + + def get_param_indices( + self, + rl, + rr, + l=16, + r=16, + transform_kind=TransformKind.LDMatrixTransform, # noqa: E741 + trans=True, + in_dtype="float16", + matrix_name="B", + group_size=1, + ): # noqa: E741 + intra_index_map, _ = get_propagate_map(trans=trans, dtype=in_dtype, matrix_name=matrix_name) + + ladder_stage3_index_map, ladder_stage3_inverse_index_map = ( + get_ladder_stage3_map(dtype=in_dtype)) + + # assume the param layout is n, k + + warp_i, warp_j = rl % l, rr % r + + spatial_i, spatial_j = rl // l, rr // r + + # If is stage3 ladder transform + if transform_kind > 2: + warp_i, warp_j = ladder_stage3_inverse_index_map.map_indices([warp_i, warp_j]) + + warp_i, warp_j = intra_index_map.map_indices([warp_i, warp_j]) + new_indices = ( + spatial_i * l + warp_i, + (spatial_j * r + warp_j) // group_size, + ) + + return new_indices + + def __post_init__(self): + # Legalize group_size + if self.with_scaling and self.group_size == -1: + object.__setattr__(self, "group_size", self.K) diff --git a/testing/python/operators/test_general_matmul_tile_schedule.py b/testing/python/operators/test_general_matmul_tile_schedule.py index 58f595984..1a83c0d18 100644 --- a/testing/python/operators/test_general_matmul_tile_schedule.py +++ b/testing/python/operators/test_general_matmul_tile_schedule.py @@ -370,7 +370,7 @@ def assert_dequantize_correctness_with_ladder_ldmatrix_propagate( layout="nt", zeros_mode="original", ): - assert with_scaling, "Currently The test only support with scaling" + if group_size == -1: group_size = K propagate_b = 3 @@ -408,10 +408,10 @@ def assert_dequantize_correctness_with_ladder_ldmatrix_propagate( matmul_func, config=bitblas.base.Hint.from_dict({ "arch": arch, - "block": [16, 128], - "warp": [16, 32], - "rstep": [128], - "pipeline_stage": 4, + "block": [128, 128], + "warp": [64, 64], + "rstep": [32], + "pipeline_stage": 2, "use_async": True, "intrin_info": intrin_info, "shared_scope": "shared.dyn", @@ -419,7 +419,7 @@ def assert_dequantize_correctness_with_ladder_ldmatrix_propagate( "b": 8, "a": 8 }, - "block_reduction_depth": 2, + "block_reduction_depth": 1, }), ) @@ -429,6 +429,8 @@ def assert_dequantize_correctness_with_ladder_ldmatrix_propagate( "tir.disable_cse_tir": True }): rt_mod = tvm.build(block_reduce_sch.mod, target=target) + src_code = rt_mod.imported_modules[0].get_source() + assert src_code is not None check_reduce(rt_mod) @@ -500,28 +502,38 @@ def assert_dequantize_correctness_with_ladder_ldmatrix_propagate( transformed_b = transformed_b.cuda() c = c.cuda() scale = scale.cuda() - if zeros is not None: + args = [a, transformed_b] + if with_scaling: + args.append(scale) + if with_scaling and with_zeros: zeros = zeros.cuda() - torch_func(a, transformed_b, scale, zeros, c) - else: - torch_func(a, transformed_b, scale, c) - - rescale_b = torch.empty_like(b, dtype=torch.float16) - for i in range(N): - for j in range(K): - if with_zeros: - if zeros_mode == "original": - rescale_b[i, - j] = (b[i, j] - zeros[i, j // group_size]) * scale[i, j // group_size] - elif zeros_mode == "rescale": - rescale_b[i, - j] = b[i, j] * scale[i, j // group_size] + zeros[i, j // group_size] + args.append(zeros) + args.append(c) + + torch_func(*args) + + args = [a] + if with_scaling: + + rescale_b = torch.empty_like(b, dtype=torch.float16) + for i in range(N): + for j in range(K): + if with_zeros: + if zeros_mode == "original": + rescale_b[i, + j] = (b[i, j] - zeros[i, j // group_size]) * scale[i, j // group_size] + elif zeros_mode == "rescale": + rescale_b[i, + j] = b[i, j] * scale[i, j // group_size] + zeros[i, j // group_size] + else: + raise NotImplementedError else: - raise NotImplementedError - else: - rescale_b[i, j] = b[i, j] * scale[i, j // group_size] + rescale_b[i, j] = b[i, j] * scale[i, j // group_size] + args.append(rescale_b.t().cuda()) + else: + args.append(b.t().cuda().to(torch.float16)) - ref_c = torch.matmul(a, rescale_b.t().cuda()) + ref_c = torch.matmul(*args) print("rescale_b is \n", c) print("ref_c is \n", ref_c) diff --git a/testing/python/operators/test_general_matmul_tilelang_kernel.py b/testing/python/operators/test_general_matmul_tilelang_kernel.py index 280c170ac..31c3de7d1 100644 --- a/testing/python/operators/test_general_matmul_tilelang_kernel.py +++ b/testing/python/operators/test_general_matmul_tilelang_kernel.py @@ -11,13 +11,18 @@ ) from bitblas.ops.general_matmul.tilelang.dequantize import ( - MatmulDequantizeScheduler,) + MatmulDequantizeScheduler, + MatmulDequantizeFineGrainedScheduler, + MatmulDequantizeWeightPropagationScheduler, +) import torch import torch.backends torch.manual_seed(0) +verbose = False + def assert_matmul_blocked_with_default_correctness( M, @@ -166,64 +171,6 @@ def assert_matmul_fine_grained_with_default_correctness( torch.matmul(A, B.T).to(getattr(torch, out_dtype)) if trans_B else torch.matmul(A, B).to( getattr(torch, out_dtype))) - # from bitblas.ops import Matmul, MatmulConfig - # matmul_config = MatmulConfig( - # M=M, - # N=N, - # K=K, - # propagate_a=False, - # propagate_b=False, - # ) - # matmul = Matmul(matmul_config, enable_tuning=False) - # prim_func = matmul.prim_func - # intrin_info = bitblas.base.hint.IntrinInfo( - # in_dtype=in_dtype, - # out_dtype=accum_dtype, - # trans_b=True, - # input_transform_kind=0, - # weight_transform_kind=0, - # ) - - # arch = bitblas.base.CUDA(target="cuda") - - # sch = bitblas.gpu.MatmulTensorizationMMA().apply_config( - # prim_func, - # config=bitblas.base.Hint.from_dict({ - # "arch": arch, - # "block": [64, 64], - # "warp": [32, 32], - # "rstep": [32], - # "pipeline_stage": 2, - # "use_async": True, - # "intrin_info": intrin_info, - # "shared_scope": "shared.dyn", - # "vectorize": { - # "b": 8, - # "a": 8 - # }, - # }), - # ) - - # with tvm.transform.PassContext(config={ - # "tir.use_async_copy": True, - # "tir.merge_static_smem": False - # }): - # rt_mod = tvm.build(sch.mod, target="cuda") - # from tvm.contrib.dlpack import to_pytorch_func - - # torch_func = to_pytorch_func(rt_mod) - - # matmul_c = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) - # torch_func(A, B, matmul_c) - - # with open("debug/matmul_ref.cu", "w") as f: - # f.write(rt_mod.imported_modules[0].get_source()) - - # with open("debug/matmul_tl.cu", "w") as f: - # f.write(src_code) - - # torch.testing.assert_close(matmul_c, ref_c, rtol=1e0, atol=1e-1) - torch.testing.assert_close(C, ref_c, rtol=1e0, atol=1e-1) @@ -439,6 +386,7 @@ def assert_matmul_blocked_dequant_with_default_correctness( ): import numpy as np from bitblas.quantization import general_compress, interleave_weight + matmul = MatmulDequantizeScheduler( M=M, N=N, @@ -462,7 +410,7 @@ def assert_matmul_blocked_dequant_with_default_correctness( src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None - + print(src_code) input_shape = (M, K) weight_shape = (N, K) output_shape = (M, N) @@ -496,17 +444,17 @@ def assert_matmul_blocked_dequant_with_default_correctness( if with_scaling: if group_size == -1: group_size = K - permuted_inputs.append(torch.ones([N, K // group_size], dtype=torch.float16).cuda()) + permuted_inputs.append(torch.randn((N, K // group_size), dtype=torch.float16).cuda()) if with_zeros: if zeros_mode == "original": permuted_inputs.append( torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros) elif zeros_mode == "rescale": - original_zeros = torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros + original_zeros = (torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros) scaled_zeros = original_zeros * permuted_inputs[-1] permuted_inputs.append(scaled_zeros) elif zeros_mode == "quantized": - original_zeros = torch.ones([K // group_size, N], dtype=torch.int8).cuda() * zeros + original_zeros = (torch.ones([K // group_size, N], dtype=torch.int8).cuda() * zeros) qzeros = general_compress( original_zeros.cpu().numpy(), source_bits=bit, storage_dtype=np.int8) permuted_inputs.append(torch.from_numpy(qzeros).cuda()) @@ -521,7 +469,30 @@ def assert_matmul_blocked_dequant_with_default_correctness( print(permuted_inputs[-1]) - ref_result = torch.matmul(inputs[0], inputs[1].t().to(torch.float16)) + args = [inputs[0]] + b = inputs[1] + if with_scaling: + scale = permuted_inputs[2] + rescale_b = torch.empty_like(b, dtype=torch.float16) + for i in range(N): + for j in range(K): + if with_zeros: + zeros = permuted_inputs[3] + if zeros_mode == "original": + rescale_b[i, j] = (b[i, j] - zeros[i, j // group_size]) * scale[i, j // + group_size] + elif zeros_mode == "rescale": + rescale_b[i, j] = ( + b[i, j] * scale[i, j // group_size] + zeros[i, j // group_size]) + else: + raise NotImplementedError + else: + rescale_b[i, j] = b[i, j] * scale[i, j // group_size] + args.append(rescale_b.t().cuda()) + else: + args.append(b.t().cuda().to(torch.float16)) + + ref_result = torch.matmul(*args) print(ref_result) if zeros_mode == "rescale": @@ -530,6 +501,289 @@ def assert_matmul_blocked_dequant_with_default_correctness( torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e2, atol=1e2) +def assert_matmul_fine_grained_dequant_with_default_correctness( + M, + N, + K, + trans_A=False, + trans_B=True, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + zeros_mode="original", +): + import numpy as np + from bitblas.quantization import general_compress, interleave_weight + + matmul = MatmulDequantizeFineGrainedScheduler( + M=M, + N=N, + K=K, + trans_A=trans_A, + trans_B=trans_B, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + num_bits=bit, + storage_dtype=storage_dtype, + source_format=source_format, + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=group_size, + fast_decoding=fast_decoding, + zeros_mode=zeros_mode, + ).with_default_config() + + mod, params = tl.lower(matmul) + src_code = mod.imported_modules[0].get_source() + # src_code is the generated cuda source + assert src_code is not None + input_shape = (M, K) + weight_shape = (N, K) + output_shape = (M, N) + inputs = [] + inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5) + maxq = 2**(bit - 1) + zeros = maxq + if source_format == "uint": + inputs.append(torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda()) + elif source_format == "int": + inputs.append(torch.randint(-maxq, maxq, weight_shape, dtype=torch.int8).cuda()) + else: + raise NotImplementedError + + inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda()) + + intweight = inputs[1] + intweight = intweight.cpu().to(torch.int8) + if source_format == "int": + intweight = intweight + maxq + if with_zeros: + inputs[1] = inputs[1] - zeros + + permuted_inputs = [] + permuted_inputs.append(inputs[0]) + qw = general_compress(intweight.cpu().numpy(), source_bits=bit, storage_dtype=np.int8) + # lop3 transformation + if fast_decoding: + qw = interleave_weight(qw, bit, target_dtype=in_dtype) + permuted_inputs.append(torch.from_numpy(qw).cuda()) + if with_scaling: + if group_size == -1: + group_size = K + permuted_inputs.append(torch.ones([N, K // group_size], dtype=torch.float16).cuda()) + if with_zeros: + if zeros_mode == "original": + permuted_inputs.append(torch.randn((N, K // group_size), dtype=torch.float16).cuda()) + elif zeros_mode == "rescale": + original_zeros = (torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros) + scaled_zeros = original_zeros * permuted_inputs[-1] + permuted_inputs.append(scaled_zeros) + elif zeros_mode == "quantized": + original_zeros = (torch.ones([K // group_size, N], dtype=torch.int8).cuda() * zeros) + qzeros = general_compress( + original_zeros.cpu().numpy(), source_bits=bit, storage_dtype=np.int8) + permuted_inputs.append(torch.from_numpy(qzeros).cuda()) + else: + raise NotImplementedError + + permuted_inputs.append(inputs[2]) + + mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + + mod(*permuted_inputs) + + print(permuted_inputs[-1]) + + args = [inputs[0]] + b = inputs[1] + if with_scaling: + scale = permuted_inputs[2] + rescale_b = torch.empty_like(b, dtype=torch.float16) + for i in range(N): + for j in range(K): + if with_zeros: + if zeros_mode == "original": + rescale_b[i, j] = (b[i, j] - zeros[i, j // group_size]) * scale[i, j // + group_size] + elif zeros_mode == "rescale": + rescale_b[i, j] = ( + b[i, j] * scale[i, j // group_size] + zeros[i, j // group_size]) + else: + raise NotImplementedError + else: + rescale_b[i, j] = b[i, j] * scale[i, j // group_size] + args.append(rescale_b.t().cuda()) + else: + args.append(b.t().cuda().to(torch.float16)) + + ref_result = torch.matmul(*args) + + print(ref_result) + if zeros_mode == "rescale": + torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e2, atol=1e2) + else: + torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e2, atol=1e2) + + +def assert_matmul_weight_transform_dequant_with_default_correctness( + M, + N, + K, + trans_A=False, + trans_B=True, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + zeros_mode="original", +): + import numpy as np + from bitblas.quantization import general_compress, interleave_weight + + matmul = MatmulDequantizeWeightPropagationScheduler( + M=M, + N=N, + K=K, + trans_A=trans_A, + trans_B=trans_B, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + num_bits=bit, + storage_dtype=storage_dtype, + source_format=source_format, + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=group_size, + fast_decoding=fast_decoding, + zeros_mode=zeros_mode, + ).with_default_config() + if verbose: + print(matmul) + mod, params = tl.lower(matmul) + + src_code = mod.imported_modules[0].get_source() + # src_code is the generated cuda source + assert src_code is not None + if verbose: + print(src_code) + input_shape = (M, K) + weight_shape = (N, K) + output_shape = (M, N) + inputs = [] + inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5) + maxq = 2**(bit - 1) + if group_size == -1: + group_size = K + + if source_format == "uint": + inputs.append(torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda()) + elif source_format == "int": + inputs.append(torch.randint(-maxq, maxq, weight_shape, dtype=torch.int8).cuda()) + else: + raise NotImplementedError + + inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda()) + + intweight = inputs[1] + intweight = intweight.cpu().to(torch.int8) + if source_format == "int": + intweight = intweight + maxq + + ladder_permutate_config = bitblas.ops.LadderPermutateConfig( + M=N, + N=K, + storage_dtype=storage_dtype, + propagate_kind="B", + transform_kind=3, + transpose_matrix=True, + ) + + ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) + + LB = ladder_permutate(intweight.cpu()).cuda().reshape(N, K) + permuted_inputs = [] + permuted_inputs.append(inputs[0]) + qw = general_compress(LB.cpu().numpy(), source_bits=bit, storage_dtype=np.int8) + + # lop3 transformation + if fast_decoding: + qw = interleave_weight(qw, bit, target_dtype=in_dtype) + qw_shape = [int(v) for v in matmul.buffer_map[matmul.params[1]].shape] + qw = qw.reshape(qw_shape) + permuted_inputs.append(torch.from_numpy(qw).cuda()) + if with_scaling: + # permuted_inputs.append(torch.ones([N, K // group_size], dtype=torch.float16).cuda()) + permuted_inputs.append(torch.randn((N, K // group_size), dtype=torch.float16).cuda()) + + zeros = None + if with_zeros: + if zeros_mode == "original": + zeros = torch.ones([N, K // group_size], dtype=torch.float16).cuda() * maxq + elif zeros_mode == "rescale": + scale = permuted_inputs[2] + original_zeros = (torch.ones([N, K // group_size], dtype=torch.float16).cuda() * maxq) + zeros = -(original_zeros * scale.cuda()) + else: + raise NotImplementedError + + if with_scaling and with_zeros: + permuted_inputs.append(zeros) + + permuted_inputs.append(inputs[2]) + + mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + + mod(*permuted_inputs) + + print(permuted_inputs[-1]) + + args = [inputs[0]] + b = inputs[1] + + if with_scaling: + scale = permuted_inputs[2] + rescale_b = torch.empty_like(b, dtype=torch.float16) + for i in range(N): + for j in range(K): + if with_zeros: + zeros = permuted_inputs[3] + if zeros_mode == "original": + rescale_b[i, j] = (b[i, j] - zeros[i, j // group_size]) * scale[i, j // + group_size] + elif zeros_mode == "rescale": + rescale_b[i, j] = ( + b[i, j] * scale[i, j // group_size] + zeros[i, j // group_size]) + else: + raise NotImplementedError + else: + rescale_b[i, j] = b[i, j] * scale[i, j // group_size] + args.append(rescale_b.t().cuda()) + else: + args.append(b.t().cuda().to(torch.float16)) + + ref_result = torch.matmul(*args) + print(ref_result) + if zeros_mode == "rescale": + torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e-2, atol=1e0) + else: + torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e-2, atol=1e0) + + def test_matmul_blocked(): # Default assert_matmul_blocked_with_default_correctness(1024, 1024, 1024) @@ -569,11 +823,25 @@ def test_matmul_blocked_dequant_with_default(): assert_matmul_blocked_dequant_with_default_correctness( 1024, 1024, 1024, source_format="uint", bit=4, with_scaling=True) assert_matmul_blocked_dequant_with_default_correctness( - 1024, 1024, 1024, source_format="uint", bit=4, with_scaling=True, with_zeros=True) + 1024, + 1024, + 1024, + source_format="uint", + bit=4, + with_scaling=True, + with_zeros=True, + ) assert_matmul_blocked_dequant_with_default_correctness( 1024, 1024, 1024, source_format="uint", bit=4, fast_decoding=True) assert_matmul_blocked_dequant_with_default_correctness( - 1024, 1024, 1024, source_format="uint", bit=4, with_scaling=True, fast_decoding=True) + 1024, + 1024, + 1024, + source_format="uint", + bit=4, + with_scaling=True, + fast_decoding=True, + ) assert_matmul_blocked_dequant_with_default_correctness( 1024, 1024, @@ -582,7 +850,79 @@ def test_matmul_blocked_dequant_with_default(): bit=4, with_scaling=True, with_zeros=True, - fast_decoding=True) + fast_decoding=True, + ) + + +def test_matmul_fine_grained_dequant_with_default(): + assert_matmul_fine_grained_dequant_with_default_correctness( + 1024, 1024, 1024, source_format="uint", bit=4) + assert_matmul_fine_grained_dequant_with_default_correctness( + 1024, 1024, 1024, source_format="uint", bit=2) + assert_matmul_fine_grained_dequant_with_default_correctness( + 1024, 1024, 1024, source_format="uint", bit=4, with_scaling=True) + assert_matmul_fine_grained_dequant_with_default_correctness( + 1024, + 1024, + 1024, + source_format="uint", + bit=4, + with_scaling=True, + with_zeros=True, + ) + assert_matmul_fine_grained_dequant_with_default_correctness( + 1024, 1024, 1024, source_format="uint", bit=4, fast_decoding=True) + assert_matmul_fine_grained_dequant_with_default_correctness( + 1024, + 1024, + 1024, + source_format="uint", + bit=4, + with_scaling=True, + fast_decoding=True, + ) + assert_matmul_fine_grained_dequant_with_default_correctness( + 1024, + 1024, + 1024, + source_format="uint", + bit=4, + with_scaling=True, + with_zeros=True, + fast_decoding=True, + ) + + +def test_matmul_weight_transform_dequant_with_default(): + assert_matmul_weight_transform_dequant_with_default_correctness( + 1024, 1024, 1024, source_format="uint", bit=4) + assert_matmul_weight_transform_dequant_with_default_correctness( + 1024, 1024, 1024, source_format="uint", bit=2) + assert_matmul_weight_transform_dequant_with_default_correctness( + 1024, 1024, 1024, source_format="uint", bit=4, with_scaling=True) + assert_matmul_weight_transform_dequant_with_default_correctness( + 1024, 1024, 1024, source_format="uint", bit=4, with_scaling=True, with_zeros=True) + assert_matmul_weight_transform_dequant_with_default_correctness( + 1024, 1024, 1024, source_format="uint", bit=4, fast_decoding=True) + assert_matmul_weight_transform_dequant_with_default_correctness( + 1024, + 1024, + 1024, + source_format="uint", + bit=4, + with_scaling=True, + fast_decoding=True, + ) + assert_matmul_weight_transform_dequant_with_default_correctness( + 1024, + 1024, + 1024, + source_format="uint", + bit=4, + with_scaling=True, + fast_decoding=True, + with_zeros=True, + ) if __name__ == "__main__": diff --git a/testing/python/tilelang/test_tilelang_dequantize_gemm.py b/testing/python/tilelang/test_tilelang_dequantize_gemm.py index bb3f38d24..620ef5be7 100644 --- a/testing/python/tilelang/test_tilelang_dequantize_gemm.py +++ b/testing/python/tilelang/test_tilelang_dequantize_gemm.py @@ -172,18 +172,18 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( micro_size_k = 32 # This is a debug config - block_row_warps = 1 - block_col_warps = 4 + block_row_warps = 2 + block_col_warps = 2 - warp_rows = 1 - warp_cols = 2 + warp_rows = 4 + warp_cols = 4 warp_row_tiles = micro_size_x * warp_rows warp_col_tiles = micro_size_y * warp_cols shared_scope = "shared.dyn" # Pipeline Stage stage = 2 - reduce_k = 2 + reduce_k = 1 block_M = block_row_warps * warp_row_tiles block_N = block_col_warps * warp_col_tiles @@ -423,6 +423,8 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct # Get Reference Result ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype)) + print("Ref C: ", ref_c) + print("C: ", C) torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) @@ -437,5 +439,4 @@ def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(): if __name__ == "__main__": - # bitblas.testing.main() - run_gemm(256, 256, 256, "float16", "float16", "float16", 128, 128, 32, num_threads=128) + bitblas.testing.main() diff --git a/tutorials/.gitignore b/tutorials/.gitignore new file mode 100644 index 000000000..4ffe21f30 --- /dev/null +++ b/tutorials/.gitignore @@ -0,0 +1,3 @@ +progress +debug* +.ipynb* diff --git a/tutorials/1.fast_and_efficient_codegen.ipynb b/tutorials/1.fast_and_efficient_codegen.ipynb new file mode 100644 index 000000000..4c34d06e6 --- /dev/null +++ b/tutorials/1.fast_and_efficient_codegen.ipynb @@ -0,0 +1,1427 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "86b403bf-fadd-4d14-81a1-d805eed113ae", + "metadata": {}, + "source": [ + "# Fast and Efficient Code Generation with BitBLAS Roller Component\n", + "\n", + "Reimplemented and improved from **OSDI 22'Roller**: https://www.usenix.org/system/files/osdi22-zhu.pdf\n", + "\n", + "Core Code: https://github.com/microsoft/BitBLAS/blob/main/bitblas/base/roller\n", + "\n", + "Only takes seconds to optimize high performance kernels via hardware-aware white box search space recommendation.\n", + "\n", + "
\n", + " \"BitBLAS\n", + "
\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "8a8a16cc-64c2-4568-bb1f-9ef391a6ebd7", + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize Import\n", + "import bitblas\n", + "from bitblas import tvm as tvm" + ] + }, + { + "cell_type": "markdown", + "id": "8ebb1b8b-9348-4ac5-8a5b-32658e414012", + "metadata": {}, + "source": [ + "## 1. Get start with an elememt-wise add\n", + "tensor expression: B = A + 1.0" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "e61b868a-0e69-4e46-a517-7b91a043a63f", + "metadata": {}, + "outputs": [], + "source": [ + "from tvm import te\n", + "\n", + "A = te.placeholder((1024, 1024), name=\"A\", dtype=\"float16\")\n", + "\n", + "def fcompute(i, j):\n", + " return A[i, j] + 1.0\n", + "\n", + "B = te.compute((1024, 1024), fcompute, name=\"B\")\n", + "\n", + "args = [A, B]\n", + "func = te.create_prim_func(args)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "289e13e2-6eca-4b11-a788-a9ac1a7a8acf", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# from tvm.script import tir as T\n", + "\n", + "@T.prim_func\n", + "def main(A: T.Buffer((1024, 1024), \"float16\"), B: T.Buffer((1024, 1024), \"float32\")):\n", + " T.func_attr({\"tir.noalias\": T.bool(True)})\n", + " # with T.block(\"root\"):\n", + " for i, j in T.grid(1024, 1024):\n", + " with T.block(\"B\"):\n", + " v_i, v_j = T.axis.remap(\"SS\", [i, j])\n", + " T.reads(A[v_i, v_j])\n", + " T.writes(B[v_i, v_j])\n", + " B[v_i, v_j] = T.Cast(\"float32\", A[v_i, v_j]) + T.float32(1)\n" + ] + } + ], + "source": [ + "print(func) # TIR Script Function" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "ca3448f5-a796-4089-981b-eb4ff9f17a48", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "target='nvidia/geforce-rtx-4090'\n" + ] + } + ], + "source": [ + "# import fast tunning related toolkits\n", + "from bitblas.base.roller.policy import DefaultPolicy\n", + "from bitblas.base.arch import CUDA\n", + "from bitblas.base.utils import apply_and_build\n", + "\n", + "target = bitblas.auto_detect_nvidia_target()\n", + "print(f\"{target=}\")\n", + "arch = CUDA(target)" + ] + }, + { + "cell_type": "markdown", + "id": "17230883-f5cb-40d4-8ed2-f069a5c07d2a", + "metadata": {}, + "source": [ + "### 2. Example Arch: CUDA\n", + "Codebase: https://github.com/microsoft/BitBLAS/blob/main/bitblas/base/arch/cuda.py\n", + "\n", + "```python\n", + "class CUDA(TileDevice):\n", + "\n", + " def __init__(self, target: Union[Target, str]):\n", + " if isinstance(target, str):\n", + " target = tvm.target.Target(target)\n", + " self.target = target\n", + " self.sm_version = check_sm_version(self.target.arch)\n", + " device = tvm.runtime.cuda(0)\n", + " if not device.exist:\n", + " raise RuntimeError(\"Cannot find cuda device 0.\")\n", + " self.device: tvm.runtime.Device = device\n", + " self.platform: str = \"CUDA\"\n", + " self.smem_cap = device.max_shared_memory_per_block\n", + " self.compute_max_core = device.multi_processor_count\n", + " self.warp_size = device.warp_size\n", + " self.compute_capability = device.compute_version.replace(\".\", \"\")\n", + " self.reg_cap: int = 65536\n", + " self.max_smem_usage: int = 2 * self.smem_cap\n", + " self.sm_partition: int = 4\n", + " self.l2_cache_size_bytes: int = target.l2_cache_size_bytes\n", + " # the number of transaction size in bytes\n", + " self.transaction_size: List[int] = [32, 128] # in bytes\n", + " # bandwidth in MB/s, will be used for recommend basic tile size\n", + " # TODO(lei): find some way to get the real bandwidth\n", + " # However, the ratio of bandwidth between different devices can\n", + " # be similar. The bandwidth can work for another devices as well.\n", + " self.bandwidth: List[int] = [750, 12080]\n", + " # get the available tensor instructions during runtime to avoid\n", + " # the dependency of the tensor intrinsics registration\n", + " self.available_tensor_instructions: List[TensorInstruction] = None\n", + "\n", + " def get_avaliable_tensorintrin_shapes(self):\n", + " from tvm.tir.tensor_intrin.cuda import get_wmma_intrin_group, get_mma_intrin_group\n", + "\n", + " self.available_tensor_instructions = (\n", + " TensorInstruction(\"mma\", get_mma_intrin_group, [16, 16]),\n", + " TensorInstruction(\"wmma\", get_wmma_intrin_group, [16, 16]),\n", + " )\n", + " return [t.shape for t in self.available_tensor_instructions]\n", + "\n", + " def __repr__(self):\n", + " return f\"CUDA({self.target})\"\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "513fa032-0d29-482f-ba37-96823c826391", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'block': [128, 128], 'thread': [16, 8], 'rstep': [64], 'step': [1, 2], 'vectorize': {'A': 8, 'B': 8}}\n", + "{'block': [64, 128], 'thread': [8, 16], 'rstep': [64], 'step': [1, 2], 'vectorize': {'A': 8, 'B': 8}}\n", + "{'block': [128, 64], 'thread': [16, 8], 'rstep': [64], 'step': [1, 2], 'vectorize': {'A': 8, 'B': 8}}\n", + "{'block': [128, 256], 'thread': [8, 16], 'rstep': [64], 'step': [1, 2], 'vectorize': {'A': 8, 'B': 8}}\n", + "{'block': [256, 128], 'thread': [16, 8], 'rstep': [64], 'step': [1, 2], 'vectorize': {'A': 8, 'B': 8}}\n", + "{'block': [64, 64], 'thread': [16, 8], 'rstep': [64], 'step': [1, 2], 'vectorize': {'A': 8, 'B': 8}}\n", + "{'block': [32, 128], 'thread': [8, 16], 'rstep': [64], 'step': [1, 2], 'vectorize': {'A': 8, 'B': 8}}\n", + "{'block': [128, 32], 'thread': [16, 8], 'rstep': [64], 'step': [1, 2], 'vectorize': {'A': 8, 'B': 8}}\n", + "{'block': [64, 256], 'thread': [8, 16], 'rstep': [64], 'step': [1, 2], 'vectorize': {'A': 8, 'B': 8}}\n", + "{'block': [256, 64], 'thread': [16, 8], 'rstep': [64], 'step': [1, 2], 'vectorize': {'A': 8, 'B': 8}}\n", + "{'block': [32, 64], 'thread': [8, 16], 'rstep': [128], 'step': [1, 2], 'vectorize': {'A': 8, 'B': 8}}\n", + "{'block': [64, 32], 'thread': [16, 8], 'rstep': [128], 'step': [1, 2], 'vectorize': {'A': 8, 'B': 8}}\n", + "{'block': [32, 32], 'thread': [16, 8], 'rstep': [128], 'step': [1, 2], 'vectorize': {'A': 8, 'B': 8}}\n", + "{'block': [16, 128], 'thread': [4, 32], 'rstep': [64], 'step': [1, 2], 'vectorize': {'A': 8, 'B': 8}}\n", + "{'block': [128, 16], 'thread': [32, 4], 'rstep': [64], 'step': [1, 2], 'vectorize': {'A': 8, 'B': 8}}\n", + "{'block': [32, 256], 'thread': [4, 32], 'rstep': [64], 'step': [1, 2], 'vectorize': {'A': 8, 'B': 8}}\n", + "{'block': [256, 32], 'thread': [32, 4], 'rstep': [64], 'step': [1, 2], 'vectorize': {'A': 8, 'B': 8}}\n", + "{'block': [16, 64], 'thread': [8, 16], 'rstep': [128], 'step': [1, 2], 'vectorize': {'A': 8, 'B': 8}}\n", + "{'block': [64, 16], 'thread': [16, 8], 'rstep': [128], 'step': [1, 2], 'vectorize': {'A': 8, 'B': 8}}\n", + "{'block': [16, 32], 'thread': [8, 16], 'rstep': [256], 'step': [1, 2], 'vectorize': {'A': 8, 'B': 8}}\n" + ] + } + ], + "source": [ + "policy = DefaultPolicy(func=func, arch=arch)\n", + "configs = policy.emit_config(topk=20)\n", + "\n", + "for config in configs:\n", + " print(config)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "9ccebcf1-bdca-4734-a8e6-0c6b9ef5332f", + "metadata": {}, + "outputs": [], + "source": [ + "bitblas.set_log_level(\"Debug\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "d72a8c4f-4947-473d-a537-53b0aafd605e", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2024-10-24 12:57:34 [BitBLAS:DEBUG]: Apply config {'block': [2, 1024], 'thread': [2, 64], 'rstep': []}\n", + "2024-10-24 12:57:34 [BitBLAS:DEBUG]: Apply config {'block': [4, 512], 'thread': [4, 32], 'rstep': []}\n", + "2024-10-24 12:57:34 [BitBLAS:DEBUG]: Apply config {'block': [8, 256], 'thread': [8, 16], 'rstep': []}\n", + "2024-10-24 12:57:34 [BitBLAS:DEBUG]: Apply config {'block': [16, 128], 'thread': [16, 8], 'rstep': []}\n", + "2024-10-24 12:57:34 [BitBLAS:DEBUG]: Apply config {'block': [32, 64], 'thread': [16, 8], 'rstep': []}\n", + "2024-10-24 12:57:34 [BitBLAS:DEBUG]: Apply config {'block': [1, 1024], 'thread': [1, 128], 'rstep': []}\n", + "2024-10-24 12:57:34 [BitBLAS:DEBUG]: Apply config {'block': [2, 512], 'thread': [2, 64], 'rstep': []}\n", + "2024-10-24 12:57:34 [BitBLAS:DEBUG]: Apply config {'block': [4, 256], 'thread': [4, 32], 'rstep': []}\n", + "2024-10-24 12:57:34 [BitBLAS:DEBUG]: Apply config {'block': [8, 128], 'thread': [8, 16], 'rstep': []}\n", + "2024-10-24 12:57:34 [BitBLAS:DEBUG]: Apply config {'block': [16, 64], 'thread': [16, 8], 'rstep': []}\n", + "2024-10-24 12:57:34 [BitBLAS:DEBUG]: Apply config {'block': [1, 512], 'thread': [1, 128], 'rstep': []}\n", + "2024-10-24 12:57:34 [BitBLAS:DEBUG]: Apply config {'block': [2, 256], 'thread': [2, 64], 'rstep': []}\n", + "2024-10-24 12:57:34 [BitBLAS:DEBUG]: Apply config {'block': [4, 128], 'thread': [4, 32], 'rstep': []}\n", + "2024-10-24 12:57:34 [BitBLAS:DEBUG]: Apply config {'block': [8, 64], 'thread': [8, 16], 'rstep': []}\n", + "2024-10-24 12:57:34 [BitBLAS:DEBUG]: Apply config {'block': [1, 256], 'thread': [1, 128], 'rstep': []}\n", + "2024-10-24 12:57:34 [BitBLAS:DEBUG]: Apply config {'block': [2, 128], 'thread': [2, 64], 'rstep': []}\n", + "2024-10-24 12:57:34 [BitBLAS:DEBUG]: Apply config {'block': [4, 64], 'thread': [4, 32], 'rstep': []}\n", + "2024-10-24 12:57:34 [BitBLAS:DEBUG]: Apply config {'block': [1, 128], 'thread': [1, 128], 'rstep': []}\n", + "2024-10-24 12:57:34 [BitBLAS:DEBUG]: Apply config {'block': [2, 64], 'thread': [2, 64], 'rstep': []}\n", + "2024-10-24 12:57:34 [BitBLAS:DEBUG]: Apply config {'block': [1, 64], 'thread': [1, 64], 'rstep': []}\n", + "2024-10-24 12:58:01 [BitBLAS:INFO]: Evaluation with config {'block': [2, 1024], 'thread': [2, 64], 'rstep': []}\n", + "2024-10-24 12:58:01 [BitBLAS:INFO]: Time cost of this config: 0.005 ms\n", + "2024-10-24 12:58:01 [BitBLAS:INFO]: Evaluation with config {'block': [4, 512], 'thread': [4, 32], 'rstep': []}\n", + "2024-10-24 12:58:01 [BitBLAS:INFO]: Time cost of this config: 0.005 ms\n", + "2024-10-24 12:58:01 [BitBLAS:INFO]: Evaluation with config {'block': [8, 256], 'thread': [8, 16], 'rstep': []}\n", + "2024-10-24 12:58:01 [BitBLAS:INFO]: Time cost of this config: 0.005 ms\n", + "2024-10-24 12:58:01 [BitBLAS:INFO]: Evaluation with config {'block': [16, 128], 'thread': [16, 8], 'rstep': []}\n", + "2024-10-24 12:58:01 [BitBLAS:INFO]: Time cost of this config: 0.005 ms\n", + "2024-10-24 12:58:01 [BitBLAS:INFO]: Evaluation with config {'block': [32, 64], 'thread': [16, 8], 'rstep': []}\n", + "2024-10-24 12:58:01 [BitBLAS:INFO]: Time cost of this config: 0.005 ms\n", + "2024-10-24 12:58:01 [BitBLAS:INFO]: Evaluation with config {'block': [1, 1024], 'thread': [1, 128], 'rstep': []}\n", + "2024-10-24 12:58:01 [BitBLAS:INFO]: Time cost of this config: 0.005 ms\n", + "2024-10-24 12:58:01 [BitBLAS:INFO]: Evaluation with config {'block': [2, 512], 'thread': [2, 64], 'rstep': []}\n", + "2024-10-24 12:58:01 [BitBLAS:INFO]: Time cost of this config: 0.006 ms\n", + "2024-10-24 12:58:01 [BitBLAS:INFO]: Evaluation with config {'block': [4, 256], 'thread': [4, 32], 'rstep': []}\n", + "2024-10-24 12:58:01 [BitBLAS:INFO]: Time cost of this config: 0.005 ms\n", + "2024-10-24 12:58:02 [BitBLAS:INFO]: Evaluation with config {'block': [8, 128], 'thread': [8, 16], 'rstep': []}\n", + "2024-10-24 12:58:02 [BitBLAS:INFO]: Time cost of this config: 0.005 ms\n", + "2024-10-24 12:58:02 [BitBLAS:INFO]: Evaluation with config {'block': [16, 64], 'thread': [16, 8], 'rstep': []}\n", + "2024-10-24 12:58:02 [BitBLAS:INFO]: Time cost of this config: 0.005 ms\n", + "2024-10-24 12:58:02 [BitBLAS:INFO]: Evaluation with config {'block': [1, 512], 'thread': [1, 128], 'rstep': []}\n", + "2024-10-24 12:58:02 [BitBLAS:INFO]: Time cost of this config: 0.005 ms\n", + "2024-10-24 12:58:02 [BitBLAS:INFO]: Evaluation with config {'block': [2, 256], 'thread': [2, 64], 'rstep': []}\n", + "2024-10-24 12:58:02 [BitBLAS:INFO]: Time cost of this config: 0.005 ms\n", + "2024-10-24 12:58:02 [BitBLAS:INFO]: Evaluation with config {'block': [4, 128], 'thread': [4, 32], 'rstep': []}\n", + "2024-10-24 12:58:02 [BitBLAS:INFO]: Time cost of this config: 0.005 ms\n", + "2024-10-24 12:58:02 [BitBLAS:INFO]: Evaluation with config {'block': [8, 64], 'thread': [8, 16], 'rstep': []}\n", + "2024-10-24 12:58:02 [BitBLAS:INFO]: Time cost of this config: 0.005 ms\n", + "2024-10-24 12:58:02 [BitBLAS:INFO]: Evaluation with config {'block': [1, 256], 'thread': [1, 128], 'rstep': []}\n", + "2024-10-24 12:58:02 [BitBLAS:INFO]: Time cost of this config: 0.005 ms\n", + "2024-10-24 12:58:02 [BitBLAS:INFO]: Evaluation with config {'block': [2, 128], 'thread': [2, 64], 'rstep': []}\n", + "2024-10-24 12:58:02 [BitBLAS:INFO]: Time cost of this config: 0.008 ms\n", + "2024-10-24 12:58:02 [BitBLAS:INFO]: Evaluation with config {'block': [4, 64], 'thread': [4, 32], 'rstep': []}\n", + "2024-10-24 12:58:02 [BitBLAS:INFO]: Time cost of this config: 0.011 ms\n", + "2024-10-24 12:58:02 [BitBLAS:INFO]: Evaluation with config {'block': [1, 128], 'thread': [1, 128], 'rstep': []}\n", + "2024-10-24 12:58:02 [BitBLAS:INFO]: Time cost of this config: 0.007 ms\n", + "2024-10-24 12:58:02 [BitBLAS:INFO]: Evaluation with config {'block': [2, 64], 'thread': [2, 64], 'rstep': []}\n", + "2024-10-24 12:58:02 [BitBLAS:INFO]: Time cost of this config: 0.005 ms\n", + "2024-10-24 12:58:02 [BitBLAS:INFO]: Evaluation with config {'block': [1, 64], 'thread': [1, 64], 'rstep': []}\n", + "2024-10-24 12:58:02 [BitBLAS:INFO]: Time cost of this config: 0.005 ms\n" + ] + } + ], + "source": [ + "cpresults, best = apply_and_build(func, configs, arch, parallel_build=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "1ad2c473-d3c3-4ee1-9928-2ed1fc44b03b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# from tvm.script import ir as I\n", + "# from tvm.script import tir as T\n", + "\n", + "@I.ir_module\n", + "class Module:\n", + " @T.prim_func\n", + " def main(A: T.Buffer((1024, 1024), \"float16\"), B: T.Buffer((1024, 1024), \"float32\")):\n", + " T.func_attr({\"tir.noalias\": T.bool(True)})\n", + " # with T.block(\"root\"):\n", + " for ax0_0_ax1_0_fused in T.thread_binding(2048, thread=\"blockIdx.x\"):\n", + " for ax1_1_0 in T.thread_binding(4, thread=\"vthread.x\"):\n", + " for ax0_1_0 in T.thread_binding(1, thread=\"vthread.y\"):\n", + " for ax0_1_1_0_ax1_1_1_0_fused in T.thread_binding(128, thread=\"threadIdx.x\"):\n", + " for ax0_1_1_1, ax1_1_1_1 in T.grid(1, 1):\n", + " with T.block(\"B\"):\n", + " v0 = T.axis.spatial(1024, ax0_0_ax1_0_fused // 2 + ax0_1_0 + ax0_1_1_1)\n", + " v1 = T.axis.spatial(1024, ax0_0_ax1_0_fused % 2 * 512 + ax1_1_0 * 128 + ax0_1_1_0_ax1_1_1_0_fused + ax1_1_1_1)\n", + " T.reads(A[v0, v1])\n", + " T.writes(B[v0, v1])\n", + " B[v0, v1] = T.Cast(\"float32\", A[v0, v1]) + T.float32(1)\n" + ] + } + ], + "source": [ + "# get the scheduled ir\n", + "print(best.sch.mod)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "0b596f7d-29ae-4db9-a207-287638a97053", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530)\n", + "\n", + "typedef unsigned short uint16_t;\n", + "typedef unsigned char uint8_t;\n", + "typedef signed char int8_t;\n", + "typedef int int32_t;\n", + "typedef unsigned long long uint64_t;\n", + "typedef unsigned int uint32_t;\n", + "\n", + "#define TVM_FORCE_INLINE inline __attribute__((always_inline))\n", + "#define TVM_XINLINE TVM_FORCE_INLINE __device__ __host__\n", + "#define TVM_ALIGNED(x) __attribute__ ((aligned(x)))\n", + "#define TVM_HALF_OPERATOR(RTYPE, OP) \\\n", + " TVM_XINLINE RTYPE operator OP (half a, half b) { \\\n", + " return RTYPE(float(a) OP float(b)); \\\n", + " } \\\n", + " template \\\n", + " TVM_XINLINE RTYPE operator OP (half a, T b) { \\\n", + " return RTYPE(float(a) OP float(b)); \\\n", + " } \\\n", + " template \\\n", + " TVM_XINLINE RTYPE operator OP (T a, half b) { \\\n", + " return RTYPE(float(a) OP float(b)); \\\n", + " }\n", + "\n", + "#define TVM_HALF_ASSIGNOP(AOP, OP) \\\n", + " template \\\n", + " TVM_XINLINE half operator AOP (const T& a) { \\\n", + " return *this = half(float(*this) OP float(a)); \\\n", + " } \\\n", + " template \\\n", + " TVM_XINLINE half operator AOP (const volatile T& a) volatile { \\\n", + " return *this = half(float(*this) OP float(a)); \\\n", + " }\n", + "\n", + "class TVM_ALIGNED(2) half {\n", + " public:\n", + " uint16_t half_;\n", + "\n", + " static TVM_XINLINE half Binary(uint16_t value) {\n", + " half res;\n", + " res.half_ = value;\n", + " return res;\n", + " }\n", + "\n", + " TVM_XINLINE half() {}\n", + "\n", + " TVM_XINLINE half(const float& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const double& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const int8_t& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const int32_t& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const long long& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); }\n", + "\n", + " TVM_XINLINE operator float() const { \\\n", + " return float(half2float(half_)); \\\n", + " } \\\n", + " TVM_XINLINE operator float() const volatile { \\\n", + " return float(half2float(half_)); \\\n", + " }\n", + "\n", + "\n", + " TVM_HALF_ASSIGNOP(+=, +)\n", + " TVM_HALF_ASSIGNOP(-=, -)\n", + " TVM_HALF_ASSIGNOP(*=, *)\n", + " TVM_HALF_ASSIGNOP(/=, /)\n", + "\n", + " TVM_XINLINE half operator+() {\n", + " return *this;\n", + " }\n", + "\n", + " TVM_XINLINE half operator-() {\n", + " return half(-float(*this));\n", + " }\n", + "\n", + " TVM_XINLINE half operator=(const half& a) {\n", + " half_ = a.half_;\n", + " return a;\n", + " }\n", + "\n", + " template\n", + " TVM_XINLINE half operator=(const T& a) {\n", + " return *this = half(a);\n", + " }\n", + "\n", + " TVM_XINLINE half operator=(const half& a) volatile {\n", + " half_ = a.half_;\n", + " return a;\n", + " }\n", + "\n", + " template\n", + " TVM_XINLINE half operator=(const T& a) volatile {\n", + " return *this = half(a);\n", + " }\n", + "\n", + " private:\n", + " union Bits {\n", + " float f;\n", + " int32_t si;\n", + " uint32_t ui;\n", + " };\n", + "\n", + " static int const fp16FractionBits = 10;\n", + " static int const fp32FractionBits = 23;\n", + " static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits); // == 0x7fffff\n", + " static int32_t const fp32HiddenBit = 1 << fp32FractionBits; // == 0x800000\n", + " static int const shift = fp32FractionBits - fp16FractionBits; // == 13\n", + " static int const shiftSign = 16;\n", + " static int32_t const expAdjust = 127 - 15; // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)\n", + "\n", + " static int32_t const infN = 0x7F800000; // flt32 infinity\n", + " static int32_t const maxN = 0x477FFFFF; // max flt32 that's a flt16 normal after >> by shift\n", + " static int32_t const minN = 0x38800000; // min flt16 normal as a flt32\n", + " static int32_t const maxZ = 0x33000000; // max fp32 number that's still rounded to zero in fp16\n", + " static int32_t const signN = 0x80000000; // flt32 sign bit\n", + "\n", + " static int32_t const infC = infN >> shift;\n", + " static int32_t const nanN = (infC + 1) << shift; // minimum flt16 nan as a flt32\n", + " static int32_t const maxC = maxN >> shift;\n", + " static int32_t const minC = minN >> shift;\n", + " static int32_t const signC = signN >> shiftSign; // flt16 sign bit\n", + "\n", + " static int32_t const mulN = 0x52000000; // (1 << 23) / minN\n", + " static int32_t const mulC = 0x33800000; // minN / (1 << (23 - shift))\n", + "\n", + " static int32_t const subC = 0x003FF; // max flt32 subnormal down shifted\n", + " static int32_t const norC = 0x00400; // min flt32 normal down shifted\n", + "\n", + " static int32_t const maxD = infC - maxC - 1;\n", + " static int32_t const minD = minC - subC - 1;\n", + "\n", + " TVM_XINLINE uint16_t float2half(const float& value) const {\n", + " Bits v;\n", + " v.f = value;\n", + " uint32_t sign = v.si & signN; // grab sign bit\n", + " v.si ^= sign; // clear sign bit from v\n", + " sign >>= shiftSign; // logical shift sign to fp16 position\n", + "\n", + " if (v.si <= maxZ) {\n", + " // Handle eventual zeros here to ensure\n", + " // vshift will not exceed 32 below.\n", + " v.ui = 0;\n", + " } else if (v.si < minN) {\n", + " // Handle denorms\n", + " uint32_t exp32 = v.ui >> fp32FractionBits;\n", + " int32_t exp16 = exp32 - expAdjust;\n", + " // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.\n", + " // Smaller (so negative) exp16 values should result in greater right shifts.\n", + " uint32_t vshift = 1 - exp16;\n", + " uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);\n", + " v.ui = significand >> vshift;\n", + " v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;\n", + " } else if (v.si <= maxN) {\n", + " // Handle norms\n", + " v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;\n", + " v.ui -= expAdjust << fp32FractionBits;\n", + " } else if (v.si <= infN) {\n", + " v.si = infN;\n", + " } else if (v.si < nanN) {\n", + " v.si = nanN;\n", + " }\n", + "\n", + " v.ui >>= shift;\n", + " return sign | (v.ui & 0x7fff);\n", + " }\n", + "\n", + " // Same as above routine, except for addition of volatile keyword\n", + " TVM_XINLINE uint16_t float2half(\n", + " const volatile float& value) const volatile {\n", + " Bits v;\n", + " v.f = value;\n", + " uint32_t sign = v.si & signN; // grab sign bit\n", + " v.si ^= sign; // clear sign bit from v\n", + " sign >>= shiftSign; // logical shift sign to fp16 position\n", + "\n", + " if (v.si <= maxZ) {\n", + " // Handle eventual zeros here to ensure\n", + " // vshift will not exceed 32 below.\n", + " v.ui = 0;\n", + " } else if (v.si < minN) {\n", + " // Handle denorms\n", + " uint32_t exp32 = v.ui >> fp32FractionBits;\n", + " int32_t exp16 = exp32 - expAdjust;\n", + " // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.\n", + " // Smaller (so negative) exp16 values should result in greater right shifts.\n", + " uint32_t vshift = 1 - exp16;\n", + " uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);\n", + " v.ui = significand >> vshift;\n", + " v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;\n", + " } else if (v.si <= maxN) {\n", + " // Handle norms\n", + " v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;\n", + " v.ui -= expAdjust << fp32FractionBits;\n", + " } else if (v.si <= infN) {\n", + " v.si = infN;\n", + " } else if (v.si < nanN) {\n", + " v.si = nanN;\n", + " }\n", + "\n", + " v.ui >>= shift;\n", + " return sign | (v.ui & 0x7fff);\n", + " }\n", + "\n", + " TVM_XINLINE float half2float(const uint16_t& value) const {\n", + " Bits v;\n", + " v.ui = value;\n", + " int32_t sign = v.si & signC;\n", + " v.si ^= sign;\n", + " sign <<= shiftSign;\n", + " v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);\n", + " v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);\n", + " Bits s;\n", + " s.si = mulC;\n", + " s.f *= v.si;\n", + " int32_t mask = -(norC > v.si);\n", + " v.si <<= shift;\n", + " v.si ^= (s.si ^ v.si) & mask;\n", + " v.si |= sign;\n", + " return v.f;\n", + " }\n", + "\n", + " TVM_XINLINE float half2float(\n", + " const volatile uint16_t& value) const volatile {\n", + " Bits v;\n", + " v.ui = value;\n", + " int32_t sign = v.si & signC;\n", + " v.si ^= sign;\n", + " sign <<= shiftSign;\n", + " v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);\n", + " v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);\n", + " Bits s;\n", + " s.si = mulC;\n", + " s.f *= v.si;\n", + " int32_t mask = -(norC > v.si);\n", + " v.si <<= shift;\n", + " v.si ^= (s.si ^ v.si) & mask;\n", + " v.si |= sign;\n", + " return v.f;\n", + " }\n", + "\n", + " template\n", + " TVM_XINLINE void constructor(const T& value) {\n", + " half_ = float2half(float(value));\n", + " }\n", + "};\n", + "\n", + "TVM_HALF_OPERATOR(half, +)\n", + "TVM_HALF_OPERATOR(half, -)\n", + "TVM_HALF_OPERATOR(half, *)\n", + "TVM_HALF_OPERATOR(half, /)\n", + "TVM_HALF_OPERATOR(bool, >)\n", + "TVM_HALF_OPERATOR(bool, <)\n", + "TVM_HALF_OPERATOR(bool, >=)\n", + "TVM_HALF_OPERATOR(bool, <=)\n", + "\n", + "TVM_XINLINE half __float2half_rn(const float a) {\n", + " return half(a);\n", + "}\n", + "#else\n", + "#include \n", + "__device__ half max(half a, half b)\n", + "{\n", + " return __hgt(__half(a), __half(b)) ? a : b;\n", + "}\n", + "__device__ half min(half a, half b)\n", + "{\n", + " return __hlt(__half(a), __half(b)) ? a : b;\n", + "}\n", + "#endif\n", + "\n", + "\n", + "// Pack two half values.\n", + "static inline __device__ __host__ unsigned\n", + "__pack_half2(const half x, const half y) {\n", + " unsigned v0 = *((unsigned short *)&x);\n", + " unsigned v1 = *((unsigned short *)&y);\n", + " return (v1 << 16) | v0;\n", + "}\n", + "\n", + "#define CUDA_UNSUPPORTED_HALF_MATH_BINARY(HALF_MATH_NAME, FP32_MATH_NAME) \\\n", + "static inline __device__ __host__ half HALF_MATH_NAME(half x, half y) { \\\n", + " float tmp_x = __half2float(x); \\\n", + " float tmp_y = __half2float(y); \\\n", + " float result = FP32_MATH_NAME(tmp_x, tmp_y); \\\n", + " return __float2half(result); \\\n", + "}\n", + "\n", + "#define CUDA_UNSUPPORTED_HALF_MATH_UNARY(HALF_MATH_NAME, FP32_MATH_NAME) \\\n", + "static inline __device__ __host__ half HALF_MATH_NAME(half x) { \\\n", + " float tmp_x = __half2float(x); \\\n", + " float result = FP32_MATH_NAME(tmp_x); \\\n", + " return __float2half(result); \\\n", + "}\n", + "\n", + "// Some fp16 math functions are not supported in cuda_fp16.h,\n", + "// so we define them here to make sure the generated CUDA code\n", + "// is valid.\n", + "#if defined(__CUDA_ARCH__)\n", + "#if (__CUDA_ARCH__ >= 530)\n", + "CUDA_UNSUPPORTED_HALF_MATH_BINARY(hpow, powf)\n", + "CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf)\n", + "CUDA_UNSUPPORTED_HALF_MATH_UNARY(htan, tanf)\n", + "CUDA_UNSUPPORTED_HALF_MATH_UNARY(hatan, atanf)\n", + "CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erf)\n", + "#else\n", + "CUDA_UNSUPPORTED_HALF_MATH_UNARY(hexp, exp)\n", + "#endif\n", + "#endif\n", + "\n", + "#undef CUDA_UNSUPPORTED_HALF_MATH_BINARY\n", + "#undef CUDA_UNSUPPORTED_HALF_MATH_UNARY\n", + "\n", + "struct __align__(8) half4 {\n", + " __half x, y, z, w;\n", + " __host__ __device__ half4() : x(__half(0)), y(__half(0)), z(__half(0)), w(__half(0)) {}\n", + " __host__ __device__ half4(__half x, __half y, __half z, __half w) : x(x), y(y), z(z), w(w) {}\n", + "\n", + "};\n", + "__host__ __device__ half4 make_half4(__half x, __half y, __half z, __half w) {\n", + " return half4(x, y, z, w);\n", + "}\n", + "\n", + "#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \\\n", + " (__CUDACC_VER_MAJOR__ > 11))\n", + "#define TVM_ENABLE_L2_PREFETCH 1\n", + "#else\n", + "#define TVM_ENABLE_L2_PREFETCH 0\n", + "#endif\n", + "\n", + "#ifdef _WIN32\n", + " using uint = unsigned int;\n", + " using uchar = unsigned char;\n", + " using ushort = unsigned short;\n", + " using int64_t = long long;\n", + " using uint64_t = unsigned long long;\n", + "#else\n", + " #define uint unsigned int\n", + " #define uchar unsigned char\n", + " #define ushort unsigned short\n", + " #define int64_t long long\n", + " #define uint64_t unsigned long long\n", + "#endif\n", + "\n", + "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 800) \n", + "#define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 1\n", + "#else\n", + "#define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 0\n", + "#endif\n", + "extern \"C\" __global__ void __launch_bounds__(128) main_kernel(half* __restrict__ A, float* __restrict__ B);\n", + "extern \"C\" __global__ void __launch_bounds__(128) main_kernel(half* __restrict__ A, float* __restrict__ B) {\n", + " B[((((int)blockIdx.x) * 512) + ((int)threadIdx.x))] = (((float)A[((((int)blockIdx.x) * 512) + ((int)threadIdx.x))]) + 1.000000e+00f);\n", + " B[(((((int)blockIdx.x) * 512) + ((int)threadIdx.x)) + 128)] = (((float)A[(((((int)blockIdx.x) * 512) + ((int)threadIdx.x)) + 128)]) + 1.000000e+00f);\n", + " B[(((((int)blockIdx.x) * 512) + ((int)threadIdx.x)) + 256)] = (((float)A[(((((int)blockIdx.x) * 512) + ((int)threadIdx.x)) + 256)]) + 1.000000e+00f);\n", + " B[(((((int)blockIdx.x) * 512) + ((int)threadIdx.x)) + 384)] = (((float)A[(((((int)blockIdx.x) * 512) + ((int)threadIdx.x)) + 384)]) + 1.000000e+00f);\n", + "}\n", + "\n", + "\n" + ] + } + ], + "source": [ + "# get generated cuda source\n", + "print(best.code)" + ] + }, + { + "cell_type": "markdown", + "id": "e59a8f78-9a32-4aa6-9d04-e04ee08f61f6", + "metadata": {}, + "source": [ + "## 2. Gemm tuning with Tensor Core \n", + "Tensor Expression: $C[m, n] = A[m, k] * B[n, k]$" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "0a0411df-bafa-4dd5-b61d-85cf79911586", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# from tvm.script import tir as T\n", + "\n", + "@T.prim_func\n", + "def main(A: T.Buffer((16384, 16384), \"float16\"), B: T.Buffer((16384, 16384), \"float16\"), C: T.Buffer((16384, 16384), \"float16\")):\n", + " T.func_attr({\"tir.noalias\": T.bool(True)})\n", + " # with T.block(\"root\"):\n", + " for i, j, k in T.grid(16384, 16384, 16384):\n", + " with T.block(\"C\"):\n", + " v_i, v_j, v_k = T.axis.remap(\"SSR\", [i, j, k])\n", + " T.reads(A[v_i, v_k], B[v_j, v_k])\n", + " T.writes(C[v_i, v_j])\n", + " with T.init():\n", + " C[v_i, v_j] = T.float16(0)\n", + " C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_j, v_k]\n" + ] + } + ], + "source": [ + "M = N = K = 16384\n", + "# Describe the matrix multiplication in TE\n", + "A = te.placeholder((M, K), name=\"A\", dtype=\"float16\")\n", + "B = te.placeholder((N, K), name=\"B\", dtype=\"float16\")\n", + "\n", + "k = te.reduce_axis((0, K), name=\"k\")\n", + "C = te.compute(\n", + " (M, N),\n", + " lambda i, j: te.sum(A[i, k].astype(\"float16\") * B[j, k].astype(\"float16\"), axis=k),\n", + " name=\"C\",\n", + ")\n", + "args = [A, B, C]\n", + "func = te.create_prim_func(args)\n", + "print(func)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "a1910bda-7a58-47ac-a76b-3902366fdd62", + "metadata": {}, + "outputs": [], + "source": [ + "from bitblas.base.roller.policy import TensorCorePolicy\n", + "from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "7aa61ed1-22a4-424b-ac19-a1f67e0b2bb5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'block': [128, 128], 'warp': [64, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "{'block': [64, 256], 'warp': [32, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "{'block': [256, 64], 'warp': [128, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "{'block': [64, 128], 'warp': [32, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "{'block': [128, 64], 'warp': [64, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "{'block': [128, 256], 'warp': [64, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "{'block': [256, 128], 'warp': [128, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "{'block': [64, 64], 'warp': [32, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "{'block': [256, 256], 'warp': [128, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "{'block': [32, 256], 'warp': [16, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "{'block': [256, 32], 'warp': [128, 16], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "{'block': [64, 512], 'warp': [32, 256], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "{'block': [512, 64], 'warp': [256, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "{'block': [32, 128], 'warp': [16, 64], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "{'block': [128, 32], 'warp': [64, 16], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "{'block': [128, 512], 'warp': [64, 256], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "{'block': [512, 128], 'warp': [256, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "{'block': [32, 64], 'warp': [16, 32], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "{'block': [64, 32], 'warp': [32, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "{'block': [32, 32], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n" + ] + } + ], + "source": [ + "tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target)\n", + "policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags)\n", + "configs = policy.emit_config(topk=20)\n", + "\n", + "for config in configs:\n", + " print(config)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "4b26647f-6d65-496b-a442-8463fdacf24c", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2024-10-24 13:31:32 [BitBLAS:DEBUG]: Apply config {'block': [128, 128], 'warp': [64, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:31:32 [BitBLAS:DEBUG]: Apply config {'block': [64, 256], 'warp': [32, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:31:32 [BitBLAS:DEBUG]: Apply config {'block': [256, 64], 'warp': [128, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:31:32 [BitBLAS:DEBUG]: Apply config {'block': [64, 128], 'warp': [32, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:31:32 [BitBLAS:DEBUG]: Apply config {'block': [128, 64], 'warp': [64, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:31:32 [BitBLAS:DEBUG]: Apply config {'block': [128, 256], 'warp': [64, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:31:32 [BitBLAS:DEBUG]: Apply config {'block': [256, 128], 'warp': [128, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:31:32 [BitBLAS:DEBUG]: Apply config {'block': [64, 64], 'warp': [32, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:31:32 [BitBLAS:DEBUG]: Apply config {'block': [256, 256], 'warp': [128, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:31:32 [BitBLAS:DEBUG]: Apply config {'block': [32, 256], 'warp': [16, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:31:32 [BitBLAS:DEBUG]: Apply config {'block': [256, 32], 'warp': [128, 16], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:31:32 [BitBLAS:DEBUG]: Apply config {'block': [64, 512], 'warp': [32, 256], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:31:32 [BitBLAS:DEBUG]: Apply config {'block': [512, 64], 'warp': [256, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:31:32 [BitBLAS:DEBUG]: Apply config {'block': [32, 128], 'warp': [16, 64], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:31:32 [BitBLAS:DEBUG]: Apply config {'block': [128, 32], 'warp': [64, 16], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:31:32 [BitBLAS:DEBUG]: Apply config {'block': [128, 512], 'warp': [64, 256], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:31:32 [BitBLAS:DEBUG]: Apply config {'block': [512, 128], 'warp': [256, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:31:32 [BitBLAS:DEBUG]: Apply config {'block': [32, 64], 'warp': [16, 32], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:31:32 [BitBLAS:DEBUG]: Apply config {'block': [64, 32], 'warp': [32, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:31:32 [BitBLAS:DEBUG]: Apply config {'block': [32, 32], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:32:29 [BitBLAS:INFO]: Evaluation with config {'block': [128, 128], 'warp': [64, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:32:29 [BitBLAS:INFO]: Time cost of this config: 32.161 ms\n", + "2024-10-24 13:32:39 [BitBLAS:INFO]: Evaluation with config {'block': [64, 256], 'warp': [32, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:32:39 [BitBLAS:INFO]: Time cost of this config: 43.725 ms\n", + "2024-10-24 13:32:48 [BitBLAS:INFO]: Evaluation with config {'block': [256, 64], 'warp': [128, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:32:48 [BitBLAS:INFO]: Time cost of this config: 43.864 ms\n", + "2024-10-24 13:32:57 [BitBLAS:INFO]: Evaluation with config {'block': [64, 128], 'warp': [32, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:32:57 [BitBLAS:INFO]: Time cost of this config: 46.601 ms\n", + "2024-10-24 13:33:06 [BitBLAS:INFO]: Evaluation with config {'block': [128, 64], 'warp': [64, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:33:06 [BitBLAS:INFO]: Time cost of this config: 60.877 ms\n", + "2024-10-24 13:33:16 [BitBLAS:INFO]: Evaluation with config {'block': [128, 256], 'warp': [64, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:33:16 [BitBLAS:INFO]: Time cost of this config: 30.349 ms\n", + "2024-10-24 13:33:25 [BitBLAS:INFO]: Evaluation with config {'block': [256, 128], 'warp': [128, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:33:25 [BitBLAS:INFO]: Time cost of this config: 80.994 ms\n", + "2024-10-24 13:33:34 [BitBLAS:INFO]: Evaluation with config {'block': [64, 64], 'warp': [32, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:33:34 [BitBLAS:INFO]: Time cost of this config: 44.942 ms\n", + "2024-10-24 13:33:44 [BitBLAS:INFO]: Evaluation with config {'block': [256, 256], 'warp': [128, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:33:44 [BitBLAS:INFO]: Time cost of this config: 69.066 ms\n", + "2024-10-24 13:33:53 [BitBLAS:INFO]: Evaluation with config {'block': [32, 256], 'warp': [16, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:33:53 [BitBLAS:INFO]: Time cost of this config: 30.514 ms\n", + "2024-10-24 13:34:03 [BitBLAS:INFO]: Evaluation with config {'block': [256, 32], 'warp': [128, 16], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:34:03 [BitBLAS:INFO]: Time cost of this config: 70.178 ms\n", + "2024-10-24 13:34:12 [BitBLAS:INFO]: Evaluation with config {'block': [64, 512], 'warp': [32, 256], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:34:12 [BitBLAS:INFO]: Time cost of this config: 48.920 ms\n", + "2024-10-24 13:34:21 [BitBLAS:INFO]: Evaluation with config {'block': [512, 64], 'warp': [256, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:34:21 [BitBLAS:INFO]: Time cost of this config: 37.340 ms\n", + "2024-10-24 13:34:31 [BitBLAS:INFO]: Evaluation with config {'block': [32, 128], 'warp': [16, 64], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:34:31 [BitBLAS:INFO]: Time cost of this config: 73.770 ms\n", + "2024-10-24 13:34:40 [BitBLAS:INFO]: Evaluation with config {'block': [128, 32], 'warp': [64, 16], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:34:40 [BitBLAS:INFO]: Time cost of this config: 74.823 ms\n", + "2024-10-24 13:34:50 [BitBLAS:INFO]: Evaluation with config {'block': [128, 512], 'warp': [64, 256], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:34:50 [BitBLAS:INFO]: Time cost of this config: 78.612 ms\n", + "2024-10-24 13:34:59 [BitBLAS:INFO]: Evaluation with config {'block': [512, 128], 'warp': [256, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:34:59 [BitBLAS:INFO]: Time cost of this config: 71.001 ms\n", + "2024-10-24 13:35:08 [BitBLAS:INFO]: Evaluation with config {'block': [32, 64], 'warp': [16, 32], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:35:08 [BitBLAS:INFO]: Time cost of this config: 108.881 ms\n", + "2024-10-24 13:35:18 [BitBLAS:INFO]: Evaluation with config {'block': [64, 32], 'warp': [32, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:35:18 [BitBLAS:INFO]: Time cost of this config: 84.179 ms\n", + "2024-10-24 13:35:27 [BitBLAS:INFO]: Evaluation with config {'block': [32, 32], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}\n", + "2024-10-24 13:35:27 [BitBLAS:INFO]: Time cost of this config: 83.798 ms\n" + ] + } + ], + "source": [ + "cpresults, best = apply_and_build(func, configs, arch, parallel_build=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "b6a83a8f-359f-4777-8505-67639a91f687", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "__device__ __inline__ dim3 rasterization2DColumn(const int panel_width) {\n", + " const auto baseBlockIdx = blockIdx.x + gridDim.x *blockIdx.y;\n", + " const auto totalPanel = (gridDim.x * gridDim.y +panel_width * gridDim.x - 1) / (panel_width * gridDim.x);\n", + " const auto totalBlock = gridDim.x * gridDim.y;\n", + " const auto panelIdx = baseBlockIdx / (panel_width *gridDim.x);\n", + " const auto strideLd = panelIdx + 1 < totalPanel ?panel_width : (totalBlock - panelIdx * (panel_width *gridDim.x)) / gridDim.x;\n", + " const auto bx = (panelIdx & 1) ? gridDim.x -(baseBlockIdx - panelIdx * panel_width * gridDim.x) /strideLd - 1 : (baseBlockIdx - panelIdx * panel_width *gridDim.x) / strideLd;\n", + " const auto by = (baseBlockIdx - panelIdx * panel_width *gridDim.x) % strideLd + panelIdx * panel_width;\n", + " const auto bz = blockIdx.z;\n", + " \n", + " dim3 blockIdx(bx, by, bz);\n", + " return blockIdx;\n", + "}\n", + " #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530)\n", + "\n", + "typedef unsigned short uint16_t;\n", + "typedef unsigned char uint8_t;\n", + "typedef signed char int8_t;\n", + "typedef int int32_t;\n", + "typedef unsigned long long uint64_t;\n", + "typedef unsigned int uint32_t;\n", + "\n", + "#define TVM_FORCE_INLINE inline __attribute__((always_inline))\n", + "#define TVM_XINLINE TVM_FORCE_INLINE __device__ __host__\n", + "#define TVM_ALIGNED(x) __attribute__ ((aligned(x)))\n", + "#define TVM_HALF_OPERATOR(RTYPE, OP) \\\n", + " TVM_XINLINE RTYPE operator OP (half a, half b) { \\\n", + " return RTYPE(float(a) OP float(b)); \\\n", + " } \\\n", + " template \\\n", + " TVM_XINLINE RTYPE operator OP (half a, T b) { \\\n", + " return RTYPE(float(a) OP float(b)); \\\n", + " } \\\n", + " template \\\n", + " TVM_XINLINE RTYPE operator OP (T a, half b) { \\\n", + " return RTYPE(float(a) OP float(b)); \\\n", + " }\n", + "\n", + "#define TVM_HALF_ASSIGNOP(AOP, OP) \\\n", + " template \\\n", + " TVM_XINLINE half operator AOP (const T& a) { \\\n", + " return *this = half(float(*this) OP float(a)); \\\n", + " } \\\n", + " template \\\n", + " TVM_XINLINE half operator AOP (const volatile T& a) volatile { \\\n", + " return *this = half(float(*this) OP float(a)); \\\n", + " }\n", + "\n", + "class TVM_ALIGNED(2) half {\n", + " public:\n", + " uint16_t half_;\n", + "\n", + " static TVM_XINLINE half Binary(uint16_t value) {\n", + " half res;\n", + " res.half_ = value;\n", + " return res;\n", + " }\n", + "\n", + " TVM_XINLINE half() {}\n", + "\n", + " TVM_XINLINE half(const float& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const double& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const int8_t& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const int32_t& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const long long& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); }\n", + "\n", + " TVM_XINLINE operator float() const { \\\n", + " return float(half2float(half_)); \\\n", + " } \\\n", + " TVM_XINLINE operator float() const volatile { \\\n", + " return float(half2float(half_)); \\\n", + " }\n", + "\n", + "\n", + " TVM_HALF_ASSIGNOP(+=, +)\n", + " TVM_HALF_ASSIGNOP(-=, -)\n", + " TVM_HALF_ASSIGNOP(*=, *)\n", + " TVM_HALF_ASSIGNOP(/=, /)\n", + "\n", + " TVM_XINLINE half operator+() {\n", + " return *this;\n", + " }\n", + "\n", + " TVM_XINLINE half operator-() {\n", + " return half(-float(*this));\n", + " }\n", + "\n", + " TVM_XINLINE half operator=(const half& a) {\n", + " half_ = a.half_;\n", + " return a;\n", + " }\n", + "\n", + " template\n", + " TVM_XINLINE half operator=(const T& a) {\n", + " return *this = half(a);\n", + " }\n", + "\n", + " TVM_XINLINE half operator=(const half& a) volatile {\n", + " half_ = a.half_;\n", + " return a;\n", + " }\n", + "\n", + " template\n", + " TVM_XINLINE half operator=(const T& a) volatile {\n", + " return *this = half(a);\n", + " }\n", + "\n", + " private:\n", + " union Bits {\n", + " float f;\n", + " int32_t si;\n", + " uint32_t ui;\n", + " };\n", + "\n", + " static int const fp16FractionBits = 10;\n", + " static int const fp32FractionBits = 23;\n", + " static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits); // == 0x7fffff\n", + " static int32_t const fp32HiddenBit = 1 << fp32FractionBits; // == 0x800000\n", + " static int const shift = fp32FractionBits - fp16FractionBits; // == 13\n", + " static int const shiftSign = 16;\n", + " static int32_t const expAdjust = 127 - 15; // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)\n", + "\n", + " static int32_t const infN = 0x7F800000; // flt32 infinity\n", + " static int32_t const maxN = 0x477FFFFF; // max flt32 that's a flt16 normal after >> by shift\n", + " static int32_t const minN = 0x38800000; // min flt16 normal as a flt32\n", + " static int32_t const maxZ = 0x33000000; // max fp32 number that's still rounded to zero in fp16\n", + " static int32_t const signN = 0x80000000; // flt32 sign bit\n", + "\n", + " static int32_t const infC = infN >> shift;\n", + " static int32_t const nanN = (infC + 1) << shift; // minimum flt16 nan as a flt32\n", + " static int32_t const maxC = maxN >> shift;\n", + " static int32_t const minC = minN >> shift;\n", + " static int32_t const signC = signN >> shiftSign; // flt16 sign bit\n", + "\n", + " static int32_t const mulN = 0x52000000; // (1 << 23) / minN\n", + " static int32_t const mulC = 0x33800000; // minN / (1 << (23 - shift))\n", + "\n", + " static int32_t const subC = 0x003FF; // max flt32 subnormal down shifted\n", + " static int32_t const norC = 0x00400; // min flt32 normal down shifted\n", + "\n", + " static int32_t const maxD = infC - maxC - 1;\n", + " static int32_t const minD = minC - subC - 1;\n", + "\n", + " TVM_XINLINE uint16_t float2half(const float& value) const {\n", + " Bits v;\n", + " v.f = value;\n", + " uint32_t sign = v.si & signN; // grab sign bit\n", + " v.si ^= sign; // clear sign bit from v\n", + " sign >>= shiftSign; // logical shift sign to fp16 position\n", + "\n", + " if (v.si <= maxZ) {\n", + " // Handle eventual zeros here to ensure\n", + " // vshift will not exceed 32 below.\n", + " v.ui = 0;\n", + " } else if (v.si < minN) {\n", + " // Handle denorms\n", + " uint32_t exp32 = v.ui >> fp32FractionBits;\n", + " int32_t exp16 = exp32 - expAdjust;\n", + " // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.\n", + " // Smaller (so negative) exp16 values should result in greater right shifts.\n", + " uint32_t vshift = 1 - exp16;\n", + " uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);\n", + " v.ui = significand >> vshift;\n", + " v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;\n", + " } else if (v.si <= maxN) {\n", + " // Handle norms\n", + " v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;\n", + " v.ui -= expAdjust << fp32FractionBits;\n", + " } else if (v.si <= infN) {\n", + " v.si = infN;\n", + " } else if (v.si < nanN) {\n", + " v.si = nanN;\n", + " }\n", + "\n", + " v.ui >>= shift;\n", + " return sign | (v.ui & 0x7fff);\n", + " }\n", + "\n", + " // Same as above routine, except for addition of volatile keyword\n", + " TVM_XINLINE uint16_t float2half(\n", + " const volatile float& value) const volatile {\n", + " Bits v;\n", + " v.f = value;\n", + " uint32_t sign = v.si & signN; // grab sign bit\n", + " v.si ^= sign; // clear sign bit from v\n", + " sign >>= shiftSign; // logical shift sign to fp16 position\n", + "\n", + " if (v.si <= maxZ) {\n", + " // Handle eventual zeros here to ensure\n", + " // vshift will not exceed 32 below.\n", + " v.ui = 0;\n", + " } else if (v.si < minN) {\n", + " // Handle denorms\n", + " uint32_t exp32 = v.ui >> fp32FractionBits;\n", + " int32_t exp16 = exp32 - expAdjust;\n", + " // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.\n", + " // Smaller (so negative) exp16 values should result in greater right shifts.\n", + " uint32_t vshift = 1 - exp16;\n", + " uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);\n", + " v.ui = significand >> vshift;\n", + " v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;\n", + " } else if (v.si <= maxN) {\n", + " // Handle norms\n", + " v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;\n", + " v.ui -= expAdjust << fp32FractionBits;\n", + " } else if (v.si <= infN) {\n", + " v.si = infN;\n", + " } else if (v.si < nanN) {\n", + " v.si = nanN;\n", + " }\n", + "\n", + " v.ui >>= shift;\n", + " return sign | (v.ui & 0x7fff);\n", + " }\n", + "\n", + " TVM_XINLINE float half2float(const uint16_t& value) const {\n", + " Bits v;\n", + " v.ui = value;\n", + " int32_t sign = v.si & signC;\n", + " v.si ^= sign;\n", + " sign <<= shiftSign;\n", + " v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);\n", + " v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);\n", + " Bits s;\n", + " s.si = mulC;\n", + " s.f *= v.si;\n", + " int32_t mask = -(norC > v.si);\n", + " v.si <<= shift;\n", + " v.si ^= (s.si ^ v.si) & mask;\n", + " v.si |= sign;\n", + " return v.f;\n", + " }\n", + "\n", + " TVM_XINLINE float half2float(\n", + " const volatile uint16_t& value) const volatile {\n", + " Bits v;\n", + " v.ui = value;\n", + " int32_t sign = v.si & signC;\n", + " v.si ^= sign;\n", + " sign <<= shiftSign;\n", + " v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);\n", + " v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);\n", + " Bits s;\n", + " s.si = mulC;\n", + " s.f *= v.si;\n", + " int32_t mask = -(norC > v.si);\n", + " v.si <<= shift;\n", + " v.si ^= (s.si ^ v.si) & mask;\n", + " v.si |= sign;\n", + " return v.f;\n", + " }\n", + "\n", + " template\n", + " TVM_XINLINE void constructor(const T& value) {\n", + " half_ = float2half(float(value));\n", + " }\n", + "};\n", + "\n", + "TVM_HALF_OPERATOR(half, +)\n", + "TVM_HALF_OPERATOR(half, -)\n", + "TVM_HALF_OPERATOR(half, *)\n", + "TVM_HALF_OPERATOR(half, /)\n", + "TVM_HALF_OPERATOR(bool, >)\n", + "TVM_HALF_OPERATOR(bool, <)\n", + "TVM_HALF_OPERATOR(bool, >=)\n", + "TVM_HALF_OPERATOR(bool, <=)\n", + "\n", + "TVM_XINLINE half __float2half_rn(const float a) {\n", + " return half(a);\n", + "}\n", + "#else\n", + "#include \n", + "__device__ half max(half a, half b)\n", + "{\n", + " return __hgt(__half(a), __half(b)) ? a : b;\n", + "}\n", + "__device__ half min(half a, half b)\n", + "{\n", + " return __hlt(__half(a), __half(b)) ? a : b;\n", + "}\n", + "#endif\n", + "\n", + "\n", + "// Pack two half values.\n", + "static inline __device__ __host__ unsigned\n", + "__pack_half2(const half x, const half y) {\n", + " unsigned v0 = *((unsigned short *)&x);\n", + " unsigned v1 = *((unsigned short *)&y);\n", + " return (v1 << 16) | v0;\n", + "}\n", + "\n", + "#define CUDA_UNSUPPORTED_HALF_MATH_BINARY(HALF_MATH_NAME, FP32_MATH_NAME) \\\n", + "static inline __device__ __host__ half HALF_MATH_NAME(half x, half y) { \\\n", + " float tmp_x = __half2float(x); \\\n", + " float tmp_y = __half2float(y); \\\n", + " float result = FP32_MATH_NAME(tmp_x, tmp_y); \\\n", + " return __float2half(result); \\\n", + "}\n", + "\n", + "#define CUDA_UNSUPPORTED_HALF_MATH_UNARY(HALF_MATH_NAME, FP32_MATH_NAME) \\\n", + "static inline __device__ __host__ half HALF_MATH_NAME(half x) { \\\n", + " float tmp_x = __half2float(x); \\\n", + " float result = FP32_MATH_NAME(tmp_x); \\\n", + " return __float2half(result); \\\n", + "}\n", + "\n", + "// Some fp16 math functions are not supported in cuda_fp16.h,\n", + "// so we define them here to make sure the generated CUDA code\n", + "// is valid.\n", + "#if defined(__CUDA_ARCH__)\n", + "#if (__CUDA_ARCH__ >= 530)\n", + "CUDA_UNSUPPORTED_HALF_MATH_BINARY(hpow, powf)\n", + "CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf)\n", + "CUDA_UNSUPPORTED_HALF_MATH_UNARY(htan, tanf)\n", + "CUDA_UNSUPPORTED_HALF_MATH_UNARY(hatan, atanf)\n", + "CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erf)\n", + "#else\n", + "CUDA_UNSUPPORTED_HALF_MATH_UNARY(hexp, exp)\n", + "#endif\n", + "#endif\n", + "\n", + "#undef CUDA_UNSUPPORTED_HALF_MATH_BINARY\n", + "#undef CUDA_UNSUPPORTED_HALF_MATH_UNARY\n", + "\n", + "struct __align__(8) half4 {\n", + " __half x, y, z, w;\n", + " __host__ __device__ half4() : x(__half(0)), y(__half(0)), z(__half(0)), w(__half(0)) {}\n", + " __host__ __device__ half4(__half x, __half y, __half z, __half w) : x(x), y(y), z(z), w(w) {}\n", + "\n", + "};\n", + "__host__ __device__ half4 make_half4(__half x, __half y, __half z, __half w) {\n", + " return half4(x, y, z, w);\n", + "}\n", + "__forceinline__ __device__ unsigned int\n", + "cast_smem_ptr_to_int(const void* const smem_ptr)\n", + "{\n", + " unsigned int smem_int;\n", + " asm volatile (\"{ .reg .u64 smem_int; cvta.to.shared.u64 smem_int, %1; cvt.u32.u64 %0, smem_int; }\"\n", + " : \"=r\"(smem_int) : \"l\"(smem_ptr));\n", + " return smem_int;\n", + "}\n", + "\n", + "#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \\\n", + " (__CUDACC_VER_MAJOR__ > 11))\n", + "#define TVM_ENABLE_L2_PREFETCH 1\n", + "#else\n", + "#define TVM_ENABLE_L2_PREFETCH 0\n", + "#endif\n", + "\n", + "#ifdef _WIN32\n", + " using uint = unsigned int;\n", + " using uchar = unsigned char;\n", + " using ushort = unsigned short;\n", + " using int64_t = long long;\n", + " using uint64_t = unsigned long long;\n", + "#else\n", + " #define uint unsigned int\n", + " #define uchar unsigned char\n", + " #define ushort unsigned short\n", + " #define int64_t long long\n", + " #define uint64_t unsigned long long\n", + "#endif\n", + "\n", + "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 800) \n", + "#define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 1\n", + "#else\n", + "#define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 0\n", + "#endif\n", + "extern \"C\" __global__ void __launch_bounds__(128) main_kernel(half* __restrict__ A, half* __restrict__ B, half* __restrict__ C);\n", + "extern \"C\" __global__ void __launch_bounds__(128) main_kernel(half* __restrict__ A, half* __restrict__ B, half* __restrict__ C) {\n", + " extern __shared__ uchar buf_dyn_shmem[];\n", + " half C_reindex_shared_dyn_warp[256];\n", + " half A_reindex_shared_dyn_warp[64];\n", + " half B_reindex_shared_dyn_warp[32];\n", + " for (int var = 0; var < 1; ++var) {\n", + "\n", + " const dim3 blockIdx = rasterization2DColumn(11);\n", + " for (int ax1_0_3_init = 0; ax1_0_3_init < 8; ++ax1_0_3_init) {\n", + " for (int ax2_0_3_init = 0; ax2_0_3_init < 4; ++ax2_0_3_init) {\n", + " for (int i = 0; i < 8; ++i) {\n", + "C_reindex_shared_dyn_warp[((ax1_0_3_init * 32) + (ax2_0_3_init * 8)) + i] = 0.0;}\n", + ";\n", + " }\n", + " }\n", + " for (int ax3_0_0 = 0; ax3_0_0 < 512; ++ax3_0_0) {\n", + " __syncthreads();\n", + " #pragma unroll\n", + " for (int ax0_ax1_ax2_fused_2 = 0; ax0_ax1_ax2_fused_2 < 8; ++ax0_ax1_ax2_fused_2) {\n", + " *(uint4*)(((half*)buf_dyn_shmem) + (((((((int)threadIdx.y) * 4096) + (((int)threadIdx.z) * 2048)) + (ax0_ax1_ax2_fused_2 * 256)) + ((((int)threadIdx.x) >> 2) * 32)) + (((((int)threadIdx.x) & 3) ^ (((int)threadIdx.x) >> 3)) * 8))) = *(uint4*)(A + (((((((((int)blockIdx.y) * 4194304) + (((int)threadIdx.y) * 2097152)) + (((int)threadIdx.z) * 1048576)) + (ax0_ax1_ax2_fused_2 * 131072)) + ((((int)threadIdx.x) >> 2) * 16384)) + (ax3_0_0 * 32)) + ((((int)threadIdx.x) & 3) * 8)));\n", + " }\n", + " #pragma unroll\n", + " for (int ax0_ax1_ax2_fused_2_1 = 0; ax0_ax1_ax2_fused_2_1 < 4; ++ax0_ax1_ax2_fused_2_1) {\n", + " *(uint4*)(((half*)buf_dyn_shmem) + ((((((((int)threadIdx.y) * 2048) + (((int)threadIdx.z) * 1024)) + (ax0_ax1_ax2_fused_2_1 * 256)) + ((((int)threadIdx.x) >> 2) * 32)) + (((((int)threadIdx.x) & 3) ^ (((int)threadIdx.x) >> 3)) * 8)) + 8192)) = *(uint4*)(B + (((((((((int)blockIdx.x) * 2097152) + (((int)threadIdx.y) * 1048576)) + (((int)threadIdx.z) * 524288)) + (ax0_ax1_ax2_fused_2_1 * 131072)) + ((((int)threadIdx.x) >> 2) * 16384)) + (ax3_0_0 * 32)) + ((((int)threadIdx.x) & 3) * 8)));\n", + " }\n", + " __syncthreads();\n", + " for (int ax3_0_1 = 0; ax3_0_1 < 2; ++ax3_0_1) {\n", + " for (int ax0_0 = 0; ax0_0 < 8; ++ax0_0) {\n", + "\n", + " {\n", + " unsigned int addr;\n", + "#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST\n", + " addr = static_cast(__cvta_generic_to_shared((void *)((&(((half*)buf_dyn_shmem)[((((((int)threadIdx.y) * 4096) + (ax0_0 * 512)) + ((((int)threadIdx.x) & 15) * 32)) + ((((ax3_0_1 * 2) + (((int)threadIdx.x) >> 4)) ^ ((((int)threadIdx.x) & 7) >> 1)) * 8))])) + 0)));\n", + "#else\n", + " __asm__ __volatile__(\n", + " \"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\\n\"\n", + " : \"=r\"(addr)\n", + " : \"l\"((void *)((&(((half*)buf_dyn_shmem)[((((((int)threadIdx.y) * 4096) + (ax0_0 * 512)) + ((((int)threadIdx.x) & 15) * 32)) + ((((ax3_0_1 * 2) + (((int)threadIdx.x) >> 4)) ^ ((((int)threadIdx.x) & 7) >> 1)) * 8))])) + 0))\n", + " );\n", + "#endif\n", + " __asm__ __volatile__(\n", + " \"ldmatrix.sync.aligned.m8n8.x4.shared.b16\"\n", + " \"{%0, %1, %2, %3}, [%4];\\n\"\n", + " : \"=r\"(((unsigned *)(A_reindex_shared_dyn_warp + (ax0_0 * 8)))[0]), \"=r\"(((unsigned *)(A_reindex_shared_dyn_warp + (ax0_0 * 8)))[1]), \"=r\"(((unsigned *)(A_reindex_shared_dyn_warp + (ax0_0 * 8)))[2]), \"=r\"(((unsigned *)(A_reindex_shared_dyn_warp + (ax0_0 * 8)))[3])\n", + " : \"r\"(addr)\n", + " );\n", + " }\n", + " }\n", + " for (int ax0_0_1 = 0; ax0_0_1 < 4; ++ax0_0_1) {\n", + "\n", + " {\n", + " unsigned int addr;\n", + "#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST\n", + " addr = static_cast(__cvta_generic_to_shared((void *)((&(((half*)buf_dyn_shmem)[((((((((int)threadIdx.z) * 2048) + (ax0_0_1 * 512)) + ((((int)threadIdx.x) >> 4) * 256)) + ((((int)threadIdx.x) & 7) * 32)) + ((((ax3_0_1 * 2) + ((((int)threadIdx.x) & 15) >> 3)) ^ ((((int)threadIdx.x) & 7) >> 1)) * 8)) + 8192)])) + 0)));\n", + "#else\n", + " __asm__ __volatile__(\n", + " \"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\\n\"\n", + " : \"=r\"(addr)\n", + " : \"l\"((void *)((&(((half*)buf_dyn_shmem)[((((((((int)threadIdx.z) * 2048) + (ax0_0_1 * 512)) + ((((int)threadIdx.x) >> 4) * 256)) + ((((int)threadIdx.x) & 7) * 32)) + ((((ax3_0_1 * 2) + ((((int)threadIdx.x) & 15) >> 3)) ^ ((((int)threadIdx.x) & 7) >> 1)) * 8)) + 8192)])) + 0))\n", + " );\n", + "#endif\n", + " __asm__ __volatile__(\n", + " \"ldmatrix.sync.aligned.m8n8.x4.shared.b16\"\n", + " \"{%0, %1, %2, %3}, [%4];\\n\"\n", + " : \"=r\"(((unsigned *)(B_reindex_shared_dyn_warp + (ax0_0_1 * 8)))[0]), \"=r\"(((unsigned *)(B_reindex_shared_dyn_warp + (ax0_0_1 * 8)))[1]), \"=r\"(((unsigned *)(B_reindex_shared_dyn_warp + (ax0_0_1 * 8)))[2]), \"=r\"(((unsigned *)(B_reindex_shared_dyn_warp + (ax0_0_1 * 8)))[3])\n", + " : \"r\"(addr)\n", + " );\n", + " }\n", + " }\n", + " for (int ax1_0_3 = 0; ax1_0_3 < 8; ++ax1_0_3) {\n", + " for (int ax2_0_3 = 0; ax2_0_3 < 4; ++ax2_0_3) {\n", + "\n", + " {\n", + " __asm__ __volatile__(\n", + " \"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16\"\n", + " \"{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\\n\"\n", + " : \"=r\"(((unsigned *)(C_reindex_shared_dyn_warp + ((ax1_0_3 * 32) + (ax2_0_3 * 8))))[0]), \"=r\"(((unsigned *)(C_reindex_shared_dyn_warp + ((ax1_0_3 * 32) + (ax2_0_3 * 8))))[1])\n", + " : \"r\"(((unsigned *)((half*)A_reindex_shared_dyn_warp + (ax1_0_3 * 8)))[0]), \"r\"(((unsigned *)((half*)A_reindex_shared_dyn_warp + (ax1_0_3 * 8)))[1]), \"r\"(((unsigned *)((half*)A_reindex_shared_dyn_warp + (ax1_0_3 * 8)))[2]), \"r\"(((unsigned *)((half*)A_reindex_shared_dyn_warp + (ax1_0_3 * 8)))[3]), \"r\"(((unsigned *)((half*)B_reindex_shared_dyn_warp + (ax2_0_3 * 8)))[0]), \"r\"(((unsigned *)((half*)B_reindex_shared_dyn_warp + (ax2_0_3 * 8)))[1]), \"r\"(((unsigned *)(C_reindex_shared_dyn_warp + ((ax1_0_3 * 32) + (ax2_0_3 * 8))))[0]), \"r\"(((unsigned *)(C_reindex_shared_dyn_warp + ((ax1_0_3 * 32) + (ax2_0_3 * 8))))[1]));\n", + " }\n", + "\n", + " {\n", + " __asm__ __volatile__(\n", + " \"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16\"\n", + " \"{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\\n\"\n", + " : \"=r\"(((unsigned *)(C_reindex_shared_dyn_warp + (((ax1_0_3 * 32) + (ax2_0_3 * 8)) + 4)))[0]), \"=r\"(((unsigned *)(C_reindex_shared_dyn_warp + (((ax1_0_3 * 32) + (ax2_0_3 * 8)) + 4)))[1])\n", + " : \"r\"(((unsigned *)((half*)A_reindex_shared_dyn_warp + (ax1_0_3 * 8)))[0]), \"r\"(((unsigned *)((half*)A_reindex_shared_dyn_warp + (ax1_0_3 * 8)))[1]), \"r\"(((unsigned *)((half*)A_reindex_shared_dyn_warp + (ax1_0_3 * 8)))[2]), \"r\"(((unsigned *)((half*)A_reindex_shared_dyn_warp + (ax1_0_3 * 8)))[3]), \"r\"(((unsigned *)((half*)B_reindex_shared_dyn_warp + ((ax2_0_3 * 8) + 4)))[0]), \"r\"(((unsigned *)((half*)B_reindex_shared_dyn_warp + ((ax2_0_3 * 8) + 4)))[1]), \"r\"(((unsigned *)(C_reindex_shared_dyn_warp + (((ax1_0_3 * 32) + (ax2_0_3 * 8)) + 4)))[0]), \"r\"(((unsigned *)(C_reindex_shared_dyn_warp + (((ax1_0_3 * 32) + (ax2_0_3 * 8)) + 4)))[1]));\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + " for (int ax0 = 0; ax0 < 8; ++ax0) {\n", + " __syncthreads();\n", + " for (int ax1 = 0; ax1 < 4; ++ax1) {\n", + " for (int local_id = 0; local_id < 8; local_id+=2) {\n", + "*((uint *)&(&(((half*)buf_dyn_shmem)[(((((int)threadIdx.y) * 16384) + (((int)threadIdx.z) * 1024)) + (ax1 * 256))]))[((((((local_id % 4) / 2) * 8) + (threadIdx.x / 4)) * 16) + ((((local_id / 4) * 8) + ((threadIdx.x % 4) * 2)) + (local_id % 2)))]) = *((uint *)&C_reindex_shared_dyn_warp[((ax0 * 32) + (ax1 * 8)) + local_id]);\n", + "}\n", + ";\n", + " }\n", + " __syncthreads();\n", + " #pragma unroll\n", + " for (int ax0_ax1_ax2_ax3_ax4_fused_0 = 0; ax0_ax1_ax2_ax3_ax4_fused_0 < 4; ++ax0_ax1_ax2_ax3_ax4_fused_0) {\n", + " *(uint4*)(C + ((((((((((int)blockIdx.y) * 4194304) + (((int)threadIdx.y) * 2097152)) + (ax0 * 262144)) + ((((int)threadIdx.x) >> 1) * 16384)) + (((int)blockIdx.x) * 128)) + (((int)threadIdx.z) * 64)) + (ax0_ax1_ax2_ax3_ax4_fused_0 * 16)) + ((((int)threadIdx.x) & 1) * 8))) = *(uint4*)(((half*)buf_dyn_shmem) + ((((((int)threadIdx.y) * 16384) + (((int)threadIdx.z) * 1024)) + (ax0_ax1_ax2_ax3_ax4_fused_0 * 256)) + (((int)threadIdx.x) * 8)));\n", + " }\n", + " }\n", + " }\n", + "}\n", + "\n", + "\n" + ] + } + ], + "source": [ + "# get generated cuda source\n", + "print(best.code)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/2.auto_tensorization.ipynb b/tutorials/2.auto_tensorization.ipynb new file mode 100644 index 000000000..ebf072959 --- /dev/null +++ b/tutorials/2.auto_tensorization.ipynb @@ -0,0 +1,251 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "a3efea70-191e-40a1-abe6-c8aa0c4535c4", + "metadata": {}, + "source": [ + "# Auto Tensorization in BitBLAS\n", + "\n", + "Auto detect whether a given operator (gemm, conv2d, stencil, etc.) can be tensorized with given instructions' computation flow (MMA, DP4A, etc.)\n", + "\n", + "![image.png](./img/AutoTensorization.png)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "7f17ee04-4406-4948-98e9-f61a42ed563d", + "metadata": {}, + "outputs": [], + "source": [ + "import bitblas\n", + "from bitblas import tvm\n", + "from tvm import te, tir" + ] + }, + { + "cell_type": "markdown", + "id": "aa6c2458-7b4f-4951-addc-480df5fd9ef2", + "metadata": {}, + "source": [ + "Get a convlution expression" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "3d3016d6-57c0-40d9-bee0-72582a3c9365", + "metadata": {}, + "outputs": [], + "source": [ + "def conv2d_nhwc_hwio(n, f, h, w, c, kh, kw, s, d, p, in_dtype=\"float16\", out_dtype=\"float16\"):\n", + " A = te.placeholder((n, h, w, c), name=\"input\", dtype=in_dtype)\n", + " B = te.placeholder((kh, kw, c, f), name=\"weight\", dtype=in_dtype)\n", + "\n", + " pad_shape = (n, h + 2 * p, w + 2 * p, c)\n", + " pad_value = tir.const(0.0, A.dtype)\n", + " pad = te.compute(\n", + " pad_shape,\n", + " lambda n, h, w, c: te.if_then_else(\n", + " tir.all(\n", + " h >= p,\n", + " w >= p,\n", + " h < pad_shape[1] - p,\n", + " w < pad_shape[2] - p,\n", + " ),\n", + " A[n, h - p, w - p, c],\n", + " pad_value,\n", + " ),\n", + " name=\"pad\",\n", + " )\n", + " kernel_h, kernel_w = kh, kw\n", + " stride_h, stride_w = s, s\n", + " dilation_h, dilation_w = d, d\n", + " out_h = (h + 2 * p - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1\n", + " out_w = (w + 2 * p - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1\n", + " out_shape = (n, out_h, out_w, f)\n", + " kh = te.reduce_axis((0, kernel_h), name=\"kh\")\n", + " kw = te.reduce_axis((0, kernel_w), name=\"kw\")\n", + " c = te.reduce_axis((0, c), name=\"c\")\n", + " C = te.compute(\n", + " out_shape,\n", + " lambda n, h, w, f: te.sum(\n", + " pad[\n", + " n,\n", + " h * stride_h + kh * dilation_h,\n", + " w * stride_w + kw * dilation_w,\n", + " c,\n", + " ] * B[kh, kw, c, f],\n", + " axis=[kh, kw, c],\n", + " ),\n", + " name=\"C\",\n", + " )\n", + " return te.create_prim_func([A, B, C])" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "8cd2f6ce-69b9-47b7-b466-790ea4712c3a", + "metadata": {}, + "outputs": [], + "source": [ + "func = conv2d_nhwc_hwio(128, 64, 224, 224, 64, 1, 1, 2, 1, 3, \"float16\", \"float16\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "a2b1e4e2-8a34-437f-a9f4-72d59a7eef33", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# from tvm.script import tir as T\n", + "\n", + "@T.prim_func\n", + "def main(input: T.Buffer((128, 224, 224, 64), \"float16\"), weight: T.Buffer((1, 1, 64, 64), \"float16\"), C: T.Buffer((128, 115, 115, 64), \"float16\")):\n", + " T.func_attr({\"tir.noalias\": T.bool(True)})\n", + " # with T.block(\"root\"):\n", + " pad = T.alloc_buffer((128, 230, 230, 64), \"float16\")\n", + " for n, h, w, c in T.grid(128, 230, 230, 64):\n", + " with T.block(\"pad\"):\n", + " v_n, v_h, v_w, v_c = T.axis.remap(\"SSSS\", [n, h, w, c])\n", + " T.reads(input[v_n, v_h - 3, v_w - 3, v_c])\n", + " T.writes(pad[v_n, v_h, v_w, v_c])\n", + " pad[v_n, v_h, v_w, v_c] = T.if_then_else(3 <= v_h and 3 <= v_w and v_h < 227 and v_w < 227, input[v_n, v_h - 3, v_w - 3, v_c], T.float16(0))\n", + " for n, h, w, f, kh, kw, c in T.grid(128, 115, 115, 64, 1, 1, 64):\n", + " with T.block(\"C\"):\n", + " v_n, v_h, v_w, v_f, v_kh, v_kw, v_c = T.axis.remap(\"SSSSRRR\", [n, h, w, f, kh, kw, c])\n", + " T.reads(pad[v_n, v_h * 2 + v_kh, v_w * 2 + v_kw, v_c], weight[v_kh, v_kw, v_c, v_f])\n", + " T.writes(C[v_n, v_h, v_w, v_f])\n", + " with T.init():\n", + " C[v_n, v_h, v_w, v_f] = T.float16(0)\n", + " C[v_n, v_h, v_w, v_f] = C[v_n, v_h, v_w, v_f] + pad[v_n, v_h * 2 + v_kh, v_w * 2 + v_kw, v_c] * weight[v_kh, v_kw, v_c, v_f]\n" + ] + } + ], + "source": [ + "print(func)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "3615d7cd-7e81-4c67-91ce-1fe922fe11f1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "target='nvidia/geforce-rtx-4090'\n" + ] + } + ], + "source": [ + "from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags\n", + "from bitblas.base.arch import CUDA\n", + "\n", + "target = bitblas.auto_detect_nvidia_target()\n", + "print(f\"{target=}\")\n", + "arch = CUDA(target)" + ] + }, + { + "cell_type": "markdown", + "id": "981e0121-95fa-46d5-8c94-6992709f64f5", + "metadata": {}, + "source": [ + "## Get Tensorized Function" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "fcf8597d-9131-4441-b612-782ed9d66a13", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# from tvm.script import tir as T\n", + "\n", + "@T.prim_func\n", + "def main(input: T.Buffer((128, 224, 224, 64), \"float16\"), weight: T.Buffer((1, 1, 64, 64), \"float16\"), C: T.Buffer((128, 115, 115, 64), \"float16\")):\n", + " T.func_attr({\"dlight.tensorcore_prenormlized\": T.bool(True), \"tir.noalias\": T.bool(True)})\n", + " # with T.block(\"root\"):\n", + " pad = T.alloc_buffer((128, 230, 230, 64), \"float16\")\n", + " pad_reindex = T.alloc_buffer((1, 1692800, 64), \"float16\")\n", + " weight_reindex = T.alloc_buffer((1, 64, 64), \"float16\")\n", + " C_reindex = T.alloc_buffer((1, 1692800, 64), \"float16\")\n", + " for n, h, w, c in T.grid(128, 230, 230, 64):\n", + " with T.block(\"pad\"):\n", + " v_n, v_h, v_w, v_c = T.axis.remap(\"SSSS\", [n, h, w, c])\n", + " T.reads(input[v_n, v_h - 3, v_w - 3, v_c])\n", + " T.writes(pad[v_n, v_h, v_w, v_c])\n", + " pad[v_n, v_h, v_w, v_c] = T.if_then_else(3 <= v_h and 3 <= v_w and v_h < 227 and v_w < 227, input[v_n, v_h - 3, v_w - 3, v_c], T.float16(0))\n", + " for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(128, 115, 115, 1, 1, 64):\n", + " with T.block(\"pad_reindex_reindex\"):\n", + " v0, v1, v2, v3, v4, v5 = T.axis.remap(\"SSSSSS\", [ax0, ax1, ax2, ax3, ax4, ax5])\n", + " T.reads(pad[v0, v1 * 2 + v3, v2 * 2 + v4, v5])\n", + " T.writes(pad_reindex[0, v0 * 13225 + v1 * 115 + v2, v5])\n", + " pad_reindex[0, v0 * 13225 + v1 * 115 + v2, v5] = pad[v0, v1 * 2 + v3, v2 * 2 + v4, v5]\n", + " for ax0, ax1, ax2, ax3 in T.grid(64, 1, 1, 64):\n", + " with T.block(\"weight_reindex_reindex\"):\n", + " v0, v1, v2, v3 = T.axis.remap(\"SSSS\", [ax0, ax1, ax2, ax3])\n", + " T.reads(weight[v1, v2, v3, v0])\n", + " T.writes(weight_reindex[0, v3, v0])\n", + " weight_reindex[0, v3, v0] = weight[v1, v2, v3, v0]\n", + " for ax0, ax1, ax2, ax3 in T.grid(1, 1692800, 64, 64):\n", + " with T.block(\"C\"):\n", + " v0, v1, v2, v3 = T.axis.remap(\"SSSR\", [ax0, ax1, ax2, ax3])\n", + " T.reads(pad_reindex[0, v1, v3], weight_reindex[0, v3, v2])\n", + " T.writes(C_reindex[0, v1, v2])\n", + " with T.init():\n", + " C_reindex[0, v1, v2] = T.float16(0)\n", + " C_reindex[0, v1, v2] = C_reindex[0, v1, v2] + pad_reindex[0, v1, v3] * weight_reindex[0, v3, v2]\n", + " for ax0, ax1, ax2, ax3 in T.grid(128, 115, 115, 64):\n", + " with T.block(\"C_reindex\"):\n", + " v0, v1, v2, v3 = T.axis.remap(\"SSSS\", [ax0, ax1, ax2, ax3])\n", + " T.reads(C_reindex[0, v0 * 13225 + v1 * 115 + v2, v3])\n", + " T.writes(C[v0, v1, v2, v3])\n", + " C[v0, v1, v2, v3] = C_reindex[0, v0 * 13225 + v1 * 115 + v2, v3]\n" + ] + } + ], + "source": [ + "tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target)\n", + "print(tensorized_func)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/3.fast_decoding.ipynb b/tutorials/3.fast_decoding.ipynb new file mode 100644 index 000000000..f8b8b093e --- /dev/null +++ b/tutorials/3.fast_decoding.ipynb @@ -0,0 +1,775 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "9bc8f086-5dda-4aee-9dcb-14e4d9720f0d", + "metadata": {}, + "source": [ + "# Fast Dequantizatoin\n", + "\n", + "How to enbale fast dequantization (INT4/2/1 -> FP16/INT8) or (FP8 -> FP16)?\n", + "\n", + "![image.png](./img/FastDequantization.png)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "32b3eb3a-a0a2-4c63-9e24-a248c1fe3579", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2024-10-24 14:35:48 [BitBLAS:INFO]: Auto detected target: nvidia/geforce-rtx-4090\n", + "2024-10-24 14:35:49 [BitBLAS:DEBUG]: Cannot find the appropriate index map for tensorcore\n", + "2024-10-24 14:36:16 [BitBLAS:DEBUG]: Cannot find the appropriate index map for tensorcore\n", + "2024-10-24 14:36:16 [BitBLAS:DEBUG]: Apply config {'block': [2], 'thread': [2], 'rstep': [1024], 'reduce_thread': [64], 'vectorize': {'A': 8, 'B_decode': 8}}\n", + "2024-10-24 14:36:16 [BitBLAS:DEBUG]: Apply config {'block': [1], 'thread': [1], 'rstep': [1024], 'reduce_thread': [128], 'vectorize': {'A': 8, 'B_decode': 8}}\n", + "2024-10-24 14:36:16 [BitBLAS:DEBUG]: Apply config {'block': [4], 'thread': [4], 'rstep': [1024], 'reduce_thread': [32], 'vectorize': {'A': 8, 'B_decode': 8}}\n", + "2024-10-24 14:36:16 [BitBLAS:DEBUG]: Apply config {'block': [8], 'thread': [8], 'rstep': [1024], 'reduce_thread': [16], 'vectorize': {'A': 8, 'B_decode': 8}}\n", + "2024-10-24 14:36:16 [BitBLAS:DEBUG]: Apply config {'block': [16], 'thread': [16], 'rstep': [512], 'reduce_thread': [8], 'vectorize': {'A': 4, 'B_decode': 8}}\n", + "2024-10-24 14:36:16 [BitBLAS:DEBUG]: Apply config {'block': [32], 'thread': [32], 'rstep': [256], 'reduce_thread': [4], 'vectorize': {'A': 2, 'B_decode': 8}}\n", + "2024-10-24 14:36:16 [BitBLAS:DEBUG]: Apply config {'block': [64], 'thread': [64], 'rstep': [128], 'reduce_thread': [2], 'vectorize': {'B_decode': 8}}\n", + "2024-10-24 14:36:16 [BitBLAS:DEBUG]: Apply config {'block': [128], 'thread': [128], 'rstep': [128], 'vectorize': {'B_decode': 8}}\n", + "2024-10-24 14:36:17 [BitBLAS:DEBUG]: Warning: block config [128] is not valid for matmul, skip.\n", + "2024-10-24 14:36:17 [BitBLAS:DEBUG]: Warning: block config [128] is not valid for matmul, skip.\n", + "2024-10-24 14:36:42 [BitBLAS:DEBUG]: LocalBuilder: An exception occurred Traceback (most recent call last):\n", + " File \"/root/BitBLAS/3rdparty/tvm/python/tvm/exec/popen_worker.py\", line 87, in main\n", + " result = fn(*args, **kwargs)\n", + " File \"/root/BitBLAS/bitblas/base/utils.py\", line 257, in _build\n", + " rt_mod = tvm.build(mod, ta\t...\tm::codegen::CodeGenCUDA::VisitExpr_(tvm::tir::RampNode const*, std::ostream&)\n", + " File \"/root/BitBLAS/3rdparty/tvm/src/target/source/codegen_cuda.cc\", line 1226\n", + "ValueError: Check failed: lanes <= 4 (8 vs. 4) : Ramp of more than 4 lanes is not allowed.\n", + "\n", + "2024-10-24 14:36:42 [BitBLAS:INFO]: Evaluation with config {'block': [2], 'thread': [2], 'rstep': [1024], 'reduce_thread': [64], 'vectorize': {'A': 8, 'B_decode': 8}}\n", + "2024-10-24 14:36:42 [BitBLAS:INFO]: Time cost of this config: 0.005 ms\n", + "2024-10-24 14:36:42 [BitBLAS:INFO]: Evaluation with config {'block': [1], 'thread': [1], 'rstep': [1024], 'reduce_thread': [128], 'vectorize': {'A': 8, 'B_decode': 8}}\n", + "2024-10-24 14:36:42 [BitBLAS:INFO]: Time cost of this config: 0.004 ms\n", + "2024-10-24 14:36:42 [BitBLAS:INFO]: Evaluation with config {'block': [4], 'thread': [4], 'rstep': [1024], 'reduce_thread': [32], 'vectorize': {'A': 8, 'B_decode': 8}}\n", + "2024-10-24 14:36:42 [BitBLAS:INFO]: Time cost of this config: 0.004 ms\n", + "2024-10-24 14:36:42 [BitBLAS:INFO]: Evaluation with config {'block': [8], 'thread': [8], 'rstep': [1024], 'reduce_thread': [16], 'vectorize': {'A': 8, 'B_decode': 8}}\n", + "2024-10-24 14:36:42 [BitBLAS:INFO]: Time cost of this config: 0.004 ms\n", + "2024-10-24 14:36:42 [BitBLAS:INFO]: Evaluation with config {'block': [16], 'thread': [16], 'rstep': [512], 'reduce_thread': [8], 'vectorize': {'A': 4, 'B_decode': 8}}\n", + "2024-10-24 14:36:42 [BitBLAS:INFO]: Time cost of this config: 0.004 ms\n", + "2024-10-24 14:36:42 [BitBLAS:INFO]: Evaluation with config {'block': [32], 'thread': [32], 'rstep': [256], 'reduce_thread': [4], 'vectorize': {'A': 2, 'B_decode': 8}}\n", + "2024-10-24 14:36:42 [BitBLAS:INFO]: Time cost of this config: 0.005 ms\n", + "2024-10-24 14:36:42 [BitBLAS:INFO]: Evaluation with config {'block': [64], 'thread': [64], 'rstep': [128], 'reduce_thread': [2], 'vectorize': {'B_decode': 8}}\n", + "2024-10-24 14:36:42 [BitBLAS:INFO]: Time cost of this config: 0.008 ms\n", + "Ref output: tensor([[1586., 1519., 1561., ..., 1566., 1536., 1561.]], device='cuda:0',\n", + " dtype=torch.float16)\n", + "BitBLAS output: tensor([[1586., 1518., 1562., ..., 1567., 1535., 1562.]], device='cuda:0',\n", + " dtype=torch.float16)\n" + ] + } + ], + "source": [ + "import bitblas\n", + "import torch\n", + "\n", + "# enabling debug output\n", + "\n", + "bitblas.set_log_level(\"Debug\")\n", + "matmul_config = bitblas.MatmulConfig(\n", + " M=1, # M dimension\n", + " N=1024, # N dimension\n", + " K=1024, # K dimension\n", + " A_dtype=\"float16\", # activation A dtype\n", + " W_dtype=\"int4\", # weight W dtype\n", + " accum_dtype=\"float16\", # accumulation dtype\n", + " out_dtype=\"float16\", # output dtype\n", + " layout=\"nt\", # matrix layout, \"nt\" indicates the layout of A is non-transpose and the layout of W is transpose\n", + " with_bias=False, # bias\n", + " # configs for weight only quantization\n", + " group_size=None, # setting for grouped quantization\n", + " with_scaling=False, # setting for scaling factor\n", + " with_zeros=False, # setting for zeros\n", + " zeros_mode=None, # setting for how to calculating zeros\n", + ")\n", + "\n", + "matmul = bitblas.Matmul(config=matmul_config)\n", + "\n", + "# Create input matrices\n", + "input_tensor = torch.rand((1, 1024), dtype=torch.float16).cuda()\n", + "weight_tensor = torch.randint(0, 7, (1024, 1024), dtype=torch.int8).cuda()\n", + "\n", + "# Transform weight tensor to int4 data type\n", + "weight_tensor_int4 = matmul.transform_weight(weight_tensor)\n", + "\n", + "# Perform mixed-precision matrix multiplication\n", + "output_tensor = matmul(input_tensor, weight_tensor_int4)\n", + "\n", + "# Reference result using PyTorch matmul for comparison\n", + "ref_result = torch.matmul(input_tensor, weight_tensor.t().to(torch.float16))\n", + "# Assert that the results are close within a specified tolerance, note that the int4 randint value is a little bigger than the float16 value, so we set the atol to 1.0\n", + "print(\"Ref output:\", ref_result)\n", + "print(\"BitBLAS output:\", output_tensor)\n", + "torch.testing.assert_close(output_tensor, ref_result, rtol=1e-2, atol=1e-0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "839bd993-0120-4061-9799-88b55f905ce1", + "metadata": {}, + "outputs": [], + "source": [ + "import bitblas\n", + "import torch\n", + "\n", + "# enabling debug output\n", + "bitblas.set_log_level(\"Debug\")\n", + "matmul_config = bitblas.MatmulConfig(\n", + " M=1, # M dimension\n", + " N=16384, # N dimension\n", + " K=16384, # K dimension\n", + " A_dtype=\"float16\", # activation A dtype\n", + " W_dtype=\"float16\", # weight W dtype\n", + ")\n", + "\n", + "matmul = bitblas.Matmul(config=matmul_config, enable_tuning=False)\n", + "\n", + "matmul.hardware_aware_finetune(topk=20, parallel_build=True)\n", + "\n", + "latency = matmul.profile_latency()\n", + "print(f\"{latency=}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "832054b0-35a6-4422-b008-c6759e3a9162", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2024-10-24 14:45:08 [BitBLAS:INFO]: Auto detected target: nvidia/geforce-rtx-4090\n", + "2024-10-24 14:45:08 [BitBLAS:DEBUG]: Cannot find the appropriate index map for tensorcore\n", + "2024-10-24 14:45:10 [BitBLAS:DEBUG]: Cannot find the appropriate index map for tensorcore\n", + "2024-10-24 14:45:11 [BitBLAS:DEBUG]: Apply config {'block': [32], 'thread': [32], 'rstep': [256], 'reduce_thread': [4], 'vectorize': {'A': 2, 'B_decode': 8}}\n", + "2024-10-24 14:45:11 [BitBLAS:DEBUG]: Apply config {'block': [16], 'thread': [16], 'rstep': [512], 'reduce_thread': [8], 'vectorize': {'A': 4, 'B_decode': 8}}\n", + "2024-10-24 14:45:11 [BitBLAS:DEBUG]: Apply config {'block': [8], 'thread': [8], 'rstep': [1024], 'reduce_thread': [16], 'vectorize': {'A': 8, 'B_decode': 8}}\n", + "2024-10-24 14:45:11 [BitBLAS:DEBUG]: Apply config {'block': [4], 'thread': [4], 'rstep': [2048], 'reduce_thread': [32], 'vectorize': {'A': 8, 'B_decode': 8}}\n", + "2024-10-24 14:45:11 [BitBLAS:DEBUG]: Apply config {'block': [2], 'thread': [2], 'rstep': [4096], 'reduce_thread': [64], 'vectorize': {'A': 8, 'B_decode': 8}}\n", + "2024-10-24 14:45:11 [BitBLAS:DEBUG]: Apply config {'block': [1], 'thread': [1], 'rstep': [4096], 'reduce_thread': [128], 'vectorize': {'A': 8, 'B_decode': 8}}\n", + "2024-10-24 14:45:11 [BitBLAS:DEBUG]: Apply config {'block': [64], 'thread': [64], 'rstep': [128], 'reduce_thread': [2], 'vectorize': {'B_decode': 8}}\n", + "2024-10-24 14:45:11 [BitBLAS:DEBUG]: Apply config {'block': [128], 'thread': [128], 'rstep': [128], 'vectorize': {'B_decode': 8}}\n", + "2024-10-24 14:45:11 [BitBLAS:DEBUG]: Warning: block config [128] is not valid for matmul, skip.\n", + "2024-10-24 14:45:11 [BitBLAS:DEBUG]: Warning: block config [128] is not valid for matmul, skip.\n" + ] + } + ], + "source": [ + "matmul_config = bitblas.MatmulConfig(\n", + " M=1, # M dimension\n", + " N=16384, # N dimension\n", + " K=16384, # K dimension\n", + " A_dtype=\"float16\", # activation A dtype\n", + " W_dtype=\"int4\", # weight W dtype\n", + " fast_decoding=False, # Disable Fast Decoding\n", + ")\n", + "\n", + "matmul = bitblas.Matmul(config=matmul_config, enable_tuning=False)\n", + "\n", + "matmul.hardware_aware_finetune(topk=20, parallel_build=True)\n", + "\n", + "print(matmul.get_source())" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "5da4e137-5d02-46e7-9cb0-185d8d0137f2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "latency=0.1433375\n" + ] + } + ], + "source": [ + "latency = matmul.profile_latency()\n", + "print(f\"{latency=}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "b7b4a317-9550-42df-8c3e-be88069eabf8", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2024-10-24 14:44:22 [BitBLAS:INFO]: Auto detected target: nvidia/geforce-rtx-4090\n", + "2024-10-24 14:44:22 [BitBLAS:DEBUG]: Cannot find the appropriate index map for tensorcore\n", + "2024-10-24 14:44:25 [BitBLAS:DEBUG]: Cannot find the appropriate index map for tensorcore\n", + "2024-10-24 14:44:25 [BitBLAS:DEBUG]: Apply config {'block': [32], 'thread': [32], 'rstep': [256], 'reduce_thread': [4], 'vectorize': {'A': 2, 'B_decode': 8}}\n", + "2024-10-24 14:44:25 [BitBLAS:DEBUG]: Apply config {'block': [16], 'thread': [16], 'rstep': [512], 'reduce_thread': [8], 'vectorize': {'A': 4, 'B_decode': 8}}\n", + "2024-10-24 14:44:25 [BitBLAS:DEBUG]: Apply config {'block': [8], 'thread': [8], 'rstep': [1024], 'reduce_thread': [16], 'vectorize': {'A': 8, 'B_decode': 8}}\n", + "2024-10-24 14:44:25 [BitBLAS:DEBUG]: Apply config {'block': [4], 'thread': [4], 'rstep': [2048], 'reduce_thread': [32], 'vectorize': {'A': 8, 'B_decode': 8}}\n", + "2024-10-24 14:44:25 [BitBLAS:DEBUG]: Apply config {'block': [2], 'thread': [2], 'rstep': [4096], 'reduce_thread': [64], 'vectorize': {'A': 8, 'B_decode': 8}}\n", + "2024-10-24 14:44:25 [BitBLAS:DEBUG]: Apply config {'block': [1], 'thread': [1], 'rstep': [4096], 'reduce_thread': [128], 'vectorize': {'A': 8, 'B_decode': 8}}\n", + "2024-10-24 14:44:25 [BitBLAS:DEBUG]: Apply config {'block': [64], 'thread': [64], 'rstep': [128], 'reduce_thread': [2], 'vectorize': {'B_decode': 8}}\n", + "2024-10-24 14:44:25 [BitBLAS:DEBUG]: Apply config {'block': [128], 'thread': [128], 'rstep': [128], 'vectorize': {'B_decode': 8}}\n", + "2024-10-24 14:44:25 [BitBLAS:DEBUG]: Warning: block config [128] is not valid for matmul, skip.\n", + "2024-10-24 14:44:25 [BitBLAS:DEBUG]: Warning: block config [128] is not valid for matmul, skip.\n", + "2024-10-24 14:44:51 [BitBLAS:DEBUG]: LocalBuilder: An exception occurred Traceback (most recent call last):\n", + " File \"/root/BitBLAS/3rdparty/tvm/python/tvm/exec/popen_worker.py\", line 87, in main\n", + " result = fn(*args, **kwargs)\n", + " File \"/root/BitBLAS/bitblas/base/utils.py\", line 257, in _build\n", + " rt_mod = tvm.build(mod, ta\t...\tm::codegen::CodeGenCUDA::VisitExpr_(tvm::tir::RampNode const*, std::ostream&)\n", + " File \"/root/BitBLAS/3rdparty/tvm/src/target/source/codegen_cuda.cc\", line 1226\n", + "ValueError: Check failed: lanes <= 4 (8 vs. 4) : Ramp of more than 4 lanes is not allowed.\n", + "\n", + "2024-10-24 14:44:52 [BitBLAS:INFO]: Evaluation with config {'block': [32], 'thread': [32], 'rstep': [256], 'reduce_thread': [4], 'vectorize': {'A': 2, 'B_decode': 8}}\n", + "2024-10-24 14:44:52 [BitBLAS:INFO]: Time cost of this config: 0.151 ms\n", + "2024-10-24 14:44:53 [BitBLAS:INFO]: Evaluation with config {'block': [16], 'thread': [16], 'rstep': [512], 'reduce_thread': [8], 'vectorize': {'A': 4, 'B_decode': 8}}\n", + "2024-10-24 14:44:53 [BitBLAS:INFO]: Time cost of this config: 0.144 ms\n", + "2024-10-24 14:44:54 [BitBLAS:INFO]: Evaluation with config {'block': [8], 'thread': [8], 'rstep': [1024], 'reduce_thread': [16], 'vectorize': {'A': 8, 'B_decode': 8}}\n", + "2024-10-24 14:44:54 [BitBLAS:INFO]: Time cost of this config: 0.161 ms\n", + "2024-10-24 14:44:55 [BitBLAS:INFO]: Evaluation with config {'block': [4], 'thread': [4], 'rstep': [2048], 'reduce_thread': [32], 'vectorize': {'A': 8, 'B_decode': 8}}\n", + "2024-10-24 14:44:55 [BitBLAS:INFO]: Time cost of this config: 0.145 ms\n", + "2024-10-24 14:44:57 [BitBLAS:INFO]: Evaluation with config {'block': [2], 'thread': [2], 'rstep': [4096], 'reduce_thread': [64], 'vectorize': {'A': 8, 'B_decode': 8}}\n", + "2024-10-24 14:44:57 [BitBLAS:INFO]: Time cost of this config: 0.628 ms\n", + "2024-10-24 14:44:58 [BitBLAS:INFO]: Evaluation with config {'block': [1], 'thread': [1], 'rstep': [4096], 'reduce_thread': [128], 'vectorize': {'A': 8, 'B_decode': 8}}\n", + "2024-10-24 14:44:58 [BitBLAS:INFO]: Time cost of this config: 0.305 ms\n", + "2024-10-24 14:44:59 [BitBLAS:INFO]: Evaluation with config {'block': [128], 'thread': [128], 'rstep': [128], 'vectorize': {'B_decode': 8}}\n", + "2024-10-24 14:44:59 [BitBLAS:INFO]: Time cost of this config: 0.144 ms\n", + "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530)\n", + "\n", + "typedef unsigned short uint16_t;\n", + "typedef unsigned char uint8_t;\n", + "typedef signed char int8_t;\n", + "typedef int int32_t;\n", + "typedef unsigned long long uint64_t;\n", + "typedef unsigned int uint32_t;\n", + "\n", + "#define TVM_FORCE_INLINE inline __attribute__((always_inline))\n", + "#define TVM_XINLINE TVM_FORCE_INLINE __device__ __host__\n", + "#define TVM_ALIGNED(x) __attribute__ ((aligned(x)))\n", + "#define TVM_HALF_OPERATOR(RTYPE, OP) \\\n", + " TVM_XINLINE RTYPE operator OP (half a, half b) { \\\n", + " return RTYPE(float(a) OP float(b)); \\\n", + " } \\\n", + " template \\\n", + " TVM_XINLINE RTYPE operator OP (half a, T b) { \\\n", + " return RTYPE(float(a) OP float(b)); \\\n", + " } \\\n", + " template \\\n", + " TVM_XINLINE RTYPE operator OP (T a, half b) { \\\n", + " return RTYPE(float(a) OP float(b)); \\\n", + " }\n", + "\n", + "#define TVM_HALF_ASSIGNOP(AOP, OP) \\\n", + " template \\\n", + " TVM_XINLINE half operator AOP (const T& a) { \\\n", + " return *this = half(float(*this) OP float(a)); \\\n", + " } \\\n", + " template \\\n", + " TVM_XINLINE half operator AOP (const volatile T& a) volatile { \\\n", + " return *this = half(float(*this) OP float(a)); \\\n", + " }\n", + "\n", + "class TVM_ALIGNED(2) half {\n", + " public:\n", + " uint16_t half_;\n", + "\n", + " static TVM_XINLINE half Binary(uint16_t value) {\n", + " half res;\n", + " res.half_ = value;\n", + " return res;\n", + " }\n", + "\n", + " TVM_XINLINE half() {}\n", + "\n", + " TVM_XINLINE half(const float& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const double& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const int8_t& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const int32_t& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const long long& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); }\n", + "\n", + " TVM_XINLINE operator float() const { \\\n", + " return float(half2float(half_)); \\\n", + " } \\\n", + " TVM_XINLINE operator float() const volatile { \\\n", + " return float(half2float(half_)); \\\n", + " }\n", + "\n", + "\n", + " TVM_HALF_ASSIGNOP(+=, +)\n", + " TVM_HALF_ASSIGNOP(-=, -)\n", + " TVM_HALF_ASSIGNOP(*=, *)\n", + " TVM_HALF_ASSIGNOP(/=, /)\n", + "\n", + " TVM_XINLINE half operator+() {\n", + " return *this;\n", + " }\n", + "\n", + " TVM_XINLINE half operator-() {\n", + " return half(-float(*this));\n", + " }\n", + "\n", + " TVM_XINLINE half operator=(const half& a) {\n", + " half_ = a.half_;\n", + " return a;\n", + " }\n", + "\n", + " template\n", + " TVM_XINLINE half operator=(const T& a) {\n", + " return *this = half(a);\n", + " }\n", + "\n", + " TVM_XINLINE half operator=(const half& a) volatile {\n", + " half_ = a.half_;\n", + " return a;\n", + " }\n", + "\n", + " template\n", + " TVM_XINLINE half operator=(const T& a) volatile {\n", + " return *this = half(a);\n", + " }\n", + "\n", + " private:\n", + " union Bits {\n", + " float f;\n", + " int32_t si;\n", + " uint32_t ui;\n", + " };\n", + "\n", + " static int const fp16FractionBits = 10;\n", + " static int const fp32FractionBits = 23;\n", + " static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits); // == 0x7fffff\n", + " static int32_t const fp32HiddenBit = 1 << fp32FractionBits; // == 0x800000\n", + " static int const shift = fp32FractionBits - fp16FractionBits; // == 13\n", + " static int const shiftSign = 16;\n", + " static int32_t const expAdjust = 127 - 15; // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)\n", + "\n", + " static int32_t const infN = 0x7F800000; // flt32 infinity\n", + " static int32_t const maxN = 0x477FFFFF; // max flt32 that's a flt16 normal after >> by shift\n", + " static int32_t const minN = 0x38800000; // min flt16 normal as a flt32\n", + " static int32_t const maxZ = 0x33000000; // max fp32 number that's still rounded to zero in fp16\n", + " static int32_t const signN = 0x80000000; // flt32 sign bit\n", + "\n", + " static int32_t const infC = infN >> shift;\n", + " static int32_t const nanN = (infC + 1) << shift; // minimum flt16 nan as a flt32\n", + " static int32_t const maxC = maxN >> shift;\n", + " static int32_t const minC = minN >> shift;\n", + " static int32_t const signC = signN >> shiftSign; // flt16 sign bit\n", + "\n", + " static int32_t const mulN = 0x52000000; // (1 << 23) / minN\n", + " static int32_t const mulC = 0x33800000; // minN / (1 << (23 - shift))\n", + "\n", + " static int32_t const subC = 0x003FF; // max flt32 subnormal down shifted\n", + " static int32_t const norC = 0x00400; // min flt32 normal down shifted\n", + "\n", + " static int32_t const maxD = infC - maxC - 1;\n", + " static int32_t const minD = minC - subC - 1;\n", + "\n", + " TVM_XINLINE uint16_t float2half(const float& value) const {\n", + " Bits v;\n", + " v.f = value;\n", + " uint32_t sign = v.si & signN; // grab sign bit\n", + " v.si ^= sign; // clear sign bit from v\n", + " sign >>= shiftSign; // logical shift sign to fp16 position\n", + "\n", + " if (v.si <= maxZ) {\n", + " // Handle eventual zeros here to ensure\n", + " // vshift will not exceed 32 below.\n", + " v.ui = 0;\n", + " } else if (v.si < minN) {\n", + " // Handle denorms\n", + " uint32_t exp32 = v.ui >> fp32FractionBits;\n", + " int32_t exp16 = exp32 - expAdjust;\n", + " // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.\n", + " // Smaller (so negative) exp16 values should result in greater right shifts.\n", + " uint32_t vshift = 1 - exp16;\n", + " uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);\n", + " v.ui = significand >> vshift;\n", + " v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;\n", + " } else if (v.si <= maxN) {\n", + " // Handle norms\n", + " v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;\n", + " v.ui -= expAdjust << fp32FractionBits;\n", + " } else if (v.si <= infN) {\n", + " v.si = infN;\n", + " } else if (v.si < nanN) {\n", + " v.si = nanN;\n", + " }\n", + "\n", + " v.ui >>= shift;\n", + " return sign | (v.ui & 0x7fff);\n", + " }\n", + "\n", + " // Same as above routine, except for addition of volatile keyword\n", + " TVM_XINLINE uint16_t float2half(\n", + " const volatile float& value) const volatile {\n", + " Bits v;\n", + " v.f = value;\n", + " uint32_t sign = v.si & signN; // grab sign bit\n", + " v.si ^= sign; // clear sign bit from v\n", + " sign >>= shiftSign; // logical shift sign to fp16 position\n", + "\n", + " if (v.si <= maxZ) {\n", + " // Handle eventual zeros here to ensure\n", + " // vshift will not exceed 32 below.\n", + " v.ui = 0;\n", + " } else if (v.si < minN) {\n", + " // Handle denorms\n", + " uint32_t exp32 = v.ui >> fp32FractionBits;\n", + " int32_t exp16 = exp32 - expAdjust;\n", + " // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.\n", + " // Smaller (so negative) exp16 values should result in greater right shifts.\n", + " uint32_t vshift = 1 - exp16;\n", + " uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);\n", + " v.ui = significand >> vshift;\n", + " v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;\n", + " } else if (v.si <= maxN) {\n", + " // Handle norms\n", + " v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;\n", + " v.ui -= expAdjust << fp32FractionBits;\n", + " } else if (v.si <= infN) {\n", + " v.si = infN;\n", + " } else if (v.si < nanN) {\n", + " v.si = nanN;\n", + " }\n", + "\n", + " v.ui >>= shift;\n", + " return sign | (v.ui & 0x7fff);\n", + " }\n", + "\n", + " TVM_XINLINE float half2float(const uint16_t& value) const {\n", + " Bits v;\n", + " v.ui = value;\n", + " int32_t sign = v.si & signC;\n", + " v.si ^= sign;\n", + " sign <<= shiftSign;\n", + " v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);\n", + " v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);\n", + " Bits s;\n", + " s.si = mulC;\n", + " s.f *= v.si;\n", + " int32_t mask = -(norC > v.si);\n", + " v.si <<= shift;\n", + " v.si ^= (s.si ^ v.si) & mask;\n", + " v.si |= sign;\n", + " return v.f;\n", + " }\n", + "\n", + " TVM_XINLINE float half2float(\n", + " const volatile uint16_t& value) const volatile {\n", + " Bits v;\n", + " v.ui = value;\n", + " int32_t sign = v.si & signC;\n", + " v.si ^= sign;\n", + " sign <<= shiftSign;\n", + " v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);\n", + " v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);\n", + " Bits s;\n", + " s.si = mulC;\n", + " s.f *= v.si;\n", + " int32_t mask = -(norC > v.si);\n", + " v.si <<= shift;\n", + " v.si ^= (s.si ^ v.si) & mask;\n", + " v.si |= sign;\n", + " return v.f;\n", + " }\n", + "\n", + " template\n", + " TVM_XINLINE void constructor(const T& value) {\n", + " half_ = float2half(float(value));\n", + " }\n", + "};\n", + "\n", + "TVM_HALF_OPERATOR(half, +)\n", + "TVM_HALF_OPERATOR(half, -)\n", + "TVM_HALF_OPERATOR(half, *)\n", + "TVM_HALF_OPERATOR(half, /)\n", + "TVM_HALF_OPERATOR(bool, >)\n", + "TVM_HALF_OPERATOR(bool, <)\n", + "TVM_HALF_OPERATOR(bool, >=)\n", + "TVM_HALF_OPERATOR(bool, <=)\n", + "\n", + "TVM_XINLINE half __float2half_rn(const float a) {\n", + " return half(a);\n", + "}\n", + "#else\n", + "#include \n", + "__device__ half max(half a, half b)\n", + "{\n", + " return __hgt(__half(a), __half(b)) ? a : b;\n", + "}\n", + "__device__ half min(half a, half b)\n", + "{\n", + " return __hlt(__half(a), __half(b)) ? a : b;\n", + "}\n", + "#endif\n", + "\n", + "\n", + "// Pack two half values.\n", + "static inline __device__ __host__ unsigned\n", + "__pack_half2(const half x, const half y) {\n", + " unsigned v0 = *((unsigned short *)&x);\n", + " unsigned v1 = *((unsigned short *)&y);\n", + " return (v1 << 16) | v0;\n", + "}\n", + "\n", + "#define CUDA_UNSUPPORTED_HALF_MATH_BINARY(HALF_MATH_NAME, FP32_MATH_NAME) \\\n", + "static inline __device__ __host__ half HALF_MATH_NAME(half x, half y) { \\\n", + " float tmp_x = __half2float(x); \\\n", + " float tmp_y = __half2float(y); \\\n", + " float result = FP32_MATH_NAME(tmp_x, tmp_y); \\\n", + " return __float2half(result); \\\n", + "}\n", + "\n", + "#define CUDA_UNSUPPORTED_HALF_MATH_UNARY(HALF_MATH_NAME, FP32_MATH_NAME) \\\n", + "static inline __device__ __host__ half HALF_MATH_NAME(half x) { \\\n", + " float tmp_x = __half2float(x); \\\n", + " float result = FP32_MATH_NAME(tmp_x); \\\n", + " return __float2half(result); \\\n", + "}\n", + "\n", + "// Some fp16 math functions are not supported in cuda_fp16.h,\n", + "// so we define them here to make sure the generated CUDA code\n", + "// is valid.\n", + "#if defined(__CUDA_ARCH__)\n", + "#if (__CUDA_ARCH__ >= 530)\n", + "CUDA_UNSUPPORTED_HALF_MATH_BINARY(hpow, powf)\n", + "CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf)\n", + "CUDA_UNSUPPORTED_HALF_MATH_UNARY(htan, tanf)\n", + "CUDA_UNSUPPORTED_HALF_MATH_UNARY(hatan, atanf)\n", + "CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erf)\n", + "#else\n", + "CUDA_UNSUPPORTED_HALF_MATH_UNARY(hexp, exp)\n", + "#endif\n", + "#endif\n", + "\n", + "#undef CUDA_UNSUPPORTED_HALF_MATH_BINARY\n", + "#undef CUDA_UNSUPPORTED_HALF_MATH_UNARY\n", + "\n", + "struct __align__(8) half4 {\n", + " __half x, y, z, w;\n", + " __host__ __device__ half4() : x(__half(0)), y(__half(0)), z(__half(0)), w(__half(0)) {}\n", + " __host__ __device__ half4(__half x, __half y, __half z, __half w) : x(x), y(y), z(z), w(w) {}\n", + "\n", + "};\n", + "__host__ __device__ half4 make_half4(__half x, __half y, __half z, __half w) {\n", + " return half4(x, y, z, w);\n", + "}\n", + "\n", + "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)\n", + "#define __shfl_sync(mask, var, lane, width) \\\n", + " __shfl((var), (lane), (width))\n", + "\n", + "#define __shfl_down_sync(mask, var, offset, width) \\\n", + " __shfl_down((var), (offset), (width))\n", + "\n", + "#define __shfl_up_sync(mask, var, offset, width) \\\n", + " __shfl_up((var), (offset), (width))\n", + "#endif\n", + "\n", + "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)\n", + "#include \n", + "\n", + "\n", + "#if defined(__CUDACC_RTC__)\n", + "#define __SM_61_INTRINSICS_DECL__ __device__\n", + "#else /* !__CUDACC_RTC__ */\n", + "#define __SM_61_INTRINSICS_DECL__ static __device__ __inline__\n", + "#endif /* __CUDACC_RTC__ */\n", + "\n", + "#ifndef __CUDA_ARCH__\n", + "#define __DEF_IF_HOST { }\n", + "#else /* !__CUDA_ARCH__ */\n", + "#define __DEF_IF_HOST ;\n", + "#endif /* __CUDA_ARCH__ */\n", + "\n", + "__SM_61_INTRINSICS_DECL__ int __dp4a(unsigned int srcA, int srcB, int c) __DEF_IF_HOST\n", + "__SM_61_INTRINSICS_DECL__ int __dp4a(int srcA, unsigned int srcB, int c) __DEF_IF_HOST\n", + "\n", + "#undef __DEF_IF_HOST\n", + "\n", + "#if !defined(__CUDACC_RTC__) && defined(__CUDA_ARCH__)\n", + "__SM_61_INTRINSICS_DECL__ int __dp4a(unsigned int srcA, int srcB, int c) {\n", + " int ret;\n", + " asm volatile (\"dp4a.u32.s32 %0, %1, %2, %3;\" : \"=r\"(ret) : \"r\"(srcA), \"r\"(srcB), \"r\"(c));\n", + " return ret;\n", + "}\n", + "\n", + "__SM_61_INTRINSICS_DECL__ int __dp4a(int srcA, unsigned int srcB, int c) {\n", + " int ret;\n", + " asm volatile (\"dp4a.s32.u32 %0, %1, %2, %3;\" : \"=r\"(ret) : \"r\"(srcA), \"r\"(srcB), \"r\"(c));\n", + " return ret;\n", + "}\n", + "#endif /* !__CUDACC_RTC__ && defined(__CUDA_ARCH__) */\n", + "\n", + "#undef __SM_61_INTRINSICS_DECL__\n", + "\n", + "#endif\n", + "\n", + "#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \\\n", + " (__CUDACC_VER_MAJOR__ > 11))\n", + "#define TVM_ENABLE_L2_PREFETCH 1\n", + "#else\n", + "#define TVM_ENABLE_L2_PREFETCH 0\n", + "#endif\n", + "\n", + "#ifdef _WIN32\n", + " using uint = unsigned int;\n", + " using uchar = unsigned char;\n", + " using ushort = unsigned short;\n", + " using int64_t = long long;\n", + " using uint64_t = unsigned long long;\n", + "#else\n", + " #define uint unsigned int\n", + " #define uchar unsigned char\n", + " #define ushort unsigned short\n", + " #define int64_t long long\n", + " #define uint64_t unsigned long long\n", + "#endif\n", + "\n", + "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 800) \n", + "#define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 1\n", + "#else\n", + "#define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 0\n", + "#endif\n", + "extern \"C\" __global__ void __launch_bounds__(128) matmul_m1n16384k16384_f16xi4_simt_kernel(half* __restrict__ A, signed char* __restrict__ B, half* __restrict__ C);\n", + "extern \"C\" __global__ void __launch_bounds__(128) matmul_m1n16384k16384_f16xi4_simt_kernel(half* __restrict__ A, signed char* __restrict__ B, half* __restrict__ C) {\n", + " half in_thread_C_local[1];\n", + " signed char B_local[4];\n", + " half B_decode_local[8];\n", + " half A_local[8];\n", + " __shared__ half red_result[1];\n", + " in_thread_C_local[0] = __float2half_rn(0.000000e+00f);\n", + " for (int ax1_0 = 0; ax1_0 < 16; ++ax1_0) {\n", + " *(int*)(B_local + 0) = *(int*)(B + (((((int)blockIdx.x) * 8192) + (ax1_0 * 512)) + (((int)threadIdx.x) * 4)));\n", + " for (int ax1 = 0; ax1 < 8; ++ax1) {\n", + " B_decode_local[ax1] = (((half)((((uint)B_local[(ax1 >> 1)]) >> (((uint)(ax1 & 1)) * (uint)4)) & (uint)15)) - __float2half_rn(8.000000e+00f));\n", + " }\n", + " *(uint4*)(A_local + 0) = *(uint4*)(A + ((ax1_0 * 1024) + (((int)threadIdx.x) * 8)));\n", + " for (int ax1_2_0 = 0; ax1_2_0 < 4; ++ax1_2_0) {\n", + " for (int ax1_2_1 = 0; ax1_2_1 < 2; ++ax1_2_1) {\n", + " in_thread_C_local[0] = (in_thread_C_local[0] + (A_local[((ax1_2_0 * 2) + ax1_2_1)] * B_decode_local[((ax1_2_0 * 2) + ax1_2_1)]));\n", + " }\n", + " }\n", + " }\n", + " half red_buf0[1];\n", + " uint mask[1];\n", + " half t0[1];\n", + " half red_buf0_1[1];\n", + " uint mask_1[1];\n", + " half t0_1[1];\n", + " __shared__ half red_buf_staging[4];\n", + " red_buf0_1[0] = in_thread_C_local[0];\n", + " mask_1[0] = __activemask();\n", + " t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 16, 32);\n", + " red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);\n", + " t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 8, 32);\n", + " red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);\n", + " t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 4, 32);\n", + " red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);\n", + " t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 2, 32);\n", + " red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);\n", + " t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 1, 32);\n", + " red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);\n", + " if ((((int)threadIdx.x) % 32) == 0) {\n", + " red_buf_staging[(((int)threadIdx.x) >> 5)] = red_buf0_1[0];\n", + " }\n", + " __syncthreads();\n", + " if (((int)threadIdx.x) < 4) {\n", + " red_buf0[0] = red_buf_staging[((int)threadIdx.x)];\n", + " }\n", + " mask[0] = __activemask();\n", + " t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 2, 32);\n", + " red_buf0[0] = (red_buf0[0] + t0[0]);\n", + " t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 1, 32);\n", + " red_buf0[0] = (red_buf0[0] + t0[0]);\n", + " if (((int)threadIdx.x) == 0) {\n", + " ((volatile half*)red_result)[0] = red_buf0[0];\n", + " }\n", + " __syncthreads();\n", + " if (((int)threadIdx.x) == 0) {\n", + " C[((int)blockIdx.x)] = (half)(((volatile half*)red_result)[0]);\n", + " }\n", + "}\n", + "\n", + "\n", + "extern \"C\" void init() {\n", + " \n", + "}\n", + "\n", + "extern \"C\" void call(half* __restrict__ A, int8_t* __restrict__ B, half* __restrict__ C, cudaStream_t stream=cudaStreamDefault) {\n", + "matmul_m1n16384k16384_f16xi4_simt_kernel<<>>(A, B, C);\n", + "}\n", + "\n", + "latency=0.14325749999999998\n" + ] + } + ], + "source": [ + "matmul_config = bitblas.MatmulConfig(\n", + " M=1, # M dimension\n", + " N=16384, # N dimension\n", + " K=16384, # K dimension\n", + " A_dtype=\"float16\", # activation A dtype\n", + " W_dtype=\"int4\", # weight W dtype\n", + " fast_decoding=True, # Disable Fast Decoding\n", + ")\n", + "\n", + "\n", + "matmul = bitblas.Matmul(config=matmul_config, enable_tuning=False)\n", + "\n", + "matmul.hardware_aware_finetune(topk=20, parallel_build=True)\n", + "\n", + "print(matmul.get_source())\n", + "latency = matmul.profile_latency()\n", + "print(f\"{latency=}\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "feeeecb3-339d-42cc-9069-ee8645efc7d7", + "metadata": {}, + "source": [ + "## Performance of Fast Dequantization on A100-80G\n", + "\n", + "![image.png](./img/FastDequantization_EXP.png)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/4.dynamic_shape_codegen.ipynb b/tutorials/4.dynamic_shape_codegen.ipynb new file mode 100644 index 000000000..69554f91b --- /dev/null +++ b/tutorials/4.dynamic_shape_codegen.ipynb @@ -0,0 +1,1965 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "2e486d95-6c70-4c68-b7bb-83b97579e316", + "metadata": {}, + "source": [ + "# Code Generation with Dynamic Shape\n", + "**Large Language Models are dynamic**\n", + "![image.png](./img/DynamicTuning.png)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7e250c4a-d27f-4426-b67a-a29b1962cf94", + "metadata": {}, + "outputs": [], + "source": [ + "import bitblas\n", + "import torch\n", + "\n", + "# enabling debug output\n", + "\n", + "bitblas.set_log_level(\"Debug\")" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "4258ce53-a8e7-4cd3-80bc-7e71518bb9e6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2024-10-25 08:25:39 [BitBLAS:INFO]: Auto detected target: nvidia/geforce-rtx-4090\n", + "\n", + "template \n", + "__device__ void decode_i4b_to_f16(T1 *_i4s, T2 *B_local_decode, const int N = 8)\n", + "{\n", + " uint *h = reinterpret_cast(B_local_decode);\n", + "\n", + " static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;\n", + " static constexpr uint BOTTOM_MASK = 0x000f000f;\n", + " static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;\n", + " static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400;\n", + " uint const i4s = *reinterpret_cast(_i4s);\n", + "#pragma unroll\n", + " for (int i = 0; i < (N / 2); i++)\n", + " {\n", + "\n", + " asm volatile(\"lop3.b32 %0, %1, %2, %3, %4;\\n\"\n", + " : \"=r\"(h[i])\n", + " : \"r\"(i4s >> (4 * i)), \"n\"(BOTTOM_MASK), \"n\"(FP16_TOP_MAGIC_NUM), \"n\"(immLut));\n", + " asm volatile(\"sub.f16x2 %0, %1, %2;\\n\" : \"=r\"(h[i]) : \"r\"(h[i]), \"r\"(MEDIAN_NUM));\n", + " }\n", + "}\n", + "\n", + "template \n", + "__device__ void decode_i4s_to_f16(T1 *_i4s, T2 *B_local_decode, const int N = 8)\n", + "{\n", + " decode_i4b_to_f16(_i4s, B_local_decode, N);\n", + "}\n", + "\n", + "template \n", + "__device__ void decode_i4u_to_f16(T1 *_i4u, T2 *B_local_decode, const int N = 8)\n", + "{\n", + " decode_i4b_to_f16(_i4u, B_local_decode, N);\n", + "}\n", + "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530)\n", + "\n", + "typedef unsigned short uint16_t;\n", + "typedef unsigned char uint8_t;\n", + "typedef signed char int8_t;\n", + "typedef int int32_t;\n", + "typedef unsigned long long uint64_t;\n", + "typedef unsigned int uint32_t;\n", + "\n", + "#define TVM_FORCE_INLINE inline __attribute__((always_inline))\n", + "#define TVM_XINLINE TVM_FORCE_INLINE __device__ __host__\n", + "#define TVM_ALIGNED(x) __attribute__ ((aligned(x)))\n", + "#define TVM_HALF_OPERATOR(RTYPE, OP) \\\n", + " TVM_XINLINE RTYPE operator OP (half a, half b) { \\\n", + " return RTYPE(float(a) OP float(b)); \\\n", + " } \\\n", + " template \\\n", + " TVM_XINLINE RTYPE operator OP (half a, T b) { \\\n", + " return RTYPE(float(a) OP float(b)); \\\n", + " } \\\n", + " template \\\n", + " TVM_XINLINE RTYPE operator OP (T a, half b) { \\\n", + " return RTYPE(float(a) OP float(b)); \\\n", + " }\n", + "\n", + "#define TVM_HALF_ASSIGNOP(AOP, OP) \\\n", + " template \\\n", + " TVM_XINLINE half operator AOP (const T& a) { \\\n", + " return *this = half(float(*this) OP float(a)); \\\n", + " } \\\n", + " template \\\n", + " TVM_XINLINE half operator AOP (const volatile T& a) volatile { \\\n", + " return *this = half(float(*this) OP float(a)); \\\n", + " }\n", + "\n", + "class TVM_ALIGNED(2) half {\n", + " public:\n", + " uint16_t half_;\n", + "\n", + " static TVM_XINLINE half Binary(uint16_t value) {\n", + " half res;\n", + " res.half_ = value;\n", + " return res;\n", + " }\n", + "\n", + " TVM_XINLINE half() {}\n", + "\n", + " TVM_XINLINE half(const float& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const double& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const int8_t& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const int32_t& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const long long& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); }\n", + "\n", + " TVM_XINLINE operator float() const { \\\n", + " return float(half2float(half_)); \\\n", + " } \\\n", + " TVM_XINLINE operator float() const volatile { \\\n", + " return float(half2float(half_)); \\\n", + " }\n", + "\n", + "\n", + " TVM_HALF_ASSIGNOP(+=, +)\n", + " TVM_HALF_ASSIGNOP(-=, -)\n", + " TVM_HALF_ASSIGNOP(*=, *)\n", + " TVM_HALF_ASSIGNOP(/=, /)\n", + "\n", + " TVM_XINLINE half operator+() {\n", + " return *this;\n", + " }\n", + "\n", + " TVM_XINLINE half operator-() {\n", + " return half(-float(*this));\n", + " }\n", + "\n", + " TVM_XINLINE half operator=(const half& a) {\n", + " half_ = a.half_;\n", + " return a;\n", + " }\n", + "\n", + " template\n", + " TVM_XINLINE half operator=(const T& a) {\n", + " return *this = half(a);\n", + " }\n", + "\n", + " TVM_XINLINE half operator=(const half& a) volatile {\n", + " half_ = a.half_;\n", + " return a;\n", + " }\n", + "\n", + " template\n", + " TVM_XINLINE half operator=(const T& a) volatile {\n", + " return *this = half(a);\n", + " }\n", + "\n", + " private:\n", + " union Bits {\n", + " float f;\n", + " int32_t si;\n", + " uint32_t ui;\n", + " };\n", + "\n", + " static int const fp16FractionBits = 10;\n", + " static int const fp32FractionBits = 23;\n", + " static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits); // == 0x7fffff\n", + " static int32_t const fp32HiddenBit = 1 << fp32FractionBits; // == 0x800000\n", + " static int const shift = fp32FractionBits - fp16FractionBits; // == 13\n", + " static int const shiftSign = 16;\n", + " static int32_t const expAdjust = 127 - 15; // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)\n", + "\n", + " static int32_t const infN = 0x7F800000; // flt32 infinity\n", + " static int32_t const maxN = 0x477FFFFF; // max flt32 that's a flt16 normal after >> by shift\n", + " static int32_t const minN = 0x38800000; // min flt16 normal as a flt32\n", + " static int32_t const maxZ = 0x33000000; // max fp32 number that's still rounded to zero in fp16\n", + " static int32_t const signN = 0x80000000; // flt32 sign bit\n", + "\n", + " static int32_t const infC = infN >> shift;\n", + " static int32_t const nanN = (infC + 1) << shift; // minimum flt16 nan as a flt32\n", + " static int32_t const maxC = maxN >> shift;\n", + " static int32_t const minC = minN >> shift;\n", + " static int32_t const signC = signN >> shiftSign; // flt16 sign bit\n", + "\n", + " static int32_t const mulN = 0x52000000; // (1 << 23) / minN\n", + " static int32_t const mulC = 0x33800000; // minN / (1 << (23 - shift))\n", + "\n", + " static int32_t const subC = 0x003FF; // max flt32 subnormal down shifted\n", + " static int32_t const norC = 0x00400; // min flt32 normal down shifted\n", + "\n", + " static int32_t const maxD = infC - maxC - 1;\n", + " static int32_t const minD = minC - subC - 1;\n", + "\n", + " TVM_XINLINE uint16_t float2half(const float& value) const {\n", + " Bits v;\n", + " v.f = value;\n", + " uint32_t sign = v.si & signN; // grab sign bit\n", + " v.si ^= sign; // clear sign bit from v\n", + " sign >>= shiftSign; // logical shift sign to fp16 position\n", + "\n", + " if (v.si <= maxZ) {\n", + " // Handle eventual zeros here to ensure\n", + " // vshift will not exceed 32 below.\n", + " v.ui = 0;\n", + " } else if (v.si < minN) {\n", + " // Handle denorms\n", + " uint32_t exp32 = v.ui >> fp32FractionBits;\n", + " int32_t exp16 = exp32 - expAdjust;\n", + " // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.\n", + " // Smaller (so negative) exp16 values should result in greater right shifts.\n", + " uint32_t vshift = 1 - exp16;\n", + " uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);\n", + " v.ui = significand >> vshift;\n", + " v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;\n", + " } else if (v.si <= maxN) {\n", + " // Handle norms\n", + " v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;\n", + " v.ui -= expAdjust << fp32FractionBits;\n", + " } else if (v.si <= infN) {\n", + " v.si = infN;\n", + " } else if (v.si < nanN) {\n", + " v.si = nanN;\n", + " }\n", + "\n", + " v.ui >>= shift;\n", + " return sign | (v.ui & 0x7fff);\n", + " }\n", + "\n", + " // Same as above routine, except for addition of volatile keyword\n", + " TVM_XINLINE uint16_t float2half(\n", + " const volatile float& value) const volatile {\n", + " Bits v;\n", + " v.f = value;\n", + " uint32_t sign = v.si & signN; // grab sign bit\n", + " v.si ^= sign; // clear sign bit from v\n", + " sign >>= shiftSign; // logical shift sign to fp16 position\n", + "\n", + " if (v.si <= maxZ) {\n", + " // Handle eventual zeros here to ensure\n", + " // vshift will not exceed 32 below.\n", + " v.ui = 0;\n", + " } else if (v.si < minN) {\n", + " // Handle denorms\n", + " uint32_t exp32 = v.ui >> fp32FractionBits;\n", + " int32_t exp16 = exp32 - expAdjust;\n", + " // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.\n", + " // Smaller (so negative) exp16 values should result in greater right shifts.\n", + " uint32_t vshift = 1 - exp16;\n", + " uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);\n", + " v.ui = significand >> vshift;\n", + " v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;\n", + " } else if (v.si <= maxN) {\n", + " // Handle norms\n", + " v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;\n", + " v.ui -= expAdjust << fp32FractionBits;\n", + " } else if (v.si <= infN) {\n", + " v.si = infN;\n", + " } else if (v.si < nanN) {\n", + " v.si = nanN;\n", + " }\n", + "\n", + " v.ui >>= shift;\n", + " return sign | (v.ui & 0x7fff);\n", + " }\n", + "\n", + " TVM_XINLINE float half2float(const uint16_t& value) const {\n", + " Bits v;\n", + " v.ui = value;\n", + " int32_t sign = v.si & signC;\n", + " v.si ^= sign;\n", + " sign <<= shiftSign;\n", + " v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);\n", + " v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);\n", + " Bits s;\n", + " s.si = mulC;\n", + " s.f *= v.si;\n", + " int32_t mask = -(norC > v.si);\n", + " v.si <<= shift;\n", + " v.si ^= (s.si ^ v.si) & mask;\n", + " v.si |= sign;\n", + " return v.f;\n", + " }\n", + "\n", + " TVM_XINLINE float half2float(\n", + " const volatile uint16_t& value) const volatile {\n", + " Bits v;\n", + " v.ui = value;\n", + " int32_t sign = v.si & signC;\n", + " v.si ^= sign;\n", + " sign <<= shiftSign;\n", + " v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);\n", + " v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);\n", + " Bits s;\n", + " s.si = mulC;\n", + " s.f *= v.si;\n", + " int32_t mask = -(norC > v.si);\n", + " v.si <<= shift;\n", + " v.si ^= (s.si ^ v.si) & mask;\n", + " v.si |= sign;\n", + " return v.f;\n", + " }\n", + "\n", + " template\n", + " TVM_XINLINE void constructor(const T& value) {\n", + " half_ = float2half(float(value));\n", + " }\n", + "};\n", + "\n", + "TVM_HALF_OPERATOR(half, +)\n", + "TVM_HALF_OPERATOR(half, -)\n", + "TVM_HALF_OPERATOR(half, *)\n", + "TVM_HALF_OPERATOR(half, /)\n", + "TVM_HALF_OPERATOR(bool, >)\n", + "TVM_HALF_OPERATOR(bool, <)\n", + "TVM_HALF_OPERATOR(bool, >=)\n", + "TVM_HALF_OPERATOR(bool, <=)\n", + "\n", + "TVM_XINLINE half __float2half_rn(const float a) {\n", + " return half(a);\n", + "}\n", + "#else\n", + "#include \n", + "__device__ half max(half a, half b)\n", + "{\n", + " return __hgt(__half(a), __half(b)) ? a : b;\n", + "}\n", + "__device__ half min(half a, half b)\n", + "{\n", + " return __hlt(__half(a), __half(b)) ? a : b;\n", + "}\n", + "#endif\n", + "\n", + "\n", + "// Pack two half values.\n", + "static inline __device__ __host__ unsigned\n", + "__pack_half2(const half x, const half y) {\n", + " unsigned v0 = *((unsigned short *)&x);\n", + " unsigned v1 = *((unsigned short *)&y);\n", + " return (v1 << 16) | v0;\n", + "}\n", + "\n", + "#define CUDA_UNSUPPORTED_HALF_MATH_BINARY(HALF_MATH_NAME, FP32_MATH_NAME) \\\n", + "static inline __device__ __host__ half HALF_MATH_NAME(half x, half y) { \\\n", + " float tmp_x = __half2float(x); \\\n", + " float tmp_y = __half2float(y); \\\n", + " float result = FP32_MATH_NAME(tmp_x, tmp_y); \\\n", + " return __float2half(result); \\\n", + "}\n", + "\n", + "#define CUDA_UNSUPPORTED_HALF_MATH_UNARY(HALF_MATH_NAME, FP32_MATH_NAME) \\\n", + "static inline __device__ __host__ half HALF_MATH_NAME(half x) { \\\n", + " float tmp_x = __half2float(x); \\\n", + " float result = FP32_MATH_NAME(tmp_x); \\\n", + " return __float2half(result); \\\n", + "}\n", + "\n", + "// Some fp16 math functions are not supported in cuda_fp16.h,\n", + "// so we define them here to make sure the generated CUDA code\n", + "// is valid.\n", + "#if defined(__CUDA_ARCH__)\n", + "#if (__CUDA_ARCH__ >= 530)\n", + "CUDA_UNSUPPORTED_HALF_MATH_BINARY(hpow, powf)\n", + "CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf)\n", + "CUDA_UNSUPPORTED_HALF_MATH_UNARY(htan, tanf)\n", + "CUDA_UNSUPPORTED_HALF_MATH_UNARY(hatan, atanf)\n", + "CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erf)\n", + "#else\n", + "CUDA_UNSUPPORTED_HALF_MATH_UNARY(hexp, exp)\n", + "#endif\n", + "#endif\n", + "\n", + "#undef CUDA_UNSUPPORTED_HALF_MATH_BINARY\n", + "#undef CUDA_UNSUPPORTED_HALF_MATH_UNARY\n", + "\n", + "struct __align__(8) half4 {\n", + " __half x, y, z, w;\n", + " __host__ __device__ half4() : x(__half(0)), y(__half(0)), z(__half(0)), w(__half(0)) {}\n", + " __host__ __device__ half4(__half x, __half y, __half z, __half w) : x(x), y(y), z(z), w(w) {}\n", + "\n", + "};\n", + "__host__ __device__ half4 make_half4(__half x, __half y, __half z, __half w) {\n", + " return half4(x, y, z, w);\n", + "}\n", + "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)\n", + "#include \n", + "\n", + "\n", + "#if defined(__CUDACC_RTC__)\n", + "#define __SM_61_INTRINSICS_DECL__ __device__\n", + "#else /* !__CUDACC_RTC__ */\n", + "#define __SM_61_INTRINSICS_DECL__ static __device__ __inline__\n", + "#endif /* __CUDACC_RTC__ */\n", + "\n", + "#ifndef __CUDA_ARCH__\n", + "#define __DEF_IF_HOST { }\n", + "#else /* !__CUDA_ARCH__ */\n", + "#define __DEF_IF_HOST ;\n", + "#endif /* __CUDA_ARCH__ */\n", + "\n", + "__SM_61_INTRINSICS_DECL__ int __dp4a(unsigned int srcA, int srcB, int c) __DEF_IF_HOST\n", + "__SM_61_INTRINSICS_DECL__ int __dp4a(int srcA, unsigned int srcB, int c) __DEF_IF_HOST\n", + "\n", + "#undef __DEF_IF_HOST\n", + "\n", + "#if !defined(__CUDACC_RTC__) && defined(__CUDA_ARCH__)\n", + "__SM_61_INTRINSICS_DECL__ int __dp4a(unsigned int srcA, int srcB, int c) {\n", + " int ret;\n", + " asm volatile (\"dp4a.u32.s32 %0, %1, %2, %3;\" : \"=r\"(ret) : \"r\"(srcA), \"r\"(srcB), \"r\"(c));\n", + " return ret;\n", + "}\n", + "\n", + "__SM_61_INTRINSICS_DECL__ int __dp4a(int srcA, unsigned int srcB, int c) {\n", + " int ret;\n", + " asm volatile (\"dp4a.s32.u32 %0, %1, %2, %3;\" : \"=r\"(ret) : \"r\"(srcA), \"r\"(srcB), \"r\"(c));\n", + " return ret;\n", + "}\n", + "#endif /* !__CUDACC_RTC__ && defined(__CUDA_ARCH__) */\n", + "\n", + "#undef __SM_61_INTRINSICS_DECL__\n", + "\n", + "#endif\n", + "__forceinline__ __device__ unsigned int\n", + "cast_smem_ptr_to_int(const void* const smem_ptr)\n", + "{\n", + " unsigned int smem_int;\n", + " asm volatile (\"{ .reg .u64 smem_int; cvta.to.shared.u64 smem_int, %1; cvt.u32.u64 %0, smem_int; }\"\n", + " : \"=r\"(smem_int) : \"l\"(smem_ptr));\n", + " return smem_int;\n", + "}\n", + "\n", + "#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \\\n", + " (__CUDACC_VER_MAJOR__ > 11))\n", + "#define TVM_ENABLE_L2_PREFETCH 1\n", + "#else\n", + "#define TVM_ENABLE_L2_PREFETCH 0\n", + "#endif\n", + "\n", + "#ifdef _WIN32\n", + " using uint = unsigned int;\n", + " using uchar = unsigned char;\n", + " using ushort = unsigned short;\n", + " using int64_t = long long;\n", + " using uint64_t = unsigned long long;\n", + "#else\n", + " #define uint unsigned int\n", + " #define uchar unsigned char\n", + " #define ushort unsigned short\n", + " #define int64_t long long\n", + " #define uint64_t unsigned long long\n", + "#endif\n", + "\n", + "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 800) \n", + "#define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 1\n", + "#else\n", + "#define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 0\n", + "#endif\n", + "extern \"C\" __global__ void __launch_bounds__(128) matmul_n1024k1024_f16xi4_default_kernel(half* __restrict__ A, signed char* __restrict__ B, half* __restrict__ C, int m);\n", + "extern \"C\" __global__ void __launch_bounds__(128) matmul_n1024k1024_f16xi4_default_kernel(half* __restrict__ A, signed char* __restrict__ B, half* __restrict__ C, int m) {\n", + " extern __shared__ uchar buf_dyn_shmem[];\n", + " half C_reindex_pad_shared_dyn_warp[128];\n", + " signed char B_local[4];\n", + " half B_decode_reindex_local[8];\n", + " half A_reindex_pad_shared_dyn_warp[32];\n", + " half B_decode_reindex_shared_dyn_warp[32];\n", + " for (int var = 0; var < 1; ++var) {\n", + " for (int ax1_0_3_init = 0; ax1_0_3_init < 4; ++ax1_0_3_init) {\n", + " for (int ax2_0_3_init = 0; ax2_0_3_init < 4; ++ax2_0_3_init) {\n", + " for (int i = 0; i < 8; ++i) {\n", + "C_reindex_pad_shared_dyn_warp[((ax1_0_3_init * 32) + (ax2_0_3_init * 8)) + i] = 0.0;}\n", + ";\n", + " }\n", + " }\n", + " for (int ax3_0_0 = 0; ax3_0_0 < 32; ++ax3_0_0) {\n", + " __syncthreads();\n", + " #pragma unroll\n", + " for (int ax0_ax1_ax2_fused_0 = 0; ax0_ax1_ax2_fused_0 < 8; ++ax0_ax1_ax2_fused_0) {\n", + " half4 condval;\n", + " if (((((((((int)blockIdx.y) * 128) + (ax0_ax1_ax2_fused_0 * 16)) + (((int)threadIdx.y) * 8)) + (((int)threadIdx.z) * 4)) + (((int)threadIdx.x) >> 3)) < m)) {\n", + " condval = *(half4*)(A + (((((((((int)blockIdx.y) * 131072) + (ax0_ax1_ax2_fused_0 * 16384)) + (((int)threadIdx.y) * 8192)) + (((int)threadIdx.z) * 4096)) + ((((int)threadIdx.x) >> 3) * 1024)) + (ax3_0_0 * 32)) + ((((int)threadIdx.x) & 7) * 4)));\n", + " } else {\n", + " condval = make_half4(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f));\n", + " }\n", + " *(half4*)(((half*)buf_dyn_shmem) + (((((((ax0_ax1_ax2_fused_0 * 512) + (((int)threadIdx.y) * 256)) + (((int)threadIdx.z) * 128)) + ((((int)threadIdx.x) >> 3) * 32)) + ((((((int)threadIdx.x) & 7) >> 1) ^ ((((int)threadIdx.z) * 2) + (((int)threadIdx.x) >> 4))) * 8)) + ((((int)threadIdx.x) & 1) * 4)) + 4096)) = condval;\n", + " }\n", + " for (int ax1_ax2_0_fused_0 = 0; ax1_ax2_0_fused_0 < 4; ++ax1_ax2_0_fused_0) {\n", + " *(int*)(B_local + 0) = *(int*)(B + (((((((((int)blockIdx.x) * 65536) + (ax1_ax2_0_fused_0 * 16384)) + (((int)threadIdx.y) * 8192)) + (((int)threadIdx.z) * 4096)) + ((((int)threadIdx.x) >> 2) * 512)) + (ax3_0_0 * 16)) + ((((int)threadIdx.x) & 3) * 4)));\n", + " decode_i4s_to_f16((&(B_local[0])), (&(B_decode_reindex_local[0])), 8);\n", + " *(uint4*)(((half*)buf_dyn_shmem) + (((((ax1_ax2_0_fused_0 * 1024) + (((int)threadIdx.y) * 512)) + (((int)threadIdx.z) * 256)) + ((((int)threadIdx.x) >> 2) * 32)) + (((((int)threadIdx.x) & 3) ^ (((int)threadIdx.x) >> 3)) * 8))) = *(uint4*)(B_decode_reindex_local + 0);\n", + " }\n", + " __syncthreads();\n", + " for (int ax3_0_1 = 0; ax3_0_1 < 2; ++ax3_0_1) {\n", + " for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) {\n", + "\n", + " {\n", + " unsigned int addr;\n", + "#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST\n", + " addr = static_cast(__cvta_generic_to_shared((void *)((&(((half*)buf_dyn_shmem)[(((((((int)threadIdx.y) * 2048) + (ax1_0 * 512)) + ((((int)threadIdx.x) & 15) * 32)) + ((((ax3_0_1 * 2) + (((int)threadIdx.x) >> 4)) ^ ((((int)threadIdx.x) & 7) >> 1)) * 8)) + 4096)])) + 0)));\n", + "#else\n", + " __asm__ __volatile__(\n", + " \"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\\n\"\n", + " : \"=r\"(addr)\n", + " : \"l\"((void *)((&(((half*)buf_dyn_shmem)[(((((((int)threadIdx.y) * 2048) + (ax1_0 * 512)) + ((((int)threadIdx.x) & 15) * 32)) + ((((ax3_0_1 * 2) + (((int)threadIdx.x) >> 4)) ^ ((((int)threadIdx.x) & 7) >> 1)) * 8)) + 4096)])) + 0))\n", + " );\n", + "#endif\n", + " __asm__ __volatile__(\n", + " \"ldmatrix.sync.aligned.m8n8.x4.shared.b16\"\n", + " \"{%0, %1, %2, %3}, [%4];\\n\"\n", + " : \"=r\"(((unsigned *)(A_reindex_pad_shared_dyn_warp + (ax1_0 * 8)))[0]), \"=r\"(((unsigned *)(A_reindex_pad_shared_dyn_warp + (ax1_0 * 8)))[1]), \"=r\"(((unsigned *)(A_reindex_pad_shared_dyn_warp + (ax1_0 * 8)))[2]), \"=r\"(((unsigned *)(A_reindex_pad_shared_dyn_warp + (ax1_0 * 8)))[3])\n", + " : \"r\"(addr)\n", + " );\n", + " }\n", + " }\n", + " for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {\n", + "\n", + " {\n", + " unsigned int addr;\n", + "#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST\n", + " addr = static_cast(__cvta_generic_to_shared((void *)((&(((half*)buf_dyn_shmem)[(((((((int)threadIdx.z) * 2048) + (ax1_0_1 * 512)) + ((((int)threadIdx.x) >> 4) * 256)) + ((((int)threadIdx.x) & 7) * 32)) + ((((ax3_0_1 * 2) + ((((int)threadIdx.x) & 15) >> 3)) ^ ((((int)threadIdx.x) & 7) >> 1)) * 8))])) + 0)));\n", + "#else\n", + " __asm__ __volatile__(\n", + " \"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\\n\"\n", + " : \"=r\"(addr)\n", + " : \"l\"((void *)((&(((half*)buf_dyn_shmem)[(((((((int)threadIdx.z) * 2048) + (ax1_0_1 * 512)) + ((((int)threadIdx.x) >> 4) * 256)) + ((((int)threadIdx.x) & 7) * 32)) + ((((ax3_0_1 * 2) + ((((int)threadIdx.x) & 15) >> 3)) ^ ((((int)threadIdx.x) & 7) >> 1)) * 8))])) + 0))\n", + " );\n", + "#endif\n", + " __asm__ __volatile__(\n", + " \"ldmatrix.sync.aligned.m8n8.x4.shared.b16\"\n", + " \"{%0, %1, %2, %3}, [%4];\\n\"\n", + " : \"=r\"(((unsigned *)(B_decode_reindex_shared_dyn_warp + (ax1_0_1 * 8)))[0]), \"=r\"(((unsigned *)(B_decode_reindex_shared_dyn_warp + (ax1_0_1 * 8)))[1]), \"=r\"(((unsigned *)(B_decode_reindex_shared_dyn_warp + (ax1_0_1 * 8)))[2]), \"=r\"(((unsigned *)(B_decode_reindex_shared_dyn_warp + (ax1_0_1 * 8)))[3])\n", + " : \"r\"(addr)\n", + " );\n", + " }\n", + " }\n", + " for (int ax1_0_3 = 0; ax1_0_3 < 4; ++ax1_0_3) {\n", + " for (int ax2_0_3 = 0; ax2_0_3 < 4; ++ax2_0_3) {\n", + "\n", + " {\n", + " __asm__ __volatile__(\n", + " \"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16\"\n", + " \"{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\\n\"\n", + " : \"=r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + ((ax1_0_3 * 32) + (ax2_0_3 * 8))))[0]), \"=r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + ((ax1_0_3 * 32) + (ax2_0_3 * 8))))[1])\n", + " : \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + (ax1_0_3 * 8)))[0]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + (ax1_0_3 * 8)))[1]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + (ax1_0_3 * 8)))[2]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + (ax1_0_3 * 8)))[3]), \"r\"(((unsigned *)((half*)B_decode_reindex_shared_dyn_warp + (ax2_0_3 * 8)))[0]), \"r\"(((unsigned *)((half*)B_decode_reindex_shared_dyn_warp + (ax2_0_3 * 8)))[1]), \"r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + ((ax1_0_3 * 32) + (ax2_0_3 * 8))))[0]), \"r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + ((ax1_0_3 * 32) + (ax2_0_3 * 8))))[1]));\n", + " }\n", + "\n", + " {\n", + " __asm__ __volatile__(\n", + " \"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16\"\n", + " \"{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\\n\"\n", + " : \"=r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + (((ax1_0_3 * 32) + (ax2_0_3 * 8)) + 4)))[0]), \"=r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + (((ax1_0_3 * 32) + (ax2_0_3 * 8)) + 4)))[1])\n", + " : \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + (ax1_0_3 * 8)))[0]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + (ax1_0_3 * 8)))[1]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + (ax1_0_3 * 8)))[2]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + (ax1_0_3 * 8)))[3]), \"r\"(((unsigned *)((half*)B_decode_reindex_shared_dyn_warp + ((ax2_0_3 * 8) + 4)))[0]), \"r\"(((unsigned *)((half*)B_decode_reindex_shared_dyn_warp + ((ax2_0_3 * 8) + 4)))[1]), \"r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + (((ax1_0_3 * 32) + (ax2_0_3 * 8)) + 4)))[0]), \"r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + (((ax1_0_3 * 32) + (ax2_0_3 * 8)) + 4)))[1]));\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + " for (int ax0 = 0; ax0 < 4; ++ax0) {\n", + " __syncthreads();\n", + " for (int ax1 = 0; ax1 < 4; ++ax1) {\n", + " for (int local_id = 0; local_id < 8; local_id+=2) {\n", + "*((uint *)&(&(((half*)buf_dyn_shmem)[(((((int)threadIdx.y) * 8192) + (((int)threadIdx.z) * 1024)) + (ax1 * 256))]))[((((((local_id % 4) / 2) * 8) + (threadIdx.x / 4)) * 16) + ((((local_id / 4) * 8) + ((threadIdx.x % 4) * 2)) + (local_id % 2)))]) = *((uint *)&C_reindex_pad_shared_dyn_warp[((ax0 * 32) + (ax1 * 8)) + local_id]);\n", + "}\n", + ";\n", + " }\n", + " __syncthreads();\n", + " #pragma unroll\n", + " for (int ax0_ax1_ax2_ax3_ax4_fused_0 = 0; ax0_ax1_ax2_ax3_ax4_fused_0 < 4; ++ax0_ax1_ax2_ax3_ax4_fused_0) {\n", + " if (((((((int)blockIdx.y) * 128) + (((int)threadIdx.y) * 64)) + (ax0 * 16)) + (((int)threadIdx.x) >> 1)) < m) {\n", + " *(uint4*)(C + ((((((((((int)blockIdx.y) * 131072) + (((int)threadIdx.y) * 65536)) + (ax0 * 16384)) + ((((int)threadIdx.x) >> 1) * 1024)) + (((int)blockIdx.x) * 128)) + (((int)threadIdx.z) * 64)) + (ax0_ax1_ax2_ax3_ax4_fused_0 * 16)) + ((((int)threadIdx.x) & 1) * 8))) = *(uint4*)(((half*)buf_dyn_shmem) + ((((((int)threadIdx.y) * 8192) + (((int)threadIdx.z) * 1024)) + (ax0_ax1_ax2_ax3_ax4_fused_0 * 256)) + (((int)threadIdx.x) * 8)));\n", + " }\n", + " }\n", + " }\n", + " }\n", + "}\n", + "\n", + "\n", + "extern \"C\" void init() {\n", + " \n", + " cudaFuncSetAttribute(matmul_n1024k1024_f16xi4_default_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 20480);\n", + "\n", + "}\n", + "\n", + "extern \"C\" void call(half* __restrict__ A, int8_t* __restrict__ B, half* __restrict__ C, int m, cudaStream_t stream=cudaStreamDefault) {\n", + "if (m == 0) return; \n", + "\t\tmatmul_n1024k1024_f16xi4_default_kernel<<>>(A, B, C, m);\n", + "}\n", + "\n" + ] + } + ], + "source": [ + "matmul_config = bitblas.MatmulConfig(\n", + " M=[1, 16, 32, 64, 128, 256], # M dimension, default value is from 1 to 1024\n", + " N=1024, # N dimension\n", + " K=1024, # K dimension\n", + " A_dtype=\"float16\", # activation A dtype\n", + " W_dtype=\"int4\", # weight W dtype\n", + ")\n", + "\n", + "matmul = bitblas.Matmul(config=matmul_config, enable_tuning=False)\n", + "print(matmul.get_source())" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "d701fa2e-ef3b-49c6-a54a-42c8f5219296", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2024-10-25 08:27:16 [BitBLAS:INFO]: Start fast tuning with dynamic range\n", + "2024-10-25 08:27:16 [BitBLAS:DEBUG]: Cannot find the appropriate index map for tensorcore\n", + "2024-10-25 08:27:16 [BitBLAS:DEBUG]: Apply config {'block': [2], 'thread': [2], 'rstep': [1024], 'reduce_thread': [64], 'vectorize': {'A': 8, 'B_decode': 8}}\n", + "2024-10-25 08:27:16 [BitBLAS:DEBUG]: Apply config {'block': [1], 'thread': [1], 'rstep': [1024], 'reduce_thread': [128], 'vectorize': {'A': 8, 'B_decode': 8}}\n", + "2024-10-25 08:27:16 [BitBLAS:DEBUG]: Apply config {'block': [4], 'thread': [4], 'rstep': [1024], 'reduce_thread': [32], 'vectorize': {'A': 8, 'B_decode': 8}}\n", + "2024-10-25 08:27:16 [BitBLAS:DEBUG]: Apply config {'block': [8], 'thread': [8], 'rstep': [1024], 'reduce_thread': [16], 'vectorize': {'A': 8, 'B_decode': 8}}\n", + "2024-10-25 08:27:16 [BitBLAS:DEBUG]: Apply config {'block': [16], 'thread': [16], 'rstep': [512], 'reduce_thread': [8], 'vectorize': {'A': 4, 'B_decode': 8}}\n", + "2024-10-25 08:27:16 [BitBLAS:DEBUG]: Apply config {'block': [32], 'thread': [32], 'rstep': [256], 'reduce_thread': [4], 'vectorize': {'A': 2, 'B_decode': 8}}\n", + "2024-10-25 08:27:16 [BitBLAS:DEBUG]: Apply config {'block': [64], 'thread': [64], 'rstep': [128], 'reduce_thread': [2], 'vectorize': {'B_decode': 8}}\n", + "2024-10-25 08:27:16 [BitBLAS:DEBUG]: Apply config {'block': [128], 'thread': [128], 'rstep': [128], 'vectorize': {'B_decode': 8}}\n", + "2024-10-25 08:27:16 [BitBLAS:DEBUG]: Warning: block config [128] is not valid for matmul, skip.\n", + "2024-10-25 08:27:16 [BitBLAS:DEBUG]: Warning: block config [128] is not valid for matmul, skip.\n", + "2024-10-25 08:27:41 [BitBLAS:DEBUG]: LocalBuilder: An exception occurred Traceback (most recent call last):\n", + " File \"/root/BitBLAS/3rdparty/tvm/python/tvm/exec/popen_worker.py\", line 87, in main\n", + " result = fn(*args, **kwargs)\n", + " File \"/root/BitBLAS/bitblas/base/utils.py\", line 257, in _build\n", + " rt_mod = tvm.build(mod, ta\t...\tm::codegen::CodeGenCUDA::VisitExpr_(tvm::tir::RampNode const*, std::ostream&)\n", + " File \"/root/BitBLAS/3rdparty/tvm/src/target/source/codegen_cuda.cc\", line 1226\n", + "ValueError: Check failed: lanes <= 4 (8 vs. 4) : Ramp of more than 4 lanes is not allowed.\n", + "\n", + "2024-10-25 08:27:42 [BitBLAS:INFO]: Evaluation with config {'block': [2], 'thread': [2], 'rstep': [1024], 'reduce_thread': [64], 'vectorize': {'A': 8, 'B_decode': 8}}\n", + "2024-10-25 08:27:42 [BitBLAS:INFO]: Time cost of this config: 0.007 ms\n", + "2024-10-25 08:27:42 [BitBLAS:INFO]: Evaluation with config {'block': [1], 'thread': [1], 'rstep': [1024], 'reduce_thread': [128], 'vectorize': {'A': 8, 'B_decode': 8}}\n", + "2024-10-25 08:27:42 [BitBLAS:INFO]: Time cost of this config: 0.006 ms\n", + "2024-10-25 08:27:42 [BitBLAS:INFO]: Evaluation with config {'block': [4], 'thread': [4], 'rstep': [1024], 'reduce_thread': [32], 'vectorize': {'A': 8, 'B_decode': 8}}\n", + "2024-10-25 08:27:42 [BitBLAS:INFO]: Time cost of this config: 0.005 ms\n", + "2024-10-25 08:27:42 [BitBLAS:INFO]: Evaluation with config {'block': [8], 'thread': [8], 'rstep': [1024], 'reduce_thread': [16], 'vectorize': {'A': 8, 'B_decode': 8}}\n", + "2024-10-25 08:27:42 [BitBLAS:INFO]: Time cost of this config: 0.005 ms\n", + "2024-10-25 08:27:42 [BitBLAS:INFO]: Evaluation with config {'block': [16], 'thread': [16], 'rstep': [512], 'reduce_thread': [8], 'vectorize': {'A': 4, 'B_decode': 8}}\n", + "2024-10-25 08:27:42 [BitBLAS:INFO]: Time cost of this config: 0.005 ms\n", + "2024-10-25 08:27:42 [BitBLAS:INFO]: Evaluation with config {'block': [32], 'thread': [32], 'rstep': [256], 'reduce_thread': [4], 'vectorize': {'A': 2, 'B_decode': 8}}\n", + "2024-10-25 08:27:42 [BitBLAS:INFO]: Time cost of this config: 0.007 ms\n", + "2024-10-25 08:27:42 [BitBLAS:INFO]: Evaluation with config {'block': [64], 'thread': [64], 'rstep': [128], 'reduce_thread': [2], 'vectorize': {'B_decode': 8}}\n", + "2024-10-25 08:27:42 [BitBLAS:INFO]: Time cost of this config: 0.005 ms\n", + "2024-10-25 08:27:43 [BitBLAS:DEBUG]: Apply config {'block': [16, 16], 'warp': [16, 16], 'rstep': [256], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:27:43 [BitBLAS:DEBUG]: Apply config {'block': [16, 32], 'warp': [16, 16], 'rstep': [256], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:27:43 [BitBLAS:DEBUG]: Apply config {'block': [16, 64], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:27:43 [BitBLAS:DEBUG]: Apply config {'block': [16, 128], 'warp': [16, 32], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:27:43 [BitBLAS:DEBUG]: Apply config {'block': [16, 256], 'warp': [16, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 4, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:27:43 [BitBLAS:DEBUG]: Apply config {'block': [16, 512], 'warp': [16, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 4, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:28:27 [BitBLAS:INFO]: Evaluation with config {'block': [16, 16], 'warp': [16, 16], 'rstep': [256], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:28:27 [BitBLAS:INFO]: Time cost of this config: 0.009 ms\n", + "2024-10-25 08:28:27 [BitBLAS:INFO]: Evaluation with config {'block': [16, 32], 'warp': [16, 16], 'rstep': [256], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:28:27 [BitBLAS:INFO]: Time cost of this config: 0.011 ms\n", + "2024-10-25 08:28:27 [BitBLAS:INFO]: Evaluation with config {'block': [16, 64], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:28:27 [BitBLAS:INFO]: Time cost of this config: 0.009 ms\n", + "2024-10-25 08:28:27 [BitBLAS:INFO]: Evaluation with config {'block': [16, 128], 'warp': [16, 32], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:28:27 [BitBLAS:INFO]: Time cost of this config: 0.007 ms\n", + "2024-10-25 08:28:27 [BitBLAS:INFO]: Evaluation with config {'block': [16, 256], 'warp': [16, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 4, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:28:27 [BitBLAS:INFO]: Time cost of this config: 0.049 ms\n", + "2024-10-25 08:28:27 [BitBLAS:INFO]: Evaluation with config {'block': [16, 512], 'warp': [16, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 4, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:28:27 [BitBLAS:INFO]: Time cost of this config: 0.024 ms\n", + "2024-10-25 08:28:28 [BitBLAS:DEBUG]: Apply config {'block': [16, 16], 'warp': [16, 16], 'rstep': [256], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:28:28 [BitBLAS:DEBUG]: Apply config {'block': [16, 32], 'warp': [16, 16], 'rstep': [256], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:28:28 [BitBLAS:DEBUG]: Apply config {'block': [32, 16], 'warp': [16, 16], 'rstep': [256], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:28:28 [BitBLAS:DEBUG]: Apply config {'block': [16, 64], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:28:28 [BitBLAS:DEBUG]: Apply config {'block': [32, 32], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:28:28 [BitBLAS:DEBUG]: Apply config {'block': [32, 64], 'warp': [16, 32], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:28:28 [BitBLAS:DEBUG]: Apply config {'block': [16, 128], 'warp': [16, 32], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:28:28 [BitBLAS:DEBUG]: Apply config {'block': [32, 128], 'warp': [16, 64], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:28:28 [BitBLAS:DEBUG]: Apply config {'block': [16, 256], 'warp': [16, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 4, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:28:28 [BitBLAS:DEBUG]: Apply config {'block': [32, 256], 'warp': [16, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:28:28 [BitBLAS:DEBUG]: Apply config {'block': [16, 512], 'warp': [16, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 4, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:28:29 [BitBLAS:DEBUG]: Apply config {'block': [32, 512], 'warp': [16, 256], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:16 [BitBLAS:INFO]: Evaluation with config {'block': [16, 16], 'warp': [16, 16], 'rstep': [256], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:16 [BitBLAS:INFO]: Time cost of this config: 0.010 ms\n", + "2024-10-25 08:29:16 [BitBLAS:INFO]: Evaluation with config {'block': [16, 32], 'warp': [16, 16], 'rstep': [256], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:16 [BitBLAS:INFO]: Time cost of this config: 0.008 ms\n", + "2024-10-25 08:29:16 [BitBLAS:INFO]: Evaluation with config {'block': [32, 16], 'warp': [16, 16], 'rstep': [256], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:16 [BitBLAS:INFO]: Time cost of this config: 0.009 ms\n", + "2024-10-25 08:29:16 [BitBLAS:INFO]: Evaluation with config {'block': [16, 64], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:16 [BitBLAS:INFO]: Time cost of this config: 0.007 ms\n", + "2024-10-25 08:29:16 [BitBLAS:INFO]: Evaluation with config {'block': [32, 32], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:16 [BitBLAS:INFO]: Time cost of this config: 0.009 ms\n", + "2024-10-25 08:29:16 [BitBLAS:INFO]: Evaluation with config {'block': [32, 64], 'warp': [16, 32], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:16 [BitBLAS:INFO]: Time cost of this config: 0.024 ms\n", + "2024-10-25 08:29:16 [BitBLAS:INFO]: Evaluation with config {'block': [16, 128], 'warp': [16, 32], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:16 [BitBLAS:INFO]: Time cost of this config: 0.008 ms\n", + "2024-10-25 08:29:16 [BitBLAS:INFO]: Evaluation with config {'block': [32, 128], 'warp': [16, 64], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:16 [BitBLAS:INFO]: Time cost of this config: 0.015 ms\n", + "2024-10-25 08:29:16 [BitBLAS:INFO]: Evaluation with config {'block': [16, 256], 'warp': [16, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 4, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:16 [BitBLAS:INFO]: Time cost of this config: 0.010 ms\n", + "2024-10-25 08:29:16 [BitBLAS:INFO]: Evaluation with config {'block': [32, 256], 'warp': [16, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:16 [BitBLAS:INFO]: Time cost of this config: 0.030 ms\n", + "2024-10-25 08:29:16 [BitBLAS:INFO]: Evaluation with config {'block': [16, 512], 'warp': [16, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 4, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:16 [BitBLAS:INFO]: Time cost of this config: 0.048 ms\n", + "2024-10-25 08:29:16 [BitBLAS:INFO]: Evaluation with config {'block': [32, 512], 'warp': [16, 256], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:16 [BitBLAS:INFO]: Time cost of this config: 0.040 ms\n", + "2024-10-25 08:29:16 [BitBLAS:DEBUG]: Apply config {'block': [16, 16], 'warp': [16, 16], 'rstep': [256], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:16 [BitBLAS:DEBUG]: Apply config {'block': [16, 32], 'warp': [16, 16], 'rstep': [256], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:16 [BitBLAS:DEBUG]: Apply config {'block': [32, 16], 'warp': [16, 16], 'rstep': [256], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:16 [BitBLAS:DEBUG]: Apply config {'block': [16, 64], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:16 [BitBLAS:DEBUG]: Apply config {'block': [32, 32], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:16 [BitBLAS:DEBUG]: Apply config {'block': [32, 64], 'warp': [16, 32], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:16 [BitBLAS:DEBUG]: Apply config {'block': [64, 16], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:16 [BitBLAS:DEBUG]: Apply config {'block': [16, 128], 'warp': [16, 32], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:16 [BitBLAS:DEBUG]: Apply config {'block': [64, 32], 'warp': [32, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:16 [BitBLAS:DEBUG]: Apply config {'block': [32, 128], 'warp': [16, 64], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:17 [BitBLAS:DEBUG]: Apply config {'block': [64, 64], 'warp': [32, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:17 [BitBLAS:DEBUG]: Apply config {'block': [64, 128], 'warp': [32, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:17 [BitBLAS:DEBUG]: Apply config {'block': [16, 256], 'warp': [16, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 4, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:17 [BitBLAS:DEBUG]: Apply config {'block': [32, 256], 'warp': [16, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:17 [BitBLAS:DEBUG]: Apply config {'block': [64, 256], 'warp': [32, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:17 [BitBLAS:DEBUG]: Apply config {'block': [16, 512], 'warp': [16, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 4, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:17 [BitBLAS:DEBUG]: Apply config {'block': [32, 512], 'warp': [16, 256], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:29:17 [BitBLAS:DEBUG]: Apply config {'block': [64, 512], 'warp': [32, 256], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Evaluation with config {'block': [16, 16], 'warp': [16, 16], 'rstep': [256], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Time cost of this config: 0.010 ms\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Evaluation with config {'block': [16, 32], 'warp': [16, 16], 'rstep': [256], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Time cost of this config: 0.009 ms\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Evaluation with config {'block': [32, 16], 'warp': [16, 16], 'rstep': [256], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Time cost of this config: 0.008 ms\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Evaluation with config {'block': [16, 64], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Time cost of this config: 0.009 ms\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Evaluation with config {'block': [32, 32], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Time cost of this config: 0.009 ms\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Evaluation with config {'block': [32, 64], 'warp': [16, 32], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Time cost of this config: 0.010 ms\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Evaluation with config {'block': [64, 16], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Time cost of this config: 0.015 ms\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Evaluation with config {'block': [16, 128], 'warp': [16, 32], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Time cost of this config: 0.010 ms\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Evaluation with config {'block': [64, 32], 'warp': [32, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Time cost of this config: 0.010 ms\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Evaluation with config {'block': [32, 128], 'warp': [16, 64], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Time cost of this config: 0.011 ms\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Evaluation with config {'block': [64, 64], 'warp': [32, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Time cost of this config: 0.015 ms\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Evaluation with config {'block': [64, 128], 'warp': [32, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Time cost of this config: 0.021 ms\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Evaluation with config {'block': [16, 256], 'warp': [16, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 4, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Time cost of this config: 0.048 ms\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Evaluation with config {'block': [32, 256], 'warp': [16, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Time cost of this config: 0.030 ms\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Evaluation with config {'block': [64, 256], 'warp': [32, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Time cost of this config: 0.024 ms\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Evaluation with config {'block': [16, 512], 'warp': [16, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 4, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Time cost of this config: 0.031 ms\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Evaluation with config {'block': [32, 512], 'warp': [16, 256], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Time cost of this config: 0.040 ms\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Evaluation with config {'block': [64, 512], 'warp': [32, 256], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:04 [BitBLAS:INFO]: Time cost of this config: 0.055 ms\n", + "2024-10-25 08:30:05 [BitBLAS:DEBUG]: Apply config {'block': [16, 16], 'warp': [16, 16], 'rstep': [256], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:05 [BitBLAS:DEBUG]: Apply config {'block': [16, 32], 'warp': [16, 16], 'rstep': [256], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:05 [BitBLAS:DEBUG]: Apply config {'block': [32, 16], 'warp': [16, 16], 'rstep': [256], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:05 [BitBLAS:DEBUG]: Apply config {'block': [16, 64], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:05 [BitBLAS:DEBUG]: Apply config {'block': [32, 32], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:05 [BitBLAS:DEBUG]: Apply config {'block': [32, 64], 'warp': [16, 32], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:05 [BitBLAS:DEBUG]: Apply config {'block': [64, 16], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:05 [BitBLAS:DEBUG]: Apply config {'block': [16, 128], 'warp': [16, 32], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:05 [BitBLAS:DEBUG]: Apply config {'block': [64, 32], 'warp': [32, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:05 [BitBLAS:DEBUG]: Apply config {'block': [32, 128], 'warp': [16, 64], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:06 [BitBLAS:DEBUG]: Apply config {'block': [64, 64], 'warp': [32, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:06 [BitBLAS:DEBUG]: Apply config {'block': [64, 128], 'warp': [32, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:06 [BitBLAS:DEBUG]: Apply config {'block': [128, 16], 'warp': [32, 16], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:06 [BitBLAS:DEBUG]: Apply config {'block': [16, 256], 'warp': [16, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 4, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:06 [BitBLAS:DEBUG]: Apply config {'block': [128, 32], 'warp': [64, 16], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:06 [BitBLAS:DEBUG]: Apply config {'block': [32, 256], 'warp': [16, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:06 [BitBLAS:DEBUG]: Apply config {'block': [128, 64], 'warp': [64, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:06 [BitBLAS:DEBUG]: Apply config {'block': [64, 256], 'warp': [32, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:06 [BitBLAS:DEBUG]: Apply config {'block': [128, 128], 'warp': [64, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:06 [BitBLAS:DEBUG]: Apply config {'block': [16, 512], 'warp': [16, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 4, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Evaluation with config {'block': [16, 16], 'warp': [16, 16], 'rstep': [256], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Time cost of this config: 0.010 ms\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Evaluation with config {'block': [16, 32], 'warp': [16, 16], 'rstep': [256], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Time cost of this config: 0.010 ms\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Evaluation with config {'block': [32, 16], 'warp': [16, 16], 'rstep': [256], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Time cost of this config: 0.011 ms\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Evaluation with config {'block': [16, 64], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Time cost of this config: 0.010 ms\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Evaluation with config {'block': [32, 32], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Time cost of this config: 0.011 ms\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Evaluation with config {'block': [32, 64], 'warp': [16, 32], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Time cost of this config: 0.008 ms\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Evaluation with config {'block': [64, 16], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Time cost of this config: 0.010 ms\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Evaluation with config {'block': [16, 128], 'warp': [16, 32], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Time cost of this config: 0.015 ms\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Evaluation with config {'block': [64, 32], 'warp': [32, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Time cost of this config: 0.010 ms\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Evaluation with config {'block': [32, 128], 'warp': [16, 64], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Time cost of this config: 0.008 ms\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Evaluation with config {'block': [64, 64], 'warp': [32, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Time cost of this config: 0.015 ms\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Evaluation with config {'block': [64, 128], 'warp': [32, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Time cost of this config: 0.021 ms\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Evaluation with config {'block': [128, 16], 'warp': [32, 16], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Time cost of this config: 0.030 ms\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Evaluation with config {'block': [16, 256], 'warp': [16, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 4, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Time cost of this config: 0.024 ms\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Evaluation with config {'block': [128, 32], 'warp': [64, 16], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Time cost of this config: 0.020 ms\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Evaluation with config {'block': [32, 256], 'warp': [16, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Time cost of this config: 0.015 ms\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Evaluation with config {'block': [128, 64], 'warp': [64, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Time cost of this config: 0.018 ms\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Evaluation with config {'block': [64, 256], 'warp': [32, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Time cost of this config: 0.048 ms\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Evaluation with config {'block': [128, 128], 'warp': [64, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Time cost of this config: 0.030 ms\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Evaluation with config {'block': [16, 512], 'warp': [16, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 4, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:54 [BitBLAS:INFO]: Time cost of this config: 0.029 ms\n", + "2024-10-25 08:30:55 [BitBLAS:DEBUG]: Apply config {'block': [16, 32], 'warp': [16, 16], 'rstep': [256], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:55 [BitBLAS:DEBUG]: Apply config {'block': [32, 16], 'warp': [16, 16], 'rstep': [256], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:55 [BitBLAS:DEBUG]: Apply config {'block': [16, 16], 'warp': [16, 16], 'rstep': [256], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:55 [BitBLAS:DEBUG]: Apply config {'block': [16, 64], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:55 [BitBLAS:DEBUG]: Apply config {'block': [32, 32], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:55 [BitBLAS:DEBUG]: Apply config {'block': [32, 64], 'warp': [16, 32], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:55 [BitBLAS:DEBUG]: Apply config {'block': [64, 16], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:55 [BitBLAS:DEBUG]: Apply config {'block': [16, 128], 'warp': [16, 32], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:55 [BitBLAS:DEBUG]: Apply config {'block': [64, 32], 'warp': [32, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:55 [BitBLAS:DEBUG]: Apply config {'block': [32, 128], 'warp': [16, 64], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:56 [BitBLAS:DEBUG]: Apply config {'block': [64, 64], 'warp': [32, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:56 [BitBLAS:DEBUG]: Apply config {'block': [64, 128], 'warp': [32, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:56 [BitBLAS:DEBUG]: Apply config {'block': [128, 16], 'warp': [32, 16], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:56 [BitBLAS:DEBUG]: Apply config {'block': [16, 256], 'warp': [16, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 4, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:56 [BitBLAS:DEBUG]: Apply config {'block': [128, 32], 'warp': [64, 16], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:56 [BitBLAS:DEBUG]: Apply config {'block': [32, 256], 'warp': [16, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:56 [BitBLAS:DEBUG]: Apply config {'block': [128, 64], 'warp': [64, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:56 [BitBLAS:DEBUG]: Apply config {'block': [64, 256], 'warp': [32, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:56 [BitBLAS:DEBUG]: Apply config {'block': [128, 128], 'warp': [64, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:30:56 [BitBLAS:DEBUG]: Apply config {'block': [256, 16], 'warp': [64, 16], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 4}}\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Evaluation with config {'block': [16, 32], 'warp': [16, 16], 'rstep': [256], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Time cost of this config: 0.010 ms\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Evaluation with config {'block': [32, 16], 'warp': [16, 16], 'rstep': [256], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Time cost of this config: 0.011 ms\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Evaluation with config {'block': [16, 16], 'warp': [16, 16], 'rstep': [256], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Time cost of this config: 0.013 ms\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Evaluation with config {'block': [16, 64], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Time cost of this config: 0.014 ms\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Evaluation with config {'block': [32, 32], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Time cost of this config: 0.015 ms\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Evaluation with config {'block': [32, 64], 'warp': [16, 32], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Time cost of this config: 0.010 ms\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Evaluation with config {'block': [64, 16], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Time cost of this config: 0.011 ms\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Evaluation with config {'block': [16, 128], 'warp': [16, 32], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Time cost of this config: 0.013 ms\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Evaluation with config {'block': [64, 32], 'warp': [32, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Time cost of this config: 0.016 ms\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Evaluation with config {'block': [32, 128], 'warp': [16, 64], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Time cost of this config: 0.012 ms\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Evaluation with config {'block': [64, 64], 'warp': [32, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Time cost of this config: 0.021 ms\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Evaluation with config {'block': [64, 128], 'warp': [32, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Time cost of this config: 0.015 ms\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Evaluation with config {'block': [128, 16], 'warp': [32, 16], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Time cost of this config: 0.016 ms\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Evaluation with config {'block': [16, 256], 'warp': [16, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 4, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Time cost of this config: 0.024 ms\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Evaluation with config {'block': [128, 32], 'warp': [64, 16], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Time cost of this config: 0.018 ms\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Evaluation with config {'block': [32, 256], 'warp': [16, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Time cost of this config: 0.030 ms\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Evaluation with config {'block': [128, 64], 'warp': [64, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Time cost of this config: 0.021 ms\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Evaluation with config {'block': [64, 256], 'warp': [32, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Time cost of this config: 0.030 ms\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Evaluation with config {'block': [128, 128], 'warp': [64, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 8}}\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Time cost of this config: 0.030 ms\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Evaluation with config {'block': [256, 16], 'warp': [64, 16], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_decode_reindex': 4}}\n", + "2024-10-25 08:31:43 [BitBLAS:INFO]: Time cost of this config: 0.031 ms\n", + "\n", + "template \n", + "__device__ void decode_i4b_to_f16(T1 *_i4s, T2 *B_local_decode, const int N = 8)\n", + "{\n", + " uint *h = reinterpret_cast(B_local_decode);\n", + "\n", + " static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;\n", + " static constexpr uint BOTTOM_MASK = 0x000f000f;\n", + " static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;\n", + " static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400;\n", + " uint const i4s = *reinterpret_cast(_i4s);\n", + "#pragma unroll\n", + " for (int i = 0; i < (N / 2); i++)\n", + " {\n", + "\n", + " asm volatile(\"lop3.b32 %0, %1, %2, %3, %4;\\n\"\n", + " : \"=r\"(h[i])\n", + " : \"r\"(i4s >> (4 * i)), \"n\"(BOTTOM_MASK), \"n\"(FP16_TOP_MAGIC_NUM), \"n\"(immLut));\n", + " asm volatile(\"sub.f16x2 %0, %1, %2;\\n\" : \"=r\"(h[i]) : \"r\"(h[i]), \"r\"(MEDIAN_NUM));\n", + " }\n", + "}\n", + "\n", + "template \n", + "__device__ void decode_i4s_to_f16(T1 *_i4s, T2 *B_local_decode, const int N = 8)\n", + "{\n", + " decode_i4b_to_f16(_i4s, B_local_decode, N);\n", + "}\n", + "\n", + "template \n", + "__device__ void decode_i4u_to_f16(T1 *_i4u, T2 *B_local_decode, const int N = 8)\n", + "{\n", + " decode_i4b_to_f16(_i4u, B_local_decode, N);\n", + "}\n", + "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530)\n", + "\n", + "typedef unsigned short uint16_t;\n", + "typedef unsigned char uint8_t;\n", + "typedef signed char int8_t;\n", + "typedef int int32_t;\n", + "typedef unsigned long long uint64_t;\n", + "typedef unsigned int uint32_t;\n", + "\n", + "#define TVM_FORCE_INLINE inline __attribute__((always_inline))\n", + "#define TVM_XINLINE TVM_FORCE_INLINE __device__ __host__\n", + "#define TVM_ALIGNED(x) __attribute__ ((aligned(x)))\n", + "#define TVM_HALF_OPERATOR(RTYPE, OP) \\\n", + " TVM_XINLINE RTYPE operator OP (half a, half b) { \\\n", + " return RTYPE(float(a) OP float(b)); \\\n", + " } \\\n", + " template \\\n", + " TVM_XINLINE RTYPE operator OP (half a, T b) { \\\n", + " return RTYPE(float(a) OP float(b)); \\\n", + " } \\\n", + " template \\\n", + " TVM_XINLINE RTYPE operator OP (T a, half b) { \\\n", + " return RTYPE(float(a) OP float(b)); \\\n", + " }\n", + "\n", + "#define TVM_HALF_ASSIGNOP(AOP, OP) \\\n", + " template \\\n", + " TVM_XINLINE half operator AOP (const T& a) { \\\n", + " return *this = half(float(*this) OP float(a)); \\\n", + " } \\\n", + " template \\\n", + " TVM_XINLINE half operator AOP (const volatile T& a) volatile { \\\n", + " return *this = half(float(*this) OP float(a)); \\\n", + " }\n", + "\n", + "class TVM_ALIGNED(2) half {\n", + " public:\n", + " uint16_t half_;\n", + "\n", + " static TVM_XINLINE half Binary(uint16_t value) {\n", + " half res;\n", + " res.half_ = value;\n", + " return res;\n", + " }\n", + "\n", + " TVM_XINLINE half() {}\n", + "\n", + " TVM_XINLINE half(const float& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const double& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const int8_t& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const int32_t& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const long long& value) { constructor(value); }\n", + " TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); }\n", + "\n", + " TVM_XINLINE operator float() const { \\\n", + " return float(half2float(half_)); \\\n", + " } \\\n", + " TVM_XINLINE operator float() const volatile { \\\n", + " return float(half2float(half_)); \\\n", + " }\n", + "\n", + "\n", + " TVM_HALF_ASSIGNOP(+=, +)\n", + " TVM_HALF_ASSIGNOP(-=, -)\n", + " TVM_HALF_ASSIGNOP(*=, *)\n", + " TVM_HALF_ASSIGNOP(/=, /)\n", + "\n", + " TVM_XINLINE half operator+() {\n", + " return *this;\n", + " }\n", + "\n", + " TVM_XINLINE half operator-() {\n", + " return half(-float(*this));\n", + " }\n", + "\n", + " TVM_XINLINE half operator=(const half& a) {\n", + " half_ = a.half_;\n", + " return a;\n", + " }\n", + "\n", + " template\n", + " TVM_XINLINE half operator=(const T& a) {\n", + " return *this = half(a);\n", + " }\n", + "\n", + " TVM_XINLINE half operator=(const half& a) volatile {\n", + " half_ = a.half_;\n", + " return a;\n", + " }\n", + "\n", + " template\n", + " TVM_XINLINE half operator=(const T& a) volatile {\n", + " return *this = half(a);\n", + " }\n", + "\n", + " private:\n", + " union Bits {\n", + " float f;\n", + " int32_t si;\n", + " uint32_t ui;\n", + " };\n", + "\n", + " static int const fp16FractionBits = 10;\n", + " static int const fp32FractionBits = 23;\n", + " static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits); // == 0x7fffff\n", + " static int32_t const fp32HiddenBit = 1 << fp32FractionBits; // == 0x800000\n", + " static int const shift = fp32FractionBits - fp16FractionBits; // == 13\n", + " static int const shiftSign = 16;\n", + " static int32_t const expAdjust = 127 - 15; // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)\n", + "\n", + " static int32_t const infN = 0x7F800000; // flt32 infinity\n", + " static int32_t const maxN = 0x477FFFFF; // max flt32 that's a flt16 normal after >> by shift\n", + " static int32_t const minN = 0x38800000; // min flt16 normal as a flt32\n", + " static int32_t const maxZ = 0x33000000; // max fp32 number that's still rounded to zero in fp16\n", + " static int32_t const signN = 0x80000000; // flt32 sign bit\n", + "\n", + " static int32_t const infC = infN >> shift;\n", + " static int32_t const nanN = (infC + 1) << shift; // minimum flt16 nan as a flt32\n", + " static int32_t const maxC = maxN >> shift;\n", + " static int32_t const minC = minN >> shift;\n", + " static int32_t const signC = signN >> shiftSign; // flt16 sign bit\n", + "\n", + " static int32_t const mulN = 0x52000000; // (1 << 23) / minN\n", + " static int32_t const mulC = 0x33800000; // minN / (1 << (23 - shift))\n", + "\n", + " static int32_t const subC = 0x003FF; // max flt32 subnormal down shifted\n", + " static int32_t const norC = 0x00400; // min flt32 normal down shifted\n", + "\n", + " static int32_t const maxD = infC - maxC - 1;\n", + " static int32_t const minD = minC - subC - 1;\n", + "\n", + " TVM_XINLINE uint16_t float2half(const float& value) const {\n", + " Bits v;\n", + " v.f = value;\n", + " uint32_t sign = v.si & signN; // grab sign bit\n", + " v.si ^= sign; // clear sign bit from v\n", + " sign >>= shiftSign; // logical shift sign to fp16 position\n", + "\n", + " if (v.si <= maxZ) {\n", + " // Handle eventual zeros here to ensure\n", + " // vshift will not exceed 32 below.\n", + " v.ui = 0;\n", + " } else if (v.si < minN) {\n", + " // Handle denorms\n", + " uint32_t exp32 = v.ui >> fp32FractionBits;\n", + " int32_t exp16 = exp32 - expAdjust;\n", + " // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.\n", + " // Smaller (so negative) exp16 values should result in greater right shifts.\n", + " uint32_t vshift = 1 - exp16;\n", + " uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);\n", + " v.ui = significand >> vshift;\n", + " v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;\n", + " } else if (v.si <= maxN) {\n", + " // Handle norms\n", + " v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;\n", + " v.ui -= expAdjust << fp32FractionBits;\n", + " } else if (v.si <= infN) {\n", + " v.si = infN;\n", + " } else if (v.si < nanN) {\n", + " v.si = nanN;\n", + " }\n", + "\n", + " v.ui >>= shift;\n", + " return sign | (v.ui & 0x7fff);\n", + " }\n", + "\n", + " // Same as above routine, except for addition of volatile keyword\n", + " TVM_XINLINE uint16_t float2half(\n", + " const volatile float& value) const volatile {\n", + " Bits v;\n", + " v.f = value;\n", + " uint32_t sign = v.si & signN; // grab sign bit\n", + " v.si ^= sign; // clear sign bit from v\n", + " sign >>= shiftSign; // logical shift sign to fp16 position\n", + "\n", + " if (v.si <= maxZ) {\n", + " // Handle eventual zeros here to ensure\n", + " // vshift will not exceed 32 below.\n", + " v.ui = 0;\n", + " } else if (v.si < minN) {\n", + " // Handle denorms\n", + " uint32_t exp32 = v.ui >> fp32FractionBits;\n", + " int32_t exp16 = exp32 - expAdjust;\n", + " // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.\n", + " // Smaller (so negative) exp16 values should result in greater right shifts.\n", + " uint32_t vshift = 1 - exp16;\n", + " uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);\n", + " v.ui = significand >> vshift;\n", + " v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;\n", + " } else if (v.si <= maxN) {\n", + " // Handle norms\n", + " v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;\n", + " v.ui -= expAdjust << fp32FractionBits;\n", + " } else if (v.si <= infN) {\n", + " v.si = infN;\n", + " } else if (v.si < nanN) {\n", + " v.si = nanN;\n", + " }\n", + "\n", + " v.ui >>= shift;\n", + " return sign | (v.ui & 0x7fff);\n", + " }\n", + "\n", + " TVM_XINLINE float half2float(const uint16_t& value) const {\n", + " Bits v;\n", + " v.ui = value;\n", + " int32_t sign = v.si & signC;\n", + " v.si ^= sign;\n", + " sign <<= shiftSign;\n", + " v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);\n", + " v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);\n", + " Bits s;\n", + " s.si = mulC;\n", + " s.f *= v.si;\n", + " int32_t mask = -(norC > v.si);\n", + " v.si <<= shift;\n", + " v.si ^= (s.si ^ v.si) & mask;\n", + " v.si |= sign;\n", + " return v.f;\n", + " }\n", + "\n", + " TVM_XINLINE float half2float(\n", + " const volatile uint16_t& value) const volatile {\n", + " Bits v;\n", + " v.ui = value;\n", + " int32_t sign = v.si & signC;\n", + " v.si ^= sign;\n", + " sign <<= shiftSign;\n", + " v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);\n", + " v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);\n", + " Bits s;\n", + " s.si = mulC;\n", + " s.f *= v.si;\n", + " int32_t mask = -(norC > v.si);\n", + " v.si <<= shift;\n", + " v.si ^= (s.si ^ v.si) & mask;\n", + " v.si |= sign;\n", + " return v.f;\n", + " }\n", + "\n", + " template\n", + " TVM_XINLINE void constructor(const T& value) {\n", + " half_ = float2half(float(value));\n", + " }\n", + "};\n", + "\n", + "TVM_HALF_OPERATOR(half, +)\n", + "TVM_HALF_OPERATOR(half, -)\n", + "TVM_HALF_OPERATOR(half, *)\n", + "TVM_HALF_OPERATOR(half, /)\n", + "TVM_HALF_OPERATOR(bool, >)\n", + "TVM_HALF_OPERATOR(bool, <)\n", + "TVM_HALF_OPERATOR(bool, >=)\n", + "TVM_HALF_OPERATOR(bool, <=)\n", + "\n", + "TVM_XINLINE half __float2half_rn(const float a) {\n", + " return half(a);\n", + "}\n", + "#else\n", + "#include \n", + "__device__ half max(half a, half b)\n", + "{\n", + " return __hgt(__half(a), __half(b)) ? a : b;\n", + "}\n", + "__device__ half min(half a, half b)\n", + "{\n", + " return __hlt(__half(a), __half(b)) ? a : b;\n", + "}\n", + "#endif\n", + "\n", + "\n", + "// Pack two half values.\n", + "static inline __device__ __host__ unsigned\n", + "__pack_half2(const half x, const half y) {\n", + " unsigned v0 = *((unsigned short *)&x);\n", + " unsigned v1 = *((unsigned short *)&y);\n", + " return (v1 << 16) | v0;\n", + "}\n", + "\n", + "#define CUDA_UNSUPPORTED_HALF_MATH_BINARY(HALF_MATH_NAME, FP32_MATH_NAME) \\\n", + "static inline __device__ __host__ half HALF_MATH_NAME(half x, half y) { \\\n", + " float tmp_x = __half2float(x); \\\n", + " float tmp_y = __half2float(y); \\\n", + " float result = FP32_MATH_NAME(tmp_x, tmp_y); \\\n", + " return __float2half(result); \\\n", + "}\n", + "\n", + "#define CUDA_UNSUPPORTED_HALF_MATH_UNARY(HALF_MATH_NAME, FP32_MATH_NAME) \\\n", + "static inline __device__ __host__ half HALF_MATH_NAME(half x) { \\\n", + " float tmp_x = __half2float(x); \\\n", + " float result = FP32_MATH_NAME(tmp_x); \\\n", + " return __float2half(result); \\\n", + "}\n", + "\n", + "// Some fp16 math functions are not supported in cuda_fp16.h,\n", + "// so we define them here to make sure the generated CUDA code\n", + "// is valid.\n", + "#if defined(__CUDA_ARCH__)\n", + "#if (__CUDA_ARCH__ >= 530)\n", + "CUDA_UNSUPPORTED_HALF_MATH_BINARY(hpow, powf)\n", + "CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf)\n", + "CUDA_UNSUPPORTED_HALF_MATH_UNARY(htan, tanf)\n", + "CUDA_UNSUPPORTED_HALF_MATH_UNARY(hatan, atanf)\n", + "CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erf)\n", + "#else\n", + "CUDA_UNSUPPORTED_HALF_MATH_UNARY(hexp, exp)\n", + "#endif\n", + "#endif\n", + "\n", + "#undef CUDA_UNSUPPORTED_HALF_MATH_BINARY\n", + "#undef CUDA_UNSUPPORTED_HALF_MATH_UNARY\n", + "\n", + "struct __align__(8) half4 {\n", + " __half x, y, z, w;\n", + " __host__ __device__ half4() : x(__half(0)), y(__half(0)), z(__half(0)), w(__half(0)) {}\n", + " __host__ __device__ half4(__half x, __half y, __half z, __half w) : x(x), y(y), z(z), w(w) {}\n", + "\n", + "};\n", + "__host__ __device__ half4 make_half4(__half x, __half y, __half z, __half w) {\n", + " return half4(x, y, z, w);\n", + "}\n", + "\n", + "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)\n", + "#define __shfl_sync(mask, var, lane, width) \\\n", + " __shfl((var), (lane), (width))\n", + "\n", + "#define __shfl_down_sync(mask, var, offset, width) \\\n", + " __shfl_down((var), (offset), (width))\n", + "\n", + "#define __shfl_up_sync(mask, var, offset, width) \\\n", + " __shfl_up((var), (offset), (width))\n", + "#endif\n", + "\n", + "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)\n", + "#include \n", + "\n", + "\n", + "#if defined(__CUDACC_RTC__)\n", + "#define __SM_61_INTRINSICS_DECL__ __device__\n", + "#else /* !__CUDACC_RTC__ */\n", + "#define __SM_61_INTRINSICS_DECL__ static __device__ __inline__\n", + "#endif /* __CUDACC_RTC__ */\n", + "\n", + "#ifndef __CUDA_ARCH__\n", + "#define __DEF_IF_HOST { }\n", + "#else /* !__CUDA_ARCH__ */\n", + "#define __DEF_IF_HOST ;\n", + "#endif /* __CUDA_ARCH__ */\n", + "\n", + "__SM_61_INTRINSICS_DECL__ int __dp4a(unsigned int srcA, int srcB, int c) __DEF_IF_HOST\n", + "__SM_61_INTRINSICS_DECL__ int __dp4a(int srcA, unsigned int srcB, int c) __DEF_IF_HOST\n", + "\n", + "#undef __DEF_IF_HOST\n", + "\n", + "#if !defined(__CUDACC_RTC__) && defined(__CUDA_ARCH__)\n", + "__SM_61_INTRINSICS_DECL__ int __dp4a(unsigned int srcA, int srcB, int c) {\n", + " int ret;\n", + " asm volatile (\"dp4a.u32.s32 %0, %1, %2, %3;\" : \"=r\"(ret) : \"r\"(srcA), \"r\"(srcB), \"r\"(c));\n", + " return ret;\n", + "}\n", + "\n", + "__SM_61_INTRINSICS_DECL__ int __dp4a(int srcA, unsigned int srcB, int c) {\n", + " int ret;\n", + " asm volatile (\"dp4a.s32.u32 %0, %1, %2, %3;\" : \"=r\"(ret) : \"r\"(srcA), \"r\"(srcB), \"r\"(c));\n", + " return ret;\n", + "}\n", + "#endif /* !__CUDACC_RTC__ && defined(__CUDA_ARCH__) */\n", + "\n", + "#undef __SM_61_INTRINSICS_DECL__\n", + "\n", + "#endif\n", + "__forceinline__ __device__ unsigned int\n", + "cast_smem_ptr_to_int(const void* const smem_ptr)\n", + "{\n", + " unsigned int smem_int;\n", + " asm volatile (\"{ .reg .u64 smem_int; cvta.to.shared.u64 smem_int, %1; cvt.u32.u64 %0, smem_int; }\"\n", + " : \"=r\"(smem_int) : \"l\"(smem_ptr));\n", + " return smem_int;\n", + "}\n", + "\n", + "#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \\\n", + " (__CUDACC_VER_MAJOR__ > 11))\n", + "#define TVM_ENABLE_L2_PREFETCH 1\n", + "#else\n", + "#define TVM_ENABLE_L2_PREFETCH 0\n", + "#endif\n", + "\n", + "#ifdef _WIN32\n", + " using uint = unsigned int;\n", + " using uchar = unsigned char;\n", + " using ushort = unsigned short;\n", + " using int64_t = long long;\n", + " using uint64_t = unsigned long long;\n", + "#else\n", + " #define uint unsigned int\n", + " #define uchar unsigned char\n", + " #define ushort unsigned short\n", + " #define int64_t long long\n", + " #define uint64_t unsigned long long\n", + "#endif\n", + "\n", + "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 800) \n", + "#define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 1\n", + "#else\n", + "#define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 0\n", + "#endif\n", + "extern \"C\" __global__ void __launch_bounds__(32) matmul_n1024k1024_f16xi4_tcx16x128x64w16x32_opt_m_16(half* __restrict__ A, signed char* __restrict__ B, half* __restrict__ C, int m);\n", + "extern \"C\" __global__ void __launch_bounds__(32) matmul_n1024k1024_f16xi4_tcx16x64x128w16x16_opt_m_32(half* __restrict__ A, signed char* __restrict__ B, half* __restrict__ C, int m);\n", + "extern \"C\" __global__ void __launch_bounds__(128) matmul_n1024k1024_f16xi4_simt_opt_m_1(half* __restrict__ A, signed char* __restrict__ B, half* __restrict__ C, int m);\n", + "extern \"C\" __global__ void __launch_bounds__(32) matmul_n1024k1024_f16xi4_tcx32x64x128w16x32_opt_m_128(half* __restrict__ A, signed char* __restrict__ B, half* __restrict__ C, int m);\n", + "extern \"C\" __global__ void __launch_bounds__(128) matmul_n1024k1024_f16xi4_tcx32x64x128w16x32_opt_m_256(half* __restrict__ A, signed char* __restrict__ B, half* __restrict__ C, int m);\n", + "extern \"C\" __global__ void __launch_bounds__(32) matmul_n1024k1024_f16xi4_tcx32x16x256w16x16_opt_m_64(half* __restrict__ A, signed char* __restrict__ B, half* __restrict__ C, int m);\n", + "extern \"C\" __global__ void __launch_bounds__(32) matmul_n1024k1024_f16xi4_tcx16x128x64w16x32_opt_m_16(half* __restrict__ A, signed char* __restrict__ B, half* __restrict__ C, int m) {\n", + " extern __shared__ uchar buf_dyn_shmem[];\n", + " half C_reindex_pad_shared_dyn_warp[8];\n", + " signed char B_local[4];\n", + " half B_decode_reindex_local[8];\n", + " half A_reindex_pad_shared_dyn_warp[8];\n", + " half B_decode_reindex_shared_dyn_warp[8];\n", + " for (int var = 0; var < 1; ++var) {\n", + " for (int i = 0; i < 8; ++i) {\n", + "C_reindex_pad_shared_dyn_warp[0 + i] = 0.0;}\n", + ";\n", + " for (int ax3_0_0 = 0; ax3_0_0 < 4; ++ax3_0_0) {\n", + " __syncthreads();\n", + " #pragma unroll\n", + " for (int ax0_ax1_ax2_fused_0 = 0; ax0_ax1_ax2_fused_0 < 16; ++ax0_ax1_ax2_fused_0) {\n", + " uint4 condval;\n", + " if ((((((int)blockIdx.y) * 16) + ax0_ax1_ax2_fused_0) < m)) {\n", + " condval = *(uint4*)(A + ((((((int)blockIdx.y) * 16384) + (ax0_ax1_ax2_fused_0 * 1024)) + (ax3_0_0 * 256)) + (((int)threadIdx.x) * 8)));\n", + " } else {\n", + " condval = make_uint4(__pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f)), __pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f)), __pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f)), __pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f)));\n", + " }\n", + " *(uint4*)(((half*)buf_dyn_shmem) + (((ax0_ax1_ax2_fused_0 * 264) + (((int)threadIdx.x) * 8)) + 256)) = condval;\n", + " }\n", + " for (int ax1_ax2_0_fused_0 = 0; ax1_ax2_0_fused_0 < 16; ++ax1_ax2_0_fused_0) {\n", + " *(int*)(B_local + 0) = *(int*)(B + ((((((int)blockIdx.x) * 8192) + (ax1_ax2_0_fused_0 * 512)) + (ax3_0_0 * 128)) + (((int)threadIdx.x) * 4)));\n", + " decode_i4s_to_f16((&(B_local[0])), (&(B_decode_reindex_local[0])), 8);\n", + " *(uint4*)(((half*)buf_dyn_shmem) + (((ax1_ax2_0_fused_0 * 264) + (((int)threadIdx.x) * 8)) + 4480)) = *(uint4*)(B_decode_reindex_local + 0);\n", + " }\n", + " __syncthreads();\n", + " for (int ax3_0_1 = 0; ax3_0_1 < 16; ++ax3_0_1) {\n", + "\n", + " {\n", + " unsigned int addr;\n", + "#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST\n", + " addr = static_cast(__cvta_generic_to_shared((void *)((&(((half*)buf_dyn_shmem)[((ax3_0_1 * 16) + 256)])) + (((((int)threadIdx.x) & 15) * 264) + ((((int)threadIdx.x) >> 4) * 8)))));\n", + "#else\n", + " __asm__ __volatile__(\n", + " \"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\\n\"\n", + " : \"=r\"(addr)\n", + " : \"l\"((void *)((&(((half*)buf_dyn_shmem)[((ax3_0_1 * 16) + 256)])) + (((((int)threadIdx.x) & 15) * 264) + ((((int)threadIdx.x) >> 4) * 8))))\n", + " );\n", + "#endif\n", + " __asm__ __volatile__(\n", + " \"ldmatrix.sync.aligned.m8n8.x4.shared.b16\"\n", + " \"{%0, %1, %2, %3}, [%4];\\n\"\n", + " : \"=r\"(((unsigned *)(A_reindex_pad_shared_dyn_warp + 0))[0]), \"=r\"(((unsigned *)(A_reindex_pad_shared_dyn_warp + 0))[1]), \"=r\"(((unsigned *)(A_reindex_pad_shared_dyn_warp + 0))[2]), \"=r\"(((unsigned *)(A_reindex_pad_shared_dyn_warp + 0))[3])\n", + " : \"r\"(addr)\n", + " );\n", + " }\n", + "\n", + " {\n", + " unsigned int addr;\n", + "#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST\n", + " addr = static_cast(__cvta_generic_to_shared((void *)((&(((half*)buf_dyn_shmem)[((ax3_0_1 * 16) + 4480)])) + ((((((int)threadIdx.x) >> 4) * 2112) + ((((int)threadIdx.x) & 7) * 264)) + (((((int)threadIdx.x) & 15) >> 3) * 8)))));\n", + "#else\n", + " __asm__ __volatile__(\n", + " \"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\\n\"\n", + " : \"=r\"(addr)\n", + " : \"l\"((void *)((&(((half*)buf_dyn_shmem)[((ax3_0_1 * 16) + 4480)])) + ((((((int)threadIdx.x) >> 4) * 2112) + ((((int)threadIdx.x) & 7) * 264)) + (((((int)threadIdx.x) & 15) >> 3) * 8))))\n", + " );\n", + "#endif\n", + " __asm__ __volatile__(\n", + " \"ldmatrix.sync.aligned.m8n8.x4.shared.b16\"\n", + " \"{%0, %1, %2, %3}, [%4];\\n\"\n", + " : \"=r\"(((unsigned *)(B_decode_reindex_shared_dyn_warp + 0))[0]), \"=r\"(((unsigned *)(B_decode_reindex_shared_dyn_warp + 0))[1]), \"=r\"(((unsigned *)(B_decode_reindex_shared_dyn_warp + 0))[2]), \"=r\"(((unsigned *)(B_decode_reindex_shared_dyn_warp + 0))[3])\n", + " : \"r\"(addr)\n", + " );\n", + " }\n", + "\n", + " {\n", + " __asm__ __volatile__(\n", + " \"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16\"\n", + " \"{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\\n\"\n", + " : \"=r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 0))[0]), \"=r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 0))[1])\n", + " : \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[0]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[1]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[2]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[3]), \"r\"(((unsigned *)((half*)B_decode_reindex_shared_dyn_warp + 0))[0]), \"r\"(((unsigned *)((half*)B_decode_reindex_shared_dyn_warp + 0))[1]), \"r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 0))[0]), \"r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 0))[1]));\n", + " }\n", + "\n", + " {\n", + " __asm__ __volatile__(\n", + " \"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16\"\n", + " \"{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\\n\"\n", + " : \"=r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 4))[0]), \"=r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 4))[1])\n", + " : \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[0]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[1]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[2]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[3]), \"r\"(((unsigned *)((half*)B_decode_reindex_shared_dyn_warp + 4))[0]), \"r\"(((unsigned *)((half*)B_decode_reindex_shared_dyn_warp + 4))[1]), \"r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 4))[0]), \"r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 4))[1]));\n", + " }\n", + " }\n", + " }\n", + " __syncthreads();\n", + " for (int local_id = 0; local_id < 8; local_id+=2) {\n", + "*((uint *)&(&(((half*)buf_dyn_shmem)[0]))[((((((local_id % 4) / 2) * 8) + (threadIdx.x / 4)) * 16) + ((((local_id / 4) * 8) + ((threadIdx.x % 4) * 2)) + (local_id % 2)))]) = *((uint *)&C_reindex_pad_shared_dyn_warp[0 + local_id]);\n", + "}\n", + ";\n", + " }\n", + " __syncthreads();\n", + " #pragma unroll\n", + " for (int ax0_ax1_ax2_ax3_ax4_fused_0 = 0; ax0_ax1_ax2_ax3_ax4_fused_0 < 1; ++ax0_ax1_ax2_ax3_ax4_fused_0) {\n", + " if (((((int)blockIdx.y) * 16) + (((int)threadIdx.x) >> 1)) < m) {\n", + " *(uint4*)(C + ((((((int)blockIdx.y) * 16384) + ((((int)threadIdx.x) >> 1) * 1024)) + (((int)blockIdx.x) * 16)) + ((((int)threadIdx.x) & 1) * 8))) = *(uint4*)(((half*)buf_dyn_shmem) + (((int)threadIdx.x) * 8));\n", + " }\n", + " }\n", + "}\n", + "\n", + "extern \"C\" __global__ void __launch_bounds__(32) matmul_n1024k1024_f16xi4_tcx16x64x128w16x16_opt_m_32(half* __restrict__ A, signed char* __restrict__ B, half* __restrict__ C, int m) {\n", + " extern __shared__ uchar buf_dyn_shmem[];\n", + " half C_reindex_pad_shared_dyn_warp[8];\n", + " signed char B_local[4];\n", + " half B_decode_reindex_local[8];\n", + " half A_reindex_pad_shared_dyn_warp[8];\n", + " half B_decode_reindex_shared_dyn_warp[8];\n", + " for (int var = 0; var < 1; ++var) {\n", + " for (int i = 0; i < 8; ++i) {\n", + "C_reindex_pad_shared_dyn_warp[0 + i] = 0.0;}\n", + ";\n", + " for (int ax3_0_0 = 0; ax3_0_0 < 4; ++ax3_0_0) {\n", + " __syncthreads();\n", + " #pragma unroll\n", + " for (int ax0_ax1_ax2_fused_0 = 0; ax0_ax1_ax2_fused_0 < 16; ++ax0_ax1_ax2_fused_0) {\n", + " uint4 condval;\n", + " if ((((((int)blockIdx.y) * 16) + ax0_ax1_ax2_fused_0) < m)) {\n", + " condval = *(uint4*)(A + ((((((int)blockIdx.y) * 16384) + (ax0_ax1_ax2_fused_0 * 1024)) + (ax3_0_0 * 256)) + (((int)threadIdx.x) * 8)));\n", + " } else {\n", + " condval = make_uint4(__pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f)), __pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f)), __pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f)), __pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f)));\n", + " }\n", + " *(uint4*)(((half*)buf_dyn_shmem) + (((ax0_ax1_ax2_fused_0 * 264) + (((int)threadIdx.x) * 8)) + 256)) = condval;\n", + " }\n", + " for (int ax1_ax2_0_fused_0 = 0; ax1_ax2_0_fused_0 < 16; ++ax1_ax2_0_fused_0) {\n", + " *(int*)(B_local + 0) = *(int*)(B + ((((((int)blockIdx.x) * 8192) + (ax1_ax2_0_fused_0 * 512)) + (ax3_0_0 * 128)) + (((int)threadIdx.x) * 4)));\n", + " decode_i4s_to_f16((&(B_local[0])), (&(B_decode_reindex_local[0])), 8);\n", + " *(uint4*)(((half*)buf_dyn_shmem) + (((ax1_ax2_0_fused_0 * 264) + (((int)threadIdx.x) * 8)) + 4480)) = *(uint4*)(B_decode_reindex_local + 0);\n", + " }\n", + " __syncthreads();\n", + " for (int ax3_0_1 = 0; ax3_0_1 < 16; ++ax3_0_1) {\n", + "\n", + " {\n", + " unsigned int addr;\n", + "#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST\n", + " addr = static_cast(__cvta_generic_to_shared((void *)((&(((half*)buf_dyn_shmem)[((ax3_0_1 * 16) + 256)])) + (((((int)threadIdx.x) & 15) * 264) + ((((int)threadIdx.x) >> 4) * 8)))));\n", + "#else\n", + " __asm__ __volatile__(\n", + " \"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\\n\"\n", + " : \"=r\"(addr)\n", + " : \"l\"((void *)((&(((half*)buf_dyn_shmem)[((ax3_0_1 * 16) + 256)])) + (((((int)threadIdx.x) & 15) * 264) + ((((int)threadIdx.x) >> 4) * 8))))\n", + " );\n", + "#endif\n", + " __asm__ __volatile__(\n", + " \"ldmatrix.sync.aligned.m8n8.x4.shared.b16\"\n", + " \"{%0, %1, %2, %3}, [%4];\\n\"\n", + " : \"=r\"(((unsigned *)(A_reindex_pad_shared_dyn_warp + 0))[0]), \"=r\"(((unsigned *)(A_reindex_pad_shared_dyn_warp + 0))[1]), \"=r\"(((unsigned *)(A_reindex_pad_shared_dyn_warp + 0))[2]), \"=r\"(((unsigned *)(A_reindex_pad_shared_dyn_warp + 0))[3])\n", + " : \"r\"(addr)\n", + " );\n", + " }\n", + "\n", + " {\n", + " unsigned int addr;\n", + "#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST\n", + " addr = static_cast(__cvta_generic_to_shared((void *)((&(((half*)buf_dyn_shmem)[((ax3_0_1 * 16) + 4480)])) + ((((((int)threadIdx.x) >> 4) * 2112) + ((((int)threadIdx.x) & 7) * 264)) + (((((int)threadIdx.x) & 15) >> 3) * 8)))));\n", + "#else\n", + " __asm__ __volatile__(\n", + " \"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\\n\"\n", + " : \"=r\"(addr)\n", + " : \"l\"((void *)((&(((half*)buf_dyn_shmem)[((ax3_0_1 * 16) + 4480)])) + ((((((int)threadIdx.x) >> 4) * 2112) + ((((int)threadIdx.x) & 7) * 264)) + (((((int)threadIdx.x) & 15) >> 3) * 8))))\n", + " );\n", + "#endif\n", + " __asm__ __volatile__(\n", + " \"ldmatrix.sync.aligned.m8n8.x4.shared.b16\"\n", + " \"{%0, %1, %2, %3}, [%4];\\n\"\n", + " : \"=r\"(((unsigned *)(B_decode_reindex_shared_dyn_warp + 0))[0]), \"=r\"(((unsigned *)(B_decode_reindex_shared_dyn_warp + 0))[1]), \"=r\"(((unsigned *)(B_decode_reindex_shared_dyn_warp + 0))[2]), \"=r\"(((unsigned *)(B_decode_reindex_shared_dyn_warp + 0))[3])\n", + " : \"r\"(addr)\n", + " );\n", + " }\n", + "\n", + " {\n", + " __asm__ __volatile__(\n", + " \"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16\"\n", + " \"{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\\n\"\n", + " : \"=r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 0))[0]), \"=r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 0))[1])\n", + " : \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[0]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[1]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[2]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[3]), \"r\"(((unsigned *)((half*)B_decode_reindex_shared_dyn_warp + 0))[0]), \"r\"(((unsigned *)((half*)B_decode_reindex_shared_dyn_warp + 0))[1]), \"r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 0))[0]), \"r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 0))[1]));\n", + " }\n", + "\n", + " {\n", + " __asm__ __volatile__(\n", + " \"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16\"\n", + " \"{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\\n\"\n", + " : \"=r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 4))[0]), \"=r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 4))[1])\n", + " : \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[0]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[1]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[2]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[3]), \"r\"(((unsigned *)((half*)B_decode_reindex_shared_dyn_warp + 4))[0]), \"r\"(((unsigned *)((half*)B_decode_reindex_shared_dyn_warp + 4))[1]), \"r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 4))[0]), \"r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 4))[1]));\n", + " }\n", + " }\n", + " }\n", + " __syncthreads();\n", + " for (int local_id = 0; local_id < 8; local_id+=2) {\n", + "*((uint *)&(&(((half*)buf_dyn_shmem)[0]))[((((((local_id % 4) / 2) * 8) + (threadIdx.x / 4)) * 16) + ((((local_id / 4) * 8) + ((threadIdx.x % 4) * 2)) + (local_id % 2)))]) = *((uint *)&C_reindex_pad_shared_dyn_warp[0 + local_id]);\n", + "}\n", + ";\n", + " }\n", + " __syncthreads();\n", + " #pragma unroll\n", + " for (int ax0_ax1_ax2_ax3_ax4_fused_0 = 0; ax0_ax1_ax2_ax3_ax4_fused_0 < 1; ++ax0_ax1_ax2_ax3_ax4_fused_0) {\n", + " if (((((int)blockIdx.y) * 16) + (((int)threadIdx.x) >> 1)) < m) {\n", + " *(uint4*)(C + ((((((int)blockIdx.y) * 16384) + ((((int)threadIdx.x) >> 1) * 1024)) + (((int)blockIdx.x) * 16)) + ((((int)threadIdx.x) & 1) * 8))) = *(uint4*)(((half*)buf_dyn_shmem) + (((int)threadIdx.x) * 8));\n", + " }\n", + " }\n", + "}\n", + "\n", + "extern \"C\" __global__ void __launch_bounds__(128) matmul_n1024k1024_f16xi4_simt_opt_m_1(half* __restrict__ A, signed char* __restrict__ B, half* __restrict__ C, int m) {\n", + " half in_thread_C_local[1];\n", + " signed char B_local[4];\n", + " half B_decode_local[8];\n", + " half A_local[8];\n", + " __shared__ half red_result[2];\n", + " in_thread_C_local[0] = __float2half_rn(0.000000e+00f);\n", + " for (int ax2_0 = 0; ax2_0 < 2; ++ax2_0) {\n", + " *(int*)(B_local + 0) = *(int*)(B + ((((((int)blockIdx.x) * 1024) + (((int)threadIdx.y) * 512)) + (ax2_0 * 256)) + (((int)threadIdx.x) * 4)));\n", + " decode_i4s_to_f16((&(B_local[0])), (&(B_decode_local[0])), 8);\n", + " *(uint4*)(A_local + 0) = *(uint4*)(A + (((((int)blockIdx.z) * 1024) + (ax2_0 * 512)) + (((int)threadIdx.x) * 8)));\n", + " for (int ax2_2_0 = 0; ax2_2_0 < 4; ++ax2_2_0) {\n", + " for (int ax2_2_1 = 0; ax2_2_1 < 2; ++ax2_2_1) {\n", + " in_thread_C_local[0] = (in_thread_C_local[0] + (A_local[((ax2_2_0 * 2) + ax2_2_1)] * B_decode_local[((ax2_2_0 * 2) + ax2_2_1)]));\n", + " }\n", + " }\n", + " }\n", + " half red_buf0[1];\n", + " uint mask[1];\n", + " half t0[1];\n", + " half red_buf0_1[1];\n", + " uint mask_1[1];\n", + " half t0_1[1];\n", + " __shared__ half red_buf_staging[4];\n", + " red_buf0_1[0] = in_thread_C_local[0];\n", + " mask_1[0] = __activemask();\n", + " t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 16, 32);\n", + " red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);\n", + " t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 8, 32);\n", + " red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);\n", + " t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 4, 32);\n", + " red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);\n", + " t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 2, 32);\n", + " red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);\n", + " t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 1, 32);\n", + " red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);\n", + " if ((((int)threadIdx.x) % 32) == 0) {\n", + " red_buf_staging[((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 5))] = red_buf0_1[0];\n", + " }\n", + " __syncthreads();\n", + " if (((int)threadIdx.x) < 2) {\n", + " red_buf0[0] = red_buf_staging[((((int)threadIdx.y) * 2) + ((int)threadIdx.x))];\n", + " }\n", + " mask[0] = __activemask();\n", + " t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 1, 32);\n", + " red_buf0[0] = (red_buf0[0] + t0[0]);\n", + " if (((int)threadIdx.x) == 0) {\n", + " ((volatile half*)red_result)[((int)threadIdx.y)] = red_buf0[0];\n", + " }\n", + " __syncthreads();\n", + " if (((int)threadIdx.x) == 0) {\n", + " C[(((((int)blockIdx.z) * 1024) + (((int)blockIdx.x) * 2)) + ((int)threadIdx.y))] = (half)(((volatile half*)red_result)[((int)threadIdx.y)]);\n", + " }\n", + "}\n", + "\n", + "extern \"C\" __global__ void __launch_bounds__(32) matmul_n1024k1024_f16xi4_tcx32x64x128w16x32_opt_m_128(half* __restrict__ A, signed char* __restrict__ B, half* __restrict__ C, int m) {\n", + " extern __shared__ uchar buf_dyn_shmem[];\n", + " half C_reindex_pad_shared_dyn_warp[8];\n", + " signed char B_local[4];\n", + " half B_decode_reindex_local[8];\n", + " half A_reindex_pad_shared_dyn_warp[8];\n", + " half B_decode_reindex_shared_dyn_warp[8];\n", + " for (int var = 0; var < 1; ++var) {\n", + " for (int i = 0; i < 8; ++i) {\n", + "C_reindex_pad_shared_dyn_warp[0 + i] = 0.0;}\n", + ";\n", + " for (int ax3_0_0 = 0; ax3_0_0 < 4; ++ax3_0_0) {\n", + " __syncthreads();\n", + " #pragma unroll\n", + " for (int ax0_ax1_ax2_fused_0 = 0; ax0_ax1_ax2_fused_0 < 16; ++ax0_ax1_ax2_fused_0) {\n", + " uint4 condval;\n", + " if ((((((int)blockIdx.y) * 16) + ax0_ax1_ax2_fused_0) < m)) {\n", + " condval = *(uint4*)(A + ((((((int)blockIdx.y) * 16384) + (ax0_ax1_ax2_fused_0 * 1024)) + (ax3_0_0 * 256)) + (((int)threadIdx.x) * 8)));\n", + " } else {\n", + " condval = make_uint4(__pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f)), __pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f)), __pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f)), __pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f)));\n", + " }\n", + " *(uint4*)(((half*)buf_dyn_shmem) + (((ax0_ax1_ax2_fused_0 * 264) + (((int)threadIdx.x) * 8)) + 256)) = condval;\n", + " }\n", + " for (int ax1_ax2_0_fused_0 = 0; ax1_ax2_0_fused_0 < 16; ++ax1_ax2_0_fused_0) {\n", + " *(int*)(B_local + 0) = *(int*)(B + ((((((int)blockIdx.x) * 8192) + (ax1_ax2_0_fused_0 * 512)) + (ax3_0_0 * 128)) + (((int)threadIdx.x) * 4)));\n", + " decode_i4s_to_f16((&(B_local[0])), (&(B_decode_reindex_local[0])), 8);\n", + " *(uint4*)(((half*)buf_dyn_shmem) + (((ax1_ax2_0_fused_0 * 264) + (((int)threadIdx.x) * 8)) + 4480)) = *(uint4*)(B_decode_reindex_local + 0);\n", + " }\n", + " __syncthreads();\n", + " for (int ax3_0_1 = 0; ax3_0_1 < 16; ++ax3_0_1) {\n", + "\n", + " {\n", + " unsigned int addr;\n", + "#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST\n", + " addr = static_cast(__cvta_generic_to_shared((void *)((&(((half*)buf_dyn_shmem)[((ax3_0_1 * 16) + 256)])) + (((((int)threadIdx.x) & 15) * 264) + ((((int)threadIdx.x) >> 4) * 8)))));\n", + "#else\n", + " __asm__ __volatile__(\n", + " \"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\\n\"\n", + " : \"=r\"(addr)\n", + " : \"l\"((void *)((&(((half*)buf_dyn_shmem)[((ax3_0_1 * 16) + 256)])) + (((((int)threadIdx.x) & 15) * 264) + ((((int)threadIdx.x) >> 4) * 8))))\n", + " );\n", + "#endif\n", + " __asm__ __volatile__(\n", + " \"ldmatrix.sync.aligned.m8n8.x4.shared.b16\"\n", + " \"{%0, %1, %2, %3}, [%4];\\n\"\n", + " : \"=r\"(((unsigned *)(A_reindex_pad_shared_dyn_warp + 0))[0]), \"=r\"(((unsigned *)(A_reindex_pad_shared_dyn_warp + 0))[1]), \"=r\"(((unsigned *)(A_reindex_pad_shared_dyn_warp + 0))[2]), \"=r\"(((unsigned *)(A_reindex_pad_shared_dyn_warp + 0))[3])\n", + " : \"r\"(addr)\n", + " );\n", + " }\n", + "\n", + " {\n", + " unsigned int addr;\n", + "#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST\n", + " addr = static_cast(__cvta_generic_to_shared((void *)((&(((half*)buf_dyn_shmem)[((ax3_0_1 * 16) + 4480)])) + ((((((int)threadIdx.x) >> 4) * 2112) + ((((int)threadIdx.x) & 7) * 264)) + (((((int)threadIdx.x) & 15) >> 3) * 8)))));\n", + "#else\n", + " __asm__ __volatile__(\n", + " \"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\\n\"\n", + " : \"=r\"(addr)\n", + " : \"l\"((void *)((&(((half*)buf_dyn_shmem)[((ax3_0_1 * 16) + 4480)])) + ((((((int)threadIdx.x) >> 4) * 2112) + ((((int)threadIdx.x) & 7) * 264)) + (((((int)threadIdx.x) & 15) >> 3) * 8))))\n", + " );\n", + "#endif\n", + " __asm__ __volatile__(\n", + " \"ldmatrix.sync.aligned.m8n8.x4.shared.b16\"\n", + " \"{%0, %1, %2, %3}, [%4];\\n\"\n", + " : \"=r\"(((unsigned *)(B_decode_reindex_shared_dyn_warp + 0))[0]), \"=r\"(((unsigned *)(B_decode_reindex_shared_dyn_warp + 0))[1]), \"=r\"(((unsigned *)(B_decode_reindex_shared_dyn_warp + 0))[2]), \"=r\"(((unsigned *)(B_decode_reindex_shared_dyn_warp + 0))[3])\n", + " : \"r\"(addr)\n", + " );\n", + " }\n", + "\n", + " {\n", + " __asm__ __volatile__(\n", + " \"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16\"\n", + " \"{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\\n\"\n", + " : \"=r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 0))[0]), \"=r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 0))[1])\n", + " : \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[0]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[1]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[2]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[3]), \"r\"(((unsigned *)((half*)B_decode_reindex_shared_dyn_warp + 0))[0]), \"r\"(((unsigned *)((half*)B_decode_reindex_shared_dyn_warp + 0))[1]), \"r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 0))[0]), \"r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 0))[1]));\n", + " }\n", + "\n", + " {\n", + " __asm__ __volatile__(\n", + " \"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16\"\n", + " \"{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\\n\"\n", + " : \"=r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 4))[0]), \"=r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 4))[1])\n", + " : \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[0]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[1]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[2]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[3]), \"r\"(((unsigned *)((half*)B_decode_reindex_shared_dyn_warp + 4))[0]), \"r\"(((unsigned *)((half*)B_decode_reindex_shared_dyn_warp + 4))[1]), \"r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 4))[0]), \"r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 4))[1]));\n", + " }\n", + " }\n", + " }\n", + " __syncthreads();\n", + " for (int local_id = 0; local_id < 8; local_id+=2) {\n", + "*((uint *)&(&(((half*)buf_dyn_shmem)[0]))[((((((local_id % 4) / 2) * 8) + (threadIdx.x / 4)) * 16) + ((((local_id / 4) * 8) + ((threadIdx.x % 4) * 2)) + (local_id % 2)))]) = *((uint *)&C_reindex_pad_shared_dyn_warp[0 + local_id]);\n", + "}\n", + ";\n", + " }\n", + " __syncthreads();\n", + " #pragma unroll\n", + " for (int ax0_ax1_ax2_ax3_ax4_fused_0 = 0; ax0_ax1_ax2_ax3_ax4_fused_0 < 1; ++ax0_ax1_ax2_ax3_ax4_fused_0) {\n", + " if (((((int)blockIdx.y) * 16) + (((int)threadIdx.x) >> 1)) < m) {\n", + " *(uint4*)(C + ((((((int)blockIdx.y) * 16384) + ((((int)threadIdx.x) >> 1) * 1024)) + (((int)blockIdx.x) * 16)) + ((((int)threadIdx.x) & 1) * 8))) = *(uint4*)(((half*)buf_dyn_shmem) + (((int)threadIdx.x) * 8));\n", + " }\n", + " }\n", + "}\n", + "\n", + "extern \"C\" __global__ void __launch_bounds__(128) matmul_n1024k1024_f16xi4_tcx32x64x128w16x32_opt_m_256(half* __restrict__ A, signed char* __restrict__ B, half* __restrict__ C, int m) {\n", + " extern __shared__ uchar buf_dyn_shmem[];\n", + " half C_reindex_pad_shared_dyn_warp[8];\n", + " signed char B_local[4];\n", + " half B_decode_reindex_local[8];\n", + " half A_reindex_pad_shared_dyn_warp[8];\n", + " half B_decode_reindex_shared_dyn_warp[8];\n", + " for (int var = 0; var < 1; ++var) {\n", + " for (int i = 0; i < 8; ++i) {\n", + "C_reindex_pad_shared_dyn_warp[0 + i] = 0.0;}\n", + ";\n", + " for (int ax3_0_0 = 0; ax3_0_0 < 8; ++ax3_0_0) {\n", + " __syncthreads();\n", + " #pragma unroll\n", + " for (int ax0_ax1_ax2_fused_0 = 0; ax0_ax1_ax2_fused_0 < 2; ++ax0_ax1_ax2_fused_0) {\n", + " uint4 condval;\n", + " if ((((((((int)blockIdx.y) * 16) + (ax0_ax1_ax2_fused_0 * 8)) + (((int)threadIdx.z) * 2)) + (((int)threadIdx.x) >> 4)) < m)) {\n", + " condval = *(uint4*)(A + ((((((((int)blockIdx.y) * 16384) + (ax0_ax1_ax2_fused_0 * 8192)) + (((int)threadIdx.z) * 2048)) + ((((int)threadIdx.x) >> 4) * 1024)) + (ax3_0_0 * 128)) + ((((int)threadIdx.x) & 15) * 8)));\n", + " } else {\n", + " condval = make_uint4(__pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f)), __pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f)), __pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f)), __pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f)));\n", + " }\n", + " *(uint4*)(((half*)buf_dyn_shmem) + ((((ax0_ax1_ax2_fused_0 * 1088) + (((int)threadIdx.z) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = condval;\n", + " }\n", + " for (int ax1_ax2_0_fused_0 = 0; ax1_ax2_0_fused_0 < 8; ++ax1_ax2_0_fused_0) {\n", + " *(int*)(B_local + 0) = *(int*)(B + ((((((((int)blockIdx.x) * 32768) + (ax1_ax2_0_fused_0 * 4096)) + (((int)threadIdx.z) * 1024)) + ((((int)threadIdx.x) >> 4) * 512)) + (ax3_0_0 * 64)) + ((((int)threadIdx.x) & 15) * 4)));\n", + " decode_i4s_to_f16((&(B_local[0])), (&(B_decode_reindex_local[0])), 8);\n", + " *(uint4*)(((half*)buf_dyn_shmem) + (((((ax1_ax2_0_fused_0 * 1088) + (((int)threadIdx.z) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8)) + 2176)) = *(uint4*)(B_decode_reindex_local + 0);\n", + " }\n", + " __syncthreads();\n", + " for (int ax3_0_1 = 0; ax3_0_1 < 8; ++ax3_0_1) {\n", + "\n", + " {\n", + " unsigned int addr;\n", + "#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST\n", + " addr = static_cast(__cvta_generic_to_shared((void *)((&(((half*)buf_dyn_shmem)[(ax3_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8)))));\n", + "#else\n", + " __asm__ __volatile__(\n", + " \"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\\n\"\n", + " : \"=r\"(addr)\n", + " : \"l\"((void *)((&(((half*)buf_dyn_shmem)[(ax3_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8))))\n", + " );\n", + "#endif\n", + " __asm__ __volatile__(\n", + " \"ldmatrix.sync.aligned.m8n8.x4.shared.b16\"\n", + " \"{%0, %1, %2, %3}, [%4];\\n\"\n", + " : \"=r\"(((unsigned *)(A_reindex_pad_shared_dyn_warp + 0))[0]), \"=r\"(((unsigned *)(A_reindex_pad_shared_dyn_warp + 0))[1]), \"=r\"(((unsigned *)(A_reindex_pad_shared_dyn_warp + 0))[2]), \"=r\"(((unsigned *)(A_reindex_pad_shared_dyn_warp + 0))[3])\n", + " : \"r\"(addr)\n", + " );\n", + " }\n", + "\n", + " {\n", + " unsigned int addr;\n", + "#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST\n", + " addr = static_cast(__cvta_generic_to_shared((void *)((&(((half*)buf_dyn_shmem)[(((((int)threadIdx.z) * 2176) + (ax3_0_1 * 16)) + 2176)])) + ((((((int)threadIdx.x) >> 4) * 1088) + ((((int)threadIdx.x) & 7) * 136)) + (((((int)threadIdx.x) & 15) >> 3) * 8)))));\n", + "#else\n", + " __asm__ __volatile__(\n", + " \"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\\n\"\n", + " : \"=r\"(addr)\n", + " : \"l\"((void *)((&(((half*)buf_dyn_shmem)[(((((int)threadIdx.z) * 2176) + (ax3_0_1 * 16)) + 2176)])) + ((((((int)threadIdx.x) >> 4) * 1088) + ((((int)threadIdx.x) & 7) * 136)) + (((((int)threadIdx.x) & 15) >> 3) * 8))))\n", + " );\n", + "#endif\n", + " __asm__ __volatile__(\n", + " \"ldmatrix.sync.aligned.m8n8.x4.shared.b16\"\n", + " \"{%0, %1, %2, %3}, [%4];\\n\"\n", + " : \"=r\"(((unsigned *)(B_decode_reindex_shared_dyn_warp + 0))[0]), \"=r\"(((unsigned *)(B_decode_reindex_shared_dyn_warp + 0))[1]), \"=r\"(((unsigned *)(B_decode_reindex_shared_dyn_warp + 0))[2]), \"=r\"(((unsigned *)(B_decode_reindex_shared_dyn_warp + 0))[3])\n", + " : \"r\"(addr)\n", + " );\n", + " }\n", + "\n", + " {\n", + " __asm__ __volatile__(\n", + " \"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16\"\n", + " \"{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\\n\"\n", + " : \"=r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 0))[0]), \"=r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 0))[1])\n", + " : \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[0]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[1]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[2]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[3]), \"r\"(((unsigned *)((half*)B_decode_reindex_shared_dyn_warp + 0))[0]), \"r\"(((unsigned *)((half*)B_decode_reindex_shared_dyn_warp + 0))[1]), \"r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 0))[0]), \"r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 0))[1]));\n", + " }\n", + "\n", + " {\n", + " __asm__ __volatile__(\n", + " \"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16\"\n", + " \"{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\\n\"\n", + " : \"=r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 4))[0]), \"=r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 4))[1])\n", + " : \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[0]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[1]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[2]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[3]), \"r\"(((unsigned *)((half*)B_decode_reindex_shared_dyn_warp + 4))[0]), \"r\"(((unsigned *)((half*)B_decode_reindex_shared_dyn_warp + 4))[1]), \"r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 4))[0]), \"r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 4))[1]));\n", + " }\n", + " }\n", + " }\n", + " __syncthreads();\n", + " for (int local_id = 0; local_id < 8; local_id+=2) {\n", + "*((uint *)&(&(((half*)buf_dyn_shmem)[(((int)threadIdx.z) * 256)]))[((((((local_id % 4) / 2) * 8) + (threadIdx.x / 4)) * 16) + ((((local_id / 4) * 8) + ((threadIdx.x % 4) * 2)) + (local_id % 2)))]) = *((uint *)&C_reindex_pad_shared_dyn_warp[0 + local_id]);\n", + "}\n", + ";\n", + " }\n", + " __syncthreads();\n", + " #pragma unroll\n", + " for (int ax0_ax1_ax2_ax3_ax4_fused_0 = 0; ax0_ax1_ax2_ax3_ax4_fused_0 < 4; ++ax0_ax1_ax2_ax3_ax4_fused_0) {\n", + " if (((((int)blockIdx.y) * 16) + (((int)threadIdx.x) >> 1)) < m) {\n", + " *(uint4*)(C + (((((((int)blockIdx.y) * 16384) + ((((int)threadIdx.x) >> 1) * 1024)) + (((int)blockIdx.x) * 64)) + (ax0_ax1_ax2_ax3_ax4_fused_0 * 16)) + ((((int)threadIdx.x) & 1) * 8))) = *(uint4*)(((half*)buf_dyn_shmem) + ((ax0_ax1_ax2_ax3_ax4_fused_0 * 256) + (((int)threadIdx.x) * 8)));\n", + " }\n", + " }\n", + "}\n", + "\n", + "extern \"C\" __global__ void __launch_bounds__(32) matmul_n1024k1024_f16xi4_tcx32x16x256w16x16_opt_m_64(half* __restrict__ A, signed char* __restrict__ B, half* __restrict__ C, int m) {\n", + " extern __shared__ uchar buf_dyn_shmem[];\n", + " half C_reindex_pad_shared_dyn_warp[8];\n", + " signed char B_local[4];\n", + " half B_decode_reindex_local[8];\n", + " half A_reindex_pad_shared_dyn_warp[8];\n", + " half B_decode_reindex_shared_dyn_warp[8];\n", + " for (int var = 0; var < 1; ++var) {\n", + " for (int i = 0; i < 8; ++i) {\n", + "C_reindex_pad_shared_dyn_warp[0 + i] = 0.0;}\n", + ";\n", + " for (int ax3_0_0 = 0; ax3_0_0 < 4; ++ax3_0_0) {\n", + " __syncthreads();\n", + " #pragma unroll\n", + " for (int ax0_ax1_ax2_fused_0 = 0; ax0_ax1_ax2_fused_0 < 16; ++ax0_ax1_ax2_fused_0) {\n", + " uint4 condval;\n", + " if ((((((int)blockIdx.y) * 16) + ax0_ax1_ax2_fused_0) < m)) {\n", + " condval = *(uint4*)(A + ((((((int)blockIdx.y) * 16384) + (ax0_ax1_ax2_fused_0 * 1024)) + (ax3_0_0 * 256)) + (((int)threadIdx.x) * 8)));\n", + " } else {\n", + " condval = make_uint4(__pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f)), __pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f)), __pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f)), __pack_half2(__float2half_rn(0.000000e+00f), __float2half_rn(0.000000e+00f)));\n", + " }\n", + " *(uint4*)(((half*)buf_dyn_shmem) + (((ax0_ax1_ax2_fused_0 * 264) + (((int)threadIdx.x) * 8)) + 256)) = condval;\n", + " }\n", + " for (int ax1_ax2_0_fused_0 = 0; ax1_ax2_0_fused_0 < 16; ++ax1_ax2_0_fused_0) {\n", + " *(int*)(B_local + 0) = *(int*)(B + ((((((int)blockIdx.x) * 8192) + (ax1_ax2_0_fused_0 * 512)) + (ax3_0_0 * 128)) + (((int)threadIdx.x) * 4)));\n", + " decode_i4s_to_f16((&(B_local[0])), (&(B_decode_reindex_local[0])), 8);\n", + " *(uint4*)(((half*)buf_dyn_shmem) + (((ax1_ax2_0_fused_0 * 264) + (((int)threadIdx.x) * 8)) + 4480)) = *(uint4*)(B_decode_reindex_local + 0);\n", + " }\n", + " __syncthreads();\n", + " for (int ax3_0_1 = 0; ax3_0_1 < 16; ++ax3_0_1) {\n", + "\n", + " {\n", + " unsigned int addr;\n", + "#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST\n", + " addr = static_cast(__cvta_generic_to_shared((void *)((&(((half*)buf_dyn_shmem)[((ax3_0_1 * 16) + 256)])) + (((((int)threadIdx.x) & 15) * 264) + ((((int)threadIdx.x) >> 4) * 8)))));\n", + "#else\n", + " __asm__ __volatile__(\n", + " \"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\\n\"\n", + " : \"=r\"(addr)\n", + " : \"l\"((void *)((&(((half*)buf_dyn_shmem)[((ax3_0_1 * 16) + 256)])) + (((((int)threadIdx.x) & 15) * 264) + ((((int)threadIdx.x) >> 4) * 8))))\n", + " );\n", + "#endif\n", + " __asm__ __volatile__(\n", + " \"ldmatrix.sync.aligned.m8n8.x4.shared.b16\"\n", + " \"{%0, %1, %2, %3}, [%4];\\n\"\n", + " : \"=r\"(((unsigned *)(A_reindex_pad_shared_dyn_warp + 0))[0]), \"=r\"(((unsigned *)(A_reindex_pad_shared_dyn_warp + 0))[1]), \"=r\"(((unsigned *)(A_reindex_pad_shared_dyn_warp + 0))[2]), \"=r\"(((unsigned *)(A_reindex_pad_shared_dyn_warp + 0))[3])\n", + " : \"r\"(addr)\n", + " );\n", + " }\n", + "\n", + " {\n", + " unsigned int addr;\n", + "#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST\n", + " addr = static_cast(__cvta_generic_to_shared((void *)((&(((half*)buf_dyn_shmem)[((ax3_0_1 * 16) + 4480)])) + ((((((int)threadIdx.x) >> 4) * 2112) + ((((int)threadIdx.x) & 7) * 264)) + (((((int)threadIdx.x) & 15) >> 3) * 8)))));\n", + "#else\n", + " __asm__ __volatile__(\n", + " \"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\\n\"\n", + " : \"=r\"(addr)\n", + " : \"l\"((void *)((&(((half*)buf_dyn_shmem)[((ax3_0_1 * 16) + 4480)])) + ((((((int)threadIdx.x) >> 4) * 2112) + ((((int)threadIdx.x) & 7) * 264)) + (((((int)threadIdx.x) & 15) >> 3) * 8))))\n", + " );\n", + "#endif\n", + " __asm__ __volatile__(\n", + " \"ldmatrix.sync.aligned.m8n8.x4.shared.b16\"\n", + " \"{%0, %1, %2, %3}, [%4];\\n\"\n", + " : \"=r\"(((unsigned *)(B_decode_reindex_shared_dyn_warp + 0))[0]), \"=r\"(((unsigned *)(B_decode_reindex_shared_dyn_warp + 0))[1]), \"=r\"(((unsigned *)(B_decode_reindex_shared_dyn_warp + 0))[2]), \"=r\"(((unsigned *)(B_decode_reindex_shared_dyn_warp + 0))[3])\n", + " : \"r\"(addr)\n", + " );\n", + " }\n", + "\n", + " {\n", + " __asm__ __volatile__(\n", + " \"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16\"\n", + " \"{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\\n\"\n", + " : \"=r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 0))[0]), \"=r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 0))[1])\n", + " : \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[0]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[1]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[2]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[3]), \"r\"(((unsigned *)((half*)B_decode_reindex_shared_dyn_warp + 0))[0]), \"r\"(((unsigned *)((half*)B_decode_reindex_shared_dyn_warp + 0))[1]), \"r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 0))[0]), \"r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 0))[1]));\n", + " }\n", + "\n", + " {\n", + " __asm__ __volatile__(\n", + " \"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16\"\n", + " \"{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\\n\"\n", + " : \"=r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 4))[0]), \"=r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 4))[1])\n", + " : \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[0]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[1]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[2]), \"r\"(((unsigned *)((half*)A_reindex_pad_shared_dyn_warp + 0))[3]), \"r\"(((unsigned *)((half*)B_decode_reindex_shared_dyn_warp + 4))[0]), \"r\"(((unsigned *)((half*)B_decode_reindex_shared_dyn_warp + 4))[1]), \"r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 4))[0]), \"r\"(((unsigned *)(C_reindex_pad_shared_dyn_warp + 4))[1]));\n", + " }\n", + " }\n", + " }\n", + " __syncthreads();\n", + " for (int local_id = 0; local_id < 8; local_id+=2) {\n", + "*((uint *)&(&(((half*)buf_dyn_shmem)[0]))[((((((local_id % 4) / 2) * 8) + (threadIdx.x / 4)) * 16) + ((((local_id / 4) * 8) + ((threadIdx.x % 4) * 2)) + (local_id % 2)))]) = *((uint *)&C_reindex_pad_shared_dyn_warp[0 + local_id]);\n", + "}\n", + ";\n", + " }\n", + " __syncthreads();\n", + " #pragma unroll\n", + " for (int ax0_ax1_ax2_ax3_ax4_fused_0 = 0; ax0_ax1_ax2_ax3_ax4_fused_0 < 1; ++ax0_ax1_ax2_ax3_ax4_fused_0) {\n", + " if (((((int)blockIdx.y) * 16) + (((int)threadIdx.x) >> 1)) < m) {\n", + " *(uint4*)(C + ((((((int)blockIdx.y) * 16384) + ((((int)threadIdx.x) >> 1) * 1024)) + (((int)blockIdx.x) * 16)) + ((((int)threadIdx.x) & 1) * 8))) = *(uint4*)(((half*)buf_dyn_shmem) + (((int)threadIdx.x) * 8));\n", + " }\n", + " }\n", + "}\n", + "\n", + "\n", + "extern \"C\" void init() {\n", + " \n", + " cudaFuncSetAttribute(matmul_n1024k1024_f16xi4_tcx16x128x64w16x32_opt_m_16, cudaFuncAttributeMaxDynamicSharedMemorySize, 17408);\n", + "\n", + " cudaFuncSetAttribute(matmul_n1024k1024_f16xi4_tcx16x64x128w16x16_opt_m_32, cudaFuncAttributeMaxDynamicSharedMemorySize, 17408);\n", + "\n", + " cudaFuncSetAttribute(matmul_n1024k1024_f16xi4_tcx32x64x128w16x32_opt_m_128, cudaFuncAttributeMaxDynamicSharedMemorySize, 17408);\n", + "\n", + " cudaFuncSetAttribute(matmul_n1024k1024_f16xi4_tcx32x64x128w16x32_opt_m_256, cudaFuncAttributeMaxDynamicSharedMemorySize, 21760);\n", + "\n", + " cudaFuncSetAttribute(matmul_n1024k1024_f16xi4_tcx32x16x256w16x16_opt_m_64, cudaFuncAttributeMaxDynamicSharedMemorySize, 17408);\n", + "\n", + "}\n", + "\n", + "extern \"C\" void call(half* __restrict__ A, int8_t* __restrict__ B, half* __restrict__ C, int m, cudaStream_t stream=cudaStreamDefault) {\n", + " if (m == 0) return; \n", + " if (m <= 1) {\n", + " matmul_n1024k1024_f16xi4_simt_opt_m_1<<>>(A, B, C, m); \n", + " }\n", + " else if (m <= 16) {\n", + " matmul_n1024k1024_f16xi4_tcx16x128x64w16x32_opt_m_16<<>>(A, B, C, m); \n", + " }\n", + " else if (m <= 32) {\n", + " matmul_n1024k1024_f16xi4_tcx16x64x128w16x16_opt_m_32<<>>(A, B, C, m); \n", + " }\n", + " else if (m <= 64) {\n", + " matmul_n1024k1024_f16xi4_tcx32x16x256w16x16_opt_m_64<<>>(A, B, C, m); \n", + " }\n", + " else if (m <= 128) {\n", + " matmul_n1024k1024_f16xi4_tcx32x64x128w16x32_opt_m_128<<>>(A, B, C, m); \n", + " }\n", + " else if (m <= 256) {\n", + " matmul_n1024k1024_f16xi4_tcx32x64x128w16x32_opt_m_256<<>>(A, B, C, m); \n", + " }\n", + " else {\n", + " matmul_n1024k1024_f16xi4_tcx32x64x128w16x32_opt_m_256<<>>(A, B, C, m); \n", + " }\n", + "\n", + "}\n", + "\n" + ] + } + ], + "source": [ + "matmul.hardware_aware_finetune(topk=20, parallel_build=True)\n", + "print(matmul.get_source())" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/5.ladder_end2end.ipynb b/tutorials/5.ladder_end2end.ipynb new file mode 100644 index 000000000..50441da32 --- /dev/null +++ b/tutorials/5.ladder_end2end.ipynb @@ -0,0 +1,275 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f35f5227-1ded-44c4-985b-b75fe748876d", + "metadata": {}, + "source": [ + "# Ladder\n", + "\n", + "**Codebase:https://github.com/microsoft/BitBLAS/tree/osdi24_ladder_artifact**\n", + "```python\n", + "python ladder_from_onnx.py --prefix ./llama2_70b_single_layer/model.onnx\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "5bfda684-1ca2-4999-8609-505a3c974b7e", + "metadata": {}, + "outputs": [], + "source": [ + "import ladder" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "fe8e1349-c2f0-4cce-a39c-9a6bd0fff338", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/root/Ladder/3rdparty/tvm/python/tvm/target/target.py:397: UserWarning: Try specifying cuda arch by adding 'arch=sm_xx' to your target.\n", + " warnings.warn(\"Try specifying cuda arch by adding 'arch=sm_xx' to your target.\")\n", + "32\n", + "Output93, \n", + "add92, reshape91, \n", + "nn.dense90, \n", + "exp47, subtract46, \n", + "divide44, Constant43, reshape42, \n", + "reshape39, transpose38, reshape37, broadcast_to36, concatenate32, add34, strided_slice31, negative30, transpose40, expand_dims35, multiply33, strided_slice29, multiply28, \n", + "transpose27, reshape26, reshape25, \n", + "max45, \n", + "reshape23, add22, multiply21, concatenate20, strided_slice19, negative18, strided_slice17, multiply16, \n", + "nn.batch_matmul41, \n", + "reshape14, transpose15, reshape13, \n", + "nn.dense24, \n", + "reshape11, multiply10, cast9, multiply8, \n", + "divide7, sqrt6, add5, Constant4, mean3, Constant1, \n", + "reshape83, \n", + "nn.dense12, \n", + "reshape89, reshape87, multiply88, multiply85, sigmoid84, \n", + "multiply2, cast0, \n", + "sum48, \n", + "reshape52, cast51, cast50, divide49, \n", + "nn.dense53, \n", + "reshape60, reshape59, transpose61, broadcast_to58, expand_dims57, transpose56, reshape55, reshape54, \n", + "nn.batch_matmul62, \n", + "reshape66, transpose64, reshape65, reshape63, \n", + "nn.dense67, \n", + "add69, reshape68, \n", + "multiply72, cast70, \n", + "divide77, sqrt76, add75, Constant74, mean73, Constant71, \n", + "multiply80, cast79, multiply78, \n", + "reshape81, \n", + "nn.dense82, \n", + "nn.dense86, \n", + "dense is not optimized for this platform.\n", + "dense is not optimized for this platform.\n", + "batch_matmul is not optimized for this platform.\n", + "dense is not optimized for this platform.\n", + "batch_matmul is not optimized for this platform.\n", + "dense is not optimized for this platform.\n", + "dense is not optimized for this platform.\n", + "dense is not optimized for this platform.\n", + "dense is not optimized for this platform.\n", + "2024-10-24 16:45:31 [ladder:INFO]: Tuning ['cast_multiply_0', 'mean_add_sqrt_divide_1', 'multiply_cast_multiply_reshape_2']\n", + "2024-10-24 16:45:32 [ladder:INFO]: Tuning ['cast_multiply_0', 'mean_add_sqrt_divide_1']\n", + "2024-10-24 16:45:39 [ladder:INFO]: result: 0.003686400130391121 \n", + "2024-10-24 16:45:39 [ladder:INFO]: Tuning ['cast_multiply_0'] \n", + "2024-10-24 16:45:46 [ladder:INFO]: result: 0.0025298823602497578 \n", + "2024-10-24 16:45:46 [ladder:INFO]: Tuning ['mean_add_sqrt_divide_1'] \n", + "2024-10-24 16:45:49 [ladder:INFO]: result: 0.0032191998325288296 \n", + "2024-10-24 16:45:49 [ladder:INFO]: Tuning ['cast_multiply_0', 'mean_add_sqrt_divide_1', 'multiply_cast_multiply_reshape_2']\n", + "2024-10-24 16:45:50 [ladder:INFO]: Fusion group created: 0 ['cast_multiply_0', 'mean_add_sqrt_divide_1']\n", + "2024-10-24 16:45:50 [ladder:INFO]: Tuning ['multiply_cast_multiply_reshape_2', 'nn_dense_3', 'nn_dense_6', 'nn_dense_15']\n", + "2024-10-24 16:45:50 [ladder:INFO]: Tuning ['multiply_cast_multiply_reshape_2'] \n", + "2024-10-24 16:45:54 [ladder:INFO]: result: 0.0024840000551193953 \n", + "2024-10-24 16:45:54 [ladder:INFO]: Fusion group created: 1 ['multiply_cast_multiply_reshape_2']\n", + "2024-10-24 16:45:54 [ladder:INFO]: Tuning ['nn_dense_3', 'reshape_reshape_transpose_4']\n", + "2024-10-24 16:45:59 [ladder:INFO]: result: 0.15735039114952087 \n", + "2024-10-24 16:45:59 [ladder:INFO]: Tuning ['nn_dense_3', 'reshape_reshape_transpose_4']\n", + "2024-10-24 16:46:03 [ladder:INFO]: result: 0.20070399343967438 \n", + "2024-10-24 16:46:03 [ladder:INFO]: Tuning ['nn_dense_3'] \n", + "2024-10-24 16:46:08 [ladder:INFO]: result: 0.1441279947757721 \n", + "2024-10-24 16:46:08 [ladder:INFO]: Tuning ['reshape_reshape_transpose_4'] \n", + "2024-10-24 16:46:17 [ladder:INFO]: result: 0.00219345442019403 \n", + "2024-10-24 16:46:17 [ladder:INFO]: Tuning ['nn_dense_3'] \n", + "2024-10-24 16:46:17 [ladder:INFO]: Fusion group created: 2 ['nn_dense_3'] \n", + "2024-10-24 16:46:17 [ladder:INFO]: Tuning ['reshape_reshape_transpose_4', 'multiply_strided_slice_negative_strided_slice_concatenate_multiply_add_reshape_5']\n", + "2024-10-24 16:46:17 [ladder:INFO]: Tuning ['reshape_reshape_transpose_4'] \n", + "2024-10-24 16:46:17 [ladder:INFO]: Fusion group created: 3 ['reshape_reshape_transpose_4']\n", + "2024-10-24 16:46:17 [ladder:INFO]: Tuning ['multiply_strided_slice_negative_strided_slice_concatenate_multiply_add_reshape_5', 'nn_dense_6', 'reshape_reshape_transpose_7', 'multiply_strided_slice_negative_strided_slice_concatenate_multiply_add_expand_dims_broadcast_to_reshape_transpose_reshape_transpose_8', 'nn_batch_matmul_9']\n", + "2024-10-24 16:46:17 [ladder:INFO]: Tuning ['multiply_strided_slice_negative_strided_slice_concatenate_multiply_add_reshape_5']\n", + "2024-10-24 16:46:27 [ladder:INFO]: result: 0.002269866643473506 \n", + "2024-10-24 16:46:27 [ladder:INFO]: Fusion group created: 4 ['multiply_strided_slice_negative_strided_slice_concatenate_multiply_add_reshape_5']\n", + "2024-10-24 16:46:27 [ladder:INFO]: Tuning ['nn_dense_6', 'reshape_reshape_transpose_7']\n", + "2024-10-24 16:46:32 [ladder:INFO]: result: 0.032153598964214325 \n", + "2024-10-24 16:46:32 [ladder:INFO]: Tuning ['nn_dense_6', 'reshape_reshape_transpose_7']\n", + "2024-10-24 16:46:35 [ladder:INFO]: result: 0.16315732896327972 \n", + "2024-10-24 16:46:35 [ladder:INFO]: Tuning ['nn_dense_6'] \n", + "2024-10-24 16:46:40 [ladder:INFO]: result: 0.006829713936895132 \n", + "2024-10-24 16:46:40 [ladder:INFO]: Tuning ['reshape_reshape_transpose_7'] \n", + "2024-10-24 16:46:46 [ladder:INFO]: result: 0.002187636448070407 \n", + "2024-10-24 16:46:46 [ladder:INFO]: Tuning ['nn_dense_6'] \n", + "2024-10-24 16:46:46 [ladder:INFO]: Fusion group created: 5 ['nn_dense_6'] \n", + "2024-10-24 16:46:46 [ladder:INFO]: Tuning ['reshape_reshape_transpose_7', 'multiply_strided_slice_negative_strided_slice_concatenate_multiply_add_expand_dims_broadcast_to_reshape_transpose_reshape_transpose_8']\n", + "2024-10-24 16:46:47 [ladder:INFO]: Tuning ['reshape_reshape_transpose_7'] \n", + "2024-10-24 16:46:47 [ladder:INFO]: Fusion group created: 6 ['reshape_reshape_transpose_7']\n", + "2024-10-24 16:46:47 [ladder:INFO]: Tuning ['multiply_strided_slice_negative_strided_slice_concatenate_multiply_add_expand_dims_broadcast_to_reshape_transpose_reshape_transpose_8', 'nn_batch_matmul_9']\n", + "2024-10-24 16:46:51 [ladder:INFO]: result: 0.0027648000977933407 \n", + "2024-10-24 16:46:51 [ladder:INFO]: Tuning ['multiply_strided_slice_negative_strided_slice_concatenate_multiply_add_expand_dims_broadcast_to_reshape_transpose_reshape_transpose_8']\n", + "2024-10-24 16:47:01 [ladder:INFO]: result: 0.002321066800504923 \n", + "2024-10-24 16:47:01 [ladder:INFO]: Tuning ['nn_batch_matmul_9'] \n", + "2024-10-24 16:47:05 [ladder:INFO]: result: 0.0023893334437161684 \n", + "2024-10-24 16:47:05 [ladder:INFO]: Tuning ['multiply_strided_slice_negative_strided_slice_concatenate_multiply_add_expand_dims_broadcast_to_reshape_transpose_reshape_transpose_8', 'nn_batch_matmul_9', 'reshape_divide_10']\n", + "2024-10-24 16:47:10 [ladder:INFO]: result: 0.003276800038293004 \n", + "2024-10-24 16:47:10 [ladder:INFO]: Tuning ['reshape_divide_10'] \n", + "2024-10-24 16:47:14 [ladder:INFO]: result: 0.0024774738121777773 \n", + "2024-10-24 16:47:14 [ladder:INFO]: Tuning ['multiply_strided_slice_negative_strided_slice_concatenate_multiply_add_expand_dims_broadcast_to_reshape_transpose_reshape_transpose_8', 'nn_batch_matmul_9', 'reshape_divide_10', 'max_11', 'subtract_exp_12']\n", + "2024-10-24 16:47:19 [ladder:INFO]: result: 0.003276800038293004 \n", + "2024-10-24 16:47:19 [ladder:INFO]: Tuning ['max_11'] \n", + "2024-10-24 16:47:23 [ladder:INFO]: result: 0.0023503999691456556 \n", + "2024-10-24 16:47:23 [ladder:INFO]: Tuning ['subtract_exp_12'] \n", + "2024-10-24 16:47:27 [ladder:INFO]: result: 0.0025031110271811485 \n", + "2024-10-24 16:47:27 [ladder:INFO]: Tuning ['multiply_strided_slice_negative_strided_slice_concatenate_multiply_add_expand_dims_broadcast_to_reshape_transpose_reshape_transpose_8', 'nn_batch_matmul_9', 'reshape_divide_10', 'max_11', 'subtract_exp_12', 'sum_13', 'divide_cast_cast_reshape_14']\n", + "2024-10-24 16:47:33 [ladder:INFO]: result: 0.0033279999624937773 \n", + "2024-10-24 16:47:33 [ladder:INFO]: Tuning ['sum_13'] \n", + "2024-10-24 16:47:37 [ladder:INFO]: result: 0.0023713684640824795 \n", + "2024-10-24 16:47:37 [ladder:INFO]: Tuning ['divide_cast_cast_reshape_14'] \n", + "2024-10-24 16:47:40 [ladder:INFO]: result: 0.002368000103160739 \n", + "2024-10-24 16:47:40 [ladder:INFO]: Tuning ['multiply_strided_slice_negative_strided_slice_concatenate_multiply_add_expand_dims_broadcast_to_reshape_transpose_reshape_transpose_8', 'nn_batch_matmul_9', 'reshape_divide_10', 'max_11', 'subtract_exp_12', 'sum_13', 'divide_cast_cast_reshape_14', 'nn_dense_15', 'reshape_reshape_transpose_expand_dims_broadcast_to_reshape_reshape_transpose_16', 'nn_batch_matmul_17']\n", + "2024-10-24 16:47:41 [ladder:INFO]: Fusion group created: 7 ['multiply_strided_slice_negative_strided_slice_concatenate_multiply_add_expand_dims_broadcast_to_reshape_transpose_reshape_transpose_8', 'nn_batch_matmul_9', 'reshape_divide_10', 'max_11', 'subtract_exp_12', 'sum_13', 'divide_cast_cast_reshape_14']\n", + "2024-10-24 16:47:41 [ladder:INFO]: Tuning ['nn_dense_15', 'reshape_reshape_transpose_expand_dims_broadcast_to_reshape_reshape_transpose_16']\n", + "2024-10-24 16:47:44 [ladder:INFO]: result: 0.16315732896327972 \n", + "2024-10-24 16:47:44 [ladder:INFO]: Tuning ['nn_dense_15', 'reshape_reshape_transpose_expand_dims_broadcast_to_reshape_reshape_transpose_16']\n", + "2024-10-24 16:47:44 [ladder:ERROR]: Failed to get base tile: Traceback (most recent call last):\n", + " 3: TVMFuncCall\n", + " 2: tvm::runtime::PackedFuncObj::Extractor (tvm::runtime::Array const&, tvm::runtime::Array)>::AssignTypedLambda (*)(tvm::runtime::Array const&, tvm::runtime::Array)>(tvm::runtime::Map (*)(tvm::runtime::Array const&, tvm::runtime::Array), std::__cxx11::basic_string, std::allocator >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)\n", + " 1: tvm::arith::InverseAffineIterMap(tvm::runtime::Array const&, tvm::runtime::Array)\n", + " 0: tvm::arith::InverseAffineIterMapTransformer::operator()(tvm::runtime::Array const&, tvm::runtime::Array const&)\n", + " File \"/root/Ladder/3rdparty/tvm/src/arith/iter_affine_map.cc\", line 2128\n", + "TVMError: \n", + "---------------------------------------------------------------\n", + "An error occurred during the execution of TVM.\n", + "For more information, please see: https://tvm.apache.org/docs/errors.html\n", + "---------------------------------------------------------------\n", + " Check failed: (iter_map.size() == outputs.size()) is false: \n", + "2024-10-24 16:47:44 [ladder:INFO]: Tuning ['nn_dense_15'] \n", + "2024-10-24 16:47:44 [ladder:INFO]: Tuning ['reshape_reshape_transpose_expand_dims_broadcast_to_reshape_reshape_transpose_16']\n", + "2024-10-24 16:47:54 [ladder:INFO]: result: 0.002187636448070407 \n", + "2024-10-24 16:47:54 [ladder:INFO]: Tuning ['nn_dense_15'] \n", + "2024-10-24 16:47:54 [ladder:INFO]: Fusion group created: 8 ['nn_dense_15'] \n", + "2024-10-24 16:47:54 [ladder:INFO]: Tuning ['reshape_reshape_transpose_expand_dims_broadcast_to_reshape_reshape_transpose_16', 'nn_batch_matmul_17']\n", + "2024-10-24 16:48:42 [ladder:INFO]: result: 0.0023040000814944506 \n", + "2024-10-24 16:48:42 [ladder:INFO]: Tuning ['nn_batch_matmul_17'] \n", + "2024-10-24 16:48:50 [ladder:INFO]: result: 0.0023040000814944506 \n", + "2024-10-24 16:48:50 [ladder:INFO]: Tuning ['reshape_reshape_transpose_expand_dims_broadcast_to_reshape_reshape_transpose_16', 'nn_batch_matmul_17', 'reshape_transpose_reshape_reshape_18']\n", + "2024-10-24 16:48:55 [ladder:INFO]: result: 0.002500923117622733 \n", + "2024-10-24 16:48:55 [ladder:INFO]: Tuning ['reshape_transpose_reshape_reshape_18']\n", + "2024-10-24 16:48:59 [ladder:INFO]: result: 0.0025599999353289604 \n", + "2024-10-24 16:48:59 [ladder:INFO]: Tuning ['reshape_reshape_transpose_expand_dims_broadcast_to_reshape_reshape_transpose_16', 'nn_batch_matmul_17', 'reshape_transpose_reshape_reshape_18', 'nn_dense_19']\n", + "2024-10-24 16:48:59 [ladder:INFO]: Fusion group created: 9 ['reshape_reshape_transpose_expand_dims_broadcast_to_reshape_reshape_transpose_16', 'nn_batch_matmul_17', 'reshape_transpose_reshape_reshape_18']\n", + "2024-10-24 16:48:59 [ladder:INFO]: Tuning ['nn_dense_19', 'reshape_add_20'] \n", + "2024-10-24 16:49:03 [ladder:INFO]: result: 0.15933439135551453 \n", + "2024-10-24 16:49:03 [ladder:INFO]: Tuning ['nn_dense_19', 'reshape_add_20'] \n", + "2024-10-24 16:49:07 [ladder:INFO]: result: 0.2011733204126358 \n", + "2024-10-24 16:49:07 [ladder:INFO]: Tuning ['nn_dense_19'] \n", + "2024-10-24 16:49:07 [ladder:INFO]: Tuning ['reshape_add_20'] \n", + "2024-10-24 16:49:12 [ladder:INFO]: result: 0.002297600032761693 \n", + "2024-10-24 16:49:12 [ladder:INFO]: Tuning ['nn_dense_19'] \n", + "2024-10-24 16:49:12 [ladder:INFO]: Fusion group created: 10 ['nn_dense_19'] \n", + "2024-10-24 16:49:12 [ladder:INFO]: Tuning ['reshape_add_20', 'cast_multiply_21', 'mean_add_sqrt_divide_22', 'multiply_cast_multiply_23', 'reshape_24', 'nn_dense_25', 'reshape_26', 'nn_dense_27', 'sigmoid_multiply_reshape_multiply_reshape_28', 'nn_dense_29', 'reshape_add_30']\n", + "2024-10-24 16:49:12 [ladder:INFO]: Tuning ['reshape_add_20'] \n", + "2024-10-24 16:49:12 [ladder:INFO]: Fusion group created: 11 ['reshape_add_20'] \n", + "2024-10-24 16:49:12 [ladder:INFO]: Tuning ['cast_multiply_21', 'mean_add_sqrt_divide_22', 'multiply_cast_multiply_23']\n", + "2024-10-24 16:49:12 [ladder:INFO]: Tuning ['cast_multiply_21', 'mean_add_sqrt_divide_22']\n", + "2024-10-24 16:49:12 [ladder:INFO]: Tuning ['cast_multiply_21'] \n", + "2024-10-24 16:49:12 [ladder:INFO]: Tuning ['mean_add_sqrt_divide_22'] \n", + "2024-10-24 16:49:12 [ladder:INFO]: Tuning ['cast_multiply_21', 'mean_add_sqrt_divide_22', 'multiply_cast_multiply_23']\n", + "2024-10-24 16:49:12 [ladder:INFO]: Fusion group created: 12 ['cast_multiply_21', 'mean_add_sqrt_divide_22']\n", + "2024-10-24 16:49:12 [ladder:INFO]: Tuning ['multiply_cast_multiply_23', 'reshape_24']\n", + "2024-10-24 16:49:16 [ladder:INFO]: result: 0.002236444503068924 \n", + "2024-10-24 16:49:16 [ladder:INFO]: Tuning ['multiply_cast_multiply_23'] \n", + "2024-10-24 16:49:16 [ladder:INFO]: Tuning ['multiply_cast_multiply_23', 'reshape_24', 'nn_dense_25', 'nn_dense_27']\n", + "2024-10-24 16:49:16 [ladder:INFO]: Fusion group created: 13 ['multiply_cast_multiply_23', 'reshape_24']\n", + "2024-10-24 16:49:16 [ladder:INFO]: Tuning ['nn_dense_25', 'reshape_26'] \n", + "2024-10-24 16:49:21 [ladder:INFO]: result: 0.4925439953804016 \n", + "2024-10-24 16:49:21 [ladder:INFO]: Tuning ['nn_dense_25', 'reshape_26'] \n", + "2024-10-24 16:49:25 [ladder:INFO]: result: 0.5303786396980286 \n", + "2024-10-24 16:49:25 [ladder:INFO]: Tuning ['nn_dense_25'] \n", + "2024-10-24 16:49:25 [ladder:INFO]: Tuning ['nn_dense_25', 'reshape_26', 'nn_dense_27', 'sigmoid_multiply_reshape_multiply_reshape_28']\n", + "2024-10-24 16:49:32 [ladder:INFO]: result: 0.9988096356391907 \n", + "2024-10-24 16:49:32 [ladder:INFO]: Tuning ['nn_dense_27'] \n", + "2024-10-24 16:49:32 [ladder:INFO]: Tuning ['sigmoid_multiply_reshape_multiply_reshape_28']\n", + "2024-10-24 16:49:39 [ladder:INFO]: result: 0.002457600086927414 \n", + "2024-10-24 16:49:39 [ladder:INFO]: Fusion group created: 14 ['nn_dense_25', 'reshape_26']\n", + "2024-10-24 16:49:39 [ladder:INFO]: Tuning ['nn_dense_27', 'sigmoid_multiply_reshape_multiply_reshape_28']\n", + "2024-10-24 16:49:45 [ladder:INFO]: result: 0.5087040066719055 \n", + "2024-10-24 16:49:45 [ladder:INFO]: Tuning ['nn_dense_27', 'sigmoid_multiply_reshape_multiply_reshape_28']\n", + "2024-10-24 16:49:49 [ladder:INFO]: result: 0.5468416213989258 \n", + "2024-10-24 16:49:49 [ladder:INFO]: Tuning ['nn_dense_27'] \n", + "2024-10-24 16:49:49 [ladder:INFO]: Fusion group created: 15 ['nn_dense_27'] \n", + "2024-10-24 16:49:49 [ladder:INFO]: Tuning ['sigmoid_multiply_reshape_multiply_reshape_28', 'nn_dense_29']\n", + "2024-10-24 16:49:49 [ladder:INFO]: Tuning ['sigmoid_multiply_reshape_multiply_reshape_28']\n", + "2024-10-24 16:49:49 [ladder:INFO]: Fusion group created: 16 ['sigmoid_multiply_reshape_multiply_reshape_28']\n", + "2024-10-24 16:49:49 [ladder:INFO]: Tuning ['nn_dense_29', 'reshape_add_30'] \n", + "2024-10-24 16:49:55 [ladder:INFO]: result: 0.5228544473648071 \n", + "2024-10-24 16:49:55 [ladder:INFO]: Tuning ['nn_dense_29', 'reshape_add_30'] \n", + "2024-10-24 16:49:58 [ladder:INFO]: result: 0.6908586621284485 \n", + "2024-10-24 16:49:58 [ladder:INFO]: Tuning ['nn_dense_29'] \n", + "2024-10-24 16:50:03 [ladder:INFO]: result: 0.4927999973297119 \n", + "2024-10-24 16:50:03 [ladder:INFO]: Tuning ['reshape_add_30'] \n", + "2024-10-24 16:50:03 [ladder:INFO]: Tuning ['nn_dense_29'] \n", + "2024-10-24 16:50:03 [ladder:INFO]: Fusion group created: 17 ['nn_dense_29'] \n", + "2024-10-24 16:50:03 [ladder:INFO]: Tuning ['reshape_add_30'] \n", + "2024-10-24 16:50:03 [ladder:INFO]: Fusion group created: 18 ['reshape_add_30'] \n", + "Processing: 100%|███████████████████████████████| 47/47 [04:31<00:00, 5.78s/it]\n", + "Execution time summary:\n", + " mean (ms) median (ms) max (ms) min (ms) std (ms) \n", + " 1.8222 1.8222 1.8222 1.8222 0.0000 \n", + " \n" + ] + } + ], + "source": [ + "!python ladder_from_onnx.py --prefix /root/Ladder/artifact/models/llama_70b/llama2_70b_layer1_seq1_bs1/model.onnx" + ] + }, + { + "cell_type": "markdown", + "id": "61f57c6f-35fe-4401-b41e-363dd5f0ef0e", + "metadata": {}, + "source": [ + "## Video\n", + "\n", + "https://1drv.ms/v/c/4c1511b24254d525/Ebn1ue4ig6pJnROQCjPOyvMBCiAvp5JGlM2AGIhgew-bGw?e=YWh4Xc" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/6.tile-language.ipynb b/tutorials/6.tile-language.ipynb new file mode 100644 index 000000000..7826f0743 --- /dev/null +++ b/tutorials/6.tile-language.ipynb @@ -0,0 +1,475 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "50a965a4-79a6-4a1b-96be-f8c857c16d46", + "metadata": {}, + "source": [ + "# Tile Language in BitBLAS\n", + "\n", + "More flexiable, More Efficient Tile Programming Languange compared with Triton\n", + "\n", + "## Features\n", + "\n", + "- **Simplified Syntax**: Write GPU kernels with a more straightforward and expressive syntax.\n", + "- **High Performance**: Achieve performance comparable to manually optimized implementations.\n", + "- **Advanced Operations**: Support for complex operations like convolutions, flash-attention, and normalizations.\n", + "- **Compatibility**: Works with modern CUDA architectures.\n", + "\n", + "## OP Examples\n", + "\n", + "- [Matrix Multiplication](#quick-start)\n", + "- [Flash Attention](#flash-attention)\n", + "- [Dequantization GEMM](#dequantization-gemm)\n", + "- [RetNet](#retina-net)\n", + "- [MAMBA](#mamba)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "ff371661-da38-4b37-9e77-330d812aa7e7", + "metadata": {}, + "outputs": [], + "source": [ + "# Import Tile Language from bitblas\n", + "from bitblas import tvm as tvm\n", + "from tvm import tl\n", + "import tvm.tl.language as T" + ] + }, + { + "cell_type": "markdown", + "id": "cb42b2fa-0d66-4f22-8bd9-7561dfdde15e", + "metadata": {}, + "source": [ + "## Get Started with a GEMM Example" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "9af2499b-bca8-4ab0-bb2a-4926de21fde1", + "metadata": {}, + "outputs": [], + "source": [ + "M = N = K = 256\n", + "\n", + "A_shape = (M, K)\n", + "B_shape = (N, K)\n", + "C_shape = (M, N)\n", + "in_dtype = out_dtype = accum_dtype = \"float16\"\n", + "\n", + "block_M = block_N = 128\n", + "block_K = 32\n", + "threads = 128\n", + "num_stages = 2\n", + "\n", + "A_shared_shape = (block_M, block_K)\n", + "B_shared_shape = (block_N, block_K)\n", + "@T.prim_func\n", + "def main(A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), C: T.Buffer(\n", + " (M, N), out_dtype)):\n", + " with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):\n", + " A_shared = T.alloc_shared(A_shared_shape, in_dtype)\n", + " B_shared = T.alloc_shared(B_shared_shape, in_dtype)\n", + " C_local = T.alloc_fragment((block_M, block_N), accum_dtype)\n", + " T.clear(C_local)\n", + " for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):\n", + " T.copy(A[by * block_M, k * block_K], A_shared)\n", + " T.copy(B[bx * block_N, k * block_K], B_shared)\n", + " T.gemm(A_shared, B_shared, C_local, transpose_A=False, transpose_B=True)\n", + " T.copy(C_local, C[by * block_M, bx * block_N])\n", + "\n", + "func = main" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "5503e7a5-3ccb-4c7b-87a2-d6a9d66dea9e", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "#include \n", + "#include \n", + "#include \n", + "#include \n", + "#include \n", + "\n", + "extern \"C\" __global__ void __launch_bounds__(128) main_kernel(half_t* __restrict__ A, half_t* __restrict__ B, half_t* __restrict__ C) {\n", + " extern __shared__ __align__(1024) uchar buf_dyn_shmem[];\n", + " half_t C_local[128];\n", + " #pragma unroll\n", + " for (int i = 0; i < 64; ++i) {\n", + " *(uint1*)(C_local + (i * 2)) = make_uint1(__pack_half2(half_t(0.000000e+00f), half_t(0.000000e+00f)));\n", + " }\n", + " #pragma unroll\n", + " for (int i_1 = 0; i_1 < 4; ++i_1) {\n", + " tl::cp_async_gs<16>(buf_dyn_shmem+((((i_1 * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)), A+((((((int)blockIdx.y) * 32768) + (i_1 * 8192)) + ((((int)threadIdx.x) >> 2) * 256)) + ((((int)threadIdx.x) & 3) * 8)));\n", + " }\n", + " #pragma unroll\n", + " for (int i_2 = 0; i_2 < 4; ++i_2) {\n", + " tl::cp_async_gs<16>(buf_dyn_shmem+(((((i_2 * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 16384), B+((((((int)blockIdx.x) * 32768) + (i_2 * 8192)) + ((((int)threadIdx.x) >> 2) * 256)) + ((((int)threadIdx.x) & 3) * 8)));\n", + " }\n", + " tl::cp_async_commit();\n", + " for (int k = 0; k < 7; ++k) {\n", + " #pragma unroll\n", + " for (int i_3 = 0; i_3 < 4; ++i_3) {\n", + " tl::cp_async_gs<16>(buf_dyn_shmem+(((((((k + 1) & 1) * 8192) + (i_3 * 2048)) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)), A+((((((((int)blockIdx.y) * 32768) + (i_3 * 8192)) + ((((int)threadIdx.x) >> 2) * 256)) + (k * 32)) + ((((int)threadIdx.x) & 3) * 8)) + 32));\n", + " }\n", + " #pragma unroll\n", + " for (int i_4 = 0; i_4 < 4; ++i_4) {\n", + " tl::cp_async_gs<16>(buf_dyn_shmem+((((((((k + 1) & 1) * 8192) + (i_4 * 2048)) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 16384), B+((((((((int)blockIdx.x) * 32768) + (i_4 * 8192)) + ((((int)threadIdx.x) >> 2) * 256)) + (k * 32)) + ((((int)threadIdx.x) & 3) * 8)) + 32));\n", + " }\n", + " tl::cp_async_commit();\n", + " tl::cp_async_wait<1>();\n", + " __syncthreads();\n", + " tl::gemm_ss<128, 128, 32, 2, 2, 0, 1>((&(((half_t*)buf_dyn_shmem)[((k & 1) * 4096)])), (&(((half_t*)buf_dyn_shmem)[(((k & 1) * 4096) + 8192)])), (&(C_local[0])));\n", + " }\n", + " tl::cp_async_wait<0>();\n", + " __syncthreads();\n", + " tl::gemm_ss<128, 128, 32, 2, 2, 0, 1>((&(((half_t*)buf_dyn_shmem)[4096])), (&(((half_t*)buf_dyn_shmem)[12288])), (&(C_local[0])));\n", + " #pragma unroll\n", + " for (int i_5 = 0; i_5 < 64; ++i_5) {\n", + " *(uint1*)(C + (((((((((((int)blockIdx.y) * 32768) + (((i_5 & 7) >> 1) * 8192)) + (((((int)threadIdx.x) & 63) >> 5) * 4096)) + ((i_5 & 1) * 2048)) + (((((int)threadIdx.x) & 31) >> 2) * 256)) + (((int)blockIdx.x) * 128)) + ((i_5 >> 3) * 16)) + ((((int)threadIdx.x) >> 6) * 8)) + ((((int)threadIdx.x) & 3) * 2))) = *(uint1*)(C_local + (i_5 * 2));\n", + " }\n", + "}\n", + "\n", + "\n" + ] + } + ], + "source": [ + "rt_mod, params = tl.lower(func)\n", + "print(rt_mod.imported_modules[0].get_source())" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "212dfd69-b7d6-4a37-a4f3-07241c18488e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Assert Pass\n" + ] + } + ], + "source": [ + "mod = tl.Profiler(rt_mod, params, [2], tl.TensorSupplyType.Integer)\n", + "\n", + "def ref_program(A, B):\n", + " import torch\n", + " B = B.T\n", + " C = torch.matmul(A.to(torch.float), B.to(torch.float))\n", + " C = C.to(torch.__getattribute__(out_dtype))\n", + " return C\n", + "\n", + "mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)\n", + "print(\"Assert Pass\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "0b4fd8fd-c2ca-4740-8c14-e3646993c9b4", + "metadata": {}, + "source": [ + "## Manipulate Data Layout and Pipeline" + ] + }, + { + "cell_type": "markdown", + "id": "aec3ac4e-385f-44cb-b439-2357901e1d86", + "metadata": {}, + "source": [ + "TL also provide interface for users to manupulate the memory layout, pipeline and enable rasterization for better L2 Cache Locality. Here is an example of how to use the memory layout and rasterization:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "5da51fe1-e70d-447a-9097-6df9af6e16d6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The history saving thread hit an unexpected error (OperationalError('attempt to write a readonly database')).History will not be written to the database.\n" + ] + } + ], + "source": [ + "def matmul(M, N, K, block_M, block_N, block_K, dtype=\"float16\", accum_dtype=\"float\"):\n", + " @T.prim_func\n", + " def main(\n", + " A: T.Buffer((M, K), dtype),\n", + " B: T.Buffer((K, N), dtype),\n", + " C: T.Buffer((M, N), dtype),\n", + " ):\n", + " with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):\n", + " A_shared = T.alloc_shared((block_M, block_K), dtype)\n", + " B_shared = T.alloc_shared((block_K, block_N), dtype)\n", + " C_local = T.alloc_fragment((block_M, block_N), accum_dtype)\n", + "\n", + " \n", + " # Apply memory layout optimizations\n", + " # Or you can define your own memory layout\n", + " T.annotate_layout({\n", + " A_shared: make_swizzle_layout(A_shared),\n", + " B_shared: make_swizzle_layout(B_shared),\n", + " })\n", + "\n", + " # Enable rasterization for better L2 Cache Locality\n", + " T.use_swizzle(panel_size=10, enable=enable_rasterization)\n", + "\n", + " # Clear the local buffer\n", + " T.clear(C_local)\n", + "\n", + " # Auto pipeline the computation\n", + " for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n", + " T.copy(A[by * block_M, k * block_K], A_shared)\n", + "\n", + " # Instead of using\n", + " # T.copy(B[k * block_K, bx * block_N], B_shared)\n", + " # we can also use Parallel to auto map the thread\n", + " # bindings and vectorize the copy operation.\n", + " for k, j in T.Parallel(block_K, block_N):\n", + " B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]\n", + "\n", + " T.gemm(A_shared, B_shared, C_local)\n", + "\n", + " T.copy(C_local, C[by * block_M, bx * block_N])\n", + "\n", + " return main" + ] + }, + { + "cell_type": "markdown", + "id": "6e81ffc7-159a-4ae1-8719-2a0c03142f22", + "metadata": {}, + "source": [ + "## Implement Dequantize GEMM with simple Syntax" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "0bae1971-9aa5-4db5-8a3f-f35aa495e4d8", + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'T' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[2], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;129m@T\u001b[39m\u001b[38;5;241m.\u001b[39mprim_func\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdequant_matmul\u001b[39m(\n\u001b[1;32m 3\u001b[0m A: T\u001b[38;5;241m.\u001b[39mBuffer(A_shape, in_dtype),\n\u001b[1;32m 4\u001b[0m B: T\u001b[38;5;241m.\u001b[39mBuffer(B_shape, storage_dtype),\n\u001b[1;32m 5\u001b[0m Ct: T\u001b[38;5;241m.\u001b[39mBuffer((N, M), out_dtype),\n\u001b[1;32m 6\u001b[0m ):\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;241m.\u001b[39mKernel(T\u001b[38;5;241m.\u001b[39mceildiv(N, block_N), T\u001b[38;5;241m.\u001b[39mceildiv(M, block_M), threads\u001b[38;5;241m=\u001b[39mthreads) \u001b[38;5;28;01mas\u001b[39;00m (bx, by):\n\u001b[1;32m 8\u001b[0m A_shared \u001b[38;5;241m=\u001b[39m T\u001b[38;5;241m.\u001b[39malloc_shared(A_shared_shape, in_dtype)\n", + "\u001b[0;31mNameError\u001b[0m: name 'T' is not defined" + ] + } + ], + "source": [ + "@T.prim_func\n", + "def dequant_matmul(\n", + " A: T.Buffer(A_shape, in_dtype),\n", + " B: T.Buffer(B_shape, storage_dtype),\n", + " Ct: T.Buffer((N, M), out_dtype),\n", + "):\n", + " with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):\n", + " A_shared = T.alloc_shared(A_shared_shape, in_dtype)\n", + " B_shared = T.alloc_shared(B_shared_shape, storage_dtype)\n", + " B_local = T.alloc_fragment(B_shared_shape, storage_dtype)\n", + " B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)\n", + " Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)\n", + "\n", + " T.clear(Ct_local)\n", + " for k in T.Pipelined(\n", + " T.ceildiv(K, block_K), \n", + " num_stages=num_stages\n", + " ):\n", + " T.copy(A[by * block_M, k * block_K], A_shared)\n", + " T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)\n", + " T.copy(B_shared, B_local)\n", + " for i, j in T.Parallel(block_N, block_K):\n", + " B_dequantize_local[i, j] = _tir_packed_to_unsigned_convert(\"int\", 8)(\n", + " num_bits,\n", + " B_local[i, j // 2],\n", + " j % 2,\n", + " dtype=in_dtype,\n", + " )\n", + " T.gemm(B_dequantize_local, A_shared, Ct_local, transpose_B=True)\n", + " T.copy(Ct_local, Ct[bx * block_N, by * block_M])" + ] + }, + { + "cell_type": "markdown", + "id": "48f49319-ad63-4673-8130-b010dc8ba22e", + "metadata": {}, + "source": [ + "## If you want fine-grained control over dequantization at the thread leve" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e7e821e6-4517-47b7-8c68-b145eb513d99", + "metadata": {}, + "outputs": [], + "source": [ + "@T.prim_func\n", + "def main(\n", + " A: T.Buffer(A_shape, in_dtype),\n", + " B: T.Buffer(B_shape, storage_dtype),\n", + " C: T.Buffer((M, N), out_dtype),\n", + "):\n", + " with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):\n", + " A_shared = T.alloc_shared(A_shared_shape, in_dtype)\n", + " B_shared = T.alloc_shared(B_shared_shape, storage_dtype)\n", + " B_local = T.alloc_local([local_size_compressed], storage_dtype)\n", + " B_dequantize_local = T.alloc_local([local_size], in_dtype)\n", + " B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)\n", + " C_local = T.alloc_fragment((block_M, block_N), accum_dtype)\n", + "\n", + " tx = T.thread_binding(0, threads, thread=\"threadIdx.x\")\n", + "\n", + " T.clear(C_local)\n", + " for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):\n", + " T.copy(A[by * block_M, k * block_K], A_shared)\n", + " T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)\n", + "\n", + " for i in T.serial(block_N * block_K // num_elems_per_byte //\n", + " (threads * local_size_compressed)):\n", + " for v in T.vectorized(0, local_size_compressed):\n", + " index = i * threads * local_size_compressed + tx * local_size_compressed + v\n", + " vi = index // (block_K // num_elems_per_byte)\n", + " vj = index % (block_K // num_elems_per_byte)\n", + " B_local[v] = B_shared[vi, vj]\n", + " for v in T.serial(0, local_size):\n", + " B_dequantize_local[v] = _tir_packed_to_unsigned_convert(\n", + " storage_type, storage_nbit)(\n", + " num_bits,\n", + " B_local[v // num_elems_per_byte],\n", + " v % num_elems_per_byte,\n", + " dtype=in_dtype,\n", + " )\n", + " for v in T.vectorized(0, local_size):\n", + " index = i * threads * local_size + tx * local_size + v\n", + " vi = index // block_K\n", + " vj = index % block_K\n", + " B_dequantize_shared[vi, vj] = B_dequantize_local[v]\n", + "\n", + " T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)\n", + "\n", + " T.copy(C_local, C[by * block_M, bx * block_N])" + ] + }, + { + "cell_type": "markdown", + "id": "37d65188-ce68-4056-bbc3-eaa8ab3b2316", + "metadata": {}, + "source": [ + "## Flash Attention V3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "72f095d2-9978-43c4-ab63-21264e97adb4", + "metadata": {}, + "outputs": [], + "source": [ + "@T.prim_func\n", + "def flash_attention_v3(\n", + " Q: T.Buffer(shape, dtype),\n", + " K: T.Buffer(shape, dtype),\n", + " V: T.Buffer(shape, dtype),\n", + " Output: T.Buffer(shape, dtype),\n", + "):\n", + " with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num) as (bx, by, bz):\n", + " Q_shared = T.alloc_shared([block_M, dim], dtype)\n", + " K_shared = T.alloc_shared([block_N, dim], dtype)\n", + " V_shared = T.alloc_shared([block_N, dim], dtype)\n", + " acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)\n", + " acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)\n", + " acc_o = T.alloc_fragment([block_M, dim], accum_dtype)\n", + " scores_max = T.alloc_fragment([block_M], accum_dtype)\n", + " scores_max_prev = T.alloc_fragment([block_M], accum_dtype)\n", + " scores_scale = T.alloc_fragment([block_M], accum_dtype)\n", + " scores_sum = T.alloc_fragment([block_M], accum_dtype)\n", + " logsum = T.alloc_fragment([block_M], accum_dtype)\n", + "\n", + " T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)})\n", + " T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)\n", + " T.fill(acc_o, 0)\n", + " T.fill(logsum, 0)\n", + " T.fill(scores_max, -T.infinity(accum_dtype))\n", + " loop_range = (\n", + " T.ceildiv((bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N)\n", + " )\n", + " for k in T.Pipelined(loop_range, num_stages=num_stages):\n", + " T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared)\n", + " if is_casual:\n", + " for i, j in T.Parallel(block_M, block_N):\n", + " acc_s[i, j] = T.if_then_else(\n", + " bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)\n", + " )\n", + " else:\n", + " T.clear(acc_s)\n", + " T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)\n", + " T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared)\n", + " for i, j in T.Parallel(block_M, dim):\n", + " acc_s[i, j] *= scale\n", + " T.copy(scores_max, scores_max_prev)\n", + " T.fill(scores_max, -T.infinity(accum_dtype))\n", + " T.reduce_max(acc_s, scores_max, dim=1, clear=False)\n", + " for i in T.Parallel(block_M):\n", + " scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i])\n", + " for i, j in T.Parallel(block_M, dim):\n", + " acc_o[i, j] *= scores_scale[i]\n", + " for i, j in T.Parallel(block_M, block_N):\n", + " acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i])\n", + " T.copy(acc_s, acc_s_cast)\n", + " T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)\n", + " T.reduce_sum(acc_s, scores_sum, dim=1)\n", + " for i in T.Parallel(block_M):\n", + " logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]\n", + " for i, j in T.Parallel(block_M, dim):\n", + " acc_o[i, j] /= logsum[i]\n", + " T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/img/AutoTensorization.png b/tutorials/img/AutoTensorization.png new file mode 100644 index 000000000..37afb9771 Binary files /dev/null and b/tutorials/img/AutoTensorization.png differ diff --git a/tutorials/img/DynamicTuning.png b/tutorials/img/DynamicTuning.png new file mode 100644 index 000000000..3e1bcee48 Binary files /dev/null and b/tutorials/img/DynamicTuning.png differ diff --git a/tutorials/img/FastDequantization.png b/tutorials/img/FastDequantization.png new file mode 100644 index 000000000..230628424 Binary files /dev/null and b/tutorials/img/FastDequantization.png differ diff --git a/tutorials/img/FastDequantization_EXP.png b/tutorials/img/FastDequantization_EXP.png new file mode 100644 index 000000000..c676e46d8 Binary files /dev/null and b/tutorials/img/FastDequantization_EXP.png differ diff --git a/tutorials/img/roller.png b/tutorials/img/roller.png new file mode 100644 index 000000000..8bf38d2d6 Binary files /dev/null and b/tutorials/img/roller.png differ diff --git a/tutorials/ladder_from_onnx.py b/tutorials/ladder_from_onnx.py new file mode 100644 index 000000000..ce07e7265 --- /dev/null +++ b/tutorials/ladder_from_onnx.py @@ -0,0 +1,165 @@ +import argparse +import os.path as osp +import numpy as np +import onnx +import ladder +import tvm +from tvm import relay +from tvm.contrib.debugger import debug_executor +from tvm.contrib import graph_executor +from ladder.utils import write_mod +import os +import torch +import logging + +ladder.set_log_level(logging.INFO) + +# get file name and remove the suffix +fname = os.path.basename(__file__) +fname = os.path.splitext(fname)[0] +# create log path +log_path = "progress/e2e/" + fname + +parser = argparse.ArgumentParser() +parser.add_argument('--prefix', type=str, default='llama2-70b') +parser.add_argument('--arch', type=str, default="cuda") +parser.add_argument('--cublas', action="store_true") +parser.add_argument('--cudnn', action="store_false") +parser.add_argument('--nhwc', action="store_false") +parser.add_argument('--async_propagation', action="store_true", help="Use async propagation and async instructions, which should be only enabled on data center GPUs with async copy instructions.", default=False) +parser.add_argument("--prebuilt_path", type=str, default=None, help="Path to the prebuilt model. If set, the script will run from the prebuilt model.") +parser.add_argument("--fast_decoding", action="store_false", help="Enable fast decoding mode.") + +args = parser.parse_args() + +def run(prefix, arch, async_propagate): + if ".onnx" in prefix: + onnx_model = onnx.load(prefix) + else: + onnx_model = onnx.load(osp.join(prefix, "model.onnx")) + mod, params = relay.frontend.from_onnx( + onnx_model, convert_config={"use_welder_matmul": False}) + write_mod(mod, log_path, "load_from_onnx") + + if args.nhwc: + # must convert bias_add -> broadcast_add to propogate the layout + mod = relay.transform.InferType()(mod) + mod = relay.transform.CanonicalizeOps()(mod) + write_mod(mod, log_path, "CanonicalizeOps") + mod = relay.transform.ConvertLayout( + {"nn.conv2d": ["NHWC", "default"]})(mod) + write_mod(mod, log_path, "ConvertLayout") + mod = relay.transform.FoldConstant()(mod) + write_mod(mod, log_path, "FoldConstant") + mod = ladder.relay.transform.WelderExprRewrite(enable_softmax=True)(mod) + write_mod(mod, log_path, "expr_rewrite") + + if args.cudnn: + from tvm.relay.op.contrib.cudnn import pattern_table + seq = tvm.transform.Sequential( + [ + relay.transform.InferType(), + relay.transform.MergeComposite(pattern_table()), + relay.transform.AnnotateTarget("cudnn"), + relay.transform.PartitionGraph(bind_constants=False), + relay.transform.InferType(), + ] + ) + mod = seq(mod) + + mod = ladder.relay.transform.LadderConvImplicitGemm( + use_async_propagation=async_propagate)(mod) + write_mod(mod, log_path, "LadderConvImplicitGemm") + mod = ladder.relay.transform.LadderPerfectGemmTransform( + use_async_propagation=async_propagate)(mod) + write_mod(mod, log_path, "LadderPerfectGemmTransform") + mod = ladder.relay.transform.WelderConvImplicitGemm()(mod) + write_mod(mod, log_path, "WelderConvImplicitGemm") + mod = relay.transform.FoldConstant()(mod) + write_mod(mod, log_path, "FoldConstant") + mod = relay.transform.EliminateCommonSubexpr()(mod) + write_mod(mod, log_path, "EliminateCommonSubexpr") + mod = ladder.relay.transform.LadderRewriteInceptionLayout()(mod) + write_mod(mod, log_path, "LadderRewriteInceptionLayout") + if args.cublas: + from tvm.relay.op.contrib.cublas import pattern_table + seq = tvm.transform.Sequential( + [ + relay.transform.InferType(), + relay.transform.MergeComposite(pattern_table()), + relay.transform.AnnotateTarget("cublas"), + relay.transform.PartitionGraph(bind_constants=False), + relay.transform.InferType(), + ] + ) + mod = seq(mod) + write_mod(mod, log_path, "cublas_partition") + mod = relay.transform.DeadCodeElimination()(mod) + write_mod(mod, log_path, "DeadCodeElimination") + mod = relay.transform.FoldConstant()(mod) + write_mod(mod, log_path, "FoldConstant") + mod = relay.transform.EliminateCommonSubexpr()(mod) + write_mod(mod, log_path, "EliminateCommonSubexpr") + mod = ladder.relay.transform.WelderFuseOps()(mod) + write_mod(mod, log_path, "WelderFuseOps") + mod = ladder.relay.transform.AnnotateLadderTensorCore(arch=arch)(mod) + write_mod(mod, log_path, "AnnotateLadderTensorCore") + mod = ladder.relay.transform.AnnotateTensorCore()(mod) + write_mod(mod, log_path, "AnnotateWelderTensorCore") + if args.fast_decoding: + mod = ladder.relay.transform.AnnotateFastDecoding()(mod) + write_mod(mod, log_path, "AnnotateFastDecoding") + + mod = ladder.relay.transform.WelderTunePass(arch, topk=40,save_perf_log="./debug_group_info")(mod) + write_mod(mod, log_path, "WelderTunePass") + + factory = relay.build(mod, arch.target, params=params) + + with open(osp.join(log_path, "graph.json"), "w") as f: + f.write(factory.get_graph_json()) + with open(osp.join(log_path, "graph.params"), "wb") as f_params: + f_params.write(tvm.runtime.save_param_dict(factory.get_params())) + lib = ladder.relay.update_lib( + factory.get_lib(), arch, osp.join(log_path, "model.so")) + + rt_mod = graph_executor.create(factory.get_graph_json(), lib, tvm.cuda(0)) + rt_mod.set_input(**factory.get_params()) + print(rt_mod.benchmark(tvm.cuda(0), min_repeat_ms=500, end_to_end=False)) + + +def run_from_prebuilt(prefix, arch): + lib_path = os.path.join(prefix, "model.so") + with open(os.path.join(prefix, "graph.json")) as f: + graph_json = f.read() + with open(os.path.join(prefix, "graph.params"), "rb") as f_params: + params = f_params.read() + loaded_lib = tvm.runtime.load_module(lib_path) + module = debug_executor.create(graph_json, loaded_lib, tvm.cuda(0)) + module.load_params(params) + print(module.benchmark(tvm.cuda(0), min_repeat_ms=500, end_to_end=False)) + module.run() + # dummy input + input_shape = (1, 1) + dtype = 'int64' + input_data = tvm.nd.array(np.ones(input_shape).astype(dtype)) + module.set_input("input.1", input_data) + module.run() + outputs = [] + for i in range(module.get_num_outputs()): + out = module.get_output(i).asnumpy() + outputs.append(out) + print(outputs) + + +if __name__ == "__main__": + path = args.prefix + arch = ladder.arch.__getattribute__(args.arch)() + async_propagate = args.async_propagation + if arch.compute_capability == "80": + async_propagate = True + # path = "/home/t-leiwang/ladder_workspace/Ladder/artifact/QuickStart/qmodels/opt-125m-4bit/qmodel_b1s1/qmodel_b1s1.onnx" + prebuilt_path = args.prebuilt_path + if prebuilt_path: + run_from_prebuilt(prebuilt_path, arch) + else: + run(path, arch, async_propagate)