Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
c4853ec
Refactor Simplify function to handle multiple functions in IRModule
LeiWang1999 Oct 16, 2024
9a21acf
Update submodule commit reference
LeiWang1999 Oct 17, 2024
f8d046b
Add CUDA_DEVICE_ORDER environment variable to bashrc
LeiWang1999 Oct 17, 2024
c1371dd
test fix
LeiWang1999 Oct 17, 2024
416cad2
lint fix
LeiWang1999 Oct 17, 2024
9209d1e
Refactor test_general_matmul_bf16.py to use bitblas.testing.main()
LeiWang1999 Oct 17, 2024
1cf7570
Update submodule commit reference
LeiWang1999 Oct 17, 2024
5fec040
Update Ubuntu version in install scripts based on LLVM version
LeiWang1999 Oct 18, 2024
4e1a0d2
Update Ubuntu version in install scripts based on LLVM version
LeiWang1999 Oct 18, 2024
fa85f8c
Update submodule commit reference
LeiWang1999 Oct 19, 2024
429d5b5
Update submodule commit reference
LeiWang1999 Oct 19, 2024
4003509
Update submodule commit reference
LeiWang1999 Oct 20, 2024
1d86582
Merge branch 'main' of https://github.com/microsoft/BitBLAS into amd_hip
LeiWang1999 Oct 20, 2024
df3af0d
Update submodule commit reference
LeiWang1999 Oct 28, 2024
1f1e027
Merge branch 'main' of https://github.com/microsoft/BitBLAS into amd_hip
LeiWang1999 Oct 28, 2024
732dda6
Update submodule commit reference
LeiWang1999 Oct 29, 2024
ebffbfa
Merge branch 'main' of https://github.com/microsoft/BitBLAS into amd_hip
LeiWang1999 Oct 29, 2024
ff227fa
Merge branch 'main' of https://github.com/microsoft/BitBLAS into amd_hip
LeiWang1999 Nov 4, 2024
ac62936
[Dev] Update subproject commit for TVM
LeiWang1999 Nov 7, 2024
a7a239c
ignore profiler directories.
LeiWang1999 Nov 7, 2024
dcedbde
MFMA Support
LeiWang1999 Nov 7, 2024
e0b36f5
lint fix
LeiWang1999 Nov 7, 2024
fe668f9
Merge branch 'main' of https://github.com/microsoft/BitBLAS into amd_hip
LeiWang1999 Nov 7, 2024
3579c6b
MFMA Fixed.
LeiWang1999 Nov 8, 2024
e60ccd9
merge upstream
LeiWang1999 Nov 8, 2024
d4df21c
update
LeiWang1999 Nov 8, 2024
e4ff7f3
Merge branch 'main' of https://github.com/microsoft/BitBLAS into amd_hip
LeiWang1999 Nov 8, 2024
57e3cf9
Fix MFMA Layout Related issue
LeiWang1999 Nov 8, 2024
c3398f5
lint fix
LeiWang1999 Nov 8, 2024
ddd0219
amd hip update
LeiWang1999 Nov 8, 2024
754294f
Block GEMM Example
LeiWang1999 Nov 13, 2024
e041d91
fix amd
YangWang92 Nov 15, 2024
2910b3c
mi300 update
YangWang92 Nov 15, 2024
cf934c1
Merge branch 'main' of https://github.com/microsoft/BitBLAS into amd_hip
YangWang92 Nov 15, 2024
baf4abd
fix and enhance
YangWang92 Nov 15, 2024
c3605e8
lintfix
LeiWang1999 Nov 15, 2024
8ee2c63
enhance amd installation
LeiWang1999 Nov 15, 2024
451318a
update submodule
LeiWang1999 Nov 15, 2024
14672df
update tvm
LeiWang1999 Nov 17, 2024
ff7f6d8
implement fragement
LeiWang1999 Nov 17, 2024
cccbe68
test update
LeiWang1999 Nov 17, 2024
9edecfd
Optimize MFMA Layout
LeiWang1999 Nov 19, 2024
7b5d4d5
lint fix
LeiWang1999 Nov 19, 2024
41c2fb7
Merge branch 'main' of https://github.com/microsoft/BitBLAS into amd_hip
LeiWang1999 Nov 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 4a2e00 to a12155
7 changes: 7 additions & 0 deletions benchmark/tilelang/benchmark.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Sweep the tilelang GEMM benchmark over a fixed list of (M, N, K) shapes.
# Each run's stdout and stderr are shown on the terminal and also written to
# run_gemm_tilelang_<M>_<N>_<K>.log in the current working directory.
#
# The benchmark script is located relative to this file rather than through a
# machine-specific absolute path, so the sweep works from any checkout.
SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd)

