
Conversation

@LeiWang1999
Contributor

This pull request introduces several significant updates and additions to the bitblas library, particularly focusing on matrix multiplication (matmul) operations. Key changes include the implementation of a new matmul function for dequantized weights, the addition of various matmul schedulers, and the inclusion of comprehensive testing for these schedulers.

New Features and Implementations:

  • Matmul Function for Dequantized Weights:

    • Implemented the matmul_blocked_weight_only function in bitblas/ops/general_matmul/tilelang/dequantize/matmul_weight_only.py to handle matmul operations with dequantized weights.
  • Matmul Schedulers:

    • Added MatmulScheduler, MatmulFineGrainScheduler, and MatmulWeightPropagationScheduler in bitblas/ops/general_matmul/tilelang/dense/__init__.py for different matmul scheduling strategies.

Testing Enhancements:

  • New Test Cases:
    • Added extensive test cases in testing/python/operators/test_general_matmul_tilelang_kernel.py to validate the correctness and performance of the new matmul schedulers; a sketch of what such a check looks like is shown right after this list.
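
A test of this kind typically builds a scheduler, lowers the emitted prim_func, and compares the kernel output against a torch reference. The sketch below only illustrates the pattern: the tvm.tl lowering and profiling entry points (tl.lower, tl.Profiler, TensorSupplyType, assert_allclose) are assumed to match the utilities used elsewhere in the bitblas test suite, and the helper name is hypothetical rather than taken from the new test file.

import torch
from tvm import tl  # assumed: tvm.tl lowering/profiling utilities

from bitblas.ops.general_matmul.tilelang.dense import MatmulScheduler


def assert_matmul_blocked_correctness(M, N, K, trans_B=True, in_dtype="float16"):
    # Build the block-level scheduler with its default tile configuration.
    func = MatmulScheduler(
        M=M, N=N, K=K,
        trans_A=False, trans_B=trans_B,
        dtypeAB=in_dtype, dtypeC=in_dtype, accum_dtype="float16",
    ).with_default_config()

    # Lower the prim_func and wrap it in a profiler (assumed tvm.tl API);
    # argument index 2 is the output buffer C.
    mod, params = tl.lower(func)
    profiler = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)

    def ref_program(A, B):
        # torch reference; B is stored as (N, K) when trans_B is True.
        return torch.matmul(A, B.T if trans_B else B)

    # Compare the kernel output against the reference within a loose fp16 tolerance.
    profiler.assert_allclose(ref_program, rtol=1e-2, atol=1e-2)


assert_matmul_blocked_correctness(128, 128, 128)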

Code Quality Improvements:

  • License Additions:

    • Added missing license headers to several files, including bitblas/ops/general_matmul/tilelang/dequantize/__init__.py and bitblas/ops/general_matmul/tilelang/dequantize/matmul_weight_only.py. [1] [2]
  • Code Refactoring:

    • Refactored local memory allocation in testing/python/tilelang/test_tilelang_dequantize_gemm.py to use T.alloc_local instead of T.alloc_fragment. [1] [2]
    • Removed unnecessary print statements and added assertions to ensure outputs are not None in testing/python/tilelang/test_tilelang_dequantize_gemm.py.

These changes collectively enhance the functionality, maintainability, and reliability of the bitblas library, particularly in the context of matrix multiplication operations.

from dataclasses import dataclass

import tvm.tl.language as T  # TileLang frontend; import path as used in the bitblas tilelang modules


@dataclass
class MatmulScheduler:

    # OP Related Config
    M: int
    N: int
    K: int
    trans_A: bool = False
    trans_B: bool = False
    dtypeAB: str = "float16"
    dtypeC: str = "float16"
    accum_dtype: str = "float16"

    # Default Tile Related Params
    block_M: int = 64
    block_N: int = 64
    block_K: int = 32
    num_stages: int = 2
    threads: int = 128
    enable_rasterization: bool = False  # Enhance L2 Locality

    def with_default_config(self):
        block_M = getattr(self, "block_M", 64)
        block_N = getattr(self, "block_N", 64)
        block_K = getattr(self, "block_K", 32)
        num_stages = getattr(self, "num_stages", 2)
        threads = getattr(self, "threads", 128)
        enable_rasterization = getattr(self, "enable_rasterization", False)

        return self.apply_config(
            block_M=block_M,
            block_N=block_N,
            block_K=block_K,
            num_stages=num_stages,
            threads=threads,
            enable_rasterization=enable_rasterization,
        )

    def apply_config(
        self,
        block_M=64,
        block_N=64,
        block_K=32,
        num_stages=2,
        threads=128,
        # Enhance L2 Locality
        enable_rasterization=False,
    ):
        M, N, K = self.M, self.N, self.K
        trans_A, trans_B = self.trans_A, self.trans_B
        dtypeAB, dtypeC, accum_dtype = self.dtypeAB, self.dtypeC, self.accum_dtype

        A_shape = (K, M) if trans_A else (M, K)
        B_shape = (N, K) if trans_B else (K, N)
        A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
        B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

        @T.prim_func
        def main(
                A: T.Buffer(A_shape, dtypeAB),
                B: T.Buffer(B_shape, dtypeAB),
                C: T.Buffer((M, N), dtypeC),
        ):
            with T.Kernel(
                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
                A_shared = T.alloc_shared(A_shared_shape, dtypeAB)
                B_shared = T.alloc_shared(B_shared_shape, dtypeAB)
                C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

                if enable_rasterization:
                    # rasterization factor
                    T.use_swizzle(10)

                T.clear(C_local)
                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                    if trans_A:
                        T.copy(A[k * block_K, by * block_M], A_shared)
                    else:
                        T.copy(A[by * block_M, k * block_K], A_shared)
                    if trans_B:
                        T.copy(B[bx * block_N, k * block_K], B_shared)
                    else:
                        T.copy(B[k * block_K, bx * block_N], B_shared)
                    T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
                T.copy(C_local, C[by * block_M, bx * block_N])

        return main

    def __post_init__(self):
        # Add Config Validation
        return
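
For reference, with_default_config() simply forwards the instance's tile parameters to apply_config(), while apply_config() can also be called directly with an explicit tile configuration. A minimal sketch follows; the shape and tile values are arbitrary illustrations, not values taken from the PR.

# Block-level scheduler for a 1024x1024x1024 fp16 GEMM with B transposed.
sched = MatmulScheduler(M=1024, N=1024, K=1024, trans_A=False, trans_B=True)

# Default tile configuration: 64x64x32 tiles, 2 pipeline stages, 128 threads.
default_kernel = sched.with_default_config()

# Explicit configuration, e.g. produced by a tuning sweep, with rasterization
# enabled to improve L2 locality.
tuned_kernel = sched.apply_config(
    block_M=128,
    block_N=128,
    block_K=32,
    num_stages=3,
    threads=256,
    enable_rasterization=True,
)

Both calls return the T.prim_func main defined above, ready to be lowered by the TileLang pipeline.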

@LeiWang1999 changed the title from "[TL] Warp TL Kernel with Scheduler" to "[TL] Wrap TL Kernel with Scheduler" on Sep 28, 2024.
@LeiWang1999
Contributor Author

Implemented a BaseScheduler class:

from dataclasses import dataclass
from typing import Union

import tvm
from tvm import IRModule
from tvm.tir import PrimFunc


@dataclass
class BaseScheduler:

    enable_simplify: bool = True

    @staticmethod
    def Simplify(stmt: Union[PrimFunc, IRModule]):
        if isinstance(stmt, PrimFunc):
            return tvm.tir.transform.Simplify()(tvm.IRModule.from_expr(stmt))["main"]
        elif isinstance(stmt, IRModule):
            return tvm.tir.transform.Simplify()(stmt)
        else:
            raise ValueError(f"Unsupported type: {type(stmt)}")

    def enable_simplify(self):
        self.enable_simplify = True
        return self

    def disable_simplify(self):
        self.enable_simplify = False
        return self

    def maybe_simplify(self, stmt: Union[PrimFunc, IRModule]):
        if self.enable_simplify:
            return self.Simplify(stmt)
        return stmt

It wraps the methods common to the concrete scheduler classes. For example:

matmul = MatmulScheduler(
    M=M,
    N=N,
    K=K,
    trans_A=trans_A,
    trans_B=trans_B,
    dtypeAB=dtypeAB,
    dtypeC=dtypeC,
    accum_dtype=accum_dtype,
).disable_simplify().with_default_config()

simplified = MatmulScheduler.Simplify(matmul)

Before applying simplification:

@T.prim_func
def main(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "float16"), C: T.Buffer((128, 128), "float16")):
    # with T.block("root"):
    bx = T.launch_thread("blockIdx.x", 2)
    by = T.launch_thread("blockIdx.y", 2)
    v = T.launch_thread("threadIdx.x", 128)
    v_1 = T.launch_thread("threadIdx.y", 1)
    v_2 = T.launch_thread("threadIdx.z", 1)
    with T.block(""):
        T.reads(A[T.min(0, by * 64):T.min(0, by * 64) + (T.max(159, by * 64 + 63) + 1 - T.min(0, by * 64)), T.min(by, 0) * 64:T.min(by, 0) * 64 + (T.max(by * 64 + 31, 127) + 1 - T.min(by, 0) * 64)], B[T.min(bx, 0) * 64:T.min(bx, 0) * 64 + (T.max(bx * 64 + 63, 159) + 1 - T.min(bx, 0) * 64), T.min(0, bx * 64):T.min(0, bx * 64) + (T.max(127, bx * 64 + 31) + 1 - T.min(0, bx * 64))])
        T.writes(C[by * 64:by * 64 + 64, bx * 64:bx * 64 + 64])
        A_shared = T.alloc_buffer((64, 32), "float16", scope="shared.dyn")
        B_shared = T.alloc_buffer((64, 32), "float16", scope="shared.dyn")
        C_local = T.alloc_buffer((64, 64), "float16", scope="local.fragment")
        if T.bool(False):
            T.attr(None, "threadblock_swizzle_pattern", "tl::rasterization2DRow<10>")
            T.evaluate(0)
        T.fill(T.tvm_access_ptr(T.type_annotation("float16"), C_local.data, 0, 4096, 2), 0)
        for k in T.serial(4, annotations={"num_stages": 2}):
            if T.bool(False):
                T.copy(T.region(A[k * 32, by * 64], 1, 64, 32), T.region(A_shared[0, 0], 2, 64, 32))
            else:
                T.copy(T.region(A[by * 64, k * 32], 1, 64, 32), T.region(A_shared[0, 0], 2, 64, 32))
            if T.bool(True):
                T.copy(T.region(B[bx * 64, k * 32], 1, 64, 32), T.region(B_shared[0, 0], 2, 64, 32))
            else:
                T.copy(T.region(B[k * 32, bx * 64], 1, 64, 32), T.region(B_shared[0, 0], 2, 64, 32))
            T.gemm(T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1), T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1), T.tvm_access_ptr(T.type_annotation("float16"), C_local.data, 0, 4096, 3), T.bool(False), T.bool(True), 64, 64, 32, 0)
        T.copy(T.region(C_local[0, 0], 1, 64, 64), T.region(C[by * 64, bx * 64], 2, 64, 64))

After simplification (the constant-condition branches are folded away and the read regions are tightened):

# from tvm.script import tir as T

@T.prim_func
def main(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "float16"), C: T.Buffer((128, 128), "float16")):
    # with T.block("root"):
    bx = T.launch_thread("blockIdx.x", 2)
    by = T.launch_thread("blockIdx.y", 2)
    v = T.launch_thread("threadIdx.x", 128)
    v_1 = T.launch_thread("threadIdx.y", 1)
    v_2 = T.launch_thread("threadIdx.z", 1)
    with T.block(""):
        T.reads(A[0:160, 0:128], B[0:160, 0:128])
        T.writes(C[by * 64:by * 64 + 64, bx * 64:bx * 64 + 64])
        A_shared = T.alloc_buffer((64, 32), "float16", scope="shared.dyn")
        B_shared = T.alloc_buffer((64, 32), "float16", scope="shared.dyn")
        C_local = T.alloc_buffer((64, 64), "float16", scope="local.fragment")
        T.fill(T.tvm_access_ptr(T.type_annotation("float16"), C_local.data, 0, 4096, 2), 0)
        for k in T.serial(4, annotations={"num_stages": 2}):
            T.copy(T.region(A[by * 64, k * 32], 1, 64, 32), T.region(A_shared[0, 0], 2, 64, 32))
            T.copy(T.region(B[bx * 64, k * 32], 1, 64, 32), T.region(B_shared[0, 0], 2, 64, 32))
            T.gemm(T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1), T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1), T.tvm_access_ptr(T.type_annotation("float16"), C_local.data, 0, 4096, 3), T.bool(False), T.bool(True), 64, 64, 32, 0)
        T.copy(T.region(C_local[0, 0], 1, 64, 64), T.region(C[by * 64, bx * 64], 2, 64, 64))

@LeiWang1999 merged commit cd41b4e into microsoft:main on Sep 28, 2024.