From d8884e6f6a294fc8f1a325665d86a07603d43864 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 5 Jul 2024 08:54:26 +0000 Subject: [PATCH 01/12] Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability --- bitblas/ops/impl/base.py | 16 +++ bitblas/ops/impl/batch_matmul_impl.py | 166 ++++++++++++++++---------- 2 files changed, 119 insertions(+), 63 deletions(-) create mode 100644 bitblas/ops/impl/base.py diff --git a/bitblas/ops/impl/base.py b/bitblas/ops/impl/base.py new file mode 100644 index 000000000..6d510f7da --- /dev/null +++ b/bitblas/ops/impl/base.py @@ -0,0 +1,16 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from abc import ABC, abstractmethod + +# TODO: Refactor all the tir script implementations to use this base class +# Abstract base class for TIR script emitters +class TIRScriptEmitter(ABC): + @abstractmethod + def emit(self): + raise NotImplementedError + +# Abstract base class for TIR script selectors +class TIRScriptSelector(ABC): + @abstractmethod + def select(self): + raise NotImplementedError diff --git a/bitblas/ops/impl/batch_matmul_impl.py b/bitblas/ops/impl/batch_matmul_impl.py index 09b536afa..75449ea4b 100644 --- a/bitblas/ops/impl/batch_matmul_impl.py +++ b/bitblas/ops/impl/batch_matmul_impl.py @@ -4,63 +4,117 @@ from bitblas import tvm from tvm import te from bitblas.ops.operator import TransformKind +from .base import TIRScriptEmitter, TIRScriptSelector +from bitblas import tvm +from tvm import te +from bitblas.ops.operator import TransformKind +class BatchMatMulEmitter(TIRScriptEmitter): + def __init__( + self, + batch, + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + layout="nt", + ): + self.batch = batch + self.M = self._validate_dimension(M, "M") + self.N = self._validate_dimension(N, "N") + self.K = self._validate_dimension(K, "K") + self.in_dtype = in_dtype + self.out_dtype = out_dtype + self.accum_dtype = accum_dtype + self.with_bias = with_bias + self.layout = layout + self._validate_layout() + + @staticmethod + def _validate_dimension(dim, name): + if not isinstance(dim, int): + return tvm.te.var(name.lower()) + return dim -def matmul_nt( - Batch, - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - with_bias=False, -): - if not isinstance(M, int): - M = tvm.te.var("m") - A = te.placeholder((Batch, M, K), name="A", dtype=in_dtype) - B = te.placeholder((Batch, N, K), name="B", dtype=in_dtype) - Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + def _validate_layout(self): + if self.layout not in ["nn", "nt"]: + raise ValueError(f"Unsupported layout: {self.layout}") + if self.layout == "nn": + raise ValueError("Currently only support layout=nt") - # Describe the matrix multiplication in TE - k = te.reduce_axis((0, K), name="k") - C = te.compute( - (Batch, M, N), - lambda b, i, j: te.sum( - A[b, i, k].astype(accum_dtype) * B[b, j, k].astype(accum_dtype), axis=k), - name="C", - ) - last_output = C - if accum_dtype != out_dtype: - D = te.compute((Batch, M, N), lambda b, i, j: C[b, i, j].astype(out_dtype), name="D") - last_output = D + def _create_placeholders(self): + A = te.placeholder((self.batch, self.M, self.K), name="A", dtype=self.in_dtype) + B = te.placeholder((self.batch, self.N, self.K), name="B", dtype=self.in_dtype) + Bias = te.placeholder((self.N,), name="Bias", dtype=self.in_dtype) if self.with_bias else None + return A, B, Bias - if with_bias: - E = te.compute((Batch, M, N), lambda b, i, j: last_output[b, i, j] + Bias[j], name="E") - last_output = E + def _compute_matmul(self, A, B): + k = te.reduce_axis((0, self.K), name="k") + C = te.compute( + (self.batch, self.M, self.N), + lambda b, i, j: te.sum( + A[b, i, k].astype(self.accum_dtype) * B[b, j, k].astype(self.accum_dtype), axis=k), + name="C", + ) + return C - args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] + def _apply_bias(self, C, Bias): + if self.with_bias: + return te.compute((self.batch, self.M, self.N), lambda b, i, j: C[b, i, j] + Bias[j], name="E") + return C - func = te.create_prim_func(args) + def _convert_dtype(self, tensor): + if self.accum_dtype != self.out_dtype: + return te.compute((self.batch, self.M, self.N), lambda b, i, j: tensor[b, i, j].astype(self.out_dtype), name="D") + return tensor - return tvm.IRModule.from_expr(func) + def emit(self): + A, B, Bias = self._create_placeholders() + C = self._compute_matmul(A, B) + last_output = self._convert_dtype(C) + if self.with_bias: + last_output = self._apply_bias(last_output, Bias) + args = [A, B, Bias, last_output] if self.with_bias else [A, B, last_output] + func = te.create_prim_func(args) + return tvm.IRModule.from_expr(func) -def matmul( - Batch, - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - with_bias=False, - layout="nt", -): - if layout == "nn": - raise ValueError("Currently only support layout=nt") - return matmul_nt(Batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias) +class BatchMatMulSelector(TIRScriptSelector): + def __init__(self, propagate_a: TransformKind = TransformKind.NonTransform, propagate_b: TransformKind = TransformKind.NonTransform): + self.propagate_a = propagate_a + self.propagate_b = propagate_b + + def select( + self, + batch=1, + M=None, + N=16384, + K=16384, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + layout="nt", + ): + if layout == "nn": + if self.propagate_a or self.propagate_b: + raise ValueError("Currently only support propagate_a=False and propagate_b=False for layout=nn") + return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout).emit() + elif layout == "nt": + if self.propagate_a and self.propagate_b: + raise ValueError("Currently only support propagate_a or propagate_b for layout=nt") + elif self.propagate_a: + raise ValueError("Currently only support propagate_a=False for layout=nt") + elif self.propagate_b: + raise ValueError("Currently only support propagate_b=False for layout=nt") + else: + return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout).emit() + else: + raise ValueError(f"Unsupported layout: {layout}") def select_implementation( Batch=1, @@ -75,19 +129,5 @@ def select_implementation( propagate_a: TransformKind = TransformKind.NonTransform, propagate_b: TransformKind = TransformKind.NonTransform, ): - if layout == "nn": - if propagate_a or propagate_b: - raise ValueError( - "Currently only support propagate_a=False and propagate_b=False for layout=nn") - return matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) - elif layout == "nt": - if propagate_a and propagate_b: - raise ValueError("Currently only support propagate_a or propagate_b for layout=nt") - elif propagate_a: - raise ValueError("Currently only support propagate_a=False for layout=nt") - elif propagate_b: - raise ValueError("Currently only support propagate_b=False for layout=nt") - else: - return matmul(Batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) - else: - raise ValueError(f"Unsupported layout: {layout}") + selector = BatchMatMulSelector(propagate_a, propagate_b) + return selector.select(Batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) From fc84173f22d2f4867a8e6413117b5cd8e830ab27 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 5 Jul 2024 08:57:43 +0000 Subject: [PATCH 02/12] Refactor import statements for improved readability and maintainability --- bitblas/ops/impl/__init__.py | 2 +- bitblas/ops/impl/base.py | 4 ++++ bitblas/ops/impl/batch_matmul_impl.py | 33 ++++++++++++++++++--------- 3 files changed, 27 insertions(+), 12 deletions(-) diff --git a/bitblas/ops/impl/__init__.py b/bitblas/ops/impl/__init__.py index a254dc7fb..8a9bbd2a5 100644 --- a/bitblas/ops/impl/__init__.py +++ b/bitblas/ops/impl/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .lop3_permutate_impl import tir_interleave_weight +from .lop3_permutate_impl import tir_interleave_weight # noqa: F401 diff --git a/bitblas/ops/impl/base.py b/bitblas/ops/impl/base.py index 6d510f7da..4a67987be 100644 --- a/bitblas/ops/impl/base.py +++ b/bitblas/ops/impl/base.py @@ -2,15 +2,19 @@ # Licensed under the MIT License. from abc import ABC, abstractmethod + # TODO: Refactor all the tir script implementations to use this base class # Abstract base class for TIR script emitters class TIRScriptEmitter(ABC): + @abstractmethod def emit(self): raise NotImplementedError + # Abstract base class for TIR script selectors class TIRScriptSelector(ABC): + @abstractmethod def select(self): raise NotImplementedError diff --git a/bitblas/ops/impl/batch_matmul_impl.py b/bitblas/ops/impl/batch_matmul_impl.py index 75449ea4b..3904f36e6 100644 --- a/bitblas/ops/impl/batch_matmul_impl.py +++ b/bitblas/ops/impl/batch_matmul_impl.py @@ -5,11 +5,10 @@ from tvm import te from bitblas.ops.operator import TransformKind from .base import TIRScriptEmitter, TIRScriptSelector -from bitblas import tvm -from tvm import te -from bitblas.ops.operator import TransformKind + class BatchMatMulEmitter(TIRScriptEmitter): + def __init__( self, batch, @@ -32,7 +31,7 @@ def __init__( self.with_bias = with_bias self.layout = layout self._validate_layout() - + @staticmethod def _validate_dimension(dim, name): if not isinstance(dim, int): @@ -48,7 +47,8 @@ def _validate_layout(self): def _create_placeholders(self): A = te.placeholder((self.batch, self.M, self.K), name="A", dtype=self.in_dtype) B = te.placeholder((self.batch, self.N, self.K), name="B", dtype=self.in_dtype) - Bias = te.placeholder((self.N,), name="Bias", dtype=self.in_dtype) if self.with_bias else None + Bias = te.placeholder( + (self.N,), name="Bias", dtype=self.in_dtype) if self.with_bias else None return A, B, Bias def _compute_matmul(self, A, B): @@ -63,12 +63,16 @@ def _compute_matmul(self, A, B): def _apply_bias(self, C, Bias): if self.with_bias: - return te.compute((self.batch, self.M, self.N), lambda b, i, j: C[b, i, j] + Bias[j], name="E") + return te.compute((self.batch, self.M, self.N), + lambda b, i, j: C[b, i, j] + Bias[j], + name="E") return C def _convert_dtype(self, tensor): if self.accum_dtype != self.out_dtype: - return te.compute((self.batch, self.M, self.N), lambda b, i, j: tensor[b, i, j].astype(self.out_dtype), name="D") + return te.compute((self.batch, self.M, self.N), + lambda b, i, j: tensor[b, i, j].astype(self.out_dtype), + name="D") return tensor def emit(self): @@ -84,7 +88,10 @@ def emit(self): class BatchMatMulSelector(TIRScriptSelector): - def __init__(self, propagate_a: TransformKind = TransformKind.NonTransform, propagate_b: TransformKind = TransformKind.NonTransform): + + def __init__(self, + propagate_a: TransformKind = TransformKind.NonTransform, + propagate_b: TransformKind = TransformKind.NonTransform): self.propagate_a = propagate_a self.propagate_b = propagate_b @@ -102,8 +109,10 @@ def select( ): if layout == "nn": if self.propagate_a or self.propagate_b: - raise ValueError("Currently only support propagate_a=False and propagate_b=False for layout=nn") - return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout).emit() + raise ValueError( + "Currently only support propagate_a=False and propagate_b=False for layout=nn") + return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, + layout).emit() elif layout == "nt": if self.propagate_a and self.propagate_b: raise ValueError("Currently only support propagate_a or propagate_b for layout=nt") @@ -112,10 +121,12 @@ def select( elif self.propagate_b: raise ValueError("Currently only support propagate_b=False for layout=nt") else: - return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout).emit() + return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, + with_bias, layout).emit() else: raise ValueError(f"Unsupported layout: {layout}") + def select_implementation( Batch=1, M=None, From 02f64de6cf2d338c092dcf29ec55b69804fda892 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 5 Jul 2024 08:58:06 +0000 Subject: [PATCH 03/12] Refactor import statements for improved readability and maintainability --- bitblas/ops/impl/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitblas/ops/impl/__init__.py b/bitblas/ops/impl/__init__.py index 8a9bbd2a5..67e49b2ae 100644 --- a/bitblas/ops/impl/__init__.py +++ b/bitblas/ops/impl/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .lop3_permutate_impl import tir_interleave_weight # noqa: F401 +from .lop3_permutate_impl import tir_interleave_weight # noqa: F401 From 397eee6141599e84b509594bb99a0531e409c266 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 5 Jul 2024 16:25:47 +0000 Subject: [PATCH 04/12] disable failure email for ci --- .github/workflows/ci.yml | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ceb69fcc7..1fbdf19dd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -64,4 +64,13 @@ jobs: run: | source bitblas_ci/bin/activate cd testing/python - python -m pytest \ No newline at end of file + python -m pytest + + # Control notifications + notify: + runs-on: self-hosted + needs: [format-check, build-test] + if: failure() + steps: + - name: Notification + run: echo "Jobs failed, but no email will be sent." From 20f6ad1e7ca4e6e1ca9e13ad7c1bbc8c430a8e51 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 6 Jul 2024 03:23:50 +0000 Subject: [PATCH 05/12] remove email notifications. --- .github/workflows/ci.yml | 9 --------- 1 file changed, 9 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1fbdf19dd..511b95833 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -65,12 +65,3 @@ jobs: source bitblas_ci/bin/activate cd testing/python python -m pytest - - # Control notifications - notify: - runs-on: self-hosted - needs: [format-check, build-test] - if: failure() - steps: - - name: Notification - run: echo "Jobs failed, but no email will be sent." From b93c39431c803e22b12f71b555939785da36b96a Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 6 Jul 2024 03:25:05 +0000 Subject: [PATCH 06/12] move relax pass from testing to mlc_llm --- .../mlc_llm}/test_weight_only_transform.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename {testing/python/transform => integration/mlc_llm}/test_weight_only_transform.py (100%) diff --git a/testing/python/transform/test_weight_only_transform.py b/integration/mlc_llm/test_weight_only_transform.py similarity index 100% rename from testing/python/transform/test_weight_only_transform.py rename to integration/mlc_llm/test_weight_only_transform.py From 257693a7c3cb3083aac144182f58d38bfe3bcfdd Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 6 Jul 2024 05:51:01 +0000 Subject: [PATCH 07/12] Refactor scripts with se check_eual_ref_scripts_with_emitter function --- bitblas/ops/impl/matmul_dequantize_impl.py | 224 ++++++++++++++---- .../operators/test_tir_script_emitter.py | 52 +++- 2 files changed, 216 insertions(+), 60 deletions(-) diff --git a/bitblas/ops/impl/matmul_dequantize_impl.py b/bitblas/ops/impl/matmul_dequantize_impl.py index 1ed6b3404..e69e8fcfb 100644 --- a/bitblas/ops/impl/matmul_dequantize_impl.py +++ b/bitblas/ops/impl/matmul_dequantize_impl.py @@ -15,8 +15,10 @@ _tir_packed_to_unsigned_convert_with_zeros, ) + # TODO: The following code should be refactored. class MatMulNTDequantizeEmitter: + def __init__( self, M, @@ -52,8 +54,8 @@ def __init__( self.fast_decoding = fast_decoding self.with_bias = with_bias self.zeros_mode = zeros_mode - self.propagate_a = propagate_a - self.propagate_b = propagate_b + self.propagate_a = self._legalize_transform_kind(propagate_a) + self.propagate_b = self._legalize_transform_kind(propagate_b) self._validate_bit() self._validate_layout() @@ -69,62 +71,169 @@ def _validate_bit(self): raise ValueError(f"Unsupported bit: {self.bit}") def _validate_layout(self): - if self.layout not in ["nt"]: - raise ValueError(f"Unsupported layout: {self.layout}") + # TODO: extend the dequantize operators into General Layout + pass + + def _legalize_group_size(self): + if self.group_size == -1: + self.group_size = self.K + + def _legalize_transform_kind(self, propagate): + if propagate is None: + return TransformKind.NonTransform + if isinstance(propagate, bool): + return (TransformKind.IntraWarpTransform if propagate else TransformKind.NonTransform) + elif isinstance(propagate, int): + return TransformKind(propagate) def _create_placeholders(self): - storage_nbit = int("".join(c for c in self.storage_dtype if c.isdigit())) - n_float_per_elem = storage_nbit // self.bit - - A = te.placeholder((self.M, self.K), name="A", dtype=self.in_dtype) - B = te.placeholder((self.N, self.K // storage_nbit * self.bit), name="B", dtype=self.storage_dtype) - LUT = te.placeholder((1 << self.bit,), name="LUT", dtype=self.in_dtype) - Scale = te.placeholder((self.N, self.K // self.group_size), name="Scale", dtype=self.in_dtype) - Zeros = te.placeholder((self.N, self.K // self.group_size), name="Zeros", dtype=self.in_dtype) - QZeros = te.placeholder(((self.K // self.group_size), self.N // storage_nbit * self.bit), + storage_dtype = self.storage_dtype + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + in_dtype = self.in_dtype + bit = self.bit + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + + A = te.placeholder((self.M, self.K), name="A", dtype=in_dtype) + B = te.placeholder((self.N, self.K // storage_nbit * bit), + name="B", + dtype=storage_dtype) + if self.propagate_a: + A = te.placeholder((self.M // l, self.K // r, l, r), name="A", dtype=in_dtype) + if self.propagate_b: + target_dtype = DataType(in_dtype) + scaling_factor = 1 + if bit > 0 and bit < target_dtype.bits: + scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // target_dtype.bits) + qr = r * bit // storage_nbit + B = te.placeholder((self.N // l, (self.K // scaling_factor) // qr, l, qr), name="B", dtype=storage_dtype) + + LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) + Scale = te.placeholder((self.N, self.K // self.group_size), name="Scale", dtype=in_dtype) + Zeros = te.placeholder((self.N, self.K // self.group_size), name="Zeros", dtype=in_dtype) + QZeros = te.placeholder(((self.K // self.group_size), self.N // storage_nbit * bit), name="QZeros", dtype=self.storage_dtype) - Bias = te.placeholder((self.N,), name="Bias", dtype=self.in_dtype) - return A, B, LUT, Scale, Zeros, QZeros, Bias, storage_nbit, n_float_per_elem + Bias = te.placeholder((self.N,), name="Bias", dtype=in_dtype) + return A, B, LUT, Scale, Zeros, QZeros, Bias + + def _propagate_input(self, tensor, transform_kind=TransformKind.NonTransform, matrix_name="A"): + if transform_kind == TransformKind.NonTransform: + return tensor + in_dtype = self.in_dtype + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + _, inversed_index_map = get_propagate_map( + trans=False, dtype=in_dtype, matrix_name=matrix_name) + + def fcompute(i, j): + warp_i, warp_j = i % l, j % r + spatial_args = i // l, j // r + if transform_kind >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return tensor[new_index] + + return te.compute( + (self.M, self.K), + fcompute, + name=f"{matrix_name}_reindex", + ) + + def _propagage_weight(self, tensor, transform_kind=TransformKind.NonTransform, matrix_name="B"): + if transform_kind == TransformKind.NonTransform: + return tensor + in_dtype = self.in_dtype + bit = self.bit + storage_dtype = self.storage_dtype + storage_nbit = int("".join(c for c in self.storage_dtype if c.isdigit())) + + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + _, inversed_index_map = get_propagate_map( + trans=True, dtype=in_dtype, matrix_name=matrix_name) + target_dtype = DataType(in_dtype) + scaling_factor = 1 + if bit > 0 and bit < target_dtype.bits: + scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // + target_dtype.bits) + initial_indices = inversed_index_map.initial_indices + scaling_final_indices = inversed_index_map.map_indices( + initial_indices[:-1] + [initial_indices[-1] * scaling_factor]) + scaling_final_indices = scaling_final_indices[:-1] + [ + scaling_final_indices[-1] // scaling_factor + ] + inversed_index_map = IndexMap( + initial_indices, + scaling_final_indices, + None, + ) + + qr = r * bit // storage_nbit - def _decode_func(self, B, LUT, Scale, Zeros, QZeros, storage_nbit, n_float_per_elem): - w = None + def fcompute(i, j): + warp_i, warp_j = i % l, j % qr + spatial_args = i // l, j // qr + if transform_kind >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return tensor[new_index] + + return te.compute( + (self.N, self.K // storage_nbit * bit), + fcompute, + name=f"{matrix_name}_reindex", + ) + + def _decode_func(self, B, LUT, Scale, Zeros, QZeros): + bit = self.bit + in_dtype = self.in_dtype + storage_dtype = self.storage_dtype + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + n_float_per_elem = storage_nbit // bit + + # TODO: Move the decode function into a more general place def decode(n, k): + w = None if self.with_zeros and self.zeros_mode == "quantized": - qzeros_dequantize = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)( - self.bit, + qzeros_dequantize = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, QZeros[k, n // n_float_per_elem], n % n_float_per_elem, dtype=self.storage_dtype, ) - w = _tir_packed_to_unsigned_convert_with_zeros(self.storage_dtype, storage_nbit)( - self.bit, + w = _tir_packed_to_unsigned_convert_with_zeros(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, qzeros_dequantize, - dtype=self.in_dtype, + dtype=in_dtype, ) elif self.source_format == "uint": - if self.bit == 8: - w = B[n, k].astype(self.in_dtype) - w = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) + if bit == 8: + w = B[n, k].astype(in_dtype) + w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) elif self.source_format == "int": - if self.bit == 1: - w = _tir_packed_int_to_int_convert(self.storage_dtype, storage_nbit)( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) - if self.bit == 8: - w = B[n, k].astype(self.in_dtype) - w = _tir_packed_to_signed_convert(self.storage_dtype, storage_nbit)( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) + if bit == 1: + w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + if bit == 8: + w = B[n, k].astype(in_dtype) + w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) elif self.source_format == "fp": w = _tir_u32_to_f4_to_f16( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) elif self.source_format == "fp_e4m3": - w = _tir_u8_to_f8_e4m3_to_f16(self.bit, B[n, k], dtype=self.in_dtype) + w = _tir_u8_to_f8_e4m3_to_f16(bit, B[n, k], dtype=in_dtype) elif self.source_format == "nf": - index = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)( - self.bit, + index = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype="int32", @@ -132,7 +241,9 @@ def decode(n, k): w = LUT[index] else: raise ValueError(f"Unsupported source_format: {self.source_format}") - + + assert w is not None, "w is None" + group_size = self.group_size zeros_mode = self.zeros_mode @@ -167,7 +278,9 @@ def _compute_matmul(self, A, B_decode): def _convert_dtype(self, tensor): if self.accum_dtype != self.out_dtype: - return te.compute((self.M, self.N), lambda i, j: tensor[i, j].astype(self.out_dtype), name="D") + return te.compute((self.M, self.N), + lambda i, j: tensor[i, j].astype(self.out_dtype), + name="D") return tensor def _apply_bias(self, tensor, Bias): @@ -176,9 +289,12 @@ def _apply_bias(self, tensor, Bias): return tensor def emit(self): - A, B, LUT, Scale, Zeros, QZeros, Bias, storage_nbit, n_float_per_elem = self._create_placeholders() - B_decode = self._decode_func(B, LUT, Scale, Zeros, QZeros, storage_nbit, n_float_per_elem) - C = self._compute_matmul(A, B_decode) + A, B, LUT, Scale, Zeros, QZeros, Bias = self._create_placeholders() + A_reindex = self._propagate_input(A, self.propagate_a, "A") + B_reindex = self._propagage_weight(B, self.propagate_b, "B") + + B_decode = self._decode_func(B_reindex, LUT, Scale, Zeros, QZeros) + C = self._compute_matmul(A_reindex, B_decode) D = self._convert_dtype(C) last_output = self._apply_bias(D, Bias) @@ -212,8 +328,13 @@ def emit(self): } }, ) + if self.propagate_a: + func = func.with_attr("input_transform_kind", self.propagate_a.value) + if self.propagate_b: + func = func.with_attr("weight_transform_kind", self.propagate_b.value) return tvm.IRModule.from_expr(func) + def matmul_nt_dequantize_b( M, N, @@ -335,9 +456,12 @@ def decode_func(n, k): A[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), axis=k), name="C", ) - D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D args = [A, B] - last_output = D if source_format == "nf": args.append(LUT) if with_scaling: @@ -517,9 +641,11 @@ def decode_func(n, k): A[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), axis=k), name="C", ) - D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D args = [A, B] - last_output = D if source_format == "nf": args.append(LUT) if with_scaling: @@ -715,9 +841,11 @@ def decode_func(n, k): ), name="C", ) - D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D args = [A, B] - last_output = D if source_format == "nf": args.append(LUT) if with_scaling: diff --git a/testing/python/operators/test_tir_script_emitter.py b/testing/python/operators/test_tir_script_emitter.py index cec56b473..fcfa7d9af 100644 --- a/testing/python/operators/test_tir_script_emitter.py +++ b/testing/python/operators/test_tir_script_emitter.py @@ -1,18 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from bitblas.ops.impl.matmul_dequantize_impl import ( - MatMulNTDequantizeEmitter, - matmul_nt_dequantize_b, - matmul_nt_dequantize_b_propagate_b, - matmul_nt_dequantize_b_propagate_a_propagate_b, -) from bitblas import tvm import logging from bitblas import set_log_level set_log_level(logging.DEBUG) -def compare_tir_scripts_and_emitter( + +def check_eual_ref_scripts_with_emitter( M, N, K, @@ -28,8 +23,26 @@ def compare_tir_scripts_and_emitter( fast_decoding, with_bias, zeros_mode, + propagate_a, + propagate_b, ): - tir_script_func = matmul_nt_dequantize_b( + from bitblas.ops.impl.matmul_dequantize_impl import ( + MatMulNTDequantizeEmitter, + matmul_nt_dequantize_b, + matmul_nt_dequantize_b_propagate_b, + matmul_nt_dequantize_b_propagate_a_propagate_b, + ) + func = None + if propagate_a and propagate_b: + func = matmul_nt_dequantize_b_propagate_a_propagate_b + elif propagate_b: + func = matmul_nt_dequantize_b_propagate_b + else: + func = matmul_nt_dequantize_b + + assert func is not None, "No function found for the given configuration" + + ref_func = func( M, N, K, @@ -46,8 +59,8 @@ def compare_tir_scripts_and_emitter( with_bias, zeros_mode, ) - - emitter_func = MatMulNTDequantizeEmitter( + + emit_func = MatMulNTDequantizeEmitter( M, N, K, @@ -63,6 +76,21 @@ def compare_tir_scripts_and_emitter( fast_decoding, with_bias, zeros_mode, + propagate_a=propagate_a, + propagate_b=propagate_b, ).emit() - - tvm.ir.assert_structural_equal(tir_script_func, emitter_func) + + tvm.ir.assert_structural_equal(ref_func, emit_func) + + +def test_check_eual_ref_scripts_with_emitter(): + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "nf", True, False, -1, False, False, "original", False, False) + check_eual_ref_scripts_with_emitter(16384, 16384, 16384, "float16", "float16", "float16", 4, "int8", "nf", True, False, -1, False, False, "original", False, False) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, False) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, False) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, True) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, True) + check_eual_ref_scripts_with_emitter(1024, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", True, True) + +if __name__ == "__main__": + test_check_eual_ref_scripts_with_emitter() From 9bb7f49a968d4c71dbbc12121b4b7cb8258b2136 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 6 Jul 2024 05:51:15 +0000 Subject: [PATCH 08/12] Lint Fix --- bitblas/ops/impl/matmul_dequantize_impl.py | 13 +++++---- .../operators/test_tir_script_emitter.py | 29 ++++++++++++++----- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/bitblas/ops/impl/matmul_dequantize_impl.py b/bitblas/ops/impl/matmul_dequantize_impl.py index e69e8fcfb..7b91764ca 100644 --- a/bitblas/ops/impl/matmul_dequantize_impl.py +++ b/bitblas/ops/impl/matmul_dequantize_impl.py @@ -73,7 +73,7 @@ def _validate_bit(self): def _validate_layout(self): # TODO: extend the dequantize operators into General Layout pass - + def _legalize_group_size(self): if self.group_size == -1: self.group_size = self.K @@ -96,18 +96,19 @@ def _create_placeholders(self): l, r = 16, 32 # noqa: E741 A = te.placeholder((self.M, self.K), name="A", dtype=in_dtype) - B = te.placeholder((self.N, self.K // storage_nbit * bit), - name="B", - dtype=storage_dtype) + B = te.placeholder((self.N, self.K // storage_nbit * bit), name="B", dtype=storage_dtype) if self.propagate_a: A = te.placeholder((self.M // l, self.K // r, l, r), name="A", dtype=in_dtype) if self.propagate_b: target_dtype = DataType(in_dtype) scaling_factor = 1 if bit > 0 and bit < target_dtype.bits: - scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // target_dtype.bits) + scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // + target_dtype.bits) qr = r * bit // storage_nbit - B = te.placeholder((self.N // l, (self.K // scaling_factor) // qr, l, qr), name="B", dtype=storage_dtype) + B = te.placeholder((self.N // l, (self.K // scaling_factor) // qr, l, qr), + name="B", + dtype=storage_dtype) LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) Scale = te.placeholder((self.N, self.K // self.group_size), name="Scale", dtype=in_dtype) diff --git a/testing/python/operators/test_tir_script_emitter.py b/testing/python/operators/test_tir_script_emitter.py index fcfa7d9af..b2c7a8d4f 100644 --- a/testing/python/operators/test_tir_script_emitter.py +++ b/testing/python/operators/test_tir_script_emitter.py @@ -84,13 +84,28 @@ def check_eual_ref_scripts_with_emitter( def test_check_eual_ref_scripts_with_emitter(): - check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "nf", True, False, -1, False, False, "original", False, False) - check_eual_ref_scripts_with_emitter(16384, 16384, 16384, "float16", "float16", "float16", 4, "int8", "nf", True, False, -1, False, False, "original", False, False) - check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, False) - check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, False) - check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, True) - check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, True) - check_eual_ref_scripts_with_emitter(1024, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", True, True) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", + "nf", True, False, -1, False, False, "original", False, + False) + check_eual_ref_scripts_with_emitter(16384, 16384, 16384, "float16", "float16", "float16", 4, + "int8", "nf", True, False, -1, False, False, "original", + False, False) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", + "uint", True, False, -1, False, False, "original", False, + False) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", + "uint", True, False, -1, False, False, "original", False, + False) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", + "uint", True, False, -1, False, False, "original", False, + True) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", + "uint", True, False, -1, False, False, "original", False, + True) + check_eual_ref_scripts_with_emitter(1024, 1024, 1024, "float16", "float16", "float16", 4, + "int8", "uint", True, False, -1, False, False, "original", + True, True) + if __name__ == "__main__": test_check_eual_ref_scripts_with_emitter() From 93eb5a5fe4e3eb6242675dd5706358c4121f1672 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 6 Jul 2024 05:53:50 +0000 Subject: [PATCH 09/12] Refactor scripts with se check_eual_ref_scripts_with_emitter function --- bitblas/ops/impl/matmul_dequantize_impl.py | 198 --------------------- 1 file changed, 198 deletions(-) diff --git a/bitblas/ops/impl/matmul_dequantize_impl.py b/bitblas/ops/impl/matmul_dequantize_impl.py index 1ef14100d..7b91764ca 100644 --- a/bitblas/ops/impl/matmul_dequantize_impl.py +++ b/bitblas/ops/impl/matmul_dequantize_impl.py @@ -15,204 +15,6 @@ _tir_packed_to_unsigned_convert_with_zeros, ) -# TODO: The following code should be refactored. -class MatMulNTDequantizeEmitter: - def __init__( - self, - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - bit=4, - storage_dtype="int8", - source_format="uint", - with_scaling=False, - with_zeros=False, - group_size=-1, - fast_decoding=False, - with_bias=False, - zeros_mode="original", - propagate_a: TransformKind = TransformKind.NonTransform, - propagate_b: TransformKind = TransformKind.NonTransform, - ): - self.M = self._validate_dimension(M, "M") - self.N = N - self.K = K - self.in_dtype = in_dtype - self.out_dtype = out_dtype - self.accum_dtype = accum_dtype - self.bit = bit - self.storage_dtype = storage_dtype - self.source_format = source_format - self.with_scaling = with_scaling - self.with_zeros = with_zeros - self.group_size = group_size if group_size != -1 else K - self.fast_decoding = fast_decoding - self.with_bias = with_bias - self.zeros_mode = zeros_mode - self.propagate_a = propagate_a - self.propagate_b = propagate_b - - self._validate_bit() - self._validate_layout() - - @staticmethod - def _validate_dimension(dim, name): - if not isinstance(dim, int): - return tvm.te.var(name.lower()) - return dim - - def _validate_bit(self): - if self.bit not in [1, 2, 4, 8]: - raise ValueError(f"Unsupported bit: {self.bit}") - - def _validate_layout(self): - if self.layout not in ["nt"]: - raise ValueError(f"Unsupported layout: {self.layout}") - - def _create_placeholders(self): - storage_nbit = int("".join(c for c in self.storage_dtype if c.isdigit())) - n_float_per_elem = storage_nbit // self.bit - - A = te.placeholder((self.M, self.K), name="A", dtype=self.in_dtype) - B = te.placeholder((self.N, self.K // storage_nbit * self.bit), name="B", dtype=self.storage_dtype) - LUT = te.placeholder((1 << self.bit,), name="LUT", dtype=self.in_dtype) - Scale = te.placeholder((self.N, self.K // self.group_size), name="Scale", dtype=self.in_dtype) - Zeros = te.placeholder((self.N, self.K // self.group_size), name="Zeros", dtype=self.in_dtype) - QZeros = te.placeholder(((self.K // self.group_size), self.N // storage_nbit * self.bit), - name="QZeros", - dtype=self.storage_dtype) - Bias = te.placeholder((self.N,), name="Bias", dtype=self.in_dtype) - return A, B, LUT, Scale, Zeros, QZeros, Bias, storage_nbit, n_float_per_elem - - def _decode_func(self, B, LUT, Scale, Zeros, QZeros, storage_nbit, n_float_per_elem): - w = None - def decode(n, k): - if self.with_zeros and self.zeros_mode == "quantized": - qzeros_dequantize = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)( - self.bit, - QZeros[k, n // n_float_per_elem], - n % n_float_per_elem, - dtype=self.storage_dtype, - ) - w = _tir_packed_to_unsigned_convert_with_zeros(self.storage_dtype, storage_nbit)( - self.bit, - B[n, k // n_float_per_elem], - k % n_float_per_elem, - qzeros_dequantize, - dtype=self.in_dtype, - ) - elif self.source_format == "uint": - if self.bit == 8: - w = B[n, k].astype(self.in_dtype) - w = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) - elif self.source_format == "int": - if self.bit == 1: - w = _tir_packed_int_to_int_convert(self.storage_dtype, storage_nbit)( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) - if self.bit == 8: - w = B[n, k].astype(self.in_dtype) - w = _tir_packed_to_signed_convert(self.storage_dtype, storage_nbit)( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) - elif self.source_format == "fp": - w = _tir_u32_to_f4_to_f16( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) - elif self.source_format == "fp_e4m3": - w = _tir_u8_to_f8_e4m3_to_f16(self.bit, B[n, k], dtype=self.in_dtype) - elif self.source_format == "nf": - index = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)( - self.bit, - B[n, k // n_float_per_elem], - k % n_float_per_elem, - dtype="int32", - ) - w = LUT[index] - else: - raise ValueError(f"Unsupported source_format: {self.source_format}") - - group_size = self.group_size - zeros_mode = self.zeros_mode - - if not self.with_scaling: - return w - - if not self.with_zeros: - return w * Scale[n, k // group_size] - - if zeros_mode == "original": - w = (w - Zeros[n, k // group_size]) * Scale[n, k // group_size] - elif zeros_mode == "rescale": - w = w * Scale[n, k // group_size] - Zeros[n, k // group_size] - elif zeros_mode == "quantized": - w = w * Scale[n, k // group_size] - else: - raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) - - return w - - return te.compute((self.N, self.K), decode, name="B_decode") - - def _compute_matmul(self, A, B_decode): - k = te.reduce_axis((0, self.K), name="k") - C = te.compute( - (self.M, self.N), - lambda i, j: te.sum( - A[i, k].astype(self.accum_dtype) * B_decode[j, k].astype(self.accum_dtype), axis=k), - name="C", - ) - return C - - def _convert_dtype(self, tensor): - if self.accum_dtype != self.out_dtype: - return te.compute((self.M, self.N), lambda i, j: tensor[i, j].astype(self.out_dtype), name="D") - return tensor - - def _apply_bias(self, tensor, Bias): - if self.with_bias: - return te.compute((self.M, self.N), lambda i, j: tensor[i, j] + Bias[j], name="E") - return tensor - - def emit(self): - A, B, LUT, Scale, Zeros, QZeros, Bias, storage_nbit, n_float_per_elem = self._create_placeholders() - B_decode = self._decode_func(B, LUT, Scale, Zeros, QZeros, storage_nbit, n_float_per_elem) - C = self._compute_matmul(A, B_decode) - D = self._convert_dtype(C) - last_output = self._apply_bias(D, Bias) - - args = [A, B] - if self.source_format == "nf": - args.append(LUT) - if self.with_scaling: - args.append(Scale) - if self.with_zeros: - args.append(QZeros if self.zeros_mode == "quantized" else Zeros) - if self.with_bias: - args.append(Bias) - args.append(last_output) - - func = te.create_prim_func(args).with_attr( - "dequantize_info", - { - "B_decode": { - "decode_block": "B_decode", - "fast_decoding": self.fast_decoding, - "source_format": { - "bits": self.bit, - "format": self.source_format, - }, - "storage_dtype": self.storage_dtype, - "target_format": self.in_dtype, - "with_zeros": self.with_zeros, - "zeros_mode": self.zeros_mode, - "with_scaling": self.with_scaling, - "group_size": self.group_size, - } - }, - ) - return tvm.IRModule.from_expr(func) # TODO: The following code should be refactored. class MatMulNTDequantizeEmitter: From 9efa5ab5e6a41c86da7bbc486304f0f2f340fec9 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 24 Aug 2024 04:08:37 +0000 Subject: [PATCH 10/12] Kernel Name --- bitblas/base/utils.py | 33 +- bitblas/builder/wrapper/tir.py | 64 ++- bitblas/cache/operator.py | 4 +- bitblas/ops/general_matmul/__init__.py | 85 +++- bitblas/ops/ladder_permutate/__init__.py | 2 +- bitblas/ops/operator.py | 108 ++++- bitblas/wrapper/__init__.py | 4 - bitblas/wrapper/general.py | 519 ----------------------- 8 files changed, 228 insertions(+), 591 deletions(-) delete mode 100644 bitblas/wrapper/__init__.py delete mode 100644 bitblas/wrapper/general.py diff --git a/bitblas/base/utils.py b/bitblas/base/utils.py index d2168c850..90fab86d0 100644 --- a/bitblas/base/utils.py +++ b/bitblas/base/utils.py @@ -6,7 +6,7 @@ from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind from concurrent.futures import ThreadPoolExecutor, as_completed import numpy as np -from typing import List, Tuple, Optional, Dict, Union, Literal +from typing import List, Tuple, Optional, Dict, Union, Literal, Callable from tvm import tir, IRModule from tvm.runtime import Module from tvm.tir import Schedule @@ -455,13 +455,13 @@ def create_dispatch_func(g_var: str, func: tir.PrimFunc, refactored_funcs: List[ def create_dispatch_mod(g_var: str, original_func: tir.PrimFunc, - specialized_funcs: List[tir.PrimFunc]) -> IRModule: + specialized_funcs: List[tir.PrimFunc], function_symbols) -> IRModule: dispatch_mod: IRModule = tvm.IRModule() g_var_supply = GlobalVarSupply(dispatch_mod) refactored_funcs = [] - for func in specialized_funcs: + for f_var, func in zip(function_symbols, specialized_funcs): params, buffers_to_declare = collect_buffers_to_declare(func) - global_symbol, device_func = refactor_specialized_func(g_var, func, params, + global_symbol, device_func = refactor_specialized_func(f_var, func, params, buffers_to_declare) global_symbol = g_var_supply.fresh_global(global_symbol, add_prefix=False) dispatch_mod[global_symbol] = device_func @@ -478,6 +478,7 @@ def fast_tune_with_dynamic_range( parallel_build: bool = True, global_symbol: Optional[str] = None, dynamic_range: Optional[Dict[str, List[int]]] = None, + kernel_name_generator: Optional[Callable] = None, ) -> IRModule: if dynamic_range is None: dynamic_range = {} @@ -517,12 +518,30 @@ def fast_tune_with_dynamic_range( # Convert the Cartesian product to a list of dictionaries specialize_items: List[Dict] = [dict(zip(opt_shapes.keys(), values)) for values in product_list] + function_symbols: List[str] = [] specilized_tuned_funcs: List[tir.PrimFunc] = [] for item in specialize_items: func = func.with_attr("opt_shapes", item) _, best = fast_tune(func, target, topk, parallel_build) if best is None: return None - specilized_tuned_funcs.append(best.sch.mod["main"]) - - return create_dispatch_mod(global_symbol, func, specilized_tuned_funcs) + specialized_func = best.sch.mod["main"] + function_symbol = global_symbol + if kernel_name_generator is not None: + scheduled_mod = best.sch.mod + best_hint = best.config + assert len(scheduled_mod.get_global_vars()) == 1, ( + "The optimized module should only have one global variable for default schedule.") + assert "main" in scheduled_mod, ( + "The optimized module should have a function named 'main' for default schedule.") + default_kernal_name = kernel_name_generator.generate(best_hint) + specialized_func = scheduled_mod["main"].with_attr("global_symbol", default_kernal_name) + function_symbol = default_kernal_name + + function_symbols.append(function_symbol) + specilized_tuned_funcs.append(specialized_func) + + assert global_symbol is not None, "The global_symbol should not be None" + assert len(function_symbols) == len(specilized_tuned_funcs), ( + "The length of global_symbols should be equal to the length of specilized_tuned_funcs") + return create_dispatch_mod(global_symbol, func, specilized_tuned_funcs, function_symbols) diff --git a/bitblas/builder/wrapper/tir.py b/bitblas/builder/wrapper/tir.py index 59d63298b..5d5763861 100644 --- a/bitblas/builder/wrapper/tir.py +++ b/bitblas/builder/wrapper/tir.py @@ -13,6 +13,22 @@ logger = logging.getLogger(__name__) +PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY = """ + cudaFuncSetAttribute({}, cudaFuncAttributeMaxDynamicSharedMemorySize, {}); +""" + +PREDEF_INIT_FUNC = """ +extern "C" void init() {{ + {} +}} +""" + +PREDEF_HOST_FUNC = """ +extern "C" void call({}) {{ +{} +}} +""" + class TIRCUDASourceWrapper(object): _TYPE_MAP = { @@ -77,16 +93,11 @@ def get_cuda_init_func(self): call_str = """""" # If dynamic shared memory buffer is specified, prepare the cudaFuncSetAttribute call if self.dynamic_smem_buf is not None: - call_str = """ - cudaFuncSetAttribute({}, - cudaFuncAttributeMaxDynamicSharedMemorySize, {}); - """.format(self.function_name, self.dynamic_smem_buf) + call_str = ( + PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY.format(self.function_name, + self.dynamic_smem_buf)) # Format the initialization function using the call_str - init_funcs = """ - extern "C" void init() {{ - {} - }} - """.format(call_str) + init_funcs = PREDEF_INIT_FUNC.format(call_str) return init_funcs def update_lib_code(self, code: str): @@ -162,11 +173,7 @@ def legalize_c(p): call_str += "{}<<<{}, {}, {}, stream>>>({});".format(function_name, grid_str, block_str, smem_str, call_args) # Create the host function wrapper for the CUDA kernel - host_func = """ - extern "C" void call({}) {{ - {} - }} - """.format(def_args, call_str) + host_func = PREDEF_HOST_FUNC.format(def_args, call_str) # Combine the source, initialization function, and host function to form the complete library code lib_code = self.source + init_func + host_func return lib_code @@ -188,16 +195,10 @@ def get_cuda_init_func(self): for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items(): if dynamic_smem_buf is not None: # Format the cudaFuncSetAttribute call for dynamic shared memory - call_str += """ - cudaFuncSetAttribute({}, - cudaFuncAttributeMaxDynamicSharedMemorySize, {}); - """.format(function_name, dynamic_smem_buf) + call_str += ( + PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY.format(function_name, dynamic_smem_buf)) # Define the init function that will set the attributes for each kernel - init_funcs = """ -extern "C" void init() {{ - {} -}} - """.format(call_str) + init_funcs = PREDEF_INIT_FUNC.format(call_str) return init_funcs def create_dispatch_func(self, code, function_informations): @@ -278,8 +279,8 @@ def legalize_c(p): (symbolic,) = list(dynamic_symbolic_set) range_str = opt_shapes[symbolic] if last_range == 0: - call_str = "if ({} == 0) return; \n".format(symbolic,) - call_str += "if ({} <= {}) {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format( + call_str = "\tif ({} == 0) return; \n".format(symbolic,) + call_str += "\tif ({} <= {}) {{\n\t\t{}<<<{}, {}, {}, stream>>>({}); \n\t}}\n".format( symbolic, range_str, function_name, @@ -289,7 +290,7 @@ def legalize_c(p): call_args, ) else: - call_str = "\t\telse if ({} <= {}) {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format( + call_str = "\telse if ({} <= {}) {{\n\t\t{}<<<{}, {}, {}, stream>>>({}); \n\t}}\n".format( symbolic, range_str, function_name, @@ -299,18 +300,13 @@ def legalize_c(p): call_args, ) if last_range == num_items - 1: - call_str += ( - "\t\telse {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format( - function_name, grid_str, block_str, smem_str, call_args)) + call_str += ("\telse {{\n\t\t{}<<<{}, {}, {}, stream>>>({}); \n\t}}\n".format( + function_name, grid_str, block_str, smem_str, call_args)) last_range += 1 _call_str += call_str # Wrap the kernel dispatch logic in an external C function - host_func = """ -extern "C" void call({}) {{ - {} -}} - """.format(def_args, _call_str) + host_func = PREDEF_HOST_FUNC.format(def_args, _call_str) return host_func def parse_source_information(self): diff --git a/bitblas/cache/operator.py b/bitblas/cache/operator.py index 6c5ea1ebe..597b2a34f 100644 --- a/bitblas/cache/operator.py +++ b/bitblas/cache/operator.py @@ -108,8 +108,8 @@ def _save_operator_config_and_artifact(self, config, op_inst, config_path): # For writing optimized.py file optimized_file_path = os.path.join(config_path, "optimized.py") with open(optimized_file_path, "w") as optimized_file: - if op_inst.optimized_func is not None: - optimized_file.write(op_inst.optimized_func.script(show_meta=False)) + if op_inst.optimized_mod is not None: + optimized_file.write(op_inst.optimized_mod.script(show_meta=False)) if op_inst.libpath is not None: # copy lib name to the same directory as the artifact srcpath = op_inst.srcpath diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py index 2945996df..0b55ef8f7 100644 --- a/bitblas/ops/general_matmul/__init__.py +++ b/bitblas/ops/general_matmul/__init__.py @@ -6,8 +6,9 @@ from functools import reduce from enum import IntEnum from bitblas.base.arch.cuda import CUDA +from bitblas.base.roller.hint import Hint from typing import Any, Literal, Optional, Tuple, Union -from ..operator import OperatorConfig, Operator, TransformKind, OPExecutorCPU +from ..operator import OperatorConfig, Operator, TransformKind, OPExecutorCPU, BaseKernelNameGenerator from .tirscript.matmul_dequantize_impl import select_implementation as weight_dequantize_implementation from .tirscript.matmul_impl import select_implementation as consistent_implementation from ...base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 @@ -226,6 +227,85 @@ def __post_init__(self): object.__setattr__(self, "storage_dtype", self.W_dtype) +class MatmulKernelNameGenerator(BaseKernelNameGenerator): + + KERNEL_PREFIX = "matmul" + + @staticmethod + def serialize_hint(hint: Optional[Hint] = None) -> str: + if hint is None: + return "default" + else: + if hint.use_tc: + hint_prefix = "tc" + BM, BN = hint.block + WM, WN = hint.warp + BK = hint.rstep[-1] + reduce_k = hint.block_reduction_depth + pipeline_stage = hint.pipeline_stage + hint_name = f"{hint_prefix}x{BM}x{BN}x{BK}w{WM}x{WN}" + if reduce_k is not None and reduce_k > 1: + hint_name += f"xr{reduce_k}" + if pipeline_stage > 1: + hint_name += f"xp{pipeline_stage}" + return hint_name + else: + hint_prefix = "simt" + # do not annotate for simt currently + return hint_prefix + + @staticmethod + def simplify_dtype(dtype: str) -> str: + if dtype == "float32": + return "f32" + elif dtype == "float16": + return "f16" + elif dtype == "bfloat16": + return "bf16" + elif dtype.startswith("int"): + return f"i{dtype[3:]}" + elif dtype.startswith("uint"): + return f"u{dtype[4:]}" + return dtype + + def generate(self, hint=None) -> str: + config = self.config + kernel_name = self.KERNEL_PREFIX + shape_str = f"n{self.config.N}k{self.config.K}" + if isinstance(config.M, int): + shape_str = f"m{config.M}" + shape_str + + A_dtype = self.simplify_dtype(config.A_dtype) + W_dtype = self.simplify_dtype(config.W_dtype) + + precision_str = (f"A{A_dtype}W{W_dtype}") + kernel_name = "_".join([kernel_name, shape_str, precision_str]) + + # if config.with_scaling: + # kernel_name += "Scale" + + # if config.with_zeros: + # if config.zeros_mode == "original": + # kernel_name += "OriginalZeros" + # elif config.zeros_mode == "rescale": + # precision_str += "RescaleZeros" + # elif config.zeros_mode == "quantized": + # precision_str += "QuantizedZeros" + # else: + # raise ValueError(f"Unsupported zeros mode: {config.zeros_mode}") + + # if config.propagate_a is not TransformKind.NonTransform: + # kernel_name += f"_pa{config.propagate_a.value}" + # if config.propagate_b is not TransformKind.NonTransform: + # kernel_name += f"_pb{config.propagate_b.value}" + + kernel_name = "_".join([kernel_name, self.serialize_hint(hint)]) + return kernel_name + + def is_valid_config(self, config: OperatorConfig) -> bool: + return isinstance(config, MatmulConfig) + + class Matmul(Operator): # TODO(lei): This should be improved into a general datatype class. @@ -350,6 +430,9 @@ def dispatch_tir(self, # output data type self.torch_output_dtype = getattr(torch, self.out_dtype) + def get_kernel_name_generator(self): + return MatmulKernelNameGenerator(self.config) + def _alloc_workspace(self): return torch.empty(WORKSPACE_SIZE, dtype=torch.float16).cuda() diff --git a/bitblas/ops/ladder_permutate/__init__.py b/bitblas/ops/ladder_permutate/__init__.py index d09ee6dac..65ad06679 100644 --- a/bitblas/ops/ladder_permutate/__init__.py +++ b/bitblas/ops/ladder_permutate/__init__.py @@ -38,7 +38,7 @@ def __init__( target = self.target if target.kind.name == "cuda": - self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target) + self.optimized_mod = self.apply_default_schedule(self.prim_func_mod, target) if enable_tuning: self.hardware_aware_finetune() if not from_database: diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index f6fa4cca0..09aee625d 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -8,11 +8,12 @@ from tvm.contrib.dlpack import to_pytorch_func import bitblas import ctypes -from typing import List, Dict, Any, Optional +from typing import List, Dict, Any, Optional, Tuple import numpy as np from bitblas.base import fast_tune, fast_tune_with_dynamic_range from copy import deepcopy from bitblas.base.arch import get_arch +from bitblas.base.roller.hint import Hint from bitblas.builder.wrapper import TIRWrapper from bitblas.builder.lib_generator import LibraryGenerator from dataclasses import dataclass @@ -21,6 +22,16 @@ logger = logging.getLogger(__name__) +APPLY_SCHEDULE_FAILED_MESSAGE = ("Failed to apply default schedule for operator {} " + "With target {} and hint {}. \n" + "The error message: {} " + "Please perform hardware-aware tuning manually.") + +BUILD_RUNTIME_LIBRARY_FAILED_MESSAGE = ("Failed to build runtime library for operator {} " + "With target {} and hint {}. \n" + "The error message: {} " + "Please perform hardware-aware tuning manually.") + class TransformKind(IntEnum): NonTransform = 0 @@ -35,6 +46,24 @@ class OperatorConfig: pass +class BaseKernelNameGenerator(ABC): + """Optional class for generating kernel names based on the config and hint""" + + def __init__(self, config: OperatorConfig): + assert self.is_valid_config(config), (f"Invalid config for {self.__class__.__name__}: " + f"{config}") + self.config = config + + @abstractmethod + def is_valid_config(self, config: OperatorConfig): + pass + + @abstractmethod + def generate(self, hint: Hint = None) -> str: + '''Generate the kernel name based on the config and hint''' + pass + + class Operator(ABC): def __init__(self, name, config: OperatorConfig, target: Target = None): @@ -44,7 +73,7 @@ def __init__(self, name, config: OperatorConfig, target: Target = None): self.config = config self.target = target self.prim_func_mod = self._select_implementation() - self.optimized_func = None + self.optimized_mod = None self.rt_mod = None self.time_evaluator = None self.arch = get_arch(target) if target else None @@ -55,13 +84,20 @@ def __init__(self, name, config: OperatorConfig, target: Target = None): self.num_output_args: int = ( 1 # todo(lei): should be analyzed from the prim_func. ) + self.kernel_name_generator: Optional[BaseKernelNameGenerator] = ( + self.get_kernel_name_generator()) self.lib_generator = LibraryGenerator(self.arch) self.wrapper = TIRWrapper(self.arch) self.lib = None - def get_source(self, target: Target = None) -> str: + def get_kernel_name_generator(self) -> Optional[BaseKernelNameGenerator]: + return None + + def get_source(self, target: Optional[Target] = None) -> str: if target is None: target = self.target + if self.lib_generator.lib_code is not None: + return self.lib_generator.lib_code if self.rt_mod is None: self._build_runtime_module(target) return self.rt_mod.imported_modules[0].get_source() if self.rt_mod else None @@ -88,7 +124,7 @@ def _build_runtime_module(self, target: Target): # Check if the platform is CUDA and we have an optimized function if self.arch.platform == "CUDA": - if self.optimized_func is None: + if self.optimized_mod is None: return None @tvm.register_func(func_name="tvm_callback_cuda_postproc", override=True) @@ -96,17 +132,17 @@ def tvm_callback_cuda_postproc(code, _): return self.post_process(code) try: - # Use a specific TVM pass context for CUDA platforms with tvm.transform.PassContext(config={ "tir.use_async_copy": True, "tir.disable_cse_tir": True, **self.pass_context }): - rt_mod = tvm.build(self.optimized_func, target=target, name=self.name) + rt_mod = tvm.build(self.optimized_mod, target=target) except Exception: # noqa: F841 logger.debug( - "Failed to build optimized function for CUDA target with default schedule, Please consider enable hardware aware tuning!" - ) + BUILD_RUNTIME_LIBRARY_FAILED_MESSAGE.format(self.__class__.__name__, target, + "optimized", + "Failed to build optimized module")) else: # For non-CUDA platforms or when no optimized function is available, build with the primary function rt_mod = tvm.build(self.prim_func, target=target, name=self.name) @@ -122,8 +158,8 @@ def tvm_callback_cuda_postproc(code, _): if self.arch.platform == "CUDA": try: is_dynamic = ( - self.dynamic_range is not None and len(self.optimized_func.functions) > 1) - self.wrapper.assign_optimized_module(self.optimized_func) + self.dynamic_range is not None and len(self.optimized_mod.functions) > 1) + self.wrapper.assign_optimized_module(self.optimized_mod) wrapped_source = self.wrapper.wrap(self.get_source(target), is_dynamic) self.lib_generator.update_lib_code(wrapped_source) self.lib_generator.compile_lib() @@ -153,14 +189,25 @@ def apply_default_schedule(self, func_mod: IRModule, target: Target) -> IRModule return optimized_mod return None + def _update_optimized_mod(self, optimized_mod: IRModule): + self.optimized_mod = optimized_mod + def _build_default_module(self, target: Target): try: - self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target) - except Exception: - self.optimized_func = None + scheduled_mod = self.apply_default_schedule(self.prim_func_mod, target) + assert len(scheduled_mod.get_global_vars()) == 1, ( + "The optimized module should only have one global variable for default schedule.") + assert "main" in scheduled_mod, ( + "The optimized module should have a function named 'main' for default schedule.") + default_kernal_name = self.kernel_name_generator.generate() + func = scheduled_mod["main"].with_attr("global_symbol", default_kernal_name) + optimized_mod = tvm.IRModule({default_kernal_name: func}) + self._update_optimized_mod(optimized_mod) + except Exception as apply_schedule_error: + self.optimized_mod = None logger.warning( - "[BitBLAS][Warning] Apply default schedule failed. Please perform hardware-aware tuning manually." - ) + APPLY_SCHEDULE_FAILED_MESSAGE.format(self.__class__.__name__, target, "default", + apply_schedule_error)) self._build_runtime_module(target) @@ -171,12 +218,13 @@ def apply_fast_tuning(self, func: PrimFunc, target: Target, topk: int = 20, - parallel_build=True) -> IRModule: + parallel_build=True) -> Tuple[IRModule, Hint]: _, best = fast_tune(func, target, topk=topk, parallel_build=parallel_build) - if best is not None: - return best.sch.mod + # annotate the best pass context + # TODO(lei): actually we should remove this by enable pass through + # annotation in the func's attribute. self.pass_context = best.config.pass_context - return None + return ((best.sch.mod, best.config) if best is not None else (None, None)) def apply_fast_tuning_with_dynamic_range( self, @@ -186,25 +234,39 @@ def apply_fast_tuning_with_dynamic_range( dynamic_range: Dict[str, List[int]] = None, ): optimized_mod = fast_tune_with_dynamic_range( - func, target, topk=topk, parallel_build=True, dynamic_range=dynamic_range) + func, + target, + topk=topk, + parallel_build=True, + dynamic_range=dynamic_range, + kernel_name_generator=self.kernel_name_generator) if optimized_mod is not None: return optimized_mod return None def hardware_aware_finetune(self, topk: int = 20, - target: tvm.target.Target = None, + target: Optional[tvm.target.Target] = None, parallel_build=True): if target is None: target = self.target dynamic_range = self.dynamic_range func = self.prim_func if dynamic_range is not None: - self.optimized_func = self.apply_fast_tuning_with_dynamic_range( + self.optimized_mod = self.apply_fast_tuning_with_dynamic_range( func, target, topk, dynamic_range) else: - self.optimized_func = self.apply_fast_tuning( + scheduled_mod, best_hint = self.apply_fast_tuning( func, target, topk, parallel_build=parallel_build) + assert len(scheduled_mod.get_global_vars()) == 1, ( + "The optimized module should only have one global variable for default schedule.") + assert "main" in scheduled_mod, ( + "The optimized module should have a function named 'main' for default schedule.") + default_kernal_name = self.kernel_name_generator.generate(best_hint) + func = scheduled_mod["main"].with_attr("global_symbol", default_kernal_name) + optimized_mod = tvm.IRModule({default_kernal_name: func}) + self._update_optimized_mod(optimized_mod) + self._build_runtime_module(self.target) def get_profile_tensors(self, dynamic_symbolic_constraints: Optional[Dict] = None): diff --git a/bitblas/wrapper/__init__.py b/bitblas/wrapper/__init__.py deleted file mode 100644 index 1d87f8020..000000000 --- a/bitblas/wrapper/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -from .general import CUDASourceWrapper, CUDASourceWrapperWithDynamic # noqa: F401 diff --git a/bitblas/wrapper/general.py b/bitblas/wrapper/general.py deleted file mode 100644 index 4e7c65c2c..000000000 --- a/bitblas/wrapper/general.py +++ /dev/null @@ -1,519 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from bitblas import tvm -from typing import Optional, List, Dict, Union -from tvm import IRModule -from bitblas import TileDevice -from tvm.runtime import ndarray -from bitblas.utils import match_global_kernel -import re -import ctypes -import os -import tempfile -import subprocess -import logging -from tvm.driver import lower -from tvm.target import Target - -logger = logging.getLogger(__name__) - -_TYPE_MAP = { - "float32": "float", - "float16": "half", - "bfloat16": "__nv_bfloat16", - "e4m3_float8": "__nv_fp8_e4m3", - "e5m2_float8": "__nv_fp8_e5m2", - "float64": "double", - "int64": "int64_t", - "int32": "int", - "uint32": "unsigned int", - "bool": "int8_t", - "int8": "int8_t", - "uint8": "uint8_t", - "int16": "int16_t", - "uchar": "uint8_t", -} - - -def get_annotated_device_mod(mod: IRModule, target: Target): - """ - Lower the given IRModule and create a device module for the specified target. - - Parameters: - - mod: The input IRModule. - - target: The compilation target. - - Returns: - - A device module ready for execution. - """ - input_mod = lower(mod) - target_input_mod = {target: input_mod} - annotated_mods = {} - runtime = None - target_host = None - for tgt, mod in target_input_mod.items(): - if not isinstance(tgt, (str, Target)): - raise ValueError("The key of inputs must be str or " - "Target when inputs is dict.") - if not isinstance(mod, tvm.IRModule): - raise ValueError("inputs must be Schedule, IRModule, " - "or dict of str to IRModule.") - annotated_mods[tgt] = mod.with_attr("runtime", runtime) - annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) - if not target_host: - for tar, _ in annotated_mods.items(): - device_type = ndarray.device(tar.kind.name, 0).device_type - if device_type == ndarray.cpu(0).device_type: - target_host = tar - break - if not target_host: - target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" - annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) - for target, mod in annotated_mods.items(): - mixed_mod_passes = tvm.get_global_func("driver.mixed_mod_passes") - device_mod_passes = tvm.get_global_func("driver.device_mod_passes") - mod = mixed_mod_passes(mod, target)(mod) - device_mod = device_mod_passes(mod, target)(mod) - return device_mod - - -def get_thread_block_information(mod: IRModule): - """ - Extracts the thread block and grid dimensions for the reduction block within a given IRModule. - - Parameters: - - mod: The input IRModule from which to extract thread block and grid information. - - Returns: - A tuple containing two lists: - - The first list contains the dimensions of the thread block (threadIdx.x, threadIdx.y, threadIdx.z). - - The second list contains the dimensions of the grid (blockIdx.x, blockIdx.y, blockIdx.z). - """ - - # Initialize the schedule from the IRModule - sch = tvm.tir.Schedule(mod) - - # Get the root block and its child blocks - root_block = sch.get_block("root") - child_blocks = sch.get_child_blocks(root_block) - - # Initialize default block and grid dimensions (1, 1, 1) - block_dims, grid_dims = [1, 1, 1], [1, 1, 1] - - for block in child_blocks: - # Get the loops surrounding the main block - loops = sch.get_loops(block) - - # Iterate over each loop to extract thread and block bindings - for loop in loops: - stmt = sch.get(loop) - thread_binding = stmt.thread_binding - extent = int(stmt.extent) - - # Skip loops without thread binding - if thread_binding: - if "threadIdx" in thread_binding.thread_tag: - block_dims["xyz".index(thread_binding.thread_tag[-1])] = extent - elif "blockIdx" in thread_binding.thread_tag: - grid_dims["xyz".index(thread_binding.thread_tag[-1])] = extent - - return block_dims, grid_dims - - -class CUDASourceWrapper(object): - - def __init__(self, optimized_mod: IRModule, source: str, arch: TileDevice): - self.mod = optimized_mod - self.arch = arch - self.source = source - self.function_name: Optional[str] = None - self.dynamic_smem_buf: Optional[int] = None - self.block_info: Union[List[int], Dict] = [1, 1, 1] - self.grid_info: Union[List[int], Dict] = [1, 1, 1] - self.parse_source_information() - self.srcpath: Optional[str] = None - self.libpath: Optional[str] = None - self.lib_code: Optional[str] = self.update_lib_code(source) - - def load_lib(self): - return ctypes.CDLL(self.libpath) - - def remove_lib(self): - if self.libpath: - os.remove(self.libpath) - self.libpath = None - - def compile_lib(self, timeout: float = None): - arch = self.arch - src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False) - compute_version = arch.compute_capability - libpath = src.name.replace(".cu", ".so") - - command = [ - "nvcc", - "-std=c++17", - "-Xcudafe", - "--diag_suppress=177", - "--compiler-options", - "'-fPIC'", - "-lineinfo", - "--shared", - src.name, - "-lcuda", - "-gencode", - f"arch=compute_{compute_version},code=sm_{compute_version}", - "-o", - libpath, - ] - src.write(self.lib_code) - src.flush() - try: - ret = subprocess.run(command, timeout=timeout) - except subprocess.TimeoutExpired: - logger.warning(f"Compilation Timeout! {command}") - return None - if ret.returncode != 0: - logger.warning(f"Compilation Failed! {command}") - return None - self.srcpath = src.name - self.libpath = libpath - - def parse_source_information(self): - device_mod = get_annotated_device_mod(self.mod, self.arch.target) - assert (len(device_mod.functions) == 1 - ), "Only support one function in the module for static shape kernel." - for g_var, func in device_mod.functions.items(): - self.function_name = g_var.name_hint - attrs = func.attrs - if "dyn_shared_memory_buf" in attrs: - self.dynamic_smem_buf = int(attrs["dyn_shared_memory_buf"]) - if "thread_extent" in attrs: - thread_extent = attrs["thread_extent"] - for tag, extent in thread_extent.items(): - if "threadIdx" in tag: - self.block_info["xyz".index(tag[-1])] = extent - elif "blockIdx" in tag: - self.grid_info["xyz".index(tag[-1])] = extent - - def get_dynamic_symbolic_set(self, prim_func): - # Determine the set of dynamic symbols used in the function - dynamic_symbolic_set = set() - for param in prim_func.params: - buffer = prim_func.buffer_map[param] - for dim in buffer.shape: - if isinstance(dim, tvm.tir.Var): - dynamic_symbolic_set.add(dim.name) - return dynamic_symbolic_set - - def get_cuda_init_func(self): - # Initialize an empty string for the CUDA function call - call_str = """""" - # If dynamic shared memory buffer is specified, prepare the cudaFuncSetAttribute call - if self.dynamic_smem_buf is not None: - call_str = """ - cudaFuncSetAttribute({}, - cudaFuncAttributeMaxDynamicSharedMemorySize, {}); - """.format(self.function_name, self.dynamic_smem_buf) - # Format the initialization function using the call_str - init_funcs = """ - extern "C" void init() {{ - {} - }} - """.format(call_str) - return init_funcs - - def update_lib_code(self, code: str): - # Update the library code with the given code string - self.lib_code = code - # Find the index of the global kernel function in the code - index = match_global_kernel(code) - # Extract the declaration of the function starting from the found index - declaration = code[index:].split(";")[0] - - function_name = self.function_name - # Get the CUDA initialization function - init_func = self.get_cuda_init_func() - - # Locate the opening brace of the function to insert arguments - index = code.index("{", index) - function_args = [] - # Populate the function arguments from the primary function's parameters and buffers - for param in self.prim_func.params: - buffer = self.prim_func.buffer_map[param] - function_args.append({ - "name": buffer.name, - "type": _TYPE_MAP[buffer.dtype] + "* __restrict__", - }) - - dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func) - # Add dynamic symbolic parameters as integers to the function arguments - for dyn_sym in dynamic_symbolic_set: - function_args.append({"name": dyn_sym, "type": "int"}) - - function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},) - # Format the function arguments for declaration - def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args]) - - def func_call_args(s, function_args): - # Extract the function call arguments matching the function definition - pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)" - matches = re.findall(pattern, s) - call_args = [] - for match in matches: - for arg in function_args: - if arg["name"] == match: - call_args.append(match) - return call_args - - call_args = ", ".join(func_call_args(declaration, function_args)) - block_info, grid_info = self.block_info, self.grid_info - - def legalize_c(p): - # Convert TIR expressions to legal C expressions - # Directly convert to string since the special case handling - # does not alter the string representation for `tvm.tir.Var` and `IntImm`. - # Replace Python's floor division operator with C's division operator - if isinstance(p, tvm.tir.IntImm): - p = int(p) - return str(p).replace("//", "/") - - # Prepare the block and grid dimensions for the CUDA kernel launch - block_str = "dim3({}, {}, {})".format( - legalize_c(block_info[0]), - legalize_c(block_info[1]), - legalize_c(block_info[2]), - ) - grid_str = "dim3({}, {}, {})".format( - legalize_c(grid_info[0]), legalize_c(grid_info[1]), legalize_c(grid_info[2])) - # Determine the shared memory size, defaulting to 0 if not specified - smem_str = 0 if self.dynamic_smem_buf is None else self.dynamic_smem_buf - # Format the CUDA kernel launch string - if len(dynamic_symbolic_set) != 0: - call_str = "if ({} == 0) return; \n\t\t".format(list(dynamic_symbolic_set)[0]) - else: - call_str = "" - call_str += "{}<<<{}, {}, {}, stream>>>({});".format(function_name, grid_str, block_str, - smem_str, call_args) - # Create the host function wrapper for the CUDA kernel - host_func = """ - extern "C" void call({}) {{ - {} - }} - """.format(def_args, call_str) - # Combine the source, initialization function, and host function to form the complete library code - lib_code = self.source + init_func + host_func - return lib_code - - @property - def prim_func(self): - return self.mod["main"] - - -class CUDASourceWrapperWithDynamic(CUDASourceWrapper): - - def __init__(self, optimized_mod: IRModule, source: str, arch: TileDevice): - super().__init__(optimized_mod, source, arch) - - def get_cuda_init_func(self): - # Initialize an empty string to accumulate CUDA function calls for setting dynamic shared memory - call_str = """""" - # Iterate over functions and their dynamic shared memory requirements - for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items(): - if dynamic_smem_buf is not None: - # Format the cudaFuncSetAttribute call for dynamic shared memory - call_str += """ - cudaFuncSetAttribute({}, - cudaFuncAttributeMaxDynamicSharedMemorySize, {}); - """.format(function_name, dynamic_smem_buf) - # Define the init function that will set the attributes for each kernel - init_funcs = """ -extern "C" void init() {{ - {} -}} - """.format(call_str) - return init_funcs - - def create_dispatch_func(self, code, function_informations): - # Extract the set of dynamic symbolic names used in the primary function - dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func) - - # Find the location of the global kernel function in the code - index = match_global_kernel(code) - - # Analyze the function declaration to prepare for argument extraction - dummy_declaration = code[index:].split(";")[0] - - function_name = self.function_name - - # Identify the start of the function body to insert arguments - index = code.index("{", index) - function_args = [] - # Collect function arguments based on primary function's parameters and buffer mappings - for param in self.prim_func.params: - buffer = self.prim_func.buffer_map[param] - function_args.append({ - "name": buffer.name, - "type": _TYPE_MAP[buffer.dtype] + "* __restrict__", - }) - # Add dynamic symbols as integer arguments - for dyn_sym in dynamic_symbolic_set: - function_args.append({"name": dyn_sym, "type": "int"}) - - function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},) - - # Format the argument definitions for function declaration - def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args]) - - def func_call_args(s: str, function_args): - # Extract and clean the function call arguments to match the declaration - pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)" - matches = re.findall(pattern, s) - call_args = [] - for match in matches: - match = re.sub(r"\d+", "", match) # Remove numbers - match = re.sub(r"_", "", match) # Remove underscores - for arg in function_args: - if arg["name"] == match: - call_args.append(match) - return call_args - - call_args = ", ".join(func_call_args(dummy_declaration, function_args)) - - def legalize_c(p): - # Convert TIR expressions to legal C expressions - # Directly convert to string since the special case handling - # does not alter the string representation for `tvm.tir.Var` and `IntImm`. - # Replace Python's floor division operator with C's division operator - if isinstance(p, tvm.tir.IntImm): - p = int(p) - return str(p).replace("//", "/") - - last_range = 0 - num_items = len(function_informations) - _call_str = """""" - for function_name, info in function_informations.items(): - # Prepare block and grid configurations for kernel launches - block_info, grid_info = info["block_info"], info["grid_info"] - block_str = "dim3({}, {}, {})".format( - legalize_c(block_info[0]), - legalize_c(block_info[1]), - legalize_c(block_info[2]), - ) - grid_str = "dim3({}, {}, {})".format( - legalize_c(grid_info[0]), - legalize_c(grid_info[1]), - legalize_c(grid_info[2]), - ) - # Handle dynamic shared memory specification - smem_str = (0 if info["dynamic_smem_buf"] is None else info["dynamic_smem_buf"]) - opt_shapes = info["opt_shapes"] - # Generate conditional kernel launch code based on dynamic symbolic ranges - (symbolic,) = list(dynamic_symbolic_set) - range_str = opt_shapes[symbolic] - if last_range == 0: - call_str = "if ({} == 0) return; \n".format(symbolic,) - call_str += "if ({} <= {}) {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format( - symbolic, - range_str, - function_name, - grid_str, - block_str, - smem_str, - call_args, - ) - else: - call_str = "\t\telse if ({} <= {}) {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format( - symbolic, - range_str, - function_name, - grid_str, - block_str, - smem_str, - call_args, - ) - if last_range == num_items - 1: - call_str += ( - "\t\telse {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format( - function_name, grid_str, block_str, smem_str, call_args)) - last_range += 1 - _call_str += call_str - - # Wrap the kernel dispatch logic in an external C function - host_func = """ -extern "C" void call({}) {{ - {} -}} - """.format(def_args, _call_str) - return host_func - - def parse_source_information(self): - # Parse device module to extract execution configurations for each function - device_mod = get_annotated_device_mod(self.mod, self.arch.target) - block_info_map = {} - grid_info_map = {} - dynamic_smem_buf_map = {} - for g_var, func in device_mod.functions.items(): - # Default block and grid configurations - block_info = [1, 1, 1] - grid_info = [1, 1, 1] - function_name = g_var.name_hint - attrs = func.attrs - dynamic_smem_buf = None - if "dyn_shared_memory_buf" in attrs: - dynamic_smem_buf = int(attrs["dyn_shared_memory_buf"]) - if "thread_extent" in attrs: - # Extract block and grid sizes from thread extents - thread_extent = attrs["thread_extent"] - for tag, extent in thread_extent.items(): - if "threadIdx" in tag: - block_info["xyz".index(tag[-1])] = extent - elif "blockIdx" in tag: - grid_info["xyz".index(tag[-1])] = extent - # Map the extracted configurations to each function - block_info_map[function_name] = block_info - grid_info_map[function_name] = grid_info - dynamic_smem_buf_map[function_name] = dynamic_smem_buf - # Store the mappings for use in code generation - self.block_info = block_info_map - self.grid_info = grid_info_map - self.dynamic_smem_buf = dynamic_smem_buf_map - - def update_lib_code(self, code: str): - # Organize function information for code generation - function_informations = {} - for g_var, func in self.mod.functions.items(): - if g_var.name_hint == "main": - continue - function_name = g_var.name_hint - attrs = func.attrs - assert "opt_shapes" in attrs - opt_shapes = attrs["opt_shapes"] - function_informations[function_name] = { - "function_name": function_name, - "opt_shapes": opt_shapes, - "block_info": self.block_info[function_name], - "grid_info": self.grid_info[function_name], - "dynamic_smem_buf": self.dynamic_smem_buf[function_name], - } - - def compare_map_objects(map_obj): - comparable_representation = list(map_obj.values()) - return comparable_representation - - function_informations = dict( - sorted( - function_informations.items(), - key=lambda item: compare_map_objects(item[1]["opt_shapes"]))) - - self.lib_code = code - - # Generate the initialization and dispatch functions - init_func = self.get_cuda_init_func() - host_func = self.create_dispatch_func(code, function_informations) - # Concatenate source code with generated code segments - lib_code = self.source + init_func + host_func - return lib_code - - @property - def prim_func(self): - return self.mod["main"] From af99e58b32a88e0ead58477f5dfc3b3e84a87ed2 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 24 Aug 2024 04:15:07 +0000 Subject: [PATCH 11/12] Refactor TIR CUDA source wrapper for improved readability and maintainability --- bitblas/builder/wrapper/tir.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bitblas/builder/wrapper/tir.py b/bitblas/builder/wrapper/tir.py index 5d5763861..f0a549c37 100644 --- a/bitblas/builder/wrapper/tir.py +++ b/bitblas/builder/wrapper/tir.py @@ -279,8 +279,8 @@ def legalize_c(p): (symbolic,) = list(dynamic_symbolic_set) range_str = opt_shapes[symbolic] if last_range == 0: - call_str = "\tif ({} == 0) return; \n".format(symbolic,) - call_str += "\tif ({} <= {}) {{\n\t\t{}<<<{}, {}, {}, stream>>>({}); \n\t}}\n".format( + call_str = " if ({} == 0) return; \n".format(symbolic,) + call_str += " if ({} <= {}) {{\n {}<<<{}, {}, {}, stream>>>({}); \n }}\n".format( symbolic, range_str, function_name, @@ -290,7 +290,7 @@ def legalize_c(p): call_args, ) else: - call_str = "\telse if ({} <= {}) {{\n\t\t{}<<<{}, {}, {}, stream>>>({}); \n\t}}\n".format( + call_str = " else if ({} <= {}) {{\n {}<<<{}, {}, {}, stream>>>({}); \n }}\n".format( symbolic, range_str, function_name, @@ -300,7 +300,7 @@ def legalize_c(p): call_args, ) if last_range == num_items - 1: - call_str += ("\telse {{\n\t\t{}<<<{}, {}, {}, stream>>>({}); \n\t}}\n".format( + call_str += (" else {{\n {}<<<{}, {}, {}, stream>>>({}); \n }}\n".format( function_name, grid_str, block_str, smem_str, call_args)) last_range += 1 _call_str += call_str From 2729aa8b5a7d3424552359e98b122e6fa8f275ad Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 24 Aug 2024 06:43:38 +0000 Subject: [PATCH 12/12] bug fix --- bitblas/builder/wrapper/tir.py | 11 ++++++----- bitblas/ops/general_matmul/__init__.py | 2 +- bitblas/ops/operator.py | 17 ++++++++++------- .../python/builder/test_backend_tir_builder.py | 5 +++-- 4 files changed, 20 insertions(+), 15 deletions(-) diff --git a/bitblas/builder/wrapper/tir.py b/bitblas/builder/wrapper/tir.py index f0a549c37..f39c7cfab 100644 --- a/bitblas/builder/wrapper/tir.py +++ b/bitblas/builder/wrapper/tir.py @@ -180,7 +180,12 @@ def legalize_c(p): @property def prim_func(self): - return self.mod["main"] + if len(self.mod.get_global_vars()) == 1: + return self.mod[self.mod.get_global_vars()[0]] + elif "main" in self.mod: + return self.mod["main"] + else: + raise ValueError("Unable to determine primary function.") class TIRCUDASourceWrapperWithDynamic(TIRCUDASourceWrapper): @@ -377,10 +382,6 @@ def compare_map_objects(map_obj): lib_code = self.source + init_func + host_func return lib_code - @property - def prim_func(self): - return self.mod["main"] - class TIRWrapper(BaseWrapper): diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py index 0b55ef8f7..dfd22e6e8 100644 --- a/bitblas/ops/general_matmul/__init__.py +++ b/bitblas/ops/general_matmul/__init__.py @@ -278,7 +278,7 @@ def generate(self, hint=None) -> str: A_dtype = self.simplify_dtype(config.A_dtype) W_dtype = self.simplify_dtype(config.W_dtype) - precision_str = (f"A{A_dtype}W{W_dtype}") + precision_str = (f"{A_dtype}x{W_dtype}") kernel_name = "_".join([kernel_name, shape_str, precision_str]) # if config.with_scaling: diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index 09aee625d..a94da9969 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -80,7 +80,6 @@ def __init__(self, name, config: OperatorConfig, target: Target = None): self.dynamic_range = None self.pass_context: Dict = {} self.num_args = len(self.prim_func.params) - self.function_handle = None self.num_output_args: int = ( 1 # todo(lei): should be analyzed from the prim_func. ) @@ -93,10 +92,10 @@ def __init__(self, name, config: OperatorConfig, target: Target = None): def get_kernel_name_generator(self) -> Optional[BaseKernelNameGenerator]: return None - def get_source(self, target: Optional[Target] = None) -> str: + def get_source(self, target: Optional[Target] = None, kenrel_only=False) -> str: if target is None: target = self.target - if self.lib_generator.lib_code is not None: + if self.lib_generator.lib_code is not None and not kenrel_only: return self.lib_generator.lib_code if self.rt_mod is None: self._build_runtime_module(target) @@ -153,14 +152,14 @@ def tvm_callback_cuda_postproc(code, _): # Initialize a time evaluator with the built module, specifying the device and the number of runs self.time_evaluator = rt_mod.time_evaluator( rt_mod.entry_name, self.arch.device, number=10) - self.function_handle = rt_mod.get_function(rt_mod.entry_name).handle self.torch_func = to_pytorch_func(rt_mod) if self.arch.platform == "CUDA": try: is_dynamic = ( self.dynamic_range is not None and len(self.optimized_mod.functions) > 1) self.wrapper.assign_optimized_module(self.optimized_mod) - wrapped_source = self.wrapper.wrap(self.get_source(target), is_dynamic) + wrapped_source = self.wrapper.wrap( + self.get_source(target, kenrel_only=True), is_dynamic) self.lib_generator.update_lib_code(wrapped_source) self.lib_generator.compile_lib() self.lib = self.lib_generator.load_lib() @@ -377,7 +376,6 @@ def update_func(self, func: PrimFunc): def update_runtime_module(self, rt_mod, srcpath=None, libpath=None): self.rt_mod = rt_mod self.time_evaluator = rt_mod.time_evaluator(rt_mod.entry_name, self.arch.device, number=10) - self.function_handle = rt_mod.get_function(rt_mod.entry_name).handle self.torch_func = to_pytorch_func(rt_mod) if srcpath is not None: assert self.lib_generator is not None, "lib_generator is not initialized" @@ -398,7 +396,12 @@ def _select_implementation(self) -> IRModule: @property def prim_func(self): - return self.prim_func_mod["main"] + if len(self.prim_func_mod.get_global_vars()) == 1: + return self.prim_func_mod[self.prim_func_mod.get_global_vars()[0]] + elif "main" in self.prim_func_mod: + return self.prim_func_mod["main"] + else: + raise ValueError("Unable to determine primary function.") @property def srcpath(self): diff --git a/testing/python/builder/test_backend_tir_builder.py b/testing/python/builder/test_backend_tir_builder.py index 22c134b12..f65ce8066 100644 --- a/testing/python/builder/test_backend_tir_builder.py +++ b/testing/python/builder/test_backend_tir_builder.py @@ -39,8 +39,9 @@ def matmul_backend_code_wrap( ) matmul = Matmul(config=matmul_config, enable_tuning=False) backend = TIRWrapper(arch=matmul.arch) - backend.assign_optimized_module(matmul.optimized_func) - wrapped_code = backend.wrap(matmul.get_source(), is_dynamic=isinstance(M, list)) + backend.assign_optimized_module(matmul.optimized_mod) + is_dynamic = (matmul.dynamic_range is not None and len(matmul.optimized_mod.functions) > 1) + wrapped_code = backend.wrap(matmul.get_source(kenrel_only=True), is_dynamic=is_dynamic) assert "void call" in wrapped_code