Merged
25 commits
d8884e6
Refactor BatchMatMulEmitter and BatchMatMulSelector for improved read…
LeiWang1999 Jul 5, 2024
fc84173
Refactor import statements for improved readability and maintainability
LeiWang1999 Jul 5, 2024
02f64de
Refactor import statements for improved readability and maintainability
LeiWang1999 Jul 5, 2024
397eee6
disable failure email for ci
LeiWang1999 Jul 5, 2024
20f6ad1
remove email notifications.
LeiWang1999 Jul 6, 2024
b93c394
move relax pass from testing to mlc_llm
LeiWang1999 Jul 6, 2024
ba6a6df
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into main
LeiWang1999 Jul 6, 2024
257693a
Refactor scripts with the check_eual_ref_scripts_with_emitter function
LeiWang1999 Jul 6, 2024
9bb7f49
Lint Fix
LeiWang1999 Jul 6, 2024
39e7614
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into main
LeiWang1999 Jul 6, 2024
93eb5a5
Refactor scripts with the check_eual_ref_scripts_with_emitter function
LeiWang1999 Jul 6, 2024
72b9740
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into main
LeiWang1999 Aug 23, 2024
5b65979
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into main
LeiWang1999 Aug 27, 2024
d9bd479
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into main
LeiWang1999 Aug 29, 2024
99515cb
bug fix for matrix support
LeiWang1999 Aug 29, 2024
14406ef
lint fix
LeiWang1999 Aug 29, 2024
d30ec4f
dispatch tensor core based on shapes
LeiWang1999 Aug 29, 2024
fde4029
update install commands
LeiWang1999 Aug 30, 2024
6a04749
import scripts
LeiWang1999 Aug 31, 2024
9d90c40
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into docs
LeiWang1999 Aug 31, 2024
9ef14e9
remove shared mem hack
LeiWang1999 Sep 1, 2024
63f363e
revert change for swizzling
LeiWang1999 Sep 1, 2024
b29c66c
bug fix
LeiWang1999 Sep 1, 2024
4643dd9
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into docs
LeiWang1999 Sep 1, 2024
28beb13
tl examples
LeiWang1999 Sep 2, 2024
2 changes: 1 addition & 1 deletion bitblas/gpu/matmul_mma.py
@@ -648,7 +648,7 @@ def inverse_permutation(i, j, ii, jj):
         auto_inline_consumer_chain(sch, accumulator_shared_to_global)
         sch.reverse_compute_at(
             accumulator_shared_to_global,
-            sch.get_loops(store)[-5],
+            sch.get_loops(store)[-6],
             preserve_unit_loops=True,
         )
         vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global))
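The one-line change above moves the reverse_compute_at anchor one loop deeper in the nest around `store` (`get_loops(store)[-5]` to `[-6]`). A minimal, self-contained sketch of this TIR-schedule idiom follows; it is a toy workload, not the BitBLAS kernel, and the block names and loop depth are illustrative only:

# Toy demonstration of the reverse_compute_at / get_loops idiom used above
# (standalone; assumes a TVM build with TIR schedules, names are illustrative).
import tvm
from tvm.script import tir as T


@T.prim_func
def two_stage(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
    B = T.alloc_buffer((128, 128), "float32")
    for i, j in T.grid(128, 128):
        with T.block("compute"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("write_back"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0


sch = tvm.tir.Schedule(two_stage)
producer_loops = sch.get_loops(sch.get_block("compute"))
# get_loops returns loops outermost-first, so negative indices count from the
# innermost loop: introducing or removing an inner loop shifts every negative
# index by one, which is what the [-5] -> [-6] change above accounts for.
sch.reverse_compute_at(
    sch.get_block("write_back"), producer_loops[-2], preserve_unit_loops=True
)
print(sch.mod)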
8 changes: 4 additions & 4 deletions bitblas/gpu/matmul_mma_dequantize.py
@@ -578,7 +578,7 @@ def get_idx():
         auto_inline_consumer_chain(sch, accumulator_shared_to_global)
         sch.reverse_compute_at(
             accumulator_shared_to_global,
-            sch.get_loops(store)[-5],
+            sch.get_loops(store)[-6],
             preserve_unit_loops=True,
         )
         vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global))
@@ -1075,7 +1075,7 @@ def get_idx():
         auto_inline_consumer_chain(sch, accumulator_shared_to_global)
         sch.reverse_compute_at(
             accumulator_shared_to_global,
-            sch.get_loops(store)[-5],
+            sch.get_loops(store)[-6],
             preserve_unit_loops=True,
         )
         vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global))
@@ -1675,7 +1675,7 @@ def get_idx():
         auto_inline_consumer_chain(sch, accumulator_shared_to_global)
         sch.reverse_compute_at(
             accumulator_shared_to_global,
-            sch.get_loops(store)[-5],
+            sch.get_loops(store)[-6],
             preserve_unit_loops=True,
         )
         vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global))
@@ -2194,7 +2194,7 @@ def get_idx():
         auto_inline_consumer_chain(sch, accumulator_shared_to_global)
         sch.reverse_compute_at(
             accumulator_shared_to_global,
-            sch.get_loops(store)[-5],
+            sch.get_loops(store)[-6],
             preserve_unit_loops=True,
         )
         vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global))
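The same one-level shift is applied at all four call sites in this file. When adapting such schedules it helps to enumerate the loop nest with its negative indices; a minimal standalone sketch (toy prim_func, illustrative names, not the BitBLAS schedule):

# Sketch: enumerate the loop nest above a block to pick the right negative
# index into get_loops. Standalone toy example, not the BitBLAS schedule.
import tvm
from tvm.script import tir as T


@T.prim_func
def copy(A: T.Buffer((4, 8, 16), "float32"), B: T.Buffer((4, 8, 16), "float32")):
    for i, j, k in T.grid(4, 8, 16):
        with T.block("store"):
            vi, vj, vk = T.axis.remap("SSS", [i, j, k])
            B[vi, vj, vk] = A[vi, vj, vk]


sch = tvm.tir.Schedule(copy)
loops = sch.get_loops(sch.get_block("store"))
for idx, rv in enumerate(loops):
    loop = sch.get(rv)  # the underlying tir.For node
    print(idx - len(loops), loop.loop_var.name, loop.extent)  # -3 i 4, etc.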
164 changes: 164 additions & 0 deletions testing/python/tilelang/test_tilelang_dequantize_gemm.py
@@ -0,0 +1,164 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import bitblas
from bitblas import tvm as tvm
from tvm import tl
from bitblas.quantization import _tir_packed_to_unsigned_convert


def matmul(
M,
N,
K,
block_M,
block_N,
block_K,
dtypeAB,
dtypeC,
accum_dtype,
num_stages,
threads,
num_bits=4,
):
num_elems_per_byte = 8 // num_bits
storage_dtype = "int8"
A_shape = (M, K)
B_shape = (N, K // num_elems_per_byte)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)

import tvm.tl.language as T

@T.prim_func
def main(
A: T.Buffer(A_shape, dtypeAB),
B: T.Buffer(B_shape, storage_dtype),
C: T.Buffer((M, N), dtypeC),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads
) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, dtypeAB)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment([8], storage_dtype, "local")
B_dequantize_local = T.alloc_fragment([16], dtypeAB, "local")
B_dequantize_shared = T.alloc_shared(
B_dequantize_shared_shape, dtypeAB
)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)

                # Cooperative copy of the packed int8 weights: global B -> B_shared,
                # 16 bytes per thread per vectorized iteration.
                for i in T.serial(
                    block_N * block_K // num_elems_per_byte // (threads * 16)
                ):
for t in T.thread_binding(0, threads, thread="threadIdx.x"):
for v in T.vectorized(0, 16):
vi = (i * threads * 16 + t * 16 + v) // (
block_K // num_elems_per_byte
)
vj = (i * threads * 16 + t * 16 + v) % (
block_K // num_elems_per_byte
)
B_shared[vi, vj] = B[
bx * block_N + vi,
k * block_K // num_elems_per_byte + vj,
]

                # Dequantize: each thread loads 4 packed bytes, expands each byte
                # into two 4-bit values, and writes 8 elements to shared memory.
                for i in T.serial(
                    block_N * block_K // num_elems_per_byte // (threads * 4)
                ):
for t in T.thread_binding(0, threads, thread="threadIdx.x"):
for v in T.vectorized(0, 4):
vi = (i * threads * 4 + t * 4 + v) // (
block_K // num_elems_per_byte
)
vj = (i * threads * 4 + t * 4 + v) % (
block_K // num_elems_per_byte
)
B_local[v] = B_shared[vi, vj]
for v in T.serial(0, 8):
B_dequantize_local[
v
] = _tir_packed_to_unsigned_convert("int", 8)(
num_bits,
B_local[v // 2],
v % 2,
dtype=dtypeAB,
)
for v in T.vectorized(0, 8):
vi = (i * threads * 8 + t * 8 + v) // (block_K)
vj = (i * threads * 8 + t * 8 + v) % (block_K)
B_dequantize_shared[vi, vj] = B_dequantize_local[v]
T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)
T.copy(C_local, C[by * block_M, bx * block_N])

return main


def run_gemm(
M,
N,
K,
dtypeAB,
dtypeC,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
dtypeAB,
dtypeC,
dtypeAccum,
num_stages,
num_threads,
)
print(program)

mod, params = tl.lower(program)
mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)

out = mod.run_once()

print(f"output is {out}")

    # Dump the generated CUDA kernel source for inspection.
    import os

    os.makedirs("debug", exist_ok=True)
    with open("debug/kernel.cu", "w") as f:
        f.write(mod.mod.imported_modules[0].get_source())

def ref_program(A, qB):
import torch

B = (
torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half)
.to(torch.half)
.to(A.device)
)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(
torch.half
)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
        C = C.to(getattr(torch, dtypeC))
return C

mod.assert_allclose(ref_program)


def test_run_dequantize_gemm():
run_gemm(16, 16, 16, "int8", "int32", "int32", 16, 16, 16, num_threads=128)


if __name__ == "__main__":
bitblas.testing.main()
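A side note on the reference: the per-element nibble loop in `ref_program` is quadratic in Python, and a vectorized torch equivalent may be clearer for larger shapes. A minimal sketch under the same layout assumption (two unsigned 4-bit values per int8 byte, low nibble first); `unpack_int4` is a hypothetical helper, not part of this PR:

# Hypothetical vectorized equivalent of ref_program's nibble-unpacking loop.
# Assumes two unsigned 4-bit values per int8 byte, low nibble stored first.
import torch


def unpack_int4(qB: torch.Tensor) -> torch.Tensor:
    low = (qB & 0xF).to(torch.half)          # even output columns (low nibble)
    high = ((qB >> 4) & 0xF).to(torch.half)  # odd output columns (high nibble)
    # Interleave along the last axis: (N, K // 2) -> (N, K).
    return torch.stack([low, high], dim=-1).reshape(qB.shape[0], -1)


# B = unpack_int4(qB).to(A.device) reproduces the nested loop
#   B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)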
173 changes: 173 additions & 0 deletions testing/python/tilelang/test_tilelang_flash_atten.py
@@ -0,0 +1,173 @@
import argparse
import torch
from tvm import tl
import tvm.tl.language as T
from tvm.tl.autotuner import autotune, jit
from functools import partial
import itertools


def get_configs():
block_M = [32, 64, 128]
block_N = [32, 64, 128]
num_stages = [1, 2]
thread_num = [128, 256]
_configs = list(itertools.product(block_M, block_N, num_stages, thread_num))

configs = [
{
"block_M": c[0],
"block_N": c[1],
"num_stages": c[2],
"thread_num": c[3],
}
for c in _configs
]
return configs


def ref_program(Q, K, V, causal):
    from flash_attn.flash_attn_interface import flash_attn_func

    return flash_attn_func(Q, K, V, causal=causal)


def flashattn(batch, heads, seq_len, dim, is_causal):

@autotune(
configs=get_configs(),
keys=["block_M", "block_N", "num_stages", "thread_num"],
warmup=10,
rep=5,
)
@jit(
out_idx=[3],
supply_type=tl.TensorSupplyType.Normal,
        ref_prog=partial(ref_program, causal=is_causal),
rtol=0.01,
atol=0.01,
)
def kernel(block_M=None, block_N=None, num_stages=None, thread_num=None):
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
dtype = "float16"
accum_dtype = "float"

@T.prim_func
def main(
Q: T.Buffer(shape, dtype), # type: ignore
K: T.Buffer(shape, dtype), # type: ignore
V: T.Buffer(shape, dtype), # type: ignore
Output: T.Buffer(shape, dtype), # type: ignore
):
with T.Kernel(
T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num
) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
Q_local = T.alloc_fragment([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)

T.annotate_layout(
{Q_shared: tl.layout.make_swizzled_layout(Q_shared)}
)
T.copy(
Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared
)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
T.copy(Q_shared, Q_local)
for i, j in T.Parallel(block_M, dim):
Q_local[i, j] *= scale
loop_range = (
T.ceildiv((bx + 1) * block_M, block_N)
                    if is_causal
else T.ceildiv(seq_len, block_N)
)
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(
K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared
)
                    if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(
bx * block_M + i >= k * block_N + j,
0,
-T.infinity(acc_s.dtype),
)
else:
T.clear(acc_s)
T.gemm(
Q_local,
K_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
)
T.copy(
V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared
)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(
scores_max_prev[i] - scores_max[i]
)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i])
T.copy(acc_s, acc_s_cast)
T.gemm(
acc_s_cast,
V_shared,
acc_o,
policy=T.GemmWarpPolicy.FullRow,
)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(
acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]
)

return main

return kernel()


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=64, help="Batch size")
parser.add_argument("--h", type=int, default=12, help="Number of heads")
parser.add_argument("--n_ctx", type=int, default=2048, help="Context size")
parser.add_argument(
"--d_head", type=int, default=256, help="Head dimension"
)
    # Note: argparse's type=bool treats any non-empty string as True.
    parser.add_argument("--causal", type=bool, default=True, help="Causal flag")
args = parser.parse_args()
BATCH, H, N_CTX, D_HEAD = args.batch, args.h, args.n_ctx, args.d_head
    causal = args.causal
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
total_flops = 2 * flops_per_matmul
    if causal:
total_flops *= 0.5

best_latency, best_config, ref_latency = flashattn(
        BATCH, H, N_CTX, D_HEAD, causal
)
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
print(f"Ref TFlops: {total_flops / ref_latency * 1e-9}")
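A closing note on the kernel's numerics: the scores_max / scores_scale / logsum updates implement the standard online-softmax recurrence, with log2(e) folded into `scale` so that `exp2` replaces `exp`. A standalone numpy sketch (illustrative shapes, not the kernel's tiling) checking that blockwise rescaling matches a base-2 softmax-weighted sum:

# Standalone check of the online-softmax recurrence used in the kernel:
# process scores block by block, rescaling the running output and sum.
import numpy as np

rng = np.random.default_rng(0)
scores = rng.standard_normal(256)      # one query row, illustrative length
values = rng.standard_normal((256, 16))
block = 64

m = -np.inf                 # running max of the scores seen so far
denom = 0.0                 # running sum of exp2(score - m)
acc = np.zeros(16)          # running un-normalized output row

for start in range(0, scores.size, block):
    s = scores[start : start + block]
    v = values[start : start + block]
    m_new = max(m, s.max())
    rescale = np.exp2(m - m_new)   # matches scores_scale in the kernel
    p = np.exp2(s - m_new)
    acc = acc * rescale + p @ v
    denom = denom * rescale + p.sum()
    m = m_new

out = acc / denom                  # matches the final acc_o[i, j] /= logsum[i]
p_all = np.exp2(scores - scores.max())
ref = (p_all / p_all.sum()) @ values
assert np.allclose(out, ref)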