diff --git a/3rdparty/tvm b/3rdparty/tvm
index af0b40391..5a8b30a0b 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit af0b403916c853160df4ee3d046bdd4182c1ea44
+Subproject commit 5a8b30a0be08ccdf4335dd273ceb9eff974ada9f
diff --git a/bitblas/common.py b/bitblas/common.py
index b2023f7b8..f1b0aa361 100644
--- a/bitblas/common.py
+++ b/bitblas/common.py
@@ -5,4 +5,4 @@
 
 BITBLAS_DEFAULT_CACHE_PATH = os.path.expanduser("~/.cache/bitblas")
 
-MAX_ERROR_MESSAGE_LENGTH = 200
+MAX_ERROR_MESSAGE_LENGTH = 500
diff --git a/bitblas/gpu/matmul_mma_dequantize.py b/bitblas/gpu/matmul_mma_dequantize.py
index f5796f589..ac057a600 100644
--- a/bitblas/gpu/matmul_mma_dequantize.py
+++ b/bitblas/gpu/matmul_mma_dequantize.py
@@ -1170,7 +1170,7 @@ def sch_shared_memory_prefetch_with_config(
         """
 
         weight_transform_kind = config.intrin_info.weight_transform_kind
-        if weight_transform_kind == TransformKind.LDMatrixTransform and config.block_reduction_depth is not None:
+        if weight_transform_kind == TransformKind.LDMatrixTransform:
             return self.sch_warp_memory_prefetch_with_config(func, config)
 
         is_cross_thread_reduce = (
@@ -1826,8 +1826,7 @@ def check_dequantize_info(dequantize_info):
         block_col_warps = config.block[1] // warp_col_tiles
         stage = config.pipeline_stage
         use_async = config.use_async
-        assert (config.block_reduction_depth is not None), "block_reduction_depth is required"
-        reduce_k = config.block_reduction_depth
+        reduce_k = config.block_reduction_depth if config.block_reduction_depth is not None else 1
         chunk = config.rstep[0] // reduce_k
 
         micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"]
@@ -2005,9 +2004,12 @@ def get_param_indices(
         i0, i1, i2, i3 = sch.split(i, factors=i_factors)
         j0, j1, j2, j3 = sch.split(j, factors=j_factors)
         k0, k1 = sch.split(k, k_factors)
-        k0, kr = sch.split(k0, [None, reduce_k])
+        if reduce_k > 1:
+            k0, kr = sch.split(k0, [None, reduce_k])
 
-        sch.reorder(i0, j0, i1, j1, i2, j2, kr, k0, k1, i3, j3)
+            sch.reorder(i0, j0, i1, j1, i2, j2, kr, k0, k1, i3, j3)
+        else:
+            sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3)
 
         block_idy = sch.fuse(i0, j0)
         block_idx = sch.fuse(i1, j1)
@@ -2017,12 +2019,13 @@ def get_param_indices(
         sch.bind(batch, "blockIdx.z")
         sch.bind(block_idx, "blockIdx.x")
         sch.bind(block_idy, "blockIdx.y")
-        thread_idz = j2 = thread_idy = sch.fuse(thread_idy, thread_idz)
-        sch.bind(thread_idy, "threadIdx.y")
-
-        # Put the thread binding after the shared memory prefetch
-        # Otherwise there's a axis missing bug behind tvm
-        sch.bind(kr, "threadIdx.z")
+        if reduce_k > 1:
+            thread_idz = j2 = thread_idy = sch.fuse(thread_idy, thread_idz)
+            sch.bind(thread_idy, "threadIdx.y")
+            sch.bind(kr, "threadIdx.z")
+        else:
+            sch.bind(thread_idy, "threadIdx.y")
+            sch.bind(thread_idz, "threadIdx.z")
 
         def smooth_layout_recover(block, scope, l=16, r=16, enable=True):  # noqa: E741
             if not enable:
@@ -2054,13 +2057,20 @@ def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False):
             ndim = len(sch.get(block_read).iter_vars)
             fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])
 
-            f_0, f_r, f_1, f_2, f_3, f_4 = sch.split(
-                fused, factors=[None, reduce_k, num_ty, num_tz, warp_size, vec_len])
+            if reduce_k > 1:
+                f_r, f_0, f_1, f_2, f_3, f_4 = sch.split(
+                    fused, factors=[reduce_k, num_ty, num_tz, None, warp_size, vec_len])
+                sch.bind(f_3, "threadIdx.x")
+                f_0 = f_1 = sch.fuse(f_0, f_1)
+                sch.bind(f_0, "threadIdx.y")
+                sch.bind(f_r, "threadIdx.z")
+            else:
+                f_0, f_1, f_2, f_3, f_4 = sch.split(
+                    fused, factors=[num_ty, num_tz, None, warp_size, vec_len])
+                sch.bind(f_3, "threadIdx.x")
+                sch.bind(f_1, "threadIdx.z")
+                sch.bind(f_0, "threadIdx.y")
 
-            sch.bind(f_3, "threadIdx.x")
-            f_1 = f_2 = sch.fuse(f_1, f_2)
-            sch.bind(f_1, "threadIdx.y")
-            sch.bind(f_r, "threadIdx.z")
             sch.vectorize(f_4)
             sch.unroll(f_0)
             sch.annotate(f_0, "pragma_unroll_explicit", False)
@@ -2085,7 +2095,8 @@ def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False):
 
         # Put the thread binding after the shared memory prefetch
        # Otherwise there's a axis missing bug behind tvm
-        sch.bind(kr, "threadIdx.z")
+        if reduce_k > 1:
+            sch.bind(kr, "threadIdx.z")
        # create read cache to load matrix from shared memory to wmma fragments
         A_mat = sch.cache_read(block_outer, 0, "warp")
         sch.compute_at(A_mat, k1, preserve_unit_loops=True)
@@ -2127,18 +2138,24 @@ def get_idx():
             ndim = len(sch.get(B_shared).iter_vars)
             _ = sch.fuse(*sch.get_loops(B_shared)[-ndim:])
-            _bind_thread_based_with_block_reduce_on_config(
-                sch,
-                B_shared,
-                num_ty,
-                num_tz,
-                warp_size,
-                reduce_k,
-            )
+            if reduce_k > 1:
+                _bind_thread_based_with_block_reduce_on_config(
+                    sch,
+                    B_shared,
+                    num_ty,
+                    num_tz,
+                    warp_size,
+                    reduce_k,
+                )
+            else:
+                _bind_thread_based_on_config(sch, B_shared, num_ty, num_tz, warp_size)
 
             return B_dequantized_mat
 
         B_dequantized_mat = warp_memory_dequantize()
-
+        # Put the thread binding after the shared memory prefetch
+        # Otherwise there's an axis missing bug behind tvm
+        if reduce_k > 1:
+            sch.bind(kr, "threadIdx.z")
         # create write cache to store matrix from wmma fragments to shared memory and global memory
         if cache_write_required:
             accumulator_shared_to_global = sch.cache_write(block_outer, 0, shared_scope)
@@ -2159,9 +2176,13 @@ def get_idx():
                 sch.get_loops(store)[-6],
                 preserve_unit_loops=True,
             )
-            vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global))
             fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-5:])
-            f0, f1, f2 = sch.split(fused, factors=[None, warp_size, vec_len])
+            f0, f1, f2 = sch.split(
+                fused,
+                factors=[
+                    None, warp_size,
+                    get_coalesced_veclen(sch.get(accumulator_shared_to_global))
+                ])
             sch.bind(f1, "threadIdx.x")
             sch.vectorize(f2)
             sch.unroll(f0)
@@ -2205,6 +2226,7 @@ def get_idx():
             j0, j1 = sch.split(j, factors=[None, b_lr[1]])
             sch.reorder(i0, j0, i1, j1)
             _ = sch.blockize(i1)
+            vec_len = get_coalesced_veclen(sch.get(B_dequantized_mat))
             sch.transform_block_layout(
                 B_dequantized_mat,
                 lambda i, j: ((i * b_lr[1] + j) // vec_len, (i * b_lr[1] + j) % vec_len))
diff --git a/integration/BitNet/benchmark_inference_latency.py b/integration/BitNet/benchmark_inference_latency.py
index 694a824e8..1d711d3c0 100644
--- a/integration/BitNet/benchmark_inference_latency.py
+++ b/integration/BitNet/benchmark_inference_latency.py
@@ -11,6 +11,7 @@
 parser = argparse.ArgumentParser()
 parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str)
 
+
 def profile(model, input_data):
     import time
 
@@ -34,26 +35,25 @@ def get_runtime(num_repeats=1):
     times = get_runtime(num_repeats)
     return np.mean(times)
 
+
 def main():
     model = BitnetForCausalLM.from_pretrained(
         '1bitLLM/bitnet_b1_58-3B',
         device_map='auto',
-        low_cpu_mem_usage=True, 
+        low_cpu_mem_usage=True,
         use_flash_attention_2=True,
         torch_dtype=torch.float16,
     ).half()
     with torch.no_grad():
-        model._post_process_weights()
-
-    benchmark_sets = [
-        (1, 1),
-        (128, 1),
-        (1, 2048)
-    ]
+        model.quantize()
+        model = torch.compile(model)
+
+    benchmark_sets = [(1024, 1), (1, 2048)]
     for batch_size, seq_len in benchmark_sets:
         input_id = torch.ones(batch_size, seq_len).long().cuda()
         latency = profile(model, input_id)
         print(f"Batch size: {batch_size}, Seq len: {seq_len}, Latency: {latency}")
 
+
 if __name__ == '__main__':
     main()
diff --git a/integration/pytorch/bitblas_linear.py b/integration/pytorch/bitblas_linear.py
index c315cf6c4..041be483a 100644
--- a/integration/pytorch/bitblas_linear.py
+++ b/integration/pytorch/bitblas_linear.py
@@ -24,7 +24,7 @@ def error_raiser_bitblas(*args, **kwargs):
     autogptq_bitblas_cuda = bitblas_import_exception
 
 from bitblas.utils import auto_detect_nvidia_target  # noqa: E402
-from bitblas.ops.matmul import MatmulConfig, Matmul  # noqa: E402
+from bitblas.ops import MatmulConfig, Matmul  # noqa: E402
 
 
 class Linear(nn.Module):
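
Reviewer note on the recurring `reduce_k` pattern in bitblas/gpu/matmul_mma_dequantize.py above: `config.block_reduction_depth` is now optional, every schedule step that previously assumed a block-reduction split falls back to the plain thread binding when it is absent, and the loop variable `kr` only exists on the `reduce_k > 1` path, so each `sch.bind(kr, "threadIdx.z")` has to be guarded as well. A minimal runnable sketch of that control flow, using a hypothetical `Schedule` stub (toy stand-ins, not the real tvm.tir.Schedule API):

    # Sketch only: hypothetical stubs that mimic the split/bind calls in the diff.
    class Schedule:
        def split(self, loop, factor):
            # Peel `factor` iterations off `loop`; returns (outer, inner).
            return f"{loop}.outer", f"{loop}.in{factor}"

        def bind(self, loop, thread):
            print(f"bind {loop} -> {thread}")

    def schedule_k(sch, k0, block_reduction_depth):
        # Mirrors the patch: default to 1 instead of asserting non-None.
        reduce_k = block_reduction_depth if block_reduction_depth is not None else 1
        if reduce_k > 1:
            k0, kr = sch.split(k0, reduce_k)  # `kr` exists only on this path,
            sch.bind(kr, "threadIdx.z")       # so its binding is guarded too
        return k0

    schedule_k(Schedule(), "k0", None)  # reduce_k == 1: no extra split or binding
    schedule_k(Schedule(), "k0", 2)     # prints: bind k0.in2 -> threadIdx.z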