for shape in \
    "2048 2048 2048" \
    "4096 4096 4096" \
    "8192 8192 8192" \
    "16384 16384 16384" \
    "8192 8192 1024" \
    "8192 8192 2048" \
    "8192 8192 4096"
do
    # Intentional word splitting: "$shape" holds "M N K" and becomes $1 $2 $3.
    set -- $shape
    python "${SCRIPT_DIR}/benchmark_tilelang_matmul.py" --m "$1" --n "$2" --k "$3" 2>&1 \
        | tee "run_gemm_tilelang_${1}_${2}_${3}.log"
done
91 changes: 91 additions & 0 deletions benchmark/tilelang/benchmark_tilelang_matmul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import argparse
from tvm import tl
import tvm.tl.language as T
from tvm.tl.autotuner import *
import itertools


def ref_program(A, B):
    """Reference result for the tuned kernel: ``A`` times the transpose of ``B``."""
    B_transposed = B.T
    return A @ B_transposed


def get_configs():
    """Return every candidate tuning configuration for the GEMM kernel.

    The result is the Cartesian product of the per-parameter candidate lists
    below, one dict per combination. Parameters listed later vary fastest, so
    the ordering matches ``itertools.product`` over the lists in this order.
    """
    search_space = {
        'block_M': [64, 128, 256],
        'block_N': [64, 128, 256],
        'block_K': [64, 128, 256],
        'num_stages': [0, 1, 2, 3, 4],
        'thread_num': [128, 256],
        'enable_rasteration': [True, False],
    }
    names = list(search_space)
    return [dict(zip(names, combo)) for combo in itertools.product(*search_space.values())]


