Merged
Changes from all commits
27 commits
ca42750
fix for relax
LeiWang1999 Dec 8, 2024
58fa7bf
lint fix
LeiWang1999 Dec 8, 2024
8275513
save import bitblas time
LeiWang1999 Dec 10, 2024
fb7de9b
bug fix for tl backend
LeiWang1999 Dec 10, 2024
02cf643
support input transform_kind
LeiWang1999 Dec 11, 2024
65fb3b4
hint identifier
LeiWang1999 Dec 11, 2024
ad7bc1c
annotate hint type for dequantize
LeiWang1999 Dec 11, 2024
d635713
enhance swizzling
LeiWang1999 Dec 12, 2024
a3e97de
Enhance for hardware aware tuning
LeiWang1999 Dec 12, 2024
bdbc685
test fix
LeiWang1999 Dec 12, 2024
e30b64f
remove pad factor
LeiWang1999 Dec 13, 2024
3b2646a
introduce legalize dyanmic pass
LeiWang1999 Dec 13, 2024
b44e42f
Merge branch 'main' of https://github.com/microsoft/BitBLAS into lega…
LeiWang1999 Dec 13, 2024
9462884
update 3rdparty
LeiWang1999 Dec 16, 2024
d662748
testfix
LeiWang1999 Dec 16, 2024
8c05d7b
test code commit
LeiWang1999 Dec 16, 2024
cdd0753
enhance typing and fix test for int4 dequantize gemm
LeiWang1999 Dec 16, 2024
b9c343c
lint fix
LeiWang1999 Dec 16, 2024
bf6903a
TEST FIX
LeiWang1999 Dec 16, 2024
ab0fef2
lint fix
LeiWang1999 Dec 16, 2024
f5e036c
Merge branch 'main' of https://github.com/microsoft/BitBLAS into chan…
LeiWang1999 Dec 16, 2024
ee770d4
Bugfix for bias
LeiWang1999 Dec 16, 2024
7a45262
lint fix
LeiWang1999 Dec 16, 2024
c48302d
lint fix
LeiWang1999 Dec 16, 2024
0b11dfe
test fix
LeiWang1999 Dec 16, 2024
6d1a7e4
Implement Bias
LeiWang1999 Dec 17, 2024
e729caa
fallback nf to tir implementation.
LeiWang1999 Dec 17, 2024
5 changes: 2 additions & 3 deletions bitblas/base/roller/policy/tensorcore.py
@@ -328,9 +328,8 @@ def _score(node, thread): # small is better
         # TODO: This is a dummy mul which avoid reusing some shared memory.
         # Should be removed in the future.
         if td.smem_cost > (self.arch.smem_cap):
-            debug_message = f"Tile Dict: {td.output_tile} Shared memory exceeds the static capacity," \
-                " use dynamic shared memory."
-            logger.debug(debug_message)
+            # Tile Dict: {td.output_tile} Shared memory exceeds the static capacity
+            # use dynamic shared memory.
             codegen_dict.shared_scope = "shared.dyn"

         codegen_dict.shared_scope = "shared.dyn"
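For context on the shared_scope fallback above: statically declared shared memory is capped per block (typically 48 KB on NVIDIA GPUs), while "shared.dyn" lowers to dynamically allocated shared memory that can opt in to the larger per-SM budget. A minimal sketch of the decision, with an assumed capacity constant where the policy reads the real value from self.arch.smem_cap:

STATIC_SMEM_CAP_BYTES = 48 * 1024   # assumed cap; the policy uses self.arch.smem_cap

def choose_shared_scope(smem_cost_bytes: int) -> str:
    # Mirrors the intent of the check above: spill to dynamic shared memory
    # when the tile footprint exceeds the static capacity.
    if smem_cost_bytes > STATIC_SMEM_CAP_BYTES:
        return "shared.dyn"   # sized at kernel launch time
    return "shared"           # sized at compile time

assert choose_shared_scope(96 * 1024) == "shared.dyn"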
6 changes: 3 additions & 3 deletions bitblas/builder/wrapper/tl.py
@@ -19,9 +19,9 @@ class TLCUDASourceWrapper(object):
     _TYPE_MAP = {
         "float32": "float",
         "float16": "half_t",
-        "bfloat16": "__nv_bfloat16",
-        "e4m3_float8": "__nv_fp8_e4m3",
-        "e5m2_float8": "__nv_fp8_e5m2",
+        "bfloat16": "bfloat16_t",
+        "e4m3_float8": "float_e4m3_t",
+        "e5m2_float8": "float_e5m2_t",
         "float64": "double",
         "int64": "int64_t",
         "int32": "int",
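The mapping change swaps raw CUDA builtin types (__nv_bfloat16, __nv_fp8_e4m3) for the CUTLASS-style typedefs that tilelang-generated sources define (bfloat16_t, float_e4m3_t). A sketch of how such a dtype-to-C++ map is typically consumed when emitting a wrapper signature; render_params is a hypothetical helper, not part of this PR:

_TYPE_MAP = {
    "float32": "float",
    "float16": "half_t",
    "bfloat16": "bfloat16_t",        # tilelang typedef, no longer __nv_bfloat16
    "e4m3_float8": "float_e4m3_t",
    "e5m2_float8": "float_e5m2_t",
}

def render_params(buffers):
    # buffers: iterable of (name, dtype) pairs -> C++ parameter list.
    # Unknown dtypes raise KeyError rather than emitting invalid C++.
    return ", ".join(f"{_TYPE_MAP[dtype]}* {name}" for name, dtype in buffers)

print(render_params([("A", "float16"), ("B", "bfloat16")]))
# half_t* A, bfloat16_t* B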
2 changes: 1 addition & 1 deletion bitblas/module/__init__.py
@@ -274,7 +274,7 @@ def forward(self, A, output=None):
         self.init_params()
         args = [A_void, *self.q_params]
         if output is None:
-            output = torch.empty(
+            output = torch.zeros(
                 A.shape[:-1] + (self.out_features,),
                 dtype=getattr(torch, self.bitblas_matmul.out_dtype),
                 device=A.device)
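The switch from torch.empty to torch.zeros matters whenever the kernel accumulates into the output buffer instead of overwriting every element (for example a bias or split-k style epilogue). A minimal sketch of the hazard, assuming an accumulating write:

import torch

out = torch.empty(4)   # uninitialized: holds whatever bytes the allocator returns
out += 1.0             # accumulation now folds that garbage into the result

out = torch.zeros(4)   # zero-filled: accumulation starts from a known state
out += 1.0             # correct: tensor([1., 1., 1., 1.])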
37 changes: 29 additions & 8 deletions bitblas/ops/general_matmul/__init__.py
@@ -15,6 +15,7 @@
 from .tilelang.dense import select_scheduler as consistent_scheduler
 from .tilelang.dequantize import select_scheduler as weight_dequantize_scheduler
 from ...base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2
+from bitblas.utils import retrieve_func_from_module
 from bitblas.utils.target_detector import auto_detect_nvidia_target
 from dataclasses import dataclass
 from ..ladder_permutate import LadderPermutate, LadderPermutateConfig
@@ -350,7 +351,7 @@ def __init__(
         target: Optional[Union[str, Target]] = None,
         enable_tuning: bool = True,
         from_database: bool = False,
-        backend: str = "tir",
+        backend: str = "tl",
     ):
         # if from database, we should disable default schedule
         # to save compilation time
@@ -370,8 +371,14 @@ def __init__(
         self.bit = bit

         # This is a hack to support the int4 and uint4
+        # legalize the backend (hacky implementation)
+        # TODO(lei): In future release we should remove
+        # by implementing all the operators in the tl backend.
         if config.A_dtype in ["int4", "uint4"]:
             backend = "tl"
+        if source_format in ["nf"]:
+            backend = "tir"

         super().__init__(name, config, target, backend)

         if source_format == "int" and self.with_zeros:
@@ -383,13 +390,13 @@ def __init__(
         if target.kind.name not in ("cuda", "hip"):
             raise ValueError("Currently only support cuda and hip target")

-        self.dispatch_tir(target, from_database, source_format, enable_tuning)
+        self.dispatch(target, from_database, source_format, enable_tuning)

-    def dispatch_tir(self,
-                     target: Target,
-                     from_database: bool = False,
-                     source_format: str = "uint",
-                     enable_tuning: bool = True):
+    def dispatch(self,
+                 target: Target,
+                 from_database: bool = False,
+                 source_format: str = "uint",
+                 enable_tuning: bool = True):

         if isinstance(self.M, Tuple):
             self.dynamic_range = {"m": self.M}
@@ -638,7 +645,21 @@ def post_process(self, code: str) -> str:
         return code

     def retrieve_weight_shape(self):
-        return [int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape]
+        prim_func = self.prim_func
+
+        # retrieve from tilelang backend
+        if prim_func is None and self.scheduled_ir_module is not None:
+            prim_func = retrieve_func_from_module(self.scheduled_ir_module)
+
+        if prim_func is None and self.is_tilelang_backend():
+            # If from_database and from tilelang backend, we should construct a default module
+            self._update_optimized_mod(self.scheduler_with_default(self.scheduler))
+            prim_func = retrieve_func_from_module(self.scheduled_ir_module)
+
+        if prim_func is not None:
+            return [int(i) for i in prim_func.buffer_map[prim_func.params[1]].shape]
+
+        raise ValueError("The weight shape is not available.")

     def transform_weight(self, weight, scale=None, zeros=None, bias=None):
         """
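retrieve_weight_shape reports the storage shape of the weight buffer (params[1] of the scheduled function), with the new fallbacks covering the tilelang backend, where self.prim_func can be None and the function must be recovered from, or rebuilt into, scheduled_ir_module. A hedged usage sketch; the config values are illustrative, and int8 storage for the packed int4 weight is an assumption, not something this diff pins down:

import torch
import bitblas

matmul = bitblas.Matmul(
    config=bitblas.MatmulConfig(
        M=1, N=1024, K=1024, A_dtype="float16", W_dtype="int4"))
# Allocate the packed weight with exactly the layout the kernel expects.
packed = torch.empty(matmul.retrieve_weight_shape(), dtype=torch.int8, device="cuda")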
5 changes: 2 additions & 3 deletions bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py
@@ -111,8 +111,8 @@ def apply_config(

         A_shape = (M, K)
         B_shape = (N, K)
-        C_shape = (M, N)
         Bias_shape = (N,)
+        C_shape = (M, N)

         dp4a_size = 4
         use_dp4a = in_dtype == "int8" and accum_dtype == "int32"
@@ -121,8 +121,8 @@
         def main(
                 A: T.Buffer(A_shape, in_dtype),
                 B: T.Buffer(B_shape, in_dtype),
-                C: T.Buffer(C_shape, out_dtype),
                 Bias: T.Buffer(Bias_shape, out_dtype),
+                C: T.Buffer(C_shape, out_dtype),
         ):
             with T.Kernel(
                     T.ceildiv(N, n_partition), M, threads=(reduce_thread, n_partition)) as (
@@ -186,5 +186,4 @@ def __post_init__(self):
         # Validate the matrix transpose settings
         assert self.trans_A is False, "Currently only support Matrix A not transposed"
         assert self.trans_B is True, "Currently only support Matrix B transposed"
-        assert self.with_bias is False, "Currently only support without bias"
         return
1 change: 0 additions & 1 deletion bitblas/ops/general_matmul/tilelang/dense/matmul.py
@@ -242,7 +242,6 @@ def __post_init__(self):
         # Validate the matrix transpose settings
         assert self.trans_A is False, "Currently only support Matrix A not transposed"
         assert self.trans_B is True, "Currently only support Matrix B transposed"
-        assert self.with_bias is False, "Currently only support without bias"
         assert self.input_transform_kind == TransformKind.NonTransform, "Currently only support NonTransform for input"

         return
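With the with_bias assertions dropped, the dense schedulers now accept biased GEMMs. Their reference semantics as a numpy sketch: trans_B is True throughout, so B is stored as (N, K) and the kernels compute C = A @ B.T + Bias, with the (N,)-shaped Bias broadcast across rows:

import numpy as np

M, N, K = 16, 32, 64
A = np.random.randn(M, K).astype(np.float32)
B = np.random.randn(N, K).astype(np.float32)   # trans_B=True layout
Bias = np.random.randn(N).astype(np.float32)

C = A @ B.T + Bias          # Bias broadcasts over the M rows
assert C.shape == (M, N)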
31 changes: 21 additions & 10 deletions bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py
@@ -57,7 +57,7 @@ def check_require_cache(self) -> bool:

         conditions: List[bool] = []
         conditions.append(False)
-        # Bias Add should be done in shared memory
+        # Bias Add should be performed in shared memory
         conditions.append(with_bias)
         return any(conditions)  # Always set to False Currently

@@ -172,6 +172,8 @@ def apply_config(
             self.accum_dtype,
         )

+        with_bias = self.with_bias
+
         shared_scope = "shared.dyn"

         block_M = block_size_x * thread_row_tiles
@@ -183,6 +185,7 @@
         C_shape = (M, N)
         A_shared_shape = (block_M, block_K)
         B_shared_shape = (block_N, block_K)
+        Bias_shape = (N,)

         threads = thread_row_tiles * thread_col_tiles
         local_size_a = block_M // thread_row_tiles
@@ -198,6 +201,7 @@
         def main(
                 A: T.Buffer(A_shape, in_dtype),
                 B: T.Buffer(B_shape, in_dtype),
+                Bias: T.Buffer(Bias_shape, out_dtype),
                 C: T.Buffer(C_shape, out_dtype),
         ):
             with T.Kernel(
@@ -249,21 +253,28 @@ def main(
                        else:
                            for dp4a_idx in T.serial(dp4a_size):
                                C_local[i * local_size_b + j] += (
-                                   A_local[i, mk * dp4a_size + dp4a_idx] *
-                                   B_local[j, mk * dp4a_size + dp4a_idx])
-
-           for i, j in T.grid(local_size_a, local_size_b):
-               C[
-                   by * block_M + warp_m * local_size_a + i,
-                   bx * block_N + warp_n * local_size_b + j,
-               ] = C_local[i * local_size_b + j]
+                                   A_local[i,
+                                           mk * dp4a_size + dp4a_idx].astype(accum_dtype) *
+                                   B_local[j,
+                                           mk * dp4a_size + dp4a_idx].astype(accum_dtype))
+
+           if with_bias:
+               for i, j in T.grid(local_size_a, local_size_b):
+                   C_local[i * local_size_b + j] += Bias[bx * block_N + warp_n * local_size_b +
+                                                         j]
+
+           for i in T.serial(local_size_a):
+               for j in T.vectorized(local_size_b):
+                   C[
+                       by * block_M + warp_m * local_size_a + i,
+                       bx * block_N + warp_n * local_size_b + j,
+                   ] = C_local[i * local_size_b + j]

        return self.post_process(main)

     def __post_init__(self):
         # Validate the matrix transpose settings
         assert self.trans_A is False, "Currently only support Matrix A not transposed"
         assert self.trans_B is True, "Currently only support Matrix B transposed"
-        assert self.with_bias is False, "Currently only support without bias"

         return
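The new .astype(accum_dtype) casts guard the scalar (non-DP4A) int8 path: the DP4A intrinsic widens its four int8 products into an int32 accumulator in hardware, but the scalar fallback previously multiplied in int8. A numpy sketch of the overflow the cast prevents:

import numpy as np

a = np.full(64, 100, dtype=np.int8)
b = np.full(64, 100, dtype=np.int8)

# int8 * int8 wraps per element: 100 * 100 = 10000 -> 16 (mod 256)
overflowed = (a * b).astype(np.int64).sum()                 # 64 * 16 = 1024
# widen first, as the kernel now does with .astype(accum_dtype)
widened = (a.astype(np.int32) * b.astype(np.int32)).sum()   # 64 * 10000 = 640000

assert overflowed == 1024 and widened == 640000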
18 changes: 11 additions & 7 deletions bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py
@@ -69,7 +69,7 @@ def check_require_cache(self) -> bool:

         conditions: List[bool] = []
         conditions.append(False)
-        # Bias Add should be done in shared memory
+        # Bias Add should be performed in shared memory
         conditions.append(with_bias)
         return any(conditions)  # Always set to False Currently

@@ -227,8 +227,8 @@ def apply_config(
         def main(
                 A: T.Buffer(A_shape, in_dtype),
                 B: T.Buffer(B_shape, in_dtype),
-                C: T.Buffer(C_shape, out_dtype),
                 Bias: T.Buffer(Bias_shape, out_dtype),
+                C: T.Buffer(C_shape, out_dtype),
         ):
             with T.Kernel(
                     T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
@@ -444,16 +444,15 @@ def apply_config(
             chunk=chunk,
         )

-        # cache_write_required = self.check_require_cache()
-        cache_write_required = False
+        cache_write_required = self.check_require_cache()

         # Define the main kernel using the generated configuration
         @T.prim_func
         def main(
                 A: T.Buffer(A_shape, in_dtype),
                 B: T.Buffer(B_shape, in_dtype),
-                C: T.Buffer(C_shape, out_dtype),
                 Bias: T.Buffer(Bias_shape, out_dtype),
+                C: T.Buffer(C_shape, out_dtype),
         ):
             # Grid and thread configuration for CUDA kernel
             with T.Kernel(
@@ -667,8 +666,8 @@ def apply_config(
         def main(
                 A: T.Buffer(A_shape, in_dtype),
                 B: T.Buffer(B_shape, in_dtype),
-                C: T.Buffer(C_shape, out_dtype),
                 Bias: T.Buffer(Bias_shape, out_dtype),
+                C: T.Buffer(C_shape, out_dtype),
         ):
             # Grid and thread configuration for CUDA kernel
             with T.Kernel(
@@ -867,6 +866,8 @@ def apply_config(
         in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype
         assert in_dtype == "int4", "Only support int4 input"
         assert accum_dtype == "int32", "Only support int32 accumulation"
+        with_bias = self.with_bias
+        assert not with_bias, "Currently do not support bias"
         storage_dtype = "int8"

# Calculate the micro size per warp using a helper function
@@ -879,6 +880,8 @@
         # Define the shapes of matrices and shared memory buffers
         A_shape = (M, K)
         B_shape = (N, K)
+        Bias_shape = (N,)
+        C_shape = (M, N)
         A_shared_shape = (block_M, block_K)
         B_shared_shape = (block_N, block_K)
         C_shared_shape = (
@@ -918,7 +921,8 @@
         def main(
                 A: T.Buffer(A_shape, storage_dtype),
                 B: T.Buffer(B_shape, storage_dtype),
-                C: T.Buffer((M, N), out_dtype),
+                Bias: T.Buffer(Bias_shape, out_dtype),
+                C: T.Buffer(C_shape, out_dtype),
         ):
             # Grid and thread configuration for CUDA kernel
             with T.Kernel(
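Re-enabling cache_write_required = self.check_require_cache() means a biased kernel stages its accumulators through shared memory so the Bias add can happen before the final global store, per the "Bias Add should be performed in shared memory" comment. A sketch of the decision as the method now behaves; note that its trailing "Always set to False Currently" comment becomes stale once with_bias can be True:

def check_require_cache(with_bias: bool) -> bool:
    # Mirrors the method above: a list of trigger conditions, of which
    # with_bias is currently the only live one.
    conditions = [False, with_bias]
    return any(conditions)

assert check_require_cache(True) is True
assert check_require_cache(False) is False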
@@ -168,8 +168,8 @@ def main(
                 Scale: T.Buffer(Scale_shape, in_dtype),
                 Qzeros: T.Buffer(Qzeros_shape, storage_dtype),
                 Zeros: T.Buffer(Zeros_shape, in_dtype),
-                C: T.Buffer(C_shape, out_dtype),
                 Bias: T.Buffer(Bias_shape, in_dtype),
+                C: T.Buffer(C_shape, out_dtype),
         ):
             with T.Kernel(
                     T.ceildiv(N, n_partition),
@@ -624,8 +624,8 @@ def general_shared_dequant_matmul(
                 Scale: T.Buffer(Scale_shape, in_dtype),
                 Qzeros: T.Buffer(Qzeros_shape, storage_dtype),
                 Zeros: T.Buffer(Zeros_shape, in_dtype),
-                C: T.Buffer(C_shape, out_dtype),
                 Bias: T.Buffer(Bias_shape, in_dtype),
+                C: T.Buffer(C_shape, out_dtype),
         ):
             with T.Kernel(
                     T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
@@ -623,8 +623,8 @@ def general_shared_dequant_matmul(
                 Scale: T.Buffer(Scale_shape, in_dtype),
                 Qzeros: T.Buffer(Qzeros_shape, storage_dtype),
                 Zeros: T.Buffer(Zeros_shape, in_dtype),
-                C: T.Buffer(C_shape, out_dtype),
                 Bias: T.Buffer(Bias_shape, in_dtype),
+                C: T.Buffer(C_shape, out_dtype),
         ):
             with T.Kernel(
                     T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
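One convention runs through every signature change above: C moves after Bias, so the output buffer is uniformly the last parameter across the dense, dequantize, and gemv kernels. A sketch (not BitBLAS API) of why an outputs-last convention simplifies generic launch code:

def launch(kernel, *buffers):
    # Works for (A, B, C), (A, B, Bias, C), and (B, ..., Zeros, Bias, C) alike:
    # everything before the last buffer is an input.
    *inputs, output = buffers
    kernel(*inputs, output)
    return output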