Description
Hello everyone,
I am trying to run an INT1 × float16 matrix multiplication with `W_dtype='uint1'`. However, I ran into the same issue reported in #35: `torch.testing.assert_close` fails when `W_dtype='uint1'`, but passes for other `W_dtype` values.
I tried to resolve this by switching BitBLAS versions with `pip install bitblas==0.0.1.dev4`, but it didn't help.
Environment
- GPU: A800 80GB
- BitBLAS version: 0.0.1.dev4
- PyTorch version: 2.2.1+cu121
- CUDA: 12.4
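The versions listed above can also be gathered programmatically, which makes it easier to compare setups across reports. This is a minimal sketch using a hypothetical helper (`collect_env_info` is not part of BitBLAS or PyTorch); it assumes only the standard library's `importlib.metadata` plus an optional PyTorch install. Note that `torch.version.cuda` reports the CUDA version PyTorch was built against (12.1 for a `+cu121` wheel), which can differ from the system CUDA toolkit (12.4 here).

```python
import importlib.metadata as metadata
import platform


def collect_env_info():
    """Hypothetical helper: gather the versions relevant to this bug report."""
    info = {"python": platform.python_version()}
    # Installed distribution versions (same names as used with pip install)
    for dist in ("bitblas", "torch"):
        try:
            info[dist] = metadata.version(dist)
        except metadata.PackageNotFoundError:
            info[dist] = "not installed"
    # CUDA build version and GPU name, only if PyTorch is importable
    try:
        import torch
    except ImportError:
        return info
    info["torch_cuda_build"] = torch.version.cuda  # None for CPU-only builds
    info["gpu"] = (
        torch.cuda.get_device_name(0) if torch.cuda.is_available() else "none"
    )
    return info


print(collect_env_info())
```

PyTorch's own `python -m torch.utils.collect_env` prints a fuller report if more detail is needed.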
```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import bitblas
import torch

bitblas.set_log_level("DEBUG")

matmul_config = bitblas.MatmulConfig(
    M=1,  # M dimension
    N=1024,  # N dimension
    K=1024,  # K dimension
    A_dtype="float16",  # activation A dtype
    W_dtype="uint1",  # weight W dtype
    accum_dtype="float16",  # accumulation dtype
    out_dtype="float16",  # output dtype
    layout="nt",  # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose
    with_bias=False,  # bias
    # configs for weight only quantization
    group_size=None,  # setting for grouped quantization
    with_scaling=False,  # setting for scaling factor
    with_zeros=False,  # setting for zeros
    zeros_mode=None,  # setting for how to calculate zeros
)

matmul = bitblas.Matmul(config=matmul_config)

# Create input matrices
input_tensor = torch.rand((1, 1024), dtype=torch.float16).cuda()
weight_tensor = torch.randint(0, 2, (1024, 1024), dtype=torch.int8).cuda()

# Transform the weight tensor to the uint1 format
print(weight_tensor)
weight_tensor_int1 = matmul.transform_weight(weight_tensor)
print(weight_tensor_int1)

# Perform mixed-precision matrix multiplication
output_tensor = matmul(input_tensor, weight_tensor_int1)

# Reference result using PyTorch matmul for comparison
ref_result = torch.matmul(input_tensor, weight_tensor.t().to(torch.float16))

# Assert that the results are close within the specified tolerance; atol is set to 1.0
# because each output is a large float16 sum of 1024 products
print("Ref output:", ref_result)
print("BitBLAS output:", output_tensor)
torch.testing.assert_close(output_tensor, ref_result, rtol=1e-2, atol=1e-0)
```

Output:

```text
TVM target not found. Please set the TVM target environment variable using `export TVM_TARGET=<target>`, where <target> is one of the available targets can be found in the output of `tools/get_available_targets.py`.
2024-12-24 05:09:03 [BitBLAS:WARNING]: TVM target not found. Please set the TVM target environment variable using `export TVM_TARGET=<target>`, where <target> is one of the available targets can be found in the output of `tools/get_available_targets.py`.
2024-12-24 05:09:03 [BitBLAS:INFO]: Auto detected target: cuda
2024-12-24 05:09:04 [BitBLAS:DEBUG]: Cannot find the appropriate index map for tensorcore
2024-12-24 05:09:17 [BitBLAS:DEBUG]: Cannot find the appropriate index map for tensorcore
2024-12-24 05:09:17 [BitBLAS:DEBUG]: Apply config {'block': [1], 'thread': [1], 'rstep': [1024], 'reduce_thread': [128], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-12-24 05:09:17 [BitBLAS:DEBUG]: Apply config {'block': [4], 'thread': [4], 'rstep': [1024], 'reduce_thread': [32], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-12-24 05:09:17 [BitBLAS:DEBUG]: Apply config {'block': [2], 'thread': [2], 'rstep': [1024], 'reduce_thread': [64], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-12-24 05:09:17 [BitBLAS:DEBUG]: Apply config {'block': [8], 'thread': [8], 'rstep': [1024], 'reduce_thread': [16], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-12-24 05:09:17 [BitBLAS:DEBUG]: Apply config {'block': [16], 'thread': [16], 'rstep': [512], 'reduce_thread': [8], 'vectorize': {'A': 4, 'B_decode': 8}}
2024-12-24 05:09:17 [BitBLAS:DEBUG]: Apply config {'block': [32], 'thread': [32], 'rstep': [256], 'reduce_thread': [4], 'vectorize': {'A': 2, 'B_decode': 8}}
2024-12-24 05:09:17 [BitBLAS:DEBUG]: Apply config {'block': [64], 'thread': [64], 'rstep': [128], 'reduce_thread': [2], 'vectorize': {'B_decode': 8}}
2024-12-24 05:09:17 [BitBLAS:DEBUG]: Apply config {'block': [128], 'thread': [128], 'rstep': [128], 'vectorize': {'B_decode': 8}}
2024-12-24 05:09:17 [BitBLAS:DEBUG]: Warning: block config [128] is not valid for matmul, skip.
2024-12-24 05:09:17 [BitBLAS:DEBUG]: Warning: block config [128] is not valid for matmul, skip.
WARNING:bitblas.utils.target_detector:TVM target not found. Please set the TVM target environment variable using `export TVM_TARGET=<target>`, where <target> is one of the available targets can be found in the output of `tools/get_available_targets.py`.
WARNING:bitblas.utils.target_detector:TVM target not found. Please set the TVM target environment variable using `export TVM_TARGET=<target>`, where <target> is one of the available targets can be found in the output of `tools/get_available_targets.py`.
WARNING:bitblas.utils.target_detector:TVM target not found. Please set the TVM target environment variable using `export TVM_TARGET=<target>`, where <target> is one of the available targets can be found in the output of `tools/get_available_targets.py`.
WARNING:bitblas.utils.target_detector:TVM target not found. Please set the TVM target environment variable using `export TVM_TARGET=<target>`, where <target> is one of the available targets can be found in the output of `tools/get_available_targets.py`.
WARNING:bitblas.utils.target_detector:TVM target not found. Please set the TVM target environment variable using `export TVM_TARGET=<target>`, where <target> is one of the available targets can be found in the output of `tools/get_available_targets.py`.
WARNING:bitblas.utils.target_detector:TVM target not found. Please set the TVM target environment variable using `export TVM_TARGET=<target>`, where <target> is one of the available targets can be found in the output of `tools/get_available_targets.py`.
WARNING:bitblas.utils.target_detector:TVM target not found. Please set the TVM target environment variable using `export TVM_TARGET=<target>`, where <target> is one of the available targets can be found in the output of `tools/get_available_targets.py`.
WARNING:bitblas.utils.target_detector:TVM target not found. Please set the TVM target environment variable using `export TVM_TARGET=<target>`, where <target> is one of the available targets can be found in the output of `tools/get_available_targets.py`.
2024-12-24 05:09:20 [BitBLAS:DEBUG]: LocalBuilder: An exception occurred Traceback (most recent call last):
  File "/home/chentianqi/miniconda3/envs/llava/lib/python3.10/site-packages/bitblas/3rdparty/tvm/python/tvm/exec/popen_worker.py", line 87, in main
    result = fn(*args, **kwargs)
  File "/home/chentianqi/miniconda3/envs/llava/lib/python3.10/site-packages/bitblas/base/utils.py", line 212, in _build
    rt_mod = tvm.build(mod, target=arch.target)
  File "/home/chentianqi/miniconda3/envs/llava/lib/python3.10/site-packages/bitblas/3rdparty/tvm/python/tvm/driver/build_module.py", line 297, in build
    rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
  File "/home/chentianqi/miniconda3/envs/llava/lib/python3.10/site-packages/bitblas/3rdparty/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/home/chentianqi/miniconda3/envs/llava/lib/python3.10/site-packages/bitblas/3rdparty/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
ValueError: Traceback (most recent call last):
  50: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)>::AssignTypedLambda<tvm::{lambda(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)#6}>(tvm::{lambda(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)#6}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  49: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
  48: tvm::codegen::Build(tvm::IRModule, tvm::Target)
  47: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::IRModule, tvm::Target)>::AssignTypedLambda<tvm::runtime::Module (*)(tvm::IRModule, tvm::Target)>(tvm::runtime::Module (*)(tvm::IRModule, tvm::Target), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  46: tvm::codegen::BuildCUDA(tvm::IRModule, tvm::Target)
  45: tvm::codegen::CodeGenC::AddFunction(tvm::GlobalVar const&, tvm::tir::PrimFunc const&)
  44: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::DeclBufferNode const*)
  43: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::DeclBufferNode const*)
  42: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::DeclBufferNode const*)
  41: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::DeclBufferNode const*)
  40: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::DeclBufferNode const*)
  39: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::DeclBufferNode const*)
  38: tvm::codegen::CodeGenCUDA::VisitStmt_(tvm::tir::AttrStmtNode const*)
  37: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::AttrStmtNode const*)
  36: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  35: tvm::codegen::CodeGenCUDA::VisitStmt_(tvm::tir::AllocateNode const*)
  34: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  33: tvm::codegen::CodeGenCUDA::VisitStmt_(tvm::tir::AllocateNode const*)
  32: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  31: tvm::codegen::CodeGenCUDA::VisitStmt_(tvm::tir::AllocateNode const*)
  30: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  29: tvm::codegen::CodeGenCUDA::VisitStmt_(tvm::tir::AttrStmtNode const*)
  28: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::AttrStmtNode const*)
  27: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  26: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::SeqStmtNode const*)
  25: tvm::codegen::CodeGenCUDA::VisitStmt_(tvm::tir::ForNode const*)
  24: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::ForNode const*)
  23: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  22: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::SeqStmtNode const*)
  21: tvm::codegen::CodeGenCUDA::VisitStmt_(tvm::tir::ForNode const*)
  20: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::ForNode const*)
  19: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  18: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::BufferStoreNode const*)
  17: tvm::codegen::CodeGenC::PrintExpr[abi:cxx11](tvm::PrimExpr const&)
  16: tvm::codegen::CodeGenC::PrintExpr(tvm::PrimExpr const&, std::ostream&)
  15: tvm::codegen::CodeGenCUDA::VisitExpr_(tvm::tir::CastNode const*, std::ostream&)
  14: tvm::codegen::CodeGenC::PrintExpr[abi:cxx11](tvm::PrimExpr const&)
  13: tvm::codegen::CodeGenC::PrintExpr(tvm::PrimExpr const&, std::ostream&)
  12: tvm::codegen::CodeGenCUDA::VisitExpr_(tvm::tir::CallNode const*, std::ostream&)
  11: tvm::codegen::CodeGenC::VisitExpr_(tvm::tir::CallNode const*, std::ostream&)
  10: tvm::codegen::PrintBinaryIntrinsic(tvm::tir::CallNode const*, char const*, std::ostream&, tvm::codegen::CodeGenC*)
  9: tvm::codegen::CodeGenCUDA::PrintVecBinaryOp(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::DataType, tvm::PrimExpr, tvm::PrimExpr, std::ostream&)
  8: tvm::codegen::CodeGenC::PrintExpr[abi:cxx11](tvm::PrimExpr const&)
  7: tvm::codegen::CodeGenC::PrintExpr(tvm::PrimExpr const&, std::ostream&)
  6: tvm::codegen::CodeGenCUDA::VisitExpr_(tvm::tir::CallNode const*, std::ostream&)
  5: tvm::codegen::CodeGenC::VisitExpr_(tvm::tir::CallNode const*, std::ostream&)
  4: tvm::codegen::PrintBinaryIntrinsic(tvm::tir::CallNode const*, char const*, std::ostream&, tvm::codegen::CodeGenC*)
  3: tvm::codegen::CodeGenCUDA::PrintVecBinaryOp(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::DataType, tvm::PrimExpr, tvm::PrimExpr, std::ostream&)
  2: tvm::codegen::CodeGenC::PrintExpr[abi:cxx11](tvm::PrimExpr const&)
  1: tvm::codegen::CodeGenC::PrintExpr(tvm::PrimExpr const&, std::ostream&)
  0: tvm::codegen::CodeGenCUDA::VisitExpr_(tvm::tir::RampNode const*, std::ostream&)
  File "/root/BitBLAS/3rdparty/tvm/src/target/source/codegen_cuda.cc", line 1224
ValueError: Check failed: lanes <= 4 (8 vs. 4) : Ramp of more than 4 lanes is not allowed.
2024-12-24 05:09:20 [BitBLAS:INFO]: Evaluation with config {'block': [1], 'thread': [1], 'rstep': [1024], 'reduce_thread': [128], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-12-24 05:09:20 [BitBLAS:INFO]: Time cost of this config: 0.007 ms
2024-12-24 05:09:20 [BitBLAS:INFO]: Evaluation with config {'block': [4], 'thread': [4], 'rstep': [1024], 'reduce_thread': [32], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-12-24 05:09:20 [BitBLAS:INFO]: Time cost of this config: 0.006 ms
2024-12-24 05:09:20 [BitBLAS:INFO]: Evaluation with config {'block': [2], 'thread': [2], 'rstep': [1024], 'reduce_thread': [64], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-12-24 05:09:20 [BitBLAS:INFO]: Time cost of this config: 0.006 ms
2024-12-24 05:09:20 [BitBLAS:INFO]: Evaluation with config {'block': [8], 'thread': [8], 'rstep': [1024], 'reduce_thread': [16], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-12-24 05:09:20 [BitBLAS:INFO]: Time cost of this config: 0.006 ms
2024-12-24 05:09:20 [BitBLAS:INFO]: Evaluation with config {'block': [16], 'thread': [16], 'rstep': [512], 'reduce_thread': [8], 'vectorize': {'A': 4, 'B_decode': 8}}
2024-12-24 05:09:20 [BitBLAS:INFO]: Time cost of this config: 0.006 ms
2024-12-24 05:09:20 [BitBLAS:INFO]: Evaluation with config {'block': [32], 'thread': [32], 'rstep': [256], 'reduce_thread': [4], 'vectorize': {'A': 2, 'B_decode': 8}}
2024-12-24 05:09:20 [BitBLAS:INFO]: Time cost of this config: 0.012 ms
2024-12-24 05:09:20 [BitBLAS:INFO]: Evaluation with config {'block': [64], 'thread': [64], 'rstep': [128], 'reduce_thread': [2], 'vectorize': {'B_decode': 8}}
2024-12-24 05:09:20 [BitBLAS:INFO]: Time cost of this config: 0.008 ms
Ref output: tensor([[269.7500, 251.8750, 250.3750, ..., 259.7500, 267.2500, 260.5000]],
       device='cuda:0', dtype=torch.float16)
BitBLAS output: tensor([[216.0000, 205.8750, 203.2500, ..., 206.8750, 213.0000, 205.5000]],
       device='cuda:0', dtype=torch.float16)
Traceback (most recent call last):
  File "/home/chentianqi/delta_compression/kernel.py", line 44, in <module>
    torch.testing.assert_close(output_tensor, ref_result, rtol=1e-2, atol=1e-0)
  File "/home/chentianqi/miniconda3/envs/llava/lib/python3.10/site-packages/torch/testing/_comparison.py", line 1520, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!
Mismatched elements: 1024 / 1024 (100.0%)
Greatest absolute difference: 65.5 at index (0, 111) (up to 1.0 allowed)
Greatest relative difference: 0.240234375 at index (0, 945) (up to 0.01 allowed)
```
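To make the expected semantics concrete: with `W_dtype='uint1'` every weight is 0 or 1, so each output element is simply the sum of the activations at the positions where the corresponding weight row holds a 1. With `torch.rand` activations (mean 0.5) and `torch.randint(0, 2, ...)` weights (mean 0.5), the expected value of each output is K × 0.5 × 0.5 = 1024 × 0.25 = 256, which agrees with the reference values printed above (about 250 to 270), whereas the printed BitBLAS values (about 203 to 216) are lower and all 1024 elements mismatch. The sketch below is framework-free Python with hypothetical helpers (`pack_uint1`, `unpack_uint1`, `uint1_matvec`); it assumes a simple LSB-first packing of 8 weights per byte, which is not necessarily the layout BitBLAS's `transform_weight` produces.

```python
import random


def pack_uint1(bits):
    """Pack a list of 0/1 values into bytes, 8 per byte, LSB-first (assumed layout)."""
    if len(bits) % 8 != 0:
        raise ValueError("length must be a multiple of 8")
    packed = []
    for i in range(0, len(bits), 8):
        byte = 0
        for j, b in enumerate(bits[i:i + 8]):
            byte |= (b & 1) << j  # bit j of the byte holds element i + j
        packed.append(byte)
    return packed


def unpack_uint1(packed):
    """Inverse of pack_uint1: expand each byte back into 8 values of 0/1."""
    return [(byte >> j) & 1 for byte in packed for j in range(8)]


def uint1_matvec(x, packed_rows):
    """y[n] = sum_k x[k] * W[n][k] for an "nt" layout (each row of W has length K)."""
    return [sum(xk for xk, w in zip(x, unpack_uint1(row)) if w) for row in packed_rows]


# Small demo with the same K as the issue and a handful of output columns
random.seed(0)
K, N = 1024, 4
x = [random.random() for _ in range(K)]
W = [[random.randint(0, 1) for _ in range(K)] for _ in range(N)]
packed = [pack_uint1(row) for row in W]  # 1024 weights -> 128 bytes per row
ref = [sum(xk * wk for xk, wk in zip(x, row)) for row in W]
out = uint1_matvec(x, packed)
print(ref)
print(out)  # should match ref and lie near K * 0.25 = 256
```

Because 1-bit packing is lossless, decoding the packed weights back to 0/1 and comparing them with the original `weight_tensor` is one way to tell whether a mismatch comes from the weight transformation or from the kernel itself.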