diff --git a/3rdparty/tvm b/3rdparty/tvm
index 192ed5484..3c6317a1e 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit 192ed5484db254311c8169e0291ee5ab78eaf186
+Subproject commit 3c6317a1ea614b7277ffe0b4ede18b4652afad1c
diff --git a/README.md b/README.md
index 7eeda4ab0..43f1d92d7 100644
--- a/README.md
+++ b/README.md
@@ -61,6 +61,14 @@ For more detailed information on benchmark sets with other formats (NF4/FP4) and
 | **A_dtype** | **W_dtype** | **Accum_dtype** | **Out_dtype** | **BitBLAS Support** | **Tested Platform** |
 |:-----------:|:-----------:|:---------------:|:--------------------:|:-------------------:|:----------------------------------------------------:|
+| BF16 | BF16 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
+| BF16 | FP4_E2M1 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
+| BF16 | FP8_E4M3 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
+| BF16 | INT8 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
+| BF16 | UINT4/INT4 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
+| BF16 | UINT2/INT2 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
+| BF16 | UINT1 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
+| BF16 | NF4 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) |
 | FP16 | FP16 | FP32/FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) |
 | FP16 | FP4_E2M1 | FP32/FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) |
 | FP16 | FP8_E4M3 | FP32/FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) |
diff --git a/VERSION b/VERSION
index 419bd5f01..511aa8188 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-0.0.1.dev13
\ No newline at end of file
+0.0.1.dev14
\ No newline at end of file
diff --git a/bitblas/__init__.py b/bitblas/__init__.py
index a1bc95f39..58faec4ce 100644
--- a/bitblas/__init__.py
+++ b/bitblas/__init__.py
@@ -112,4 +112,4 @@ def new_func(*args, **kwargs):
 
     return decorator
 
-__version__ = "0.0.1.dev13"
+__version__ = "0.0.1.dev14"
diff --git a/bitblas/base/utils.py b/bitblas/base/utils.py
index 92c642b68..d2168c850 100644
--- a/bitblas/base/utils.py
+++ b/bitblas/base/utils.py
@@ -20,6 +20,8 @@ import itertools
 from tvm.ir.supply import GlobalVarSupply
 from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2
+from bitblas.utils.tensor_adapter import (
+    np_float2np_bf16,)
 import logging
 
 logger = logging.getLogger(__name__)
@@ -149,17 +151,21 @@ def map_numpy_type(intype):
 
         numpy_dtype = map_numpy_type(arg.dtype)
         if distribution == "uniform":
-            profile_tensors.append(
-                tvm.nd.array(
-                    np.random.rand(*[var_wrapper(i) for i in arg.shape]).astype(numpy_dtype),
-                    device=device,
-                ))
+            data_np = np.random.rand(*[var_wrapper(i) for i in arg.shape])
+            if arg.dtype == "bfloat16":
+                profile_tensors.append(
+                    tvm.nd.empty(data_np.shape, device=device, dtype=arg.dtype).copyfrom(
+                        np_float2np_bf16(data_np.astype(np.float32))))
+            else:
+                profile_tensors.append(tvm.nd.array(data_np.astype(numpy_dtype), device=device))
         elif distribution == "onefill":
-            profile_tensors.append(
-                tvm.nd.array(
-                    np.ones([var_wrapper(i) for i in arg.shape]).astype(numpy_dtype),
-                    device=device,
-                ))
+            data_np = np.ones([var_wrapper(i) for i in arg.shape])
+            if arg.dtype == "bfloat16":
+                profile_tensors.append(
+                    tvm.nd.empty(data_np.shape, device=device,
+                                 dtype=arg.dtype).copyfrom(np_float2np_bf16(data_np.astype(np.float32))))
+            else:
+                profile_tensors.append(tvm.nd.array(data_np.astype(numpy_dtype), device=device))
         else:
             raise ValueError("Not supported distribution: ", distribution)
     return profile_tensors
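A note for reviewers (not part of the patch): the bfloat16 branch above exists because NumPy has no native bfloat16 dtype, so the profile tensor is materialized as float32, re-encoded as uint16 bit patterns, and only then copied into a `bfloat16` TVM NDArray via `tvm.nd.empty(...).copyfrom(...)`. Below is a minimal sketch of such a float32-to-bf16 re-encoding; the helper name and the round-to-nearest-even rounding are illustrative assumptions and not necessarily what `np_float2np_bf16` implements.

```python
import numpy as np


def float32_to_bf16_bits(arr: np.ndarray) -> np.ndarray:
    """Re-encode float32 values as bfloat16 bit patterns stored in uint16.

    bfloat16 keeps the upper 16 bits of an IEEE float32; the bias term below
    performs round-to-nearest-even on the truncated mantissa.
    """
    bits = np.ascontiguousarray(arr, dtype=np.float32).view(np.uint32)
    rounding_bias = ((bits >> 16) & 1) + 0x7FFF
    return ((bits + rounding_bias) >> 16).astype(np.uint16)


# Example: 1.0 is exactly representable in bf16, so its bit pattern is the
# upper half of the float32 encoding 0x3F800000.
ones_bf16 = float32_to_bf16_bits(np.ones((16, 16), dtype=np.float32))
assert ones_bf16[0, 0] == 0x3F80
```

On the TVM side, `tvm.nd.empty(shape, dtype="bfloat16", device=device).copyfrom(...)` then reinterprets those uint16 bit patterns as bfloat16 values, which is the pattern the tuner uses in the hunk above.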
diff --git a/bitblas/builder/wrapper/tir.py b/bitblas/builder/wrapper/tir.py
index 0bedf70ed..59d63298b 100644
--- a/bitblas/builder/wrapper/tir.py
+++ b/bitblas/builder/wrapper/tir.py
@@ -18,7 +18,7 @@ class TIRCUDASourceWrapper(object):
     _TYPE_MAP = {
         "float32": "float",
         "float16": "half",
-        "bfloat16": "__nv_bfloat162",
+        "bfloat16": "__nv_bfloat16",
         "e4m3_float8": "__nv_fp8_e4m3",
         "e5m2_float8": "__nv_fp8_e5m2",
         "float64": "double",
diff --git a/bitblas/gpu/gemv_dequantize.py b/bitblas/gpu/gemv_dequantize.py
index 32c8cfbd1..9d56e1233 100644
--- a/bitblas/gpu/gemv_dequantize.py
+++ b/bitblas/gpu/gemv_dequantize.py
@@ -55,7 +55,8 @@ def check_weight_decode_info(weight_decode_info):
             conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
             # check target format in ["float16", "int8"]
             conditions.append("target_format" in weight_decode_info)
-            conditions.append(weight_decode_info["target_format"] in ["float16", "int8"])
+            conditions.append(
+                weight_decode_info["target_format"] in ["float16", "bfloat16", "int8"])
             return all(conditions)
 
         if not check_weight_decode_info(weight_decode_info):
@@ -223,7 +224,8 @@ def check_weight_decode_info(weight_decode_info):
             conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
             # check target format in ["float16", "int8"]
             conditions.append("target_format" in weight_decode_info)
-            conditions.append(weight_decode_info["target_format"] in ["float16", "int8"])
+            conditions.append(
+                weight_decode_info["target_format"] in ["float16", "bfloat16", "int8"])
             return all(conditions)
 
         if not check_weight_decode_info(weight_decode_info):
diff --git a/bitblas/gpu/intrin/lop3.py b/bitblas/gpu/intrin/lop3.py
index f078e7f47..466466ed9 100644
--- a/bitblas/gpu/intrin/lop3.py
+++ b/bitblas/gpu/intrin/lop3.py
@@ -1626,7 +1626,8 @@ def get_lop3_intrin_group(
     Dict[str, str]
         A dictionary mapping the names of the intrinsics to their corresponding implementations.
     """
-    assert out_dtype in ["float16", "int8"]
+    assert out_dtype in ["float16",
+                         "int8"], (f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'int8'.")
 
     dtype_mapping = {"float16": "f16", "int8": "i8", "int32": "i32"}
     target_dtype = dtype_mapping[out_dtype]
diff --git a/bitblas/gpu/matmul_analysis.py b/bitblas/gpu/matmul_analysis.py
index 36cba1969..16f33664a 100644
--- a/bitblas/gpu/matmul_analysis.py
+++ b/bitblas/gpu/matmul_analysis.py
@@ -624,7 +624,7 @@ def check_last_trait(region: List[Range]):
         # When the func is a dequantize like ops, we should consider the M
         require_block_reduce = False
         # And we only support float16 for now
-        if hasattr(func.attrs, "dequantize_info") and in_dtype == "float16":
+        if (hasattr(func.attrs, "dequantize_info") and in_dtype in ["bfloat16", "float16"]):
             for arg in func.params:
                 inp_shape = func.buffer_map[arg].shape
                 M = inp_shape[0]
@@ -690,12 +690,14 @@ def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", inde
     )
 
     assert dtype in [
+        "bfloat16",
        "float16",
         "int8",
         "e4m3_float8",
         "e5m2_float8",
-    ], "Only support float16, int8, e4m3_float8, e5m2_float8"
-    if dtype == "float16":
+    ], "Only support bfloat16, float16, int8, e4m3_float8, e5m2_float8"
+    # TODO(lei): actually should analyze based on bits instead of dtype
+    if dtype in ["bfloat16", "float16"]:
         ldmatrix_layout = ldmatrix_32x8_to_shared_16x16_layout
         ldmatrix_layout_trans = ldmatrix_trans_32x8_to_shared_16x16_layout
     elif dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
@@ -723,7 +725,7 @@ def ldmatrix_permutation_16x32_32x16_32x16(kernel_i, kernel_j):
         local_id = kernel_j % 16
         return ldmatrix_layout(thread_id, local_id)
 
-    if dtype == "float16":
+    if dtype in ["bfloat16", "float16"]:
         ldmatrix_index_map = (
             ldmatrix_trans_permutation_16x16_32x8_16x16
             if trans else ldmatrix_permutation_16x16_32x8_16x16)
@@ -732,7 +734,7 @@ def ldmatrix_permutation_16x32_32x16_32x16(kernel_i, kernel_j):
 
     ldmatrix_index_map = IndexMap.from_func(ldmatrix_index_map, index_dtype=index_dtype)
     # TODO(lei): index_dtype should be analyzed from the schedule
-    row, col = [16, 16] if dtype == "float16" else [16, 32]
+    row, col = [16, 16] if dtype in ["bfloat16", "float16"] else [16, 32]
     inversed_index_map = ldmatrix_index_map.inverse([row, col])
 
     return ldmatrix_index_map, inversed_index_map
@@ -753,12 +755,13 @@ def shared_32x16_to_mma_32x16_layout(i, j):
         return thread_id, local_id
 
     assert dtype in [
+        "bfloat16",
         "float16",
         "int8",
         "e4m3_float8",
         "e5m2_float8",
     ], "Only support float16, int8, e4m3_float8, e5m2_float8"
-    if dtype == "float16":
+    if dtype in ["bfloat16", "float16"]:
         stage3_layout = shared_32x8_to_mma_32x8_layout
     elif dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
         stage3_layout = shared_32x16_to_mma_32x16_layout
@@ -782,14 +785,14 @@ def ladder_stage3_permutation_16x32_32x16_32x16_16x32(kernel_i, kernel_j):
         new_kernel_j = (new_thread_id * 16 + new_local_id) % 32
         return new_kernel_i, new_kernel_j
 
-    if dtype == "float16":
+    if dtype in ["bfloat16", "float16"]:
         stage3_index_map = ladder_stage3_permutation_16x16_32x8_32x8_16x16
     else:
         stage3_index_map = ladder_stage3_permutation_16x32_32x16_32x16_16x32
 
     stage3_index_map = IndexMap.from_func(stage3_index_map, index_dtype=index_dtype)
     # TODO(lei): index_dtype should be analyzed from the schedule
-    row, col = [16, 16] if dtype == "float16" else [16, 32]
+    row, col = [16, 16] if dtype in ["bfloat16", "float16"] else [16, 32]
     inversed_index_map = stage3_index_map.inverse([row, col])
 
     return stage3_index_map, inversed_index_map
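A note for reviewers (not part of the patch): the reason bfloat16 can simply join the float16 branches in `get_propagate_map` is hinted at by the `TODO(lei)` above — the ldmatrix/MMA index maps really depend on element width, not on the dtype name. The `[16, 16]` and `[16, 32]` tiles in the code both span 32 bytes per row, so any 16-bit type shares the float16 maps while 8-bit types share the int8 maps. A tiny sketch of that relationship (the helper name is made up for illustration):

```python
def mma_operand_tile(dtype_bits: int) -> tuple:
    """Tile shape implied by a fixed 32-byte ldmatrix row."""
    rows = 16
    cols = (32 * 8) // dtype_bits  # 32 bytes per row
    return rows, cols


assert mma_operand_tile(16) == (16, 16)  # float16 / bfloat16
assert mma_operand_tile(8) == (16, 32)   # int8 / e4m3_float8 / e5m2_float8
```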
diff --git a/bitblas/gpu/matmul_mma_dequantize.py b/bitblas/gpu/matmul_mma_dequantize.py
index e3a4fef04..b5c327190 100644
--- a/bitblas/gpu/matmul_mma_dequantize.py
+++ b/bitblas/gpu/matmul_mma_dequantize.py
@@ -1237,7 +1237,8 @@ def check_weight_decode_info(weight_decode_info):
         conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
         # check target format in ["float16", "int8"]
         conditions.append("target_format" in weight_decode_info)
-        conditions.append(weight_decode_info["target_format"] in ["float16", "int8"])
+        conditions.append(
+            weight_decode_info["target_format"] in ["bfloat16", "float16", "int8"])
         return all(conditions)
 
     assert check_weight_decode_info(weight_decode_info), "Invalid B_decode_info"
diff --git a/bitblas/module/__init__.py b/bitblas/module/__init__.py
index 428814e3c..242589c7b 100644
--- a/bitblas/module/__init__.py
+++ b/bitblas/module/__init__.py
@@ -39,6 +39,25 @@ def unpack_qzeros(qzeros, bits):
     return torch.bitwise_and(unpacked_zeros + 1, 2**bits - 1)
 
 
+# For gptqv2 from gptqmodel
+def unpack_qzeros_v2(qzeros, bits):
+    qzeros = qzeros.view(torch.int32)
+    elems_per_int32 = 32 // bits
+    unpacked_zeros = torch.zeros(
+        (qzeros.shape[0], qzeros.shape[1] * elems_per_int32),
+        dtype=torch.int8,
+        device=qzeros.device,
+        requires_grad=False,
+    )
+    for col in range(unpacked_zeros.shape[1]):
+        i = col % elems_per_int32
+        unpacked_zeros[:, col] = (qzeros[:, col // elems_per_int32] >> (bits * i))
+
+    # See AutoGPTQ qlinear_cuda_old.py line 303 for the original unpacking logic.
+    # NOTE: Unlike unpack_qzeros (v1), the gptqmodel v2 layout stores zeros without the +1 offset, so none is applied here.
+    return torch.bitwise_and(unpacked_zeros, 2**bits - 1)
+
+
 def unpack_qweight(qweight, bits):
     qweight = qweight.view(torch.int8)
     elems_per_int8 = 8 // bits
@@ -318,6 +337,31 @@ def repack_from_gptq(self, gptq_module):
         if self.bias is not None:
             self.bias = gptq_module.bias.data.to(torch.float16).contiguous()
 
+    def repack_from_gptq_v2(self, gptq_module):
+        # qweight in gptq old quant linear stored with (out_features, in_features), should be transposed.
+        qweight = gptq_module.qweight.T.contiguous().view(self.TORCH_STORAGE_DTYPE)
+        intweight = unpack_qweight(qweight, self.bits).contiguous()
+        if self.bitblas_matmul.weight_transform is not None:
+            qweight = self.bitblas_matmul.weight_transform(intweight.cpu()).cuda()
+        self.qweight = qweight
+        # scales in gptq old quant linear stored with (in_features // group_size, out_features), should be transposed.
+        scales = gptq_module.scales.T.contiguous().view(self.torch_dtype)
+        self.scales = scales
+        # qzeros should be dequantized to int zeros.
+        intzeros = unpack_qzeros_v2(gptq_module.qzeros, self.bits).T.contiguous()
+        if self.bitblas_matmul.config.zeros_mode == "original":
+            self.zeros = intzeros.to(torch.float16).contiguous()
+        elif self.bitblas_matmul.config.zeros_mode == "rescale":
+            self.zeros[:, :] = intzeros.to(torch.float16)[:, :] * self.scales[:, :]
+        elif self.bitblas_matmul.config.zeros_mode == "quantized":
+            self.zeros = (
+                torch.Tensor(general_compress(intzeros.T.contiguous().cpu().numpy(), self.bits)).to(
+                    self.qweight.device).to(self.zeros.dtype).contiguous())
+        else:
+            raise ValueError(f"Unsupported zeros type: {self.bitblas_matmul.config.zeros_mode}")
+        if self.bias is not None:
+            self.bias = gptq_module.bias.data.to(torch.float16).contiguous()
+
     @property
     def consistent(self):
         return self.is_consitent
diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py
index dea4042e1..2945996df 100644
--- a/bitblas/ops/general_matmul/__init__.py
+++ b/bitblas/ops/general_matmul/__init__.py
@@ -30,6 +30,7 @@
     ("float64", "float64"),
     ("float32", "float32"),
     ("float16", "float16"),
+    ("bfloat16", "bfloat16"),
     ("int8", "int8"),
     ("e4m3_float8", "e4m3_float8"),
     ("e4m3_float8", "e5m2_float8"),
@@ -140,7 +141,7 @@ def __initialize_propagate(self, propagate_a: Optional[TransformKind],
 
         # TODO(lei): This is a limitation arose by pytorch and llvm
         # Should be removed in the future.
-        if self.A_dtype in ["e4m3_float8", "e5m2_float8"]:
+        if self.A_dtype in ["e4m3_float8", "e5m2_float8", "bfloat16"]:
             object.__setattr__(self, "propagate_a", TransformKind.NonTransform)
             object.__setattr__(self, "propagate_b", TransformKind.NonTransform)
 
@@ -159,6 +160,9 @@ def is_not_fast_decoding_supported():
             # if the w_dtype is int4/uint4 and the a_dtype is int8
             # we do not require fast decoding
             conditions.append(self.W_dtype in ["int4", "uint4"] and self.A_dtype in ["int8"])
+            # fast decoding is not supported for bfloat16 activations yet
+            # TODO(lei): implement a bfloat16 fast-decoding path to improve performance
+            conditions.append(self.A_dtype == "bfloat16")
             return any(conditions)
 
         if fast_decoding is not None:
@@ -214,6 +218,7 @@ def __post_init__(self):
 
         if self.A_dtype == self.W_dtype and self.W_dtype in [
                 "float16",
+                "bfloat16",
                 "int8",
                 "e4m3_float8",
                 "e5m2_float8",
@@ -228,6 +233,7 @@ class Matmul(Operator):
         "float64": ("fp", 64),
         "float32": ("fp", 32),
         "float16": ("fp", 16),
+        "bfloat16": ("bf", 16),
         "int32": ("int", 32),
         "uint32": ("uint", 32),
         "int16": ("int", 16),
@@ -260,8 +266,13 @@ def __init__(
         if target is None:
             target = auto_detect_nvidia_target()
             logger.info(f"Auto detected target: {target}")
+
+        assert (config.A_dtype in self.BITBLAS_TRICK_DTYPE_MAP), f"Unsupported input dtype {config.A_dtype}"
+
+        assert (config.W_dtype
+                in self.BITBLAS_TRICK_DTYPE_MAP), f"Unsupported weight dtype {config.W_dtype}"
+
         source_format, bit = self.BITBLAS_TRICK_DTYPE_MAP[config.W_dtype]
 
         self.source_format = source_format
diff --git a/bitblas/ops/ladder_permutate/ladder_permutate_impl.py b/bitblas/ops/ladder_permutate/ladder_permutate_impl.py
index 44d368a1a..5d6d8f981 100644
--- a/bitblas/ops/ladder_permutate/ladder_permutate_impl.py
+++ b/bitblas/ops/ladder_permutate/ladder_permutate_impl.py
@@ -12,7 +12,7 @@ def select_implementation(
     M: int,
     N: int,
-    datatype: Literal["float16", "int8", "e4m3_float8", "e5m2_float8"] = "float16",
+    datatype: Literal["float16", "bfloat16", "int8", "e4m3_float8", "e5m2_float8"] = "float16",
     dequantize_bits: int = -1,
     storage_dtype: Literal["float16", "int8", "uint8", "int32", "uint32"] = "float16",
     propagate_kind: Literal["A", "B"] = "B",
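A note for reviewers (not part of the patch): a usage sketch of what the `general_matmul` changes above enable. It mirrors the `MatmulConfig`/`Matmul` example from the README with the activation dtype swapped to bfloat16; the shapes, the uint4 weight choice, and the float32 accumulation are illustrative assumptions taken from the new README support rows, not a tested configuration.

```python
import bitblas
import torch

config = bitblas.MatmulConfig(
    M=1,
    N=1024,
    K=1024,
    A_dtype="bfloat16",   # newly accepted by this change set
    W_dtype="uint4",
    accum_dtype="float32",
    out_dtype="float16",
    layout="nt",
    with_bias=False,
)
matmul = bitblas.Matmul(config=config)

# Quantized weight in its integer form, repacked into the BitBLAS layout.
weight = torch.randint(0, 16, (1024, 1024), dtype=torch.int8).cuda()
weight_bitblas = matmul.transform_weight(weight)

activation = torch.rand(1, 1024, dtype=torch.bfloat16).cuda()
output = matmul(activation, weight_bitblas)  # expected shape (1, 1024), float16
```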
"B"] = "B", diff --git a/bitblas/utils/tensor_adapter.py b/bitblas/utils/tensor_adapter.py index d4d052dbb..5dbcb1663 100644 --- a/bitblas/utils/tensor_adapter.py +++ b/bitblas/utils/tensor_adapter.py @@ -91,11 +91,11 @@ def tvm_tensor_to_torch(tensor: Union[tvm.te.Tensor, tvm.nd.NDArray]): else: raise RuntimeError("Not supported type: ", type(tensor)) + def lazy_tvm_tensor_to_torch(tensor: Union[tvm.te.Tensor, tvm.nd.NDArray]): # It additionally needs the ctypes type as torch type def as_tensor(address, shape, elems_inbytes, torch_type): - arr = (ctypes.c_int8 * elems_inbytes).from_address( - address) + arr = (ctypes.c_int8 * elems_inbytes).from_address(address) return torch.frombuffer(arr, dtype=torch_type).view(*shape) if isinstance(tensor, tvm.nd.NDArray): @@ -110,11 +110,11 @@ def as_tensor(address, shape, elems_inbytes, torch_type): else: raise RuntimeError("Not supported type: ", type(tensor)) + def lazy_torch_to_tvm_tensor(tensor): # It additionally needs the ctypes type as torch type def as_tensor(address, shape, elems_inbytes, numpy_type): - arr = (ctypes.c_int8 * elems_inbytes).from_address( - address) + arr = (ctypes.c_int8 * elems_inbytes).from_address(address) return np.frombuffer(arr, dtype=numpy_type).reshape(shape) if isinstance(tensor, torch.Tensor): @@ -122,9 +122,24 @@ def as_tensor(address, shape, elems_inbytes, numpy_type): shape = tensor.shape torch_dtype = tensor.dtype numpy_dtype = str(torch_dtype).replace("torch.", "") - num_elems_inbytes = prod(shape) * tensor.itemsize + num_elems_inbytes = prod(shape) * tensor.itemsize np_tensor = as_tensor(data_ptr, shape, num_elems_inbytes, numpy_dtype) tvm_tensor = tvm.nd.array(np_tensor) return tvm_tensor else: raise RuntimeError("Not supported type: ", type(tensor)) + + +def np_float2np_bf16(arr): + """Convert a numpy array of float to a numpy array + of bf16 in uint16""" + orig = arr.view("