
Is int1 x float16 supported? #35

@chromecast56

Thanks for the great project! I was wondering whether the repo supports int1 x float16 matmul and, if so, how I should go about it. For reference, I'm trying to strengthen the kernel in this repo.
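For context, an int1 x float16 matmul comes down to decoding the 1-bit weights to float16 and then running an ordinary GEMV/GEMM. Below is a minimal NumPy sketch of that reference computation. It assumes the weights take the unsigned values {0, 1} and are packed 8 per byte along K, least-significant bit first. Both are illustrative assumptions, not BitBLAS's actual int1 semantics or the layout `matmul.transform_weight` produces, and `pack_int1` / `int1_fp16_matmul` are hypothetical helpers.

```python
import numpy as np


def pack_int1(w: np.ndarray) -> np.ndarray:
    """Hypothetical helper: pack an (N, K) array of 0/1 values into (N, K // 8) uint8, LSB-first along K."""
    assert w.shape[1] % 8 == 0, "K must be a multiple of 8"
    return np.packbits(w.astype(np.uint8), axis=1, bitorder="little")


def int1_fp16_matmul(a: np.ndarray, w_packed: np.ndarray, k: int) -> np.ndarray:
    """Hypothetical helper: compute A @ W^T for float16 A (M, K) and packed 1-bit W (N, K // 8)."""
    # Decode: unpack the bits back to 0/1 values, then convert them to float16.
    w = np.unpackbits(w_packed, axis=1, count=k, bitorder="little").astype(np.float16)
    # Accumulate in float32 for a stable reference, then cast the result back to float16.
    return (a.astype(np.float32) @ w.astype(np.float32).T).astype(np.float16)


rng = np.random.default_rng(0)
M, N, K = 1, 1024, 1024
a = rng.random((M, K)).astype(np.float16)
w = rng.integers(0, 2, size=(N, K), dtype=np.int8)  # assumed {0, 1} weights
w_packed = pack_int1(w)
out = int1_fp16_matmul(a, w_packed, K)
print(w_packed.shape, out.shape)  # (1024, 128) (1, 1024)
```

Packing 8 weights per byte cuts weight storage to 1/16 of float16; the reference accumulates in float32, so a kernel that accumulates in float16 will drift from it somewhat.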

My attempt:

```python
import os

# Make nvcc visible and pin the GPU before BitBLAS is imported and compiles kernels.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["PATH"] = os.environ["PATH"] + ":/usr/local/cuda/bin/"

import bitblas
import torch

bitblas.set_log_level("DEBUG")
matmul_config = bitblas.MatmulConfig(
    M=1,  # M dimension
    N=1024,  # N dimension
    K=1024,  # K dimension
    A_dtype="float16",  # activation A dtype
    W_dtype="int1",  # weight W dtype
    accum_dtype="float16",  # accumulation dtype
    out_dtype="float16",  # output dtype
    layout="nt",  # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose
    with_bias=False,  # bias
    # configs for weight only quantization
    group_size=None,  # setting for grouped quantization
    with_scaling=False,  # setting for scaling factor
    with_zeros=False,  # setting for zeros
    zeros_mode=None,  # setting for how zeros are calculated
)

matmul = bitblas.Matmul(config=matmul_config)

# Create input matrices
input_tensor = torch.rand((1, 1024), dtype=torch.float16).cuda()
# A 1-bit weight holds only two values, so draw from {0, 1} (randint's upper bound is exclusive).
# Assumption: check how BitBLAS interprets signed "int1" before trusting the comparison below.
weight_tensor = torch.randint(0, 2, (1024, 1024), dtype=torch.int8).cuda()

# Transform the weight tensor to BitBLAS's packed int1 format
weight_tensor_int1 = matmul.transform_weight(weight_tensor)

# Perform mixed-precision matrix multiplication
output_tensor = matmul(input_tensor, weight_tensor_int1)

# Reference result using PyTorch matmul for comparison
ref_result = torch.matmul(input_tensor, weight_tensor.t().to(torch.float16))
print("Ref output:", ref_result)
print("BitBLAS output:", output_tensor)
# Assert that the results are close; the accumulated values are fairly large, so atol is set to 1.0
torch.testing.assert_close(output_tensor, ref_result, rtol=1e-2, atol=1e-0)
```

Output:

2024-05-03 11:57:37 [BitBLAS:DEBUG]: Cannot find the appropriate index map for tensorcore
2024-05-03 11:57:37 [BitBLAS:DEBUG]: [BitBLAS][Error] applying rule <bitblas.gpu.gemv.GEMV object at 0x7ff5038880d0> failed
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Cannot find the appropriate index map for tensorcore
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Apply config {'block': [1], 'thread': [1], 'rstep': [1024], 'reduce_thread': [128], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Apply config {'block': [4], 'thread': [4], 'rstep': [1024], 'reduce_thread': [32], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Apply config {'block': [2], 'thread': [2], 'rstep': [1024], 'reduce_thread': [64], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Apply config {'block': [8], 'thread': [8], 'rstep': [1024], 'reduce_thread': [16], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Apply schedule failed: Traceback (most recent call last):
  3: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)>::AssignTypedLambda<tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}>(tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  2: tvm::tir::TracedScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
  1: tvm::tir::ConcreteScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
  0: tvm::tir::TensorIntrin::Get(tvm::runtime::String, bool)
  File "/home/t-leiwang/mlc_workspace/BitBLAS/3rdparty/tvm/src/tir/ir/function.cc", line 151
ValueError: TensorIntrin 'lop3_fast_decode_i1_to_int8_to_f16_l8_' is not registered
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Apply config {'block': [16], 'thread': [16], 'rstep': [512], 'reduce_thread': [8], 'vectorize': {'A': 4, 'B_decode': 8}}
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Apply schedule failed: Traceback (most recent call last):
  3: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)>::AssignTypedLambda<tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}>(tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  2: tvm::tir::TracedScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
  1: tvm::tir::ConcreteScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
  0: tvm::tir::TensorIntrin::Get(tvm::runtime::String, bool)
  File "/home/t-leiwang/mlc_workspace/BitBLAS/3rdparty/tvm/src/tir/ir/function.cc", line 151
ValueError: TensorIntrin 'lop3_fast_decode_i1_to_int8_to_f16_l8_' is not registered
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Apply config {'block': [32], 'thread': [32], 'rstep': [256], 'reduce_thread': [4], 'vectorize': {'A': 2, 'B_decode': 8}}
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Apply schedule failed: Traceback (most recent call last):
  3: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)>::AssignTypedLambda<tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}>(tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  2: tvm::tir::TracedScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
  1: tvm::tir::ConcreteScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
  0: tvm::tir::TensorIntrin::Get(tvm::runtime::String, bool)
...
ValueError: Check failed: lanes <= 4 (8 vs. 4) : Ramp of more than 4 lanes is not allowed.
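In the log, each schedule attempt that produces a traceback fails in `TensorIntrin::Get` because `lop3_fast_decode_i1_to_int8_to_f16_l8_` is not registered, and the truncated last attempt ends with a check that rejects an 8-lane ramp (at most 4 lanes are allowed). Conceptually, the missing decode expands every packed byte of 1-bit weights into eight values, going i1 -> int8 -> float16. A CPU sketch of that shift-and-mask decode, assuming LSB-first packing; `decode_i1_to_f16` is a hypothetical helper, not BitBLAS code and not the LOP3-based GPU implementation:

```python
import numpy as np


def decode_i1_to_f16(packed: np.ndarray) -> np.ndarray:
    """Hypothetical helper: expand packed uint8 bytes into float16 lanes, 8 lanes per byte.

    Bit i of each byte becomes lane i (LSB-first), mirroring an i1 -> int8 -> f16 decode.
    """
    shifts = np.arange(8, dtype=np.uint8)
    bits = (packed[..., None] >> shifts) & 1  # one 1-bit value per lane
    lanes = bits.astype(np.int8).astype(np.float16)  # i1 -> int8 -> f16
    return lanes.reshape(*packed.shape[:-1], packed.shape[-1] * 8)


row = np.array([0b10110001, 0xFF], dtype=np.uint8)
print(decode_i1_to_f16(row))
```

On a GPU, the same per-lane shifts and masks would be vectorized, which is where the lane-count limit in the last error comes into play.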