def matmul(M, N, K):
    """Autotune a tile-lang GEMM computing ``C = A @ B.T`` for a fixed problem size.

    Args:
        M: number of rows of ``A`` and ``C``.
        N: number of rows of ``B`` and columns of ``C``.
        K: shared reduction dimension (columns of both ``A`` and ``B``).

    Returns:
        Whatever the ``autotune``-decorated kernel returns when called; the
        script entry point below unpacks it as
        ``(best_latency, best_config, ref_latency)``.
    """

    # Every tunable produced by get_configs() is listed in ``keys``, including
    # 'enable_rasteration'. It was previously missing, so (presumably) the
    # tuner left that argument at its default of None and evaluated each
    # remaining combination twice -- TODO confirm against the autotuner.
    @autotune(
        configs=get_configs(),
        keys=['block_M', 'block_N', 'block_K', 'num_stages', 'thread_num', 'enable_rasteration'],
        warmup=3,
        rep=5)
    @jit(
        out_idx=[2],
        supply_type=tl.TensorSupplyType.Integer,
        ref_prog=ref_program,
        skip_check=True,
        profiler="tvm",
        target="hip")
    def kernel(block_M=None,
               block_N=None,
               block_K=None,
               num_stages=None,
               thread_num=None,
               enable_rasteration=None):
        dtype = "float16"
        accum_dtype = "float"

        @T.prim_func
        def main(A: T.Buffer((M, K), dtype), B: T.Buffer((N, K), dtype), C: T.Buffer((M, N),
                                                                                     dtype)):
            # One kernel instance per (block_N, block_M) output tile: bx walks
            # the N dimension, by walks the M dimension.
            with T.Kernel(
                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
                A_shared = T.alloc_shared((block_M, block_K), dtype)
                B_shared = T.alloc_shared((block_N, block_K), dtype)
                C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

                T.use_swizzle(panel_size=10, enable=enable_rasteration)

                T.clear(C_local)
                # Accumulate over K one block_K-wide slab at a time.
                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                    T.copy(A[by * block_M, k * block_K], A_shared)
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                    # B is stored as (N, K), hence transpose_B=True.
                    T.gemm(A_shared, B_shared, C_local, transpose_B=True)
                T.copy(C_local, C[by * block_M, bx * block_N])

        return main

    return kernel()


if __name__ == "__main__":
    # Problem size comes from the command line; every dimension defaults to 8192.
    cli = argparse.ArgumentParser()
    for flag, label in (('--m', 'M'), ('--n', 'N'), ('--k', 'K')):
        cli.add_argument(flag, type=int, default=8192, help=label)
    opts = cli.parse_args()
    M, N, K = opts.m, opts.n, opts.k

    # Run the autotuner; the result is unpacked as best latency, best config
    # and the latency of the reference program.
    best_latency, best_config, ref_latency = matmul(M, N, K)

    # A GEMM performs one multiply and one add per (m, n, k) triple.
    total_flops = 2 * M * N * K

    # NOTE(review): the 1e-9 factor yields TFLOPS only if the latencies are in
    # milliseconds -- presumably so; confirm against the profiler.
    print(f"Best latency: {best_latency}")
    print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
    print(f"Best config: {best_config}")
    print(f"Ref TFlops: {total_flops / ref_latency * 1e-9}")
156 changes: 156 additions & 0 deletions benchmark/tilelang/benchmark_tilelang_mha.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import argparse
import torch
from tvm import tl
import tvm.tl.language as T
from tvm.tl.autotuner import *
from functools import partial
import itertools


def get_configs():
    """Return every candidate tuning configuration for the attention kernel.

    The result is the Cartesian product of the per-parameter candidate lists
    below, one dict per combination, in ``itertools.product`` order (the last
    parameter varies fastest).
    """
    search_space = {
        'block_M': [32, 64, 128],
        'block_N': [32, 64, 128],
        'num_stages': [0, 1, 2],
        'thread_num': [128, 256],
    }
    names = tuple(search_space)
    configs = []
    for combo in itertools.product(*search_space.values()):
        configs.append(dict(zip(names, combo)))
    return configs


def ref_program(Q, K, Vt, casual):
    """Compute scaled dot-product attention in plain PyTorch as a reference.

    Axis meanings follow the einsum subscripts used below.

    Args:
        Q: queries, indexed as (batch, query position, head, head dim).
        K: keys, indexed as (batch, key position, head, head dim).
        Vt: values, indexed as (batch, head dim, head, key position).
        casual: when truthy, apply a causal mask so that query ``i`` only
            attends to keys ``j <= i``.

    Returns:
        The attention output, indexed as (batch, query position, head, head dim).
    """
    import torch.nn.functional as F
    dim = Q.size(-1)
    scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if casual:
        # Build the mask on the same device as the scores instead of forcing
        # it onto the GPU with .cuda(), so the reference also runs on CPU
        # tensors and on whichever device the inputs already live on.
        mask = torch.triu(
            torch.ones(scores.size(-2), scores.size(-1), device=scores.device),
            diagonal=1).bool()
        # Strictly-upper-triangular entries are "future" keys; -inf makes
        # their softmax weight exactly zero.
        scores.masked_fill_(mask, float('-inf'))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bdhk->bqhd', attention_weights, Vt)
    return output


def flashattn(batch, heads, seq_len, dim, is_casual):
    """Autotune a tile-lang flash-attention forward kernel for a fixed problem size.

    Args:
        batch: batch size.
        heads: number of attention heads.
        seq_len: sequence length shared by queries and keys.
        dim: per-head feature dimension.
        is_casual: when truthy, build the causally masked variant of the kernel.

    Returns:
        Whatever the ``autotune``-decorated kernel returns when called; the
        script entry point below unpacks it as
        ``(best_latency, best_config, ref_latency)``.
    """

    @autotune(
        configs=get_configs(),
        keys=['block_M', 'block_N', 'num_stages', 'thread_num'],
        warmup=10,
        rep=5)
    @jit(
        out_idx=[3],
        supply_type=tl.TensorSupplyType.Normal,
        ref_prog=partial(ref_program, casual=is_casual),
        rtol=0.01,
        atol=0.01,
        target="hip")
    def kernel(block_M=None, block_N=None, num_stages=None, thread_num=None):
        # Softmax scale 1/sqrt(dim), pre-multiplied by log2(e) because the
        # exponentials below are evaluated with exp2.
        scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
        shape = [batch, seq_len, heads, dim]
        vt_shape = [batch, dim, heads, seq_len]
        dtype = "float16"
        accum_dtype = "float"

        @T.prim_func
        def main(
                Q: T.Buffer(shape, dtype),  # type: ignore
                K: T.Buffer(shape, dtype),  # type: ignore
                Vt: T.Buffer(vt_shape, dtype),  # type: ignore
                Output: T.Buffer(shape, dtype),  # type: ignore
        ):
            # bx indexes a block of block_M query rows, by the head, bz the batch.
            with T.Kernel(
                    T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num) as (bx, by, bz):
                Q_shared = T.alloc_shared([block_M, dim], dtype)
                K_shared = T.alloc_shared([block_N, dim], dtype)
                Vt_shared = T.alloc_shared([dim, block_N], dtype)
                acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
                acc_s_shared = T.alloc_shared([block_M, block_N], accum_dtype)
                acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
                acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
                scores_max = T.alloc_fragment([block_M], accum_dtype)
                scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
                scores_scale = T.alloc_fragment([block_M], accum_dtype)
                scores_sum = T.alloc_fragment([block_M], accum_dtype)
                logsum = T.alloc_fragment([block_M], accum_dtype)

                # T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)})
                T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
                T.fill(acc_o, 0)
                T.fill(logsum, 0)
                T.fill(scores_max, -T.infinity(accum_dtype))
                # In the causal case only key blocks up to the end of this
                # query block are visited.
                loop_range = (
                    T.ceildiv(
                        (bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N))
                for k in T.Pipelined(loop_range, num_stages=num_stages):
                    T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
                    if is_casual:
                        # Seed the accumulator with the mask: 0 where the key
                        # is visible, -inf where it lies in the future.
                        for i, j in T.Parallel(block_M, block_N):
                            acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                         -T.infinity(acc_s.dtype))
                    else:
                        T.clear(acc_s)
                    T.gemm(
                        Q_shared,
                        K_shared,
                        acc_s,
                        transpose_B=True,
                        policy=T.GemmWarpPolicy.FullRow)
                    T.copy(Vt[bz, :, by, k * block_N:(k + 1) * block_N], Vt_shared)
                    # acc_s is [block_M, block_N], so the scaling loop must
                    # span block_N columns; iterating over `dim` was out of
                    # range whenever block_N != dim.
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] *= scale
                    T.copy(scores_max, scores_max_prev)
                    T.fill(scores_max, -T.infinity(accum_dtype))
                    T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                    # Factor that rebases previously accumulated values from
                    # the old row maximum to the new one.
                    for i in T.Parallel(block_M):
                        scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i])
                    for i, j in T.Parallel(block_M, dim):
                        acc_o[i, j] *= scores_scale[i]
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i])
                    # NOTE(review): the cast to `dtype` is routed through
                    # acc_s_shared rather than the direct copy kept commented
                    # out below; the reason is not visible here.
                    # T.copy(acc_s, acc_s_cast)
                    T.copy(acc_s, acc_s_shared)
                    T.copy(acc_s_shared, acc_s_cast)
                    T.gemm(
                        acc_s_cast,
                        Vt_shared,
                        acc_o,
                        transpose_B=True,
                        policy=T.GemmWarpPolicy.FullRow)
                    T.reduce_sum(acc_s, scores_sum, dim=1)
                    for i in T.Parallel(block_M):
                        logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                # Normalise by the accumulated softmax denominator once all
                # key blocks have been processed.
                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] /= logsum[i]
                T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])

        return main

    return kernel()


