
Mismatched output with INT1 × float16 #276

@1773226512

Description


Hello everyone,

I am trying to run an INT1 × float16 matrix multiplication (`W_dtype="uint1"`) with BitBLAS, but I hit the same issue reported in #35: `torch.testing.assert_close` fails when `W_dtype="uint1"`, while the same script passes for other `W_dtype` values.
I tried to work around it by switching BitBLAS versions (`pip install bitblas==0.0.1.dev4`), but the mismatch persists.

Environment

- GPU: A800 80GB
- BitBLAS version: 0.0.1.dev4
- PyTorch version: 2.2.1+cu121
- CUDA: 12.4
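To make the environment above easier to reproduce, the versions can be collected programmatically. The sketch below uses a hypothetical helper, `collect_env` (it is not part of BitBLAS), built on `importlib.metadata` plus `torch.version.cuda` and `torch.cuda.get_device_name` when PyTorch is installed. Note that `torch.version.cuda` reports the CUDA runtime PyTorch was built against (12.1 for a `+cu121` wheel), which can differ from the system toolkit version listed above.

```python
import importlib.metadata
import platform


def collect_env():
    """Hypothetical helper: gather versions relevant to a BitBLAS bug report."""
    info = {"python": platform.python_version()}
    # Installed distribution versions; None if the package is not installed.
    for pkg in ("bitblas", "torch"):
        try:
            info[pkg] = importlib.metadata.version(pkg)
        except importlib.metadata.PackageNotFoundError:
            info[pkg] = None
    try:
        import torch
    except ImportError:
        return info
    # CUDA runtime version PyTorch was compiled with (e.g. "12.1" for +cu121).
    info["torch_cuda"] = torch.version.cuda
    info["gpu"] = torch.cuda.get_device_name(0) if torch.cuda.is_available() else None
    return info


if __name__ == "__main__":
    for key, value in collect_env().items():
        print(f"{key}: {value}")
```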

```python
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # set before CUDA is initialized

import bitblas
import torch

bitblas.set_log_level("DEBUG")
matmul_config = bitblas.MatmulConfig(
    M=1,  # M dimension
    N=1024,  # N dimension
    K=1024,  # K dimension
    A_dtype="float16",  # activation A dtype
    W_dtype="uint1",  # weight W dtype
    accum_dtype="float16",  # accumulation dtype
    out_dtype="float16",  # output dtype
    layout="nt",  # matrix layout: "nt" means A is non-transposed and W is transposed
    with_bias=False,  # bias
    # configs for weight-only quantization
    group_size=None,  # setting for grouped quantization
    with_scaling=False,  # setting for scaling factor
    with_zeros=False,  # setting for zeros
    zeros_mode=None,  # setting for how zeros are calculated
)

matmul = bitblas.Matmul(config=matmul_config)

# Create input matrices: float16 activations in [0, 1) and weights in {0, 1}
input_tensor = torch.rand((1, 1024), dtype=torch.float16).cuda()
weight_tensor = torch.randint(0, 2, (1024, 1024), dtype=torch.int8).cuda()

# Transform the weight tensor to the uint1 data type
print(weight_tensor)
weight_tensor_int1 = matmul.transform_weight(weight_tensor)
print(weight_tensor_int1)

# Perform mixed-precision matrix multiplication
output_tensor = matmul(input_tensor, weight_tensor_int1)

# Reference result using PyTorch matmul for comparison
ref_result = torch.matmul(input_tensor, weight_tensor.t().to(torch.float16))

# Assert that the results are close within a specified tolerance (rtol=1e-2, atol=1.0)
print("Ref output:", ref_result)
print("BitBLAS output:", output_tensor)
torch.testing.assert_close(output_tensor, ref_result, rtol=1e-2, atol=1e-0)
```
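For context on what `transform_weight` has to produce for `W_dtype="uint1"`: each weight is a single bit, so a dense packing stores 8 weights per byte, and a 1024 × 1024 weight would then occupy 1024 × 128 bytes. The sketch below illustrates such packing in plain Python with hypothetical helpers `pack_uint1` and `unpack_uint1`; the least-significant-bit-first order is an assumption and is not necessarily the layout BitBLAS uses (it may also reorder elements for its kernels). Comparing the printed `weight_tensor_int1` shape against this expectation may be a cheap first check.

```python
def pack_uint1(bits):
    """Pack 0/1 values into bytes, 8 per byte, least-significant bit first (assumed order)."""
    if len(bits) % 8 != 0:
        raise ValueError("number of values must be a multiple of 8")
    packed = []
    for start in range(0, len(bits), 8):
        byte = 0
        for offset, bit in enumerate(bits[start:start + 8]):
            if bit not in (0, 1):
                raise ValueError(f"uint1 value out of range: {bit}")
            byte |= bit << offset  # bit `offset` of the byte holds element start + offset
        packed.append(byte)
    return packed


def unpack_uint1(packed):
    """Inverse of pack_uint1: expand each byte back into 8 values."""
    return [(byte >> offset) & 1 for byte in packed for offset in range(8)]


if __name__ == "__main__":
    row = [1, 0, 1, 1, 0, 0, 0, 1] * 128  # one K=1024 row of 1-bit weights
    packed_row = pack_uint1(row)
    print(len(packed_row))  # 128 bytes per row
    assert unpack_uint1(packed_row) == row  # lossless round trip
```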

Output:

```
TVM target not found. Please set the TVM target environment variable using `export TVM_TARGET=<target>`, where <target> is one of the available targets can be found in the output of `tools/get_available_targets.py`.
2024-12-24 05:09:03 [BitBLAS:WARNING]: TVM target not found. Please set the TVM target environment variable using `export TVM_TARGET=<target>`, where <target> is one of the available targets can be found in the output of `tools/get_available_targets.py`.
2024-12-24 05:09:03 [BitBLAS:INFO]: Auto detected target: cuda
2024-12-24 05:09:04 [BitBLAS:DEBUG]: Cannot find the appropriate index map for tensorcore
2024-12-24 05:09:17 [BitBLAS:DEBUG]: Cannot find the appropriate index map for tensorcore
2024-12-24 05:09:17 [BitBLAS:DEBUG]: Apply config {'block': [1], 'thread': [1], 'rstep': [1024], 'reduce_thread': [128], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-12-24 05:09:17 [BitBLAS:DEBUG]: Apply config {'block': [4], 'thread': [4], 'rstep': [1024], 'reduce_thread': [32], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-12-24 05:09:17 [BitBLAS:DEBUG]: Apply config {'block': [2], 'thread': [2], 'rstep': [1024], 'reduce_thread': [64], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-12-24 05:09:17 [BitBLAS:DEBUG]: Apply config {'block': [8], 'thread': [8], 'rstep': [1024], 'reduce_thread': [16], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-12-24 05:09:17 [BitBLAS:DEBUG]: Apply config {'block': [16], 'thread': [16], 'rstep': [512], 'reduce_thread': [8], 'vectorize': {'A': 4, 'B_decode': 8}}
2024-12-24 05:09:17 [BitBLAS:DEBUG]: Apply config {'block': [32], 'thread': [32], 'rstep': [256], 'reduce_thread': [4], 'vectorize': {'A': 2, 'B_decode': 8}}
2024-12-24 05:09:17 [BitBLAS:DEBUG]: Apply config {'block': [64], 'thread': [64], 'rstep': [128], 'reduce_thread': [2], 'vectorize': {'B_decode': 8}}
2024-12-24 05:09:17 [BitBLAS:DEBUG]: Apply config {'block': [128], 'thread': [128], 'rstep': [128], 'vectorize': {'B_decode': 8}}
2024-12-24 05:09:17 [BitBLAS:DEBUG]: Warning: block config [128] is not valid for matmul, skip.
2024-12-24 05:09:17 [BitBLAS:DEBUG]: Warning: block config [128] is not valid for matmul, skip.
WARNING:bitblas.utils.target_detector:TVM target not found. Please set the TVM target environment variable using `export TVM_TARGET=<target>`, where <target> is one of the available targets can be found in the output of `tools/get_available_targets.py`.
WARNING:bitblas.utils.target_detector:TVM target not found. Please set the TVM target environment variable using `export TVM_TARGET=<target>`, where <target> is one of the available targets can be found in the output of `tools/get_available_targets.py`.
WARNING:bitblas.utils.target_detector:TVM target not found. Please set the TVM target environment variable using `export TVM_TARGET=<target>`, where <target> is one of the available targets can be found in the output of `tools/get_available_targets.py`.
WARNING:bitblas.utils.target_detector:TVM target not found. Please set the TVM target environment variable using `export TVM_TARGET=<target>`, where <target> is one of the available targets can be found in the output of `tools/get_available_targets.py`.
WARNING:bitblas.utils.target_detector:TVM target not found. Please set the TVM target environment variable using `export TVM_TARGET=<target>`, where <target> is one of the available targets can be found in the output of `tools/get_available_targets.py`.
WARNING:bitblas.utils.target_detector:TVM target not found. Please set the TVM target environment variable using `export TVM_TARGET=<target>`, where <target> is one of the available targets can be found in the output of `tools/get_available_targets.py`.
WARNING:bitblas.utils.target_detector:TVM target not found. Please set the TVM target environment variable using `export TVM_TARGET=<target>`, where <target> is one of the available targets can be found in the output of `tools/get_available_targets.py`.
WARNING:bitblas.utils.target_detector:TVM target not found. Please set the TVM target environment variable using `export TVM_TARGET=<target>`, where <target> is one of the available targets can be found in the output of `tools/get_available_targets.py`.
2024-12-24 05:09:20 [BitBLAS:DEBUG]: LocalBuilder: An exception occurred Traceback (most recent call last):
  File "/home/chentianqi/miniconda3/envs/llava/lib/python3.10/site-packages/bitblas/3rdparty/tvm/python/tvm/exec/popen_worker.py", line 87, in main
    result = fn(*args, **kwargs)
  File "/home/chentianqi/miniconda3/envs/llava/lib/python3.10/site-packages/bitblas/base/utils.py", line 212, in _build
    rt_mod = tvm.build(mod, target=arch.target)
  File "/home/chentianqi/miniconda3/envs/llava/lib/python3.10/site-packages/bitblas/3rdparty/tvm/python/tvm/driver/build_module.py", line 297, in build
    rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
  File "/home/chentianqi/miniconda3/envs/llava/lib/python3.10/site-packages/bitblas/3rdparty/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/home/chentianqi/miniconda3/envs/llava/lib/python3.10/site-packages/bitblas/3rdparty/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
ValueError: Traceback (most recent call last):
  50: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)>::AssignTypedLambda<tvm::{lambda(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)#6}>(tvm::{lambda(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)#6}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  49: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
  48: tvm::codegen::Build(tvm::IRModule, tvm::Target)
  47: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::IRModule, tvm::Target)>::AssignTypedLambda<tvm::runtime::Module (*)(tvm::IRModule, tvm::Target)>(tvm::runtime::Module (*)(tvm::IRModule, tvm::Target), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  46: tvm::codegen::BuildCUDA(tvm::IRModule, tvm::Target)
  45: tvm::codegen::CodeGenC::AddFunction(tvm::GlobalVar const&, tvm::tir::PrimFunc const&)
  44: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::DeclBufferNode const*)
  43: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::DeclBufferNode const*)
  42: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::DeclBufferNode const*)
  41: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::DeclBufferNode const*)
  40: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::DeclBufferNode const*)
  39: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::DeclBufferNode const*)
  38: tvm::codegen::CodeGenCUDA::VisitStmt_(tvm::tir::AttrStmtNode const*)
  37: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::AttrStmtNode const*)
  36: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  35: tvm::codegen::CodeGenCUDA::VisitStmt_(tvm::tir::AllocateNode const*)
  34: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  33: tvm::codegen::CodeGenCUDA::VisitStmt_(tvm::tir::AllocateNode const*)
  32: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  31: tvm::codegen::CodeGenCUDA::VisitStmt_(tvm::tir::AllocateNode const*)
  30: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  29: tvm::codegen::CodeGenCUDA::VisitStmt_(tvm::tir::AttrStmtNode const*)
  28: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::AttrStmtNode const*)
  27: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  26: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::SeqStmtNode const*)
  25: tvm::codegen::CodeGenCUDA::VisitStmt_(tvm::tir::ForNode const*)
  24: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::ForNode const*)
  23: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  22: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::SeqStmtNode const*)
  21: tvm::codegen::CodeGenCUDA::VisitStmt_(tvm::tir::ForNode const*)
  20: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::ForNode const*)
  19: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  18: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::BufferStoreNode const*)
  17: tvm::codegen::CodeGenC::PrintExpr[abi:cxx11](tvm::PrimExpr const&)
  16: tvm::codegen::CodeGenC::PrintExpr(tvm::PrimExpr const&, std::ostream&)
  15: tvm::codegen::CodeGenCUDA::VisitExpr_(tvm::tir::CastNode const*, std::ostream&)
  14: tvm::codegen::CodeGenC::PrintExpr[abi:cxx11](tvm::PrimExpr const&)
  13: tvm::codegen::CodeGenC::PrintExpr(tvm::PrimExpr const&, std::ostream&)
  12: tvm::codegen::CodeGenCUDA::VisitExpr_(tvm::tir::CallNode const*, std::ostream&)
  11: tvm::codegen::CodeGenC::VisitExpr_(tvm::tir::CallNode const*, std::ostream&)
  10: tvm::codegen::PrintBinaryIntrinsic(tvm::tir::CallNode const*, char const*, std::ostream&, tvm::codegen::CodeGenC*)
  9: tvm::codegen::CodeGenCUDA::PrintVecBinaryOp(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::DataType, tvm::PrimExpr, tvm::PrimExpr, std::ostream&)
  8: tvm::codegen::CodeGenC::PrintExpr[abi:cxx11](tvm::PrimExpr const&)
  7: tvm::codegen::CodeGenC::PrintExpr(tvm::PrimExpr const&, std::ostream&)
  6: tvm::codegen::CodeGenCUDA::VisitExpr_(tvm::tir::CallNode const*, std::ostream&)
  5: tvm::codegen::CodeGenC::VisitExpr_(tvm::tir::CallNode const*, std::ostream&)
  4: tvm::codegen::PrintBinaryIntrinsic(tvm::tir::CallNode const*, char const*, std::ostream&, tvm::codegen::CodeGenC*)
  3: tvm::codegen::CodeGenCUDA::PrintVecBinaryOp(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::DataType, tvm::PrimExpr, tvm::PrimExpr, std::ostream&)
  2: tvm::codegen::CodeGenC::PrintExpr[abi:cxx11](tvm::PrimExpr const&)
  1: tvm::codegen::CodeGenC::PrintExpr(tvm::PrimExpr const&, std::ostream&)
  0: tvm::codegen::CodeGenCUDA::VisitExpr_(tvm::tir::RampNode const*, std::ostream&)
  File "/root/BitBLAS/3rdparty/tvm/src/target/source/codegen_cuda.cc", line 1224
ValueError: Check failed: lanes <= 4 (8 vs. 4) : Ramp of more than 4 lanes is not allowed.

2024-12-24 05:09:20 [BitBLAS:INFO]: Evaluation with config {'block': [1], 'thread': [1], 'rstep': [1024], 'reduce_thread': [128], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-12-24 05:09:20 [BitBLAS:INFO]: Time cost of this config: 0.007 ms
2024-12-24 05:09:20 [BitBLAS:INFO]: Evaluation with config {'block': [4], 'thread': [4], 'rstep': [1024], 'reduce_thread': [32], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-12-24 05:09:20 [BitBLAS:INFO]: Time cost of this config: 0.006 ms
2024-12-24 05:09:20 [BitBLAS:INFO]: Evaluation with config {'block': [2], 'thread': [2], 'rstep': [1024], 'reduce_thread': [64], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-12-24 05:09:20 [BitBLAS:INFO]: Time cost of this config: 0.006 ms
2024-12-24 05:09:20 [BitBLAS:INFO]: Evaluation with config {'block': [8], 'thread': [8], 'rstep': [1024], 'reduce_thread': [16], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-12-24 05:09:20 [BitBLAS:INFO]: Time cost of this config: 0.006 ms
2024-12-24 05:09:20 [BitBLAS:INFO]: Evaluation with config {'block': [16], 'thread': [16], 'rstep': [512], 'reduce_thread': [8], 'vectorize': {'A': 4, 'B_decode': 8}}
2024-12-24 05:09:20 [BitBLAS:INFO]: Time cost of this config: 0.006 ms
2024-12-24 05:09:20 [BitBLAS:INFO]: Evaluation with config {'block': [32], 'thread': [32], 'rstep': [256], 'reduce_thread': [4], 'vectorize': {'A': 2, 'B_decode': 8}}
2024-12-24 05:09:20 [BitBLAS:INFO]: Time cost of this config: 0.012 ms
2024-12-24 05:09:20 [BitBLAS:INFO]: Evaluation with config {'block': [64], 'thread': [64], 'rstep': [128], 'reduce_thread': [2], 'vectorize': {'B_decode': 8}}
2024-12-24 05:09:20 [BitBLAS:INFO]: Time cost of this config: 0.008 ms
Ref output: tensor([[269.7500, 251.8750, 250.3750,  ..., 259.7500, 267.2500, 260.5000]],
       device='cuda:0', dtype=torch.float16)
BitBLAS output: tensor([[216.0000, 205.8750, 203.2500,  ..., 206.8750, 213.0000, 205.5000]],
       device='cuda:0', dtype=torch.float16)
Traceback (most recent call last):
  File "/home/chentianqi/delta_compression/kernel.py", line 44, in <module>
    torch.testing.assert_close(output_tensor, ref_result, rtol=1e-2, atol=1e-0)
  File "/home/chentianqi/miniconda3/envs/llava/lib/python3.10/site-packages/torch/testing/_comparison.py", line 1520, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

Mismatched elements: 1024 / 1024 (100.0%)
Greatest absolute difference: 65.5 at index (0, 111) (up to 1.0 allowed)
Greatest relative difference: 0.240234375 at index (0, 945) (up to 0.01 allowed)
```
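To double-check the failure by hand: `torch.testing.assert_close` treats an element as close when `|actual - expected| <= atol + rtol * |expected|`. Applying that criterion to the six values printed in the log shows why every element fails: with `rtol=1e-2` and `atol=1.0` the allowed error is under 4, while the observed gaps are 46 to 55. The script below (plain Python; `is_close` is a hypothetical re-implementation of the documented criterion, not the PyTorch function) also prints the BitBLAS/reference ratio, about 0.79 to 0.82 on these values, and the expected size of a reference element, K · E[x] · E[w] = 1024 × 0.5 × 0.5 = 256 for `torch.rand` activations and `torch.randint(0, 2)` weights, which the reference values match.

```python
def is_close(actual, expected, rtol=1e-2, atol=1.0):
    """Element-wise criterion documented for torch.testing.assert_close."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


# First and last three elements printed in the log above.
ref_vals = [269.75, 251.875, 250.375, 259.75, 267.25, 260.5]
bitblas_vals = [216.0, 205.875, 203.25, 206.875, 213.0, 205.5]

for ref, out in zip(ref_vals, bitblas_vals):
    allowed = 1.0 + 1e-2 * abs(ref)
    print(f"ref={ref:8.3f} out={out:8.3f} |diff|={abs(out - ref):6.3f} "
          f"allowed={allowed:5.3f} ratio={out / ref:.3f} close={is_close(out, ref)}")

# Expected magnitude of each reference element: K * E[x] * E[w],
# with x ~ U(0, 1) from torch.rand and w uniform over {0, 1} from torch.randint(0, 2).
K = 1024
print("expected reference magnitude:", K * 0.5 * 0.5)  # 256.0
```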