To reproduce a failing case:
```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from bitblas import tvm as tvm
from tvm import tl
from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import (
    MatmulFineGrainScheduler,
)
import torch
import torch.backends

torch.manual_seed(0)


def assert_matmul_fine_grained_apply_config_correctness(
    M,
    N,
    K,
    trans_A=False,
    trans_B=True,
    in_dtype="float16",
    out_dtype="float16",
    accum_dtype="float16",
    block_row_warps=1,
    block_col_warps=1,
    warp_row_tiles=16,
    warp_col_tiles=16,
    chunk=32,
    num_stages=2,
    enable_rasterization=False,
):
    matmul = MatmulFineGrainScheduler(
        M=M,
        N=N,
        K=K,
        trans_A=trans_A,
        trans_B=trans_B,
        in_dtype=in_dtype,
        out_dtype=out_dtype,
        accum_dtype=accum_dtype,
    ).apply_config(
        block_row_warps=block_row_warps,
        block_col_warps=block_col_warps,
        warp_row_tiles=warp_row_tiles,
        warp_col_tiles=warp_col_tiles,
        chunk=chunk,
        num_stages=num_stages,
        enable_rasterization=enable_rasterization,
    )

    mod, params = tl.lower(matmul)
    src_code = mod.imported_modules[0].get_source()
    # src_code is the generated CUDA source
    assert src_code is not None

    A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
    B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))

    mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer)
    mod(A, B, C)

    latency = mod.do_bench(mod.func, warmup=25)
    # Ensure that the latency is not None
    assert latency is not None

    # Get reference result
    ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype))
    torch.testing.assert_close(C, ref_c, rtol=1e-1, atol=1e-1)


assert_matmul_fine_grained_apply_config_correctness(
    768, 768, 768, False, True, "float16", "float16", "float16",
    block_row_warps=1, block_col_warps=3, warp_row_tiles=16, warp_col_tiles=16, chunk=128, num_stages=0,
)
```

The output is:

```
root@e01d939002c0:~/BitBLAS# /usr/bin/python /root/BitBLAS/debug/test_issue_parallel.py
Traceback (most recent call last):
File "/root/BitBLAS/debug/test_issue_parallel.py", line 138, in <module>
assert_matmul_fine_grained_apply_config_correctness(
File "/root/BitBLAS/debug/test_issue_parallel.py", line 56, in assert_matmul_fine_grained_apply_config_correctness
mod, params = tl.lower(matmul)
File "/root/BitBLAS/bitblas/../3rdparty/tvm/python/tvm/tl/engine.py", line 84, in lower
mod = tl.transform.LayoutInference()(mod)
File "/root/BitBLAS/bitblas/../3rdparty/tvm/python/tvm/ir/transform.py", line 238, in __call__
return _ffi_transform_api.RunPass(self, mod)
File "/root/BitBLAS/bitblas/../3rdparty/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
raise_last_ffi_error()
File "/root/BitBLAS/bitblas/../3rdparty/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
tvm.error.InternalError: Traceback (most recent call last):
56: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
55: tvm::transform::Pass::operator()(tvm::IRModule) const
54: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
53: tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
52: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_3tir8PrimFuncES6_NS_8IRModuleENS_9transform11PassContextEEE17AssignTypedLambdaIZNS_2tl15LayoutInferenceEvEUlS6_S7_S9_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SG_SK_
51: tvm::tl::LayoutInferencer::Substitute(tvm::tir::PrimFunc)
50: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9Ob
49: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
48: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9Ob
47: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::BlockNode const*)
46: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::BlockNode const*)
45: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
44: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
43: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::AttrStmtNode const*)
42: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::AttrStmtNode const*)
41: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
40: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
39: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::AttrStmtNode const*)
38: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::AttrStmtNode const*)
37: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
36: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
35: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::AttrStmtNode const*)
34: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::AttrStmtNode const*)
33: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
32: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
31: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::AttrStmtNode const*)
30: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::AttrStmtNode const*)
29: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
28: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
27: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::AttrStmtNode const*)
26: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::AttrStmtNode const*)
25: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
24: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9Ob
23: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
22: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9Ob
21: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::BlockNode const*)
20: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::BlockNode const*)
19: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
18: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
17: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::ForNode const*)
16: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::ForNode const*)
15: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
14: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9Ob
13: tvm::runtime::ObjectPtr<tvm::runtime::Object> tvm::runtime::Array<tvm::tir::Stmt, void>::MapHelper<tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}, tvm::tir::Stmt>(tvm::runtime::Object, tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1})
12: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
11: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
10: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::ForNode const*)
9: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::ForNode const*)
8: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
7: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9Ob
6: tvm::runtime::ObjectPtr<tvm::runtime::Object> tvm::runtime::Array<tvm::tir::Stmt, void>::MapHelper<tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}, tvm::tir::Stmt>(tvm::runtime::Object, tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1})
5: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
4: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
3: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::ForNode const*)
2: tvm::tl::PartitionLoop(tvm::tir::For, tvm::tir::Var, tvm::arith::Analyzer*, tvm::tl::Fragment)
1: tvm::tl::FragmentNode::Inverse() const
0: tvm::tl::LayoutNode::Inverse() const
File "/root/BitBLAS/3rdparty/tvm/src/tl/layout/layout.cc", line 205
InternalError: Check failed: (res->errors.empty()) is false: ["The iterations do not traverse full iter space", "Index mapping does not form a bijective transform."]
```

The problem lies in the parallel primitives:

```python
for i in T.parallel(16):
    for j in T.parallel(128):
        A_shared[i, j] = A[x, y]
```

The block here launches 96 threads (block_row_warps=1, block_col_warps=3, 32 threads per warp), and the 16x128 iteration space is not evenly divisible by 96.
However, this tile configuration is the most efficient one, and it can be tensorized using our raw BitBLAS TIR backend.
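For reference, here is a minimal sketch of the arithmetic behind the failed bijectivity check, assuming layout inference needs the shared-memory copy extent to split evenly across the launched threads. The helper `copy_extent_divisible_by_threads` and the warp-size constant are illustrative only, not part of the BitBLAS API:

```python
def copy_extent_divisible_by_threads(block_row_warps, block_col_warps,
                                     warp_row_tiles, chunk, warp_size=32):
    """Hypothetical check: can the A_shared copy loop (block_M x chunk)
    be evenly partitioned across the threads launched per block?"""
    threads = block_row_warps * block_col_warps * warp_size  # e.g. 1 * 3 * 32 = 96
    block_m = block_row_warps * warp_row_tiles               # e.g. 1 * 16 = 16
    copy_extent = block_m * chunk                            # e.g. 16 * 128 = 2048
    return copy_extent % threads == 0


# Failing config from the repro: 2048 elements vs. 96 threads -> not divisible.
print(copy_extent_divisible_by_threads(1, 3, 16, 128))  # False
# block_col_warps=2 would give 64 threads, which divides 2048 evenly.
print(copy_extent_divisible_by_threads(1, 2, 16, 128))  # True
```

Whether an alternative such as block_col_warps=2 lowers correctly or performs comparably is not verified here; the sketch only illustrates why the 16x128 parallel copy cannot be bijectively mapped onto 96 threads.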