-
Notifications
You must be signed in to change notification settings - Fork 52
Closed
Description
Thanks for the great project - I was wondering if the repo supports int1 x float16 matmul, and if so how should I go about it? For reference, I'm trying to strengthen the kernel in this repo.
My attempt:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import bitblas
import torch
os.environ["PATH"] = os.environ["PATH"]+":/usr/local/cuda/bin/"
bitblas.set_log_level("DEBUG")
matmul_config = bitblas.MatmulConfig(
M=1, # M dimension
N=1024, # N dimension
K=1024, # K dimension
A_dtype="float16", # activation A dtype
W_dtype="int1", # weight W dtype
accum_dtype="float16", # accumulation dtype
out_dtype="float16", # output dtype
layout="nt", # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose
with_bias=False, # bias
# configs for weight only quantization
group_size=None, # setting for grouped quantization
with_scaling=False, # setting for scaling factor
with_zeros=False, # setting for zeros
zeros_mode=None, # setting for how to calculating zeros
)
matmul = bitblas.Matmul(config=matmul_config)
# Create input matrices
input_tensor = torch.rand((1, 1024), dtype=torch.float16).cuda()
weight_tensor = torch.randint(0, 7, (1024, 1024), dtype=torch.int8).cuda()
# Transform weight tensor to int4 data type
weight_tensor_int4 = matmul.transform_weight(weight_tensor)
# Perform mixed-precision matrix multiplication
output_tensor = matmul(input_tensor, weight_tensor_int4)
# Reference result using PyTorch matmul for comparison
ref_result = torch.matmul(input_tensor, weight_tensor.t().to(torch.float16))
# Assert that the results are close within a specified tolerance, note that the int4 randint value is a little bigger than the float16 value, so we set the atol to 1.0
print("Ref output:", ref_result)
print("BitBLAS output:", output_tensor)
torch.testing.assert_close(output_tensor, ref_result, rtol=1e-2, atol=1e-0)
Output:
2024-05-03 11:57:37 [BitBLAS:DEBUG]: Cannot find the appropriate index map for tensorcore
2024-05-03 11:57:37 [BitBLAS:DEBUG]: [BitBLAS][Error] applying rule <bitblas.gpu.gemv.GEMV object at 0x7ff5038880d0> failed
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Cannot find the appropriate index map for tensorcore
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Apply config {'block': [1], 'thread': [1], 'rstep': [1024], 'reduce_thread': [128], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Apply config {'block': [4], 'thread': [4], 'rstep': [1024], 'reduce_thread': [32], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Apply config {'block': [2], 'thread': [2], 'rstep': [1024], 'reduce_thread': [64], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Apply config {'block': [8], 'thread': [8], 'rstep': [1024], 'reduce_thread': [16], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Apply schedule failed: Traceback (most recent call last):
3: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)>::AssignTypedLambda<tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}>(tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
2: tvm::tir::TracedScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
1: tvm::tir::ConcreteScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
0: tvm::tir::TensorIntrin::Get(tvm::runtime::String, bool)
File "/home/t-leiwang/mlc_workspace/BitBLAS/3rdparty/tvm/src/tir/ir/function.cc", line 151
ValueError: TensorIntrin 'lop3_fast_decode_i1_to_int8_to_f16_l8_' is not registered
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Apply config {'block': [16], 'thread': [16], 'rstep': [512], 'reduce_thread': [8], 'vectorize': {'A': 4, 'B_decode': 8}}
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Apply schedule failed: Traceback (most recent call last):
3: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)>::AssignTypedLambda<tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}>(tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
2: tvm::tir::TracedScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
1: tvm::tir::ConcreteScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
0: tvm::tir::TensorIntrin::Get(tvm::runtime::String, bool)
File "/home/t-leiwang/mlc_workspace/BitBLAS/3rdparty/tvm/src/tir/ir/function.cc", line 151
ValueError: TensorIntrin 'lop3_fast_decode_i1_to_int8_to_f16_l8_' is not registered
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Apply config {'block': [32], 'thread': [32], 'rstep': [256], 'reduce_thread': [4], 'vectorize': {'A': 2, 'B_decode': 8}}
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Apply schedule failed: Traceback (most recent call last):
3: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)>::AssignTypedLambda<tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}>(tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
2: tvm::tir::TracedScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
1: tvm::tir::ConcreteScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
0: tvm::tir::TensorIntrin::Get(tvm::runtime::String, bool)
...
ValueError: Check failed: lanes <= 4 (8 vs. 4) : Ramp of more than 4 lanes is not allowed.
Metadata
Metadata
Assignees
Labels
No labels