
[Feature Request] Parallel primitive should be enhanced to improve the performance for irregular shapes #209

@LeiWang1999

Description

To reproduce a failing case:

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from bitblas import tvm as tvm
from tvm import tl
from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import (
    MatmulFineGrainScheduler,
)

import torch
import torch.backends

torch.manual_seed(0)


def assert_matmul_fine_grained_apply_config_correctness(
    M,
    N,
    K,
    trans_A=False,
    trans_B=True,
    in_dtype="float16",
    out_dtype="float16",
    accum_dtype="float16",
    block_row_warps=1,
    block_col_warps=1,
    warp_row_tiles=16,
    warp_col_tiles=16,
    chunk=32,
    num_stages=2,
    enable_rasterization=False,
):

    matmul = MatmulFineGrainScheduler(
        M=M,
        N=N,
        K=K,
        trans_A=trans_A,
        trans_B=trans_B,
        in_dtype=in_dtype,
        out_dtype=out_dtype,
        accum_dtype=accum_dtype,
    ).apply_config(
        block_row_warps=block_row_warps,
        block_col_warps=block_col_warps,
        warp_row_tiles=warp_row_tiles,
        warp_col_tiles=warp_col_tiles,
        chunk=chunk,
        num_stages=num_stages,
        enable_rasterization=enable_rasterization,
    )

    mod, params = tl.lower(matmul)
    src_code = mod.imported_modules[0].get_source()

    # src_code is the generated cuda source
    assert src_code is not None

    A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
    B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))

    mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer)

    mod(A, B, C)

    latency = mod.do_bench(mod.func, warmup=25)

    # Ensure that the latency is not None
    assert latency is not None

    # Get Reference Result
    ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype))
    torch.testing.assert_close(C, ref_c, rtol=1e-1, atol=1e-1)

assert_matmul_fine_grained_apply_config_correctness(
    768, 768, 768, False, True, "float16", "float16", "float16",
    block_row_warps=1, block_col_warps=3, warp_row_tiles=16, warp_col_tiles=16, chunk=128, num_stages=0,
)
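
For context, a rough back-of-the-envelope check of why this configuration is irregular (assuming 32 threads per warp and that the A_shared copy tile is (block_row_warps * warp_row_tiles) x chunk; the variable names below are only illustrative):

# Hypothetical arithmetic check: thread count vs. copy-tile size for the config above.
block_row_warps, block_col_warps = 1, 3
warp_row_tiles, chunk = 16, 128
warp_size = 32  # assumption: 32 threads per warp

threads = block_row_warps * block_col_warps * warp_size   # 3 warps -> 96 threads
tile_elems = (block_row_warps * warp_row_tiles) * chunk   # 16 * 128 = 2048 elements

print(threads, tile_elems, tile_elems % threads)          # 96 2048 32 -> not divisible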

The output is:

root@e01d939002c0:~/BitBLAS# /usr/bin/python /root/BitBLAS/debug/test_issue_parallel.py
Traceback (most recent call last):
  File "/root/BitBLAS/debug/test_issue_parallel.py", line 138, in <module>
    assert_matmul_fine_grained_apply_config_correctness(
  File "/root/BitBLAS/debug/test_issue_parallel.py", line 56, in assert_matmul_fine_grained_apply_config_correctness
    mod, params = tl.lower(matmul)
  File "/root/BitBLAS/bitblas/../3rdparty/tvm/python/tvm/tl/engine.py", line 84, in lower
    mod = tl.transform.LayoutInference()(mod)
  File "/root/BitBLAS/bitblas/../3rdparty/tvm/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "/root/BitBLAS/bitblas/../3rdparty/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/root/BitBLAS/bitblas/../3rdparty/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm.error.InternalError: Traceback (most recent call last):
  56: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  55: tvm::transform::Pass::operator()(tvm::IRModule) const
  54: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  53: tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  52: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_3tir8PrimFuncES6_NS_8IRModuleENS_9transform11PassContextEEE17AssignTypedLambdaIZNS_2tl15LayoutInferenceEvEUlS6_S7_S9_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SG_SK_
  51: tvm::tl::LayoutInferencer::Substitute(tvm::tir::PrimFunc)
  50: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9Ob
  49: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  48: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9Ob
  47: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::BlockNode const*)
  46: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::BlockNode const*)
  45: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  44: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
  43: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::AttrStmtNode const*)
  42: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::AttrStmtNode const*)
  41: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  40: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
  39: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::AttrStmtNode const*)
  38: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::AttrStmtNode const*)
  37: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  36: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
  35: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::AttrStmtNode const*)
  34: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::AttrStmtNode const*)
  33: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  32: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
  31: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::AttrStmtNode const*)
  30: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::AttrStmtNode const*)
  29: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  28: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
  27: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::AttrStmtNode const*)
  26: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::AttrStmtNode const*)
  25: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  24: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9Ob
  23: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  22: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9Ob
  21: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::BlockNode const*)
  20: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::BlockNode const*)
  19: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  18: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
  17: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::ForNode const*)
  16: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::ForNode const*)
  15: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  14: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9Ob
  13: tvm::runtime::ObjectPtr<tvm::runtime::Object> tvm::runtime::Array<tvm::tir::Stmt, void>::MapHelper<tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}, tvm::tir::Stmt>(tvm::runtime::Object, tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1})
  12: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  11: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
  10: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::ForNode const*)
  9: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::ForNode const*)
  8: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  7: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9Ob
  6: tvm::runtime::ObjectPtr<tvm::runtime::Object> tvm::runtime::Array<tvm::tir::Stmt, void>::MapHelper<tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}, tvm::tir::Stmt>(tvm::runtime::Object, tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1})
  5: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  4: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
  3: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::ForNode const*)
  2: tvm::tl::PartitionLoop(tvm::tir::For, tvm::tir::Var, tvm::arith::Analyzer*, tvm::tl::Fragment)
  1: tvm::tl::FragmentNode::Inverse() const
  0: tvm::tl::LayoutNode::Inverse() const
  File "/root/BitBLAS/3rdparty/tvm/src/tl/layout/layout.cc", line 205
InternalError: Check failed: (res->errors.empty()) is false: ["The iterations do not traverse full iter space", "Index mapping does not form a bijective transform."]

The problem lies in the parallel primitives:

for i in T.parallel(16):
    for j in T.parallel(128):
        A_shared[i, j] = A[x, y]

while the number of threads in the block is 96, and the 16x128 iteration space (2048 elements) is not evenly divisible by 96.

However, this tile is the most efficient among the tile configurations, and it can be tensorized by our raw BitBLAS TIR backend.
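
One possible direction for the enhancement: instead of requiring the iteration space to be evenly divisible by the thread count, the parallel lowering could flatten the 16x128 space, stride it across the 96 threads, and guard the ragged tail with a predicate. Below is a minimal Python sketch of that partitioning scheme only; it is not actual TL or TVM API code, and thread_idx and the other names are illustrative:

# Sketch: thread-strided partition of a flattened 16x128 copy over 96 threads.
def partitioned_copy_indices(thread_idx, threads=96, rows=16, cols=128):
    total = rows * cols                          # 2048 elements, not divisible by 96
    steps = (total + threads - 1) // threads     # ceil(2048 / 96) = 22 rounds per thread
    for step in range(steps):
        flat = thread_idx + step * threads
        if flat < total:                         # predicate guards the partial last round
            i, j = flat // cols, flat % cols
            yield i, j                           # the (i, j) element this thread would copy

With this scheme every thread performs 22 rounds; in the last round only threads 0..31 have valid work, and the predicate (or equivalent loop peeling) absorbs the remainder without changing the tile shape.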
