diff --git a/bitblas/gpu/matmul.py b/bitblas/gpu/matmul.py
index ad450eff2..4d21cdb46 100644
--- a/bitblas/gpu/matmul.py
+++ b/bitblas/gpu/matmul.py
@@ -4,13 +4,14 @@
 # pylint: disable=missing-docstring, invalid-name
 """A GEMM schedule rule for GPU operators."""
 from dataclasses import dataclass
-from typing import Optional
-
+from typing import Optional, List
+from contextlib import suppress
 from tvm import tir
 from tvm.target import Target
 from tvm.tir.stmt import ForKind
 
 from ..base import analysis
+from ..base.analysis import get_coalesced_veclen
 from .base import GPUScheduleRule
 from . import utils
 from .matmul_analysis import (
@@ -19,6 +20,7 @@
     get_in_out_dtypes,
     get_index_map,
     normalize_to_matmul,
+    _collect_producers,
     get_reduction_blocks,
 )
 from .matmul_mma import MatmulTensorizationMMA
@@ -242,6 +244,8 @@ def apply_config(  # pylint: disable=too-many-locals,missing-docstring
         func: tir.PrimFunc,
         config,
     ) -> tir.Schedule:
+        if "dequantize_info" in func.attrs:
+            return self.sch_dequantize_in_register_with_config(func, config)
         sch = tir.Schedule(func)
         root_block = analysis.get_root_block(sch)
         blocks = sch.get_child_blocks(root_block)
@@ -278,7 +282,7 @@ def apply_config(  # pylint: disable=too-many-locals,missing-docstring
         # Step 2. Get schedule config.
         block_row_warps = config.block[0] // (config.thread[0] * config.step[0])
         block_col_warps = config.block[1] // (config.thread[1] * config.step[1])
-        thread_row_tiles = config.thread[1] // (config.step[0] * 2)
+        thread_row_tiles = config.thread[0] // (config.step[0] * 2)
         thread_col_tiles = config.thread[1] // (config.step[1] * 2)
         vthread_row_tiles = (config.step[0] * 2)  # expand vtrhead to avoid load band conflict
         vthread_col_tiles = (config.step[1] * 2)  # expand vtrhead to avoid load band conflict
@@ -370,3 +374,296 @@ def is_trivial_load(block):
         sch.decompose_reduction(main_block, ko)
 
         return sch
+
+    def sch_dequantize_in_register_with_config(  # pylint: disable=too-many-locals,missing-docstring
+        self,
+        func: tir.PrimFunc,
+        config,
+    ) -> tir.Schedule:
+        '''
+        For devices without async copy, we can use a simple dequantize schedule without shared memory prefetch.
+            quantized weight
+                |
+                V
+            dequantized in register
+                |
+                V
+            save into shared memory
+                |
+                V
+            compute
+        '''
+        from .intrin import get_lop3_intrin_group
+
+        import_source: List[str] = []
+
+        sch = tir.Schedule(func)
+        root_block = analysis.get_root_block(sch)
+        blocks = sch.get_child_blocks(root_block)
+
+        reduction_blocks = get_reduction_blocks(sch, blocks)
+        if reduction_blocks is None:
+            return None
+
+        # In some cases the conv template will use this rule, but the tile config
+        # is not analyzed by the matmul expression.
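+        # config.block holds the [M, N] block tile produced by the matmul tile
+        # analysis, so anything other than a 2-D block config cannot be scheduled here.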
+        if len(config.block) != 2:
+            logger.debug(f"Warning: block config {config.block} is not valid for matmul, skip.")
+            return None
+
+        # Check Dequantize Info
+        dequantize_info = func.attrs["dequantize_info"]
+
+        def check_dequantize_info(dequantize_info):
+            conditions = []
+            # currently only weight-only dequantization is supported
+            conditions.append(len(dequantize_info) == 1)
+            # TODO(@lei): check that the dequantize value name is weight
+            return all(conditions)
+
+        assert check_dequantize_info(dequantize_info)
+
+        (weight_decode_info,) = list(dequantize_info.values())
+
+        def check_weight_decode_info(weight_decode_info):
+            conditions = []
+            # check source format in ["uint", "int", "fp", "nf", "fp_e4m3"]
+            conditions.append("source_format" in weight_decode_info)
+            conditions.append(weight_decode_info["source_format"]["format"] in
+                              ["uint", "int", "fp", "nf", "fp_e4m3"])
+            # check source bits in [1, 2, 4, 8]
+            conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
+            # check target format in ["float16", "int8"]
+            conditions.append("target_format" in weight_decode_info)
+            conditions.append(weight_decode_info["target_format"] in ["float16", "int8"])
+            return all(conditions)
+
+        assert check_weight_decode_info(weight_decode_info), "Invalid Weight Decode Info"
+
+        main_block = reduction_blocks[0]
+
+        block_stmt = sch.get(main_block)
+
+        # the dequantized weight must be in 'n' 't' 'n' layout for fast decoding.
+        index_maps = get_index_map(block_stmt, ["n", "t", "n"])
+        if index_maps is None:
+            return None
+        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
+
+        # Step 0. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
+        block = sch.reindex(main_block, ("read", 0))
+        sch.transform_layout(block, ("write", 0), a_index_map)
+        block = sch.reindex(main_block, ("read", 1))
+        sch.transform_layout(block, ("write", 0), b_index_map)
+        block = sch.reindex(main_block, ("write", 0))
+        sch.transform_layout(block, ("read", 0), c_index_map)
+        sch.transform_block_layout(main_block, matmul_index_map)
+
+        # Step 2. Get schedule config.
+        block_row_warps = config.block[0] // (config.thread[0] * config.step[0])
+        block_col_warps = config.block[1] // (config.thread[1] * config.step[1])
+        thread_row_tiles = config.thread[0] // (config.step[0])
+        thread_col_tiles = config.thread[1] // (config.step[1])
+        vthread_row_tiles = (config.step[0])  # expand vthread to avoid bank conflicts on load
+        vthread_col_tiles = (config.step[1])  # expand vthread to avoid bank conflicts on load
+        chunk = config.rstep[0]
+        shared_scope = config.shared_scope
+
+        num_ty = block_row_warps
+        num_tx = block_col_warps
+
+        # Step 3. Schedule matmul
+        BM = block_row_warps * vthread_row_tiles * thread_row_tiles
+        BN = block_col_warps * vthread_col_tiles * thread_col_tiles
+
+        # TODO(lei): this is a hack.
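+        # Pick BK as the largest multiple of `magic` that does not exceed `chunk`
+        # and evenly divides K, so the reduction loop below splits by BK exactly.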
+        def find_valid_number(k, chunk, magic=16):
+            # Start with the largest multiple of `magic` that does not exceed `chunk`
+            num = (chunk // magic) * magic
+            # Step down by `magic` until we find a multiple of `magic` that evenly divides k
+            while num > 0:
+                if k % num == 0:
+                    return num
+                num -= magic
+
+            return None  # If no such number is found
+
+        K = func.buffer_map[func.params[0]].shape[-1]
+        BK = find_valid_number(K, chunk)
+
+        sch.pad_einsum(
+            main_block,
+            [1, BM, BN, BK],
+        )
+        batch, y, x, k = sch.get_loops(main_block)
+        by, vy, ty, yi = sch.split(y, [None, vthread_row_tiles, block_row_warps, thread_row_tiles])
+        bx, vx, tx, xi = sch.split(x, [None, vthread_col_tiles, block_col_warps, thread_col_tiles])
+        ko, ki = sch.split(k, factors=[None, BK])
+        sch.reorder(by, bx, vy, vx, ty, tx, ko, ki, yi, xi)
+        by = sch.fuse(batch, by)
+        sch.bind(bx, "blockIdx.x")
+        sch.bind(by, "blockIdx.y")
+        sch.bind(vy, "vthread.y")
+        sch.bind(vx, "vthread.x")
+        sch.bind(ty, "threadIdx.y")
+        sch.bind(tx, "threadIdx.x")
+
+        def prod(iterable):
+            return reduce(lambda x, y: x * y, iterable, 1)
+
+        l2g = sch.cache_write(main_block, 0, "local")
+        sch.reverse_compute_at(l2g, tx, preserve_unit_loops=True)
+
+        def _cooperative_fetch(index, vec_len):
+            block = sch.cache_read(main_block, index, "shared")
+            num_loops = len(sch.get_loops(block))
+            block_local = sch.cache_read(main_block, index, "local")
+            sch.compute_at(block_local, ki, preserve_unit_loops=True)
+            sch.compute_at(block, ko, preserve_unit_loops=True)
+            loops = sch.get_loops(block)[-num_loops:]
+            _, ty, tx, vec = sch.split(
+                sch.fuse(*loops),
+                factors=[None, block_row_warps, block_col_warps, vec_len],
+            )
+
+            auto_inline_producers(sch, block)
+
+            def is_trivial_load(block):
+                # avoid vectorizing the global[v2, v1] -> shared[v1, v2] case
+                reads = sch.get(block).reads
+                writes = sch.get(block).writes
+                if len(reads) != 1 or len(writes) != 1:
+                    return False
+                return all(
+                    read.region[-1] == write.region[-1] for read, write in zip(reads, writes))
+
+            if is_trivial_load(block):
+                sch.vectorize(vec)
+
+            sch.bind(ty, "threadIdx.y")
+            sch.bind(tx, "threadIdx.x")
+
+            _, vec = sch.split(
+                sch.fuse(*sch.get_loops(block_local)[-2:]),
+                [None, vec_len // prod(config.step)],
+            )
+            sch.vectorize(vec)
+
+            return block
+
+        for i, input_region in enumerate(sch.get(main_block).reads[:1]):
+            _buffer_name = input_region.buffer.name.replace("_reindex", "").replace("_pad", "")
+            if _buffer_name not in config.cached_tensors:
+                logger.warning(
+                    f"Warning: {_buffer_name} is not in cached_tensors {config.cached_tensors}, skip."
+                )
+                continue
+
+            # otherwise, cooperatively fetch into shared memory.
+            vectorize = config.vectorize.get(_buffer_name, 1)
+
+            _cooperative_fetch(i, vec_len=vectorize)
+
+        def decode_fetch_to_shared(block, idx):
+            # step 1. create the memory hierarchy
+            # global -> local -> shared
+            block_shared = sch.cache_read(block, idx, shared_scope)
+            sch.compute_at(block_shared, ko, preserve_unit_loops=True)
+
+            decode_factor = get_coalesced_veclen(sch.get(block_shared))
+            _, B_shared_vi, _ = sch.split(
+                sch.get_loops(block_shared)[-1], factors=[None, 1, decode_factor])
+            block_shared_local = sch.cache_read(block_shared, 0, "local")
+            # global -> dequantized_local -> shared
+            # step 2. inline producers into the local block
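+            # "decode_block" names the dequantize block recorded in dequantize_info;
+            # its producers are collected here and inlined further below.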
+            weight_dequantize_block = sch.get_block(weight_decode_info["decode_block"])
+            weight_producers = _collect_producers(sch, weight_dequantize_block)
+            auto_inline_producers(sch, block_shared_local, weight_producers)
+
+            # get the target dequantize buffer's index
+            def get_idx():
+                # for LUT dequantize, the expr is LUT(w), so the index is 1
+                # maybe we can use a more general, structure-based way
+                # to analyze the index
+                if weight_decode_info["source_format"]["format"] == "nf":
+                    return 1
+                return 0
+
+            b_idx = get_idx()
+            # global -> prefetch_local -> dequantized_local -> shared
+            block_shared_local_local = sch.cache_read(block_shared_local, b_idx, "local")
+
+            sch.compute_at(block_shared_local, B_shared_vi, preserve_unit_loops=True)
+            sch.compute_at(block_shared_local_local, B_shared_vi, preserve_unit_loops=True)
+
+            dequantize_block_local = block_shared_local
+            if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]):
+                block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local")
+                sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True)
+                # inline the scale block's producers
+                auto_inline_producers(sch, block_local_scales)
+
+            if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]):
+                block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local")
+                sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True)
+                auto_inline_producers(sch, block_local_zeros)
+
+            for producer in weight_producers:
+                with suppress(Exception):
+                    auto_inline_producers(sch, producer)
+                    sch.compute_inline(producer)
+
+            # fast type conversion via LOP3 intrinsics
+            if ("fast_decoding" in weight_decode_info and weight_decode_info["fast_decoding"]):
+                source_bit = weight_decode_info["source_format"]["bits"]
+                out_dtype = weight_decode_info["target_format"]
+                lop3_intrin_info = get_lop3_intrin_group(
+                    out_dtype=out_dtype,
+                    storage_dtype=weight_decode_info["storage_dtype"],
+                    source_format=weight_decode_info["source_format"]["format"],
+                    source_bit=source_bit,
+                    with_scaling=weight_decode_info["with_scaling"],
+                    with_zeros=weight_decode_info["with_zeros"],
+                    zeros_mode=weight_decode_info["zeros_mode"],
+                )
+                sch.tensorize(
+                    sch.get_loops(dequantize_block_local)[-1],
+                    lop3_intrin_info["compute"],
+                )
+                import_source.append(lop3_intrin_info["c_source"])
+
+            union_len = (2 + 2)
+            B_shared_fused = sch.fuse(*sch.get_loops(block_shared)[-union_len:-2])
+
+            _, B_shared_ty, B_shared_tx = sch.split(B_shared_fused, factors=[None, num_ty, num_tx])
+            sch.bind(B_shared_tx, "threadIdx.x")
+            sch.bind(B_shared_ty, "threadIdx.y")
+            sch.vectorize(sch.get_loops(block_shared)[-1])
+            sch.vectorize(sch.get_loops(block_shared_local_local)[-1])
+
+            # cache small tensors, e.g. the LUT
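+            # b_idx is non-zero only for LUT-style ("nf") formats, where operand 0 of
+            # the local dequantize block is the lookup table itself.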
+            if b_idx:
+                block_shared_lut = sch.cache_read(dequantize_block_local, 0, shared_scope)
+                sch.reverse_compute_at(block_shared_lut, bx)
+                _, B_shared_tx = sch.split(
+                    sch.get_loops(block_shared_lut)[-1], factors=[None, num_tx])
+                sch.bind(B_shared_tx, "threadIdx.x")
+            return block_shared_local
+
+        _ = decode_fetch_to_shared(main_block, 1)
+
+        auto_inline_consumer_chain(sch, l2g)
+
+        _, vec = sch.split(
+            sch.fuse(*sch.get_loops(l2g)[-2:]), [None, vectorize // prod(config.step)])
+        sch.vectorize(vec)
+
+        sch.decompose_reduction(main_block, ko)
+        # attach the C source required by the tensorize intrinsics
+        if len(import_source) > 0:
+            sch.annotate(
+                ty,
+                ann_key="pragma_import_c",
+                ann_val=("\n").join(import_source),
+            )
+        return sch
diff --git a/bitblas/gpu/matmul_analysis.py b/bitblas/gpu/matmul_analysis.py
index 4a0ef532f..9ddc4500b 100644
--- a/bitblas/gpu/matmul_analysis.py
+++ b/bitblas/gpu/matmul_analysis.py
@@ -561,10 +561,15 @@ def check_sm_version(arch: str) -> int:
         sm_version = arch.replace("sm_", "")
         return int(sm_version) if sm_version.isdigit() else -1
 
-    def analysis_tensorcore_tags(sch: tir.Schedule, block: BlockRV, target: Target) -> bool:
+    def analysis_tensorcore_tags(sch: tir.Schedule, block: BlockRV,
+                                 target: Target) -> Union[bool, Dict]:
         tags: Dict[str, Union[List[int], int]] = {}
         block_stmt = sch.get(block)
 
+        # NVIDIA Tensor Cores are only available on
+        # devices with SM version >= 70.
+        if check_sm_version(target.arch) < 70:
+            return False
         # analysis tensorcore axis
         # todo(lei): maybe we can remove this in the future
         (write_buffer_region,) = block_stmt.writes
@@ -612,6 +617,11 @@ def check_last_trait(region: List[Range]):
         in_dtype, out_dtype = get_in_out_dtypes(block_stmt)
         intrin_info["in_dtype"] = in_dtype
         intrin_info["out_dtype"] = out_dtype
+
+        if 70 <= check_sm_version(target.arch) < 80 and out_dtype == "int32":
+            # INT32-accumulation Tensor Core is only enabled for SM version >= 80.
+            return False
+
         # if the last dimension is reduce axis, the B is transposed
         intrin_info["trans_b"] = check_last_trait(block_stmt.reads[1].region)
         if func.attrs is not None and "input_transform_kind" in func.attrs:
@@ -666,6 +676,7 @@ def check_last_trait(region: List[Range]):
 
     block_stmt = sch.get(main_block)
 
+    # 16 for 16-bit tensor cores, 32 for 8-bit tensor cores.
     minimal_tensorize_threshold = 16 if in_dtype in ["bfloat16", "float16"] else 32
     # the batch dimension is not taken into consideration.
     extent = block_stmt.iter_vars[1].dom.extent
diff --git a/integration/BitNet/eval_correctness.py b/integration/BitNet/eval_correctness.py
index 6bd787535..bc1012bb3 100644
--- a/integration/BitNet/eval_correctness.py
+++ b/integration/BitNet/eval_correctness.py
@@ -7,6 +7,12 @@
 from tokenization_bitnet import BitnetTokenizer
 from transformers import GenerationConfig
 import time
+import transformers
+
+print(f"transformers version is {transformers.__version__}")
+
+# the transformers version must be 4.40 or lower
+assert tuple(map(int, transformers.__version__.split(".")[:2])) <= (4, 40)
 
 torch.set_grad_enabled(False)
 bitblas.set_log_level("INFO")