4 changes: 3 additions & 1 deletion README.md
@@ -14,7 +14,7 @@ Some of the key features of BitBLAS include:
- Support customizing mixed-precision DNN operations for your specific scenarios via the flexible DSL (TIR Script).

## Latest News

- 11/04/2024 🚀🚀: We've added support for high-performance A INT4 x W INT4/INT2 Matmul.
- 10/02/2024 🚀🚀: We've added initial Flash Attention Ops, implemented in Tilelang! Please refer to the [PythonAPI](https://github.com/microsoft/BitBLAS/blob/main/docs/PythonAPI.md) and [QuickStart](https://github.com/microsoft/BitBLAS/blob/main/docs/QuickStart.md) docs and [PR #202](https://github.com/microsoft/BitBLAS/pull/202).
- 08/12/2024 🚀🚀: We've improved performance for contiguous batching. To enable it, you'll need to set specific flags. For more details, please refer to [PR #133](https://github.com/microsoft/BitBLAS/pull/133).
- 07/11/2024 ✨: Ladder is published and presented in OSDI'24. Please find [Ladder paper and presentation](https://www.usenix.org/conference/osdi24/presentation/wang-lei) if you are interested in the technical details of BitBLAS.
@@ -84,6 +84,8 @@ For more detailed information on benchmark sets with other formats (NF4/FP4) and
| INT8 | UINT1 | INT32 | FP32/INT32/FP16/INT8 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) |
| FP8_E4M3 | FP8_E4M3 | FP32 | FP32/FP16 | **√** | RTX 4090(SM_89) |
| FP8_E5M2 | FP8_E5M2 | FP32 | FP32/FP16 | **√** | RTX 4090(SM_89) |
| INT4 | INT4 | INT32 | FP32/FP16 | **√** | RTX 4090(SM_89) |

We are continuously expanding the support matrix. If you have any specific requirements, please feel free to open an issue or PR.

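To make the new support-matrix entries concrete, here is a minimal usage sketch of an A INT4 x W INT4 matmul, assuming the standard `bitblas.MatmulConfig`/`bitblas.Matmul` API from the QuickStart doc; the shapes, dtype strings, and sub-byte packing details are illustrative rather than authoritative.

```python
# Hypothetical sketch: A INT4 x W INT4 matmul through the regular BitBLAS API.
# Exact dtype strings and sub-byte tensor packing may differ from the final API.
import bitblas

matmul_config = bitblas.MatmulConfig(
    M=16384,            # rows of A
    N=16384,            # columns of W
    K=16384,            # reduction dimension
    A_dtype="int4",     # activation dtype (enabled by this PR)
    W_dtype="int4",     # weight dtype
    accum_dtype="int32",
    out_dtype="float16",
    layout="nt",        # A row-major, W transposed
)
matmul = bitblas.Matmul(config=matmul_config)
# Sub-byte operands are packed into the storage dtype (e.g. two int4 values per
# int8 element) before being handed to the generated TileLang kernel.
```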
18 changes: 11 additions & 7 deletions bitblas/gpu/intrin/lop3.py
@@ -1001,7 +1001,7 @@
{
// TODO(lei): uint4 sub should be enhanced.
// 0x03 0x03 0x03 0x03
i4s[i] = (((i4s[i] << 1) | i4s[i]) << 1) | i4s[i];
// i4s[i] = (((i4s[i] << 1) | i4s[i]) << 1) | i4s[i];
}
}
}
@@ -1625,7 +1625,7 @@ def initialize_tensor_intrin():


def get_lop3_intrin_group(
out_dtype: Literal["float16", "int8"],
out_dtype: Literal["float16", "int8", "int4"],
source_format: Literal["int", "uint"] = "uint",
source_bit: int = 4,
storage_dtype: Literal["int32", "int8"] = "int8",
@@ -1644,8 +1644,8 @@ def get_lop3_intrin_group(
in_dtype : Literal["int8"]
The data type of the input. It should be "int8".

out_dtype : Literal["float16", "int8"]
The data type of the output. It can be either "float16" or "int8".
out_dtype : Literal["float16", "int8", "int4"]
The data type of the output. It can be "float16", "int8", or "int4".

storage_nbit : int, optional
The number of bits used for storage. By default, it is 4.
@@ -1667,10 +1667,11 @@
Dict[str, str]
A dictionary mapping the names of the intrinsics to their corresponding implementations.
"""
assert out_dtype in ["float16",
"int8"], (f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'int8'.")
assert out_dtype in [
"float16", "int8", "int4"
], (f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'int8' or 'int4' .")

dtype_mapping = {"float16": "f16", "int8": "i8", "int32": "i32"}
dtype_mapping = {"float16": "f16", "int4": "i4", "int8": "i8", "int32": "i32"}
target_dtype = dtype_mapping[out_dtype]
target_bits = tvm.DataType(out_dtype).bits
loop_extent = 128 // target_bits
@@ -1707,6 +1708,7 @@ def get_lop3_intrin_group(
"i1_to_i8": decode_i1s_to_i8s,
"i2_to_i8": decode_i2s_to_i8s,
"i4_to_i8": decode_i4s_to_i8s,
"i2_to_i4": decode_i2s_to_i4s,
}
key = f"i{source_bit}_to_{target_dtype}"
if with_scaling:
@@ -1722,6 +1724,8 @@
d4f = "f16"
elif out_dtype == "int8":
d4f = "i8s"
elif out_dtype == "int4":
d4f = "i4s"
else:
raise ValueError("Unsupported target dtype: {}".format(target_dtype))
source_symbol = "u" if source_format == "uint" else "s"
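To tie the pieces above together, the sketch below recomputes the intrinsic key, loop extent, and target suffix for the new int4 path outside the full function. It mirrors the logic shown in this diff and assumes only TVM's `DataType`.

```python
# Standalone restatement of the key construction in get_lop3_intrin_group for
# the new "int4" target; "i2_to_i4" selects the decode_i2s_to_i4s intrinsic.
from tvm import DataType

def lop3_key(out_dtype: str, source_bit: int = 2):
    dtype_mapping = {"float16": "f16", "int4": "i4", "int8": "i8", "int32": "i32"}
    target_dtype = dtype_mapping[out_dtype]
    target_bits = DataType(out_dtype).bits    # 4 for "int4"
    loop_extent = 128 // target_bits          # 32 int4 lanes per 128-bit chunk
    key = f"i{source_bit}_to_{target_dtype}"  # e.g. "i2_to_i4"
    return key, loop_extent

print(lop3_key("int4", source_bit=2))  # ('i2_to_i4', 32)
```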
18 changes: 17 additions & 1 deletion bitblas/ops/general_matmul/__init__.py
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from bitblas import tvm
from tvm import DataType
from tvm.target import Target
import operator
from functools import reduce
@@ -35,6 +36,9 @@
("float16", "float16"),
("bfloat16", "bfloat16"),
("int8", "int8"),
("uint8", "uint8"),
("int4", "int4"),
("uint4", "uint4"),
("e4m3_float8", "e4m3_float8"),
("e4m3_float8", "e5m2_float8"),
("e5m2_float8", "e4m3_float8"),
@@ -142,6 +146,11 @@ def __initialize_propagate(self, propagate_a: Optional[TransformKind],
if self.A_dtype in ["e4m3_float8", "e5m2_float8", "bfloat16"]:
object.__setattr__(self, "propagate_a", TransformKind.NonTransform)
object.__setattr__(self, "propagate_b", TransformKind.NonTransform)
if self.A_dtype in ["int4", "uint4"]:
object.__setattr__(self, "propagate_a", TransformKind.NonTransform)
# TODO(lei): tl doesn't implement IntraWarpTransform
if self.propagate_b == TransformKind.IntraWarpTransform:
object.__setattr__(self, "propagate_b", TransformKind.LDMatrixTransform)

# TODO(lei): propagation can only be enabled on SM80+ Devices and MI200+
# We should add a check here to disable the propagation if the device is not supported.
@@ -358,6 +367,10 @@ def __init__(

self.source_format = source_format
self.bit = bit

# This is a hack to support int4 and uint4: force the TileLang (tl) backend for these dtypes
if config.A_dtype in ["int4", "uint4"]:
backend = "tl"
super().__init__(name, config, target, backend)

if source_format == "int" and self.with_zeros:
@@ -471,10 +484,13 @@ def _assign_ladder_permutate_b(self, target: Target, enable_tuning: bool):
# weight transform should be done in the unpacked level
# otherwise the bit trick should be applied and that is
# too complex to be implemented in the ladder permutation.
datatype = self.A_dtype
if DataType(datatype).bits < 8:
datatype = self.storage_dtype
ladder_permutate_config = LadderPermutateConfig(
M=self.N,
N=self.K,
datatype=self.A_dtype,
datatype=datatype,
dequantize_bits=-1,
storage_dtype=self.storage_dtype,
propagate_kind="B",
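The `_assign_ladder_permutate_b` change above boils down to a dtype-width check: sub-byte activations are described to the ladder permutation in their packed storage dtype. A minimal sketch, assuming the usual `int8` storage dtype:

```python
# Sub-byte fallback: the ladder permutation cannot operate on unpacked 4-bit
# elements, so it is expressed in the packed storage dtype instead.
from tvm import DataType

def permutate_datatype(a_dtype: str, storage_dtype: str = "int8") -> str:
    if DataType(a_dtype).bits < 8:  # int4/uint4 and narrower are packed
        return storage_dtype
    return a_dtype

assert permutate_datatype("float16") == "float16"
assert permutate_datatype("int4") == "int8"  # two int4 values per byte
```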
16 changes: 14 additions & 2 deletions bitblas/ops/general_matmul/tilelang/dense/__init__.py
@@ -17,6 +17,11 @@
MatmulWeightPropagationScheduler, # noqa: F401
)

from .matmul_tensorcore_s4 import (
MatmulINT4FineGrainScheduler, # noqa: F401
MatmulINT4WeightPropagationScheduler, # noqa: F401
)

from bitblas.ops.common import TransformKind
from typing import Union

@@ -82,8 +87,13 @@ def can_apply_weight_propagation_scheduler(trans_A, trans_B, propagate_a, propag
conditions.append(propagate_b == TransformKind.LDMatrixTransform)
return all(conditions)

def is_int4_dtype(dtype):
return dtype == "int4" or dtype == "uint4"

if can_apply_weight_propagation_scheduler(trans_A, trans_B, propagate_a, propagate_b):
return MatmulWeightPropagationScheduler(
Scheduler = MatmulWeightPropagationScheduler if not is_int4_dtype(
in_dtype) else MatmulINT4WeightPropagationScheduler
return Scheduler(
M=M,
N=N,
K=K,
@@ -94,7 +104,9 @@ def can_apply_weight_propagation_scheduler(trans_A, trans_B, propagate_a, propag
accum_dtype=accum_dtype,
)
if can_apply_fine_grain_scheduler(trans_A, trans_B, propagate_a, propagate_b):
return MatmulFineGrainScheduler(
Scheduler = MatmulFineGrainScheduler if not is_int4_dtype(
in_dtype) else MatmulINT4FineGrainScheduler
return Scheduler(
M=M,
N=N,
K=K,
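Condensed, the dispatch added above is a two-way choice layered on the existing weight-propagation/fine-grain split. The sketch below restates it with explicit imports; the module paths are inferred from the file layout shown in this diff.

```python
# Condensed view of the scheduler selection: int4/uint4 inputs are routed to
# the INT4-specialized schedulers, all other dtypes keep the original path.
from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import (
    MatmulFineGrainScheduler,
    MatmulWeightPropagationScheduler,
)
from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore_s4 import (
    MatmulINT4FineGrainScheduler,
    MatmulINT4WeightPropagationScheduler,
)


def select_scheduler_class(in_dtype: str, weight_propagation: bool):
    """Pick the scheduler class for a dense tilelang matmul."""
    is_int4 = in_dtype in ("int4", "uint4")
    if weight_propagation:
        return (MatmulINT4WeightPropagationScheduler
                if is_int4 else MatmulWeightPropagationScheduler)
    return MatmulINT4FineGrainScheduler if is_int4 else MatmulFineGrainScheduler
```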
33 changes: 21 additions & 12 deletions bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py
@@ -368,6 +368,9 @@ def with_default_config(self):
warp_row_tiles = getattr(self, "warp_row_tiles", 32)
warp_col_tiles = getattr(self, "warp_col_tiles", 32)
chunk = getattr(self, "chunk", 32)
# Swizzle size for INT8 Storage is 64
if DataType(self.in_dtype).bits <= 8:
chunk = 64
num_stages = getattr(self, "num_stages", 2)
enable_rasterization = getattr(self, "enable_rasterization", False)

@@ -597,7 +600,9 @@ def apply_config(
threads = warp_size * (block_row_warps * block_col_warps)

# Calculate local fragment sizes for tensor core
local_size = (micro_size_x * micro_size_y) // warp_size
local_size_a = (micro_size_x * micro_size_k) // warp_size
local_size_b = (micro_size_y * micro_size_k) // warp_size
local_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y

@@ -633,9 +638,9 @@ def main(
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size), in_dtype)
B_local = T.alloc_local((warp_cols * local_size), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

# Thread-level parallelism for Tensor Cores
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
@@ -809,7 +814,9 @@ def matmul_macro_tensorcore(

warp_size = 32 # nvidia gpu warp size is 32
threads = warp_size * (block_row_warps * block_col_warps)
local_size = (micro_size_x * micro_size_y) // warp_size
local_size_a = (micro_size_x * micro_size_k) // warp_size
local_size_b = (micro_size_y * micro_size_k) // warp_size
local_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y

Expand Down Expand Up @@ -838,9 +845,9 @@ def main(
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size), in_dtype)
B_local = T.alloc_local((warp_cols * local_size), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")

T.annotate_layout({
@@ -947,7 +954,9 @@ def matmul_macro_tensorcore_weight_propagation_level_ldmatrix(

warp_size = 32 # nvidia gpu warp size is 32
threads = warp_size * (block_row_warps * block_col_warps)
local_size = (micro_size_x * micro_size_y) // warp_size
local_size_a = (micro_size_x * micro_size_k) // warp_size
local_size_b = (micro_size_y * micro_size_k) // warp_size
local_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y

@@ -977,9 +986,9 @@ def main(
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size), in_dtype)
B_local = T.alloc_local((warp_cols * local_size), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")

T.annotate_layout({
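Finally, a worked example of why the single `local_size` above had to be split into `local_size_a/b/c`: once `micro_size_k` differs from `micro_size_y`, the per-thread A and B fragments no longer match the C accumulator fragment. The micro tile sizes below are illustrative (a plausible low-precision MMA shape), not values taken from the scheduler.

```python
# Per-thread fragment sizes for one MMA micro tile, using illustrative sizes
# micro_size_x = micro_size_y = 16 and micro_size_k = 32. With k != y, the A/B
# fragments differ from the C fragment, which a single local_size cannot express.
warp_size = 32
micro_size_x, micro_size_y, micro_size_k = 16, 16, 32

local_size_a = (micro_size_x * micro_size_k) // warp_size  # 16 A elements per thread
local_size_b = (micro_size_y * micro_size_k) // warp_size  # 16 B elements per thread
local_size_c = (micro_size_x * micro_size_y) // warp_size  # 8 accumulators per thread

print(local_size_a, local_size_b, local_size_c)  # 16 16 8
```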