2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from af0b40 to 5a8b30
2 changes: 1 addition & 1 deletion bitblas/common.py
@@ -5,4 +5,4 @@

 BITBLAS_DEFAULT_CACHE_PATH = os.path.expanduser("~/.cache/bitblas")

-MAX_ERROR_MESSAGE_LENGTH = 200
+MAX_ERROR_MESSAGE_LENGTH = 500
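
The larger cap simply lets more of a failing kernel's compile or tuning message survive truncation. As a rough illustration only (the helper below is hypothetical, not BitBLAS code), a constant like this is typically applied as follows:

```python
# Hypothetical sketch of capping long error messages; MAX_ERROR_MESSAGE_LENGTH
# mirrors the constant in bitblas/common.py, but truncate_error itself is
# illustrative and not part of BitBLAS.
MAX_ERROR_MESSAGE_LENGTH = 500


def truncate_error(message: str, limit: int = MAX_ERROR_MESSAGE_LENGTH) -> str:
    """Return the message unchanged if short, otherwise cut it and mark the cut."""
    if len(message) <= limit:
        return message
    return message[:limit] + "... (truncated)"


if __name__ == "__main__":
    print(truncate_error("x" * 1000))  # prints 500 characters plus the marker
```
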
80 changes: 51 additions & 29 deletions bitblas/gpu/matmul_mma_dequantize.py
@@ -1170,7 +1170,7 @@ def sch_shared_memory_prefetch_with_config(
     """

     weight_transform_kind = config.intrin_info.weight_transform_kind
-    if weight_transform_kind == TransformKind.LDMatrixTransform and config.block_reduction_depth is not None:
+    if weight_transform_kind == TransformKind.LDMatrixTransform:
         return self.sch_warp_memory_prefetch_with_config(func, config)

     is_cross_thread_reduce = (
@@ -1826,8 +1826,7 @@ def check_dequantize_info(dequantize_info):
     block_col_warps = config.block[1] // warp_col_tiles
     stage = config.pipeline_stage
     use_async = config.use_async
-    assert (config.block_reduction_depth is not None), "block_reduction_depth is required"
-    reduce_k = config.block_reduction_depth
+    reduce_k = config.block_reduction_depth if config.block_reduction_depth is not None else 1
     chunk = config.rstep[0] // reduce_k

     micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"]
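
With this change `block_reduction_depth` becomes optional: when it is absent the schedule degenerates to a single reduction slice (`reduce_k = 1`) and the per-stage K chunk is simply `rstep[0]`. A plain-Python sketch of that fallback arithmetic, with made-up config numbers:

```python
# Sketch of the reduce_k fallback and the derived chunk size.
# The numbers below are illustrative assumptions, not BitBLAS defaults.
def derive_reduction_split(block_reduction_depth, rstep0):
    # Fall back to a single reduction slice when no block reduction is requested.
    reduce_k = block_reduction_depth if block_reduction_depth is not None else 1
    chunk = rstep0 // reduce_k  # K elements staged per reduction slice
    return reduce_k, chunk


print(derive_reduction_split(None, 32))  # (1, 32): whole rstep handled by one slice
print(derive_reduction_split(2, 32))     # (2, 16): K split across two threadIdx.z slices
```
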
@@ -2005,9 +2004,12 @@ def get_param_indices(
     i0, i1, i2, i3 = sch.split(i, factors=i_factors)
     j0, j1, j2, j3 = sch.split(j, factors=j_factors)
     k0, k1 = sch.split(k, k_factors)
-    k0, kr = sch.split(k0, [None, reduce_k])
+    if reduce_k > 1:
+        k0, kr = sch.split(k0, [None, reduce_k])

-    sch.reorder(i0, j0, i1, j1, i2, j2, kr, k0, k1, i3, j3)
+        sch.reorder(i0, j0, i1, j1, i2, j2, kr, k0, k1, i3, j3)
+    else:
+        sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3)

     block_idy = sch.fuse(i0, j0)
     block_idx = sch.fuse(i1, j1)
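
For intuition, `sch.split(k0, [None, reduce_k])` peels an extent-`reduce_k` axis (`kr`) off the outer reduction loop so it can later be bound to `threadIdx.z`, and the reorder hoists `kr` outside the pipelined `k0`/`k1` loops; with `reduce_k == 1` that axis would be degenerate, which is why the split and the `kr` term are now skipped. A standalone sketch of the index arithmetic (plain Python, not TVM; the extents are illustrative):

```python
# Plain-Python model of sch.split(k0, [None, reduce_k]) followed by hoisting
# kr outside k0.  Extents are illustrative assumptions, not a real config.
K0_EXTENT, REDUCE_K = 8, 2
OUTER = K0_EXTENT // REDUCE_K  # the `None` factor is inferred

# Original visit order of the outer reduction loop.
original = list(range(K0_EXTENT))

# After the split, iteration (k0, kr) maps back to k0 * REDUCE_K + kr; the
# reorder walks kr in the outer position and k0 in the inner one.
reordered = [k0 * REDUCE_K + kr for kr in range(REDUCE_K) for k0 in range(OUTER)]

assert sorted(reordered) == original  # same iteration space, different order
print(reordered)  # [0, 2, 4, 6, 1, 3, 5, 7]
```
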
@@ -2017,12 +2019,13 @@
     sch.bind(batch, "blockIdx.z")
     sch.bind(block_idx, "blockIdx.x")
     sch.bind(block_idy, "blockIdx.y")
-    thread_idz = j2 = thread_idy = sch.fuse(thread_idy, thread_idz)
-    sch.bind(thread_idy, "threadIdx.y")
-
-    # Put the thread binding after the shared memory prefetch
-    # Otherwise there's a axis missing bug behind tvm
-    sch.bind(kr, "threadIdx.z")
+    if reduce_k > 1:
+        thread_idz = j2 = thread_idy = sch.fuse(thread_idy, thread_idz)
+        sch.bind(thread_idy, "threadIdx.y")
+        sch.bind(kr, "threadIdx.z")
+    else:
+        sch.bind(thread_idy, "threadIdx.y")
+        sch.bind(thread_idz, "threadIdx.z")

     def smooth_layout_recover(block, scope, l=16, r=16, enable=True):  # noqa: E741
         if not enable:
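
The binding now depends on the same condition: with block reduction enabled, the two warp-tile axes are fused onto `threadIdx.y` so that `threadIdx.z` is free for the extent-`reduce_k` slice; without it there is no `kr` loop, so `thread_idy` and `thread_idz` keep their own thread dimensions. A small sketch of the resulting thread-block shapes (tile counts are illustrative assumptions):

```python
# Thread-block shape under the two binding schemes.
# num_ty/num_tz are warp-tile counts; reduce_k is the block-reduction depth.
# All values here are illustrative assumptions.
def block_shape(num_ty, num_tz, reduce_k):
    if reduce_k > 1:
        # fused warp-tile axes -> threadIdx.y, reduction slices -> threadIdx.z
        return {"threadIdx.y": num_ty * num_tz, "threadIdx.z": reduce_k}
    # no block reduction: each warp-tile axis keeps its own thread dimension
    return {"threadIdx.y": num_ty, "threadIdx.z": num_tz}


print(block_shape(num_ty=2, num_tz=2, reduce_k=2))  # {'threadIdx.y': 4, 'threadIdx.z': 2}
print(block_shape(num_ty=2, num_tz=2, reduce_k=1))  # {'threadIdx.y': 2, 'threadIdx.z': 2}
```
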
@@ -2054,13 +2057,20 @@ def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False):
     ndim = len(sch.get(block_read).iter_vars)
     fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])

-    f_0, f_r, f_1, f_2, f_3, f_4 = sch.split(
-        fused, factors=[None, reduce_k, num_ty, num_tz, warp_size, vec_len])
+    if reduce_k > 1:
+        f_r, f_0, f_1, f_2, f_3, f_4 = sch.split(
+            fused, factors=[reduce_k, num_ty, num_tz, None, warp_size, vec_len])
+        sch.bind(f_3, "threadIdx.x")
+        f_0 = f_1 = sch.fuse(f_0, f_1)
+        sch.bind(f_0, "threadIdx.y")
+        sch.bind(f_r, "threadIdx.z")
+    else:
+        f_0, f_1, f_2, f_3, f_4 = sch.split(
+            fused, factors=[num_ty, num_tz, None, warp_size, vec_len])
+        sch.bind(f_3, "threadIdx.x")
+        sch.bind(f_1, "threadIdx.z")
+        sch.bind(f_0, "threadIdx.y")

-    sch.bind(f_3, "threadIdx.x")
-    f_1 = f_2 = sch.fuse(f_1, f_2)
-    sch.bind(f_1, "threadIdx.y")
-    sch.bind(f_r, "threadIdx.z")
     sch.vectorize(f_4)
     sch.unroll(f_0)
     sch.annotate(f_0, "pragma_unroll_explicit", False)
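
In the shared-memory copy, the fused loop is split so the innermost factor becomes the vectorized load width, the `warp_size` factor becomes `threadIdx.x`, and, when block reduction is on, the leading `reduce_k` factor becomes `threadIdx.z`; the `None` factor absorbs whatever extent remains. A plain-Python sketch of how one flat element index decomposes under those factors (all extents are illustrative assumptions):

```python
# Decompose a flat copy-loop index under split factors
# [reduce_k, num_ty, num_tz, None, warp_size, vec_len] (None is inferred).
# All extents below are illustrative assumptions.
REDUCE_K, NUM_TY, NUM_TZ, WARP_SIZE, VEC_LEN = 2, 2, 2, 32, 8
TOTAL = 4096  # extent of the fused copy loop, divisible by the fixed factors

inferred = TOTAL // (REDUCE_K * NUM_TY * NUM_TZ * WARP_SIZE * VEC_LEN)
factors = [REDUCE_K, NUM_TY, NUM_TZ, inferred, WARP_SIZE, VEC_LEN]


def decompose(flat, extents):
    """Split a flat index into per-loop coordinates, outermost factor first."""
    coords = []
    for extent in reversed(extents):
        coords.append(flat % extent)
        flat //= extent
    return list(reversed(coords))


f_r, f_0, f_1, f_2, f_3, f_4 = decompose(1234, factors)
# f_4 indexes the vectorized load, f_3 is the lane (threadIdx.x),
# f_0/f_1 fuse into threadIdx.y, and f_r is the threadIdx.z slice.
print(f_r, f_0, f_1, f_2, f_3, f_4)
```
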
@@ -2085,7 +2095,8 @@ def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False):

     # Put the thread binding after the shared memory prefetch
     # Otherwise there's a axis missing bug behind tvm
-    sch.bind(kr, "threadIdx.z")
+    if reduce_k > 1:
+        sch.bind(kr, "threadIdx.z")
     # create read cache to load matrix from shared memory to wmma fragments
     A_mat = sch.cache_read(block_outer, 0, "warp")
     sch.compute_at(A_mat, k1, preserve_unit_loops=True)
@@ -2127,18 +2138,24 @@ def get_idx():
         ndim = len(sch.get(B_shared).iter_vars)
         _ = sch.fuse(*sch.get_loops(B_shared)[-ndim:])

-        _bind_thread_based_with_block_reduce_on_config(
-            sch,
-            B_shared,
-            num_ty,
-            num_tz,
-            warp_size,
-            reduce_k,
-        )
+        if reduce_k > 1:
+            _bind_thread_based_with_block_reduce_on_config(
+                sch,
+                B_shared,
+                num_ty,
+                num_tz,
+                warp_size,
+                reduce_k,
+            )
+        else:
+            _bind_thread_based_on_config(sch, B_shared, num_ty, num_tz, warp_size)
         return B_dequantized_mat

     B_dequantized_mat = warp_memory_dequantize()

     # Put the thread binding after the shared memory prefetch
     # Otherwise there's a axis missing bug behind tvm
-    sch.bind(kr, "threadIdx.z")
+    if reduce_k > 1:
+        sch.bind(kr, "threadIdx.z")
     # create write cache to store matrix from wmma fragments to shared memory and global memory
     if cache_write_required:
         accumulator_shared_to_global = sch.cache_write(block_outer, 0, shared_scope)
@@ -2159,9 +2176,13 @@ def get_idx():
         sch.get_loops(store)[-6],
         preserve_unit_loops=True,
     )
-    vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global))
     fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-5:])
-    f0, f1, f2 = sch.split(fused, factors=[None, warp_size, vec_len])
+    f0, f1, f2 = sch.split(
+        fused,
+        factors=[
+            None, warp_size,
+            get_coalesced_veclen(sch.get(accumulator_shared_to_global))
+        ])
     sch.bind(f1, "threadIdx.x")
     sch.vectorize(f2)
     sch.unroll(f0)
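
This hunk only inlines the `get_coalesced_veclen` call into the split factors for the accumulator write-back; the helper's job is to pick the widest store that still coalesces. Purely for intuition, a hypothetical version of that heuristic (not the BitBLAS implementation): the widest power-of-two vector that fits a 16-byte access and divides the innermost extent.

```python
# Hypothetical sketch of choosing a coalesced vector length; this is NOT the
# BitBLAS get_coalesced_veclen implementation, just the common heuristic:
# the widest power-of-two vector that fits in a 16-byte access and still
# divides the innermost extent.
def coalesced_veclen(innermost_extent: int, dtype_bytes: int, max_bytes: int = 16) -> int:
    vec = max_bytes // dtype_bytes
    while vec > 1 and innermost_extent % vec != 0:
        vec //= 2
    return max(vec, 1)


print(coalesced_veclen(64, 2))  # 8 half-precision elements per 16-byte store
print(coalesced_veclen(20, 2))  # falls back to 4 so the extent still divides
```
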
@@ -2205,6 +2226,7 @@ def get_idx():
     j0, j1 = sch.split(j, factors=[None, b_lr[1]])
     sch.reorder(i0, j0, i1, j1)
     _ = sch.blockize(i1)
+    vec_len = get_coalesced_veclen(sch.get(B_dequantized_mat))
     sch.transform_block_layout(
         B_dequantized_mat, lambda i, j: ((i * b_lr[1] + j) // vec_len,
                                          (i * b_lr[1] + j) % vec_len))
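
The added `vec_len` feeds the block-layout transform directly below it, which flattens each warp-fragment coordinate `(i, j)` and regroups it into `vec_len`-wide vectors. A quick standalone check that the remap is a bijection over the fragment (tile sizes are illustrative assumptions):

```python
# Verify that the layout transform (i, j) -> ((i*cols + j) // vec_len,
# (i*cols + j) % vec_len) is a bijection over the warp fragment.
# ROWS, COLS, and VEC_LEN are illustrative assumptions.
ROWS, COLS, VEC_LEN = 16, 16, 8

mapped = {((i * COLS + j) // VEC_LEN, (i * COLS + j) % VEC_LEN)
          for i in range(ROWS) for j in range(COLS)}

# Every (vector, lane-within-vector) pair is hit exactly once.
assert len(mapped) == ROWS * COLS
assert mapped == {(v, e) for v in range(ROWS * COLS // VEC_LEN) for e in range(VEC_LEN)}
print(f"{ROWS}x{COLS} fragment regrouped into {ROWS * COLS // VEC_LEN} vectors of {VEC_LEN}")
```
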
16 changes: 8 additions & 8 deletions integration/BitNet/benchmark_inference_latency.py
@@ -11,6 +11,7 @@
 parser = argparse.ArgumentParser()
 parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str)

+
 def profile(model, input_data):
     import time

@@ -34,26 +35,25 @@ def get_runtime(num_repeats=1):
     times = get_runtime(num_repeats)
     return np.mean(times)


 def main():
     model = BitnetForCausalLM.from_pretrained(
         '1bitLLM/bitnet_b1_58-3B',
         device_map='auto',
-        low_cpu_mem_usage=True,
+        low_cpu_mem_usage=True,
         use_flash_attention_2=True,
         torch_dtype=torch.float16,
     ).half()
     with torch.no_grad():
-        model._post_process_weights()
-
-    benchmark_sets = [
-        (1, 1),
-        (128, 1),
-        (1, 2048)
-    ]
+        model.quantize()
+    model = torch.compile(model)
+
+    benchmark_sets = [(1024, 1), (1, 2048)]
     for batch_size, seq_len in benchmark_sets:
         input_id = torch.ones(batch_size, seq_len).long().cuda()
         latency = profile(model, input_id)
         print(f"Batch size: {batch_size}, Seq len: {seq_len}, Latency: {latency}")


 if __name__ == '__main__':
     main()
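
The benchmark now quantizes the model and wraps it in `torch.compile` before timing, and measures a decode-heavy shape `(1024, 1)` alongside a prefill-heavy shape `(1, 2048)`. For reference, a generic GPU latency-measurement pattern in the spirit of `profile` above, with warmup and explicit synchronization (a standalone sketch, not the script's exact code):

```python
# Standalone sketch of a CUDA latency measurement loop with warmup and
# synchronization; a generic pattern, not the exact profile() from the script.
import time

import torch


def measure_latency_ms(fn, warmup: int = 5, repeats: int = 20) -> float:
    for _ in range(warmup):          # warm up kernels, caches, and autotuners
        fn()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(repeats):
        fn()
    torch.cuda.synchronize()         # wait for all queued kernels to finish
    return (time.time() - start) / repeats * 1000.0


if __name__ == "__main__":
    x = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
    w = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
    print(f"matmul latency: {measure_latency_ms(lambda: x @ w):.3f} ms")
```
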
2 changes: 1 addition & 1 deletion integration/pytorch/bitblas_linear.py
@@ -24,7 +24,7 @@ def error_raiser_bitblas(*args, **kwargs):
 autogptq_bitblas_cuda = bitblas_import_exception

 from bitblas.utils import auto_detect_nvidia_target  # noqa: E402
-from bitblas.ops.matmul import MatmulConfig, Matmul  # noqa: E402
+from bitblas.ops import MatmulConfig, Matmul  # noqa: E402


 class Linear(nn.Module):
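
The import now comes from the `bitblas.ops` package root instead of the `bitblas.ops.matmul` module. A minimal usage sketch with the new path; the `MatmulConfig` fields follow the public BitBLAS examples, but treat the exact arguments as assumptions to verify against your installed version:

```python
# Minimal sketch of the new import path.  The MatmulConfig fields mirror the
# public BitBLAS examples; treat the exact arguments as assumptions and check
# the documentation for the installed version.
from bitblas.ops import Matmul, MatmulConfig

config = MatmulConfig(
    M=1,                # single-row (GEMV-style) workload
    N=1024,             # output features
    K=1024,             # reduction dimension
    A_dtype="float16",  # activation dtype
    W_dtype="int4",     # quantized weight dtype
    out_dtype="float16",
)
matmul = Matmul(config=config)  # building the operator triggers codegen for the local GPU
print(config)
```
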