if __name__ == "__main__":

    def _parse_bool(text):
        """Parse a command-line boolean value.

        ``type=bool`` is not usable with argparse: ``bool("False")`` is True,
        so any non-empty value would enable the flag.
        """
        lowered = text.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('', '0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=1, help='Batch size')
    parser.add_argument('--h', type=int, default=12, help='Number of heads')
    parser.add_argument('--n_ctx', type=int, default=2048, help='Context size')
    parser.add_argument('--d_head', type=int, default=128, help='Head dimension')
    parser.add_argument('--casual', type=_parse_bool, default=False, help='Casual flag')
    args = parser.parse_args()
    BATCH, H, N_CTX, D_HEAD = args.batch, args.h, args.n_ctx, args.d_head
    casual = args.casual
    # Attention performs two matmuls, each counted as
    # 2 * batch * heads * n_ctx * n_ctx * d_head FLOPs.
    flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
    total_flops = 2 * flops_per_matmul
    if casual:
        # The total is halved for the causally masked variant.
        total_flops *= 0.5

    best_latency, best_config, ref_latency = flashattn(BATCH, H, N_CTX, D_HEAD, casual)
    # NOTE(review): the 1e-9 factor yields TFLOPS only if the latencies are in
    # milliseconds -- presumably so; confirm against the profiler.
    print(f"Best latency: {best_latency}")
    print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
    print(f"Best config: {best_config}")
    print(f"Ref TFlops: {total_flops / ref_latency * 1e-9}")
21 changes: 7 additions & 14 deletions bitblas/tl/mfma_layout.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
from tvm.runtime import convert


Expand Down Expand Up @@ -62,19 +61,13 @@ def shared_16x16_to_local_64x4_layout_B(i, j):
return thread_id, local


def thread_id_shared_access_64x4_to_16x16_layout_C_smooth(thread_id, local_id):
    """Map a (thread_id, local_id) pair to a coordinate pair of the C tile.

    The first returned component is ``local_id + (thread_id // 16) * 4`` and
    the second is ``thread_id % 16``. Going by the function name, thread_id
    is presumably in [0, 64) and local_id in [0, 4), which keeps both
    components inside a 16x16 tile.
    """
    lane = thread_id % 16
    group_offset = local_id + (thread_id // 16) * 4
    return group_offset, lane


def thread_id_shared_access_64x4_to_16x16_layout_C(thread_id, local_id):
# This is a hacky implementation to simulate the performance
is_smooth = os.environ.get("TILE_LANG_SMOOTH_LAYOUT") == "1"
print(is_smooth)
if is_smooth:
return thread_id_shared_access_64x4_to_16x16_layout_C_smooth(thread_id, local_id)

def thread_id_shared_access_64x4_to_16x16_layout_C_m_n(thread_id, local_id):
    """Map a (thread_id, local_id) pair to an ``(i, j)`` coordinate of the C tile.

    ``i`` is ``local_id + (thread_id // 16) * 4`` and ``j`` is
    ``thread_id % 16``: each group of 16 consecutive threads shares a band of
    four ``i`` values and spreads across ``j``. Going by the function name,
    thread_id is presumably in [0, 64) and local_id in [0, 4).
    """
    return local_id + (thread_id // 16) * 4, thread_id % 16


def thread_id_shared_access_64x4_to_16x16_layout_C_n_m(thread_id, local_id):
    """Map a (thread_id, local_id) pair to an ``(i, j)`` coordinate of the C tile.

    ``i`` is ``thread_id % 16`` and ``j`` is
    ``local_id + (thread_id // 16) * 4``: each group of 16 consecutive threads
    spreads across ``i`` and shares a band of four ``j`` values. Going by the
    function name, thread_id is presumably in [0, 64) and local_id in [0, 4).
    """
    return thread_id % 16, local_id + (thread_id // 16) * 4
Loading
Loading