Merged

29 commits
c4853ec
Refactor Simplify function to handle multiple functions in IRModule
LeiWang1999 Oct 16, 2024
9a21acf
Update submodule commit reference
LeiWang1999 Oct 17, 2024
f8d046b
Add CUDA_DEVICE_ORDER environment variable to bashrc
LeiWang1999 Oct 17, 2024
c1371dd
test fix
LeiWang1999 Oct 17, 2024
416cad2
lint fix
LeiWang1999 Oct 17, 2024
9209d1e
Refactor test_general_matmul_bf16.py to use bitblas.testing.main()
LeiWang1999 Oct 17, 2024
1cf7570
Update submodule commit reference
LeiWang1999 Oct 17, 2024
5fec040
Update Ubuntu version in install scripts based on LLVM version
LeiWang1999 Oct 18, 2024
4e1a0d2
Update Ubuntu version in install scripts based on LLVM version
LeiWang1999 Oct 18, 2024
fa85f8c
Update submodule commit reference
LeiWang1999 Oct 19, 2024
429d5b5
Update submodule commit reference
LeiWang1999 Oct 19, 2024
4003509
Update submodule commit reference
LeiWang1999 Oct 20, 2024
1d86582
Merge branch 'main' of https://github.com/microsoft/BitBLAS into amd_hip
LeiWang1999 Oct 20, 2024
df3af0d
Update submodule commit reference
LeiWang1999 Oct 28, 2024
1f1e027
Merge branch 'main' of https://github.com/microsoft/BitBLAS into amd_hip
LeiWang1999 Oct 28, 2024
732dda6
Update submodule commit reference
LeiWang1999 Oct 29, 2024
ebffbfa
Merge branch 'main' of https://github.com/microsoft/BitBLAS into amd_hip
LeiWang1999 Oct 29, 2024
ff227fa
Merge branch 'main' of https://github.com/microsoft/BitBLAS into amd_hip
LeiWang1999 Nov 4, 2024
ac62936
[Dev] Update subproject commit for TVM
LeiWang1999 Nov 7, 2024
a7a239c
ignore profiler directories.
LeiWang1999 Nov 7, 2024
dcedbde
MFMA Support
LeiWang1999 Nov 7, 2024
e0b36f5
lint fix
LeiWang1999 Nov 7, 2024
fe668f9
Merge branch 'main' of https://github.com/microsoft/BitBLAS into amd_hip
LeiWang1999 Nov 7, 2024
3579c6b
MFMA Fixed.
LeiWang1999 Nov 8, 2024
e60ccd9
merge upstream
LeiWang1999 Nov 8, 2024
d4df21c
update
LeiWang1999 Nov 8, 2024
e4ff7f3
Merge branch 'main' of https://github.com/microsoft/BitBLAS into amd_hip
LeiWang1999 Nov 8, 2024
57e3cf9
Fix MFMA Layout Related issue
LeiWang1999 Nov 8, 2024
c3398f5
lint fix
LeiWang1999 Nov 8, 2024
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 180359 to 8847ba
37 changes: 15 additions & 22 deletions bitblas/tl/mfma_macro_generator.py
@@ -2,8 +2,9 @@
# Licensed under the MIT License.

import tvm.tl.language as T

from typing import Tuple
from tvm import DataType
from tvm.tir import PrimExpr
from tvm.runtime import convert
from .utils import (
mfma_store_index_map,)
@@ -142,12 +143,16 @@ def get_ldmatrix_index_map(self, is_b=False):

return index_map, reverse_index_map

def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, thread_bindings, rk=0):
def extract_thread_binding(self, thread_id) -> Tuple[PrimExpr, PrimExpr, PrimExpr]:
WARP_SIZE = self.WARP_SIZE
block_row_warps = self.block_row_warps
block_col_warps = self.block_col_warps
return thread_id % WARP_SIZE, (thread_id // WARP_SIZE) % block_col_warps, (
thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps

def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, thread_bindings, rk=0):
warp_row_tiles = self.warp_row_tiles
warp_cols = self.warp_cols
warp_rows = self.warp_rows
chunk = self.chunk
micro_size_x = self.micro_size_x
micro_size_k = self.micro_size_k
@@ -164,17 +169,16 @@ def _warp_ldmatrix_a(
thread_bindings,
rk=0,
):
tx = thread_bindings % WARP_SIZE
tz = (thread_bindings // (WARP_SIZE * block_col_warps)) % block_row_warps
tx, _, tz = self.extract_thread_binding(thread_bindings)
if is_transposed:
for i in T.serial(warp_cols):
for i in T.serial(warp_rows):
for local_id in T.vectorized(local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (rk * chunk + ki * micro_size_k,
tz * warp_row_tiles + i * micro_size_x)
A_local_buf[i * local_size_a + local_id] = A_shared_buf[l + row, r + col]
else:
for i in T.serial(warp_cols):
for i in T.serial(warp_rows):
for local_id in T.vectorized(local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (tz * warp_row_tiles + i * micro_size_x,
@@ -184,9 +188,6 @@ def _warp_ldmatrix_a(
return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_bindings, rk)

def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk=0):

WARP_SIZE = self.WARP_SIZE
block_col_warps = self.block_col_warps
warp_col_tiles = self.warp_col_tiles
warp_cols = self.warp_cols
chunk = self.chunk
@@ -205,8 +206,7 @@ def _warp_ldmatrix_b(
thread_bindings,
rk=0,
):
tx = thread_bindings % WARP_SIZE
ty = (thread_bindings // WARP_SIZE) % block_col_warps
tx, ty, _ = self.extract_thread_binding(thread_bindings)

if is_transposed:
for j in T.serial(warp_cols):
@@ -263,7 +263,6 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
return _warp_mma(A_local_buf, B_local_buf, C_local_buf)

def stmatrix(self, C_local_buf, C_buf, thread_bindings, pid_m=None, pid_n=None):
WARP_SIZE = self.WARP_SIZE
block_row_warps = self.block_row_warps
block_col_warps = self.block_col_warps
warp_rows = self.warp_rows
@@ -281,23 +280,17 @@ def stmatrix(self, C_local_buf, C_buf, thread_bindings, pid_m=None, pid_n=None):
# equal to the warp_size
@T.macro
def _warp_stmatrix_shared(C_local_buf, C_buf, thread_bindings):
tx = thread_bindings % WARP_SIZE
ty = (thread_bindings // WARP_SIZE) % block_row_warps
tz = (thread_bindings // (WARP_SIZE * block_row_warps)) % block_col_warps

tx, ty, tz = self.extract_thread_binding(thread_bindings)
for i, j in T.grid(warp_rows, warp_cols):
for local_id in T.serial(local_size_out):
row, col = T.meta_var(mfma_store_index_map(tx, local_id))
C_buf[ty * warp_rows + i, tz * warp_cols + j, row,
C_buf[tz * warp_rows + i, ty * warp_cols + j, row,
col] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out +
local_id]

@T.macro
def _warp_stmatrix_global(C_local_buf, C_buf, thread_bindings):
tx = thread_bindings % WARP_SIZE
ty = (thread_bindings // WARP_SIZE) % block_row_warps
tz = (thread_bindings // (WARP_SIZE * block_row_warps)) % block_col_warps

tx, ty, tz = self.extract_thread_binding(thread_bindings)
for i, j in T.grid(warp_rows, warp_cols):
for local_id in T.serial(local_size_out):
row, col = T.meta_var(mfma_store_index_map(tx, local_id))
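For readers skimming the diff: the core of this refactor is the new extract_thread_binding helper, which centralizes the flat-thread-id to (lane, warp-column, warp-row) arithmetic that ldmatrix_a, ldmatrix_b, and stmatrix previously each computed inline; callers now unpack only what they need, e.g. tx, _, tz = self.extract_thread_binding(thread_bindings) in _warp_ldmatrix_a. The sketch below is a plain-Python illustration of that decomposition, not code from the PR; WARP_SIZE = 64 and the 2x2 warp layout are assumed example values, and the real macro returns TIR PrimExpr thread bindings rather than Python ints.

# Illustrative sketch only (not part of the PR diff): the thread-id
# decomposition that extract_thread_binding centralizes.
WARP_SIZE = 64          # assumed AMD wavefront size for the MFMA path
block_row_warps = 2     # assumed example block configuration
block_col_warps = 2

def extract_thread_binding(thread_id):
    """Return (tx, ty, tz): lane within the warp, warp column, warp row."""
    tx = thread_id % WARP_SIZE
    ty = (thread_id // WARP_SIZE) % block_col_warps
    tz = (thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps
    return tx, ty, tz

if __name__ == "__main__":
    # With 2x2 warps of 64 threads, thread 200 is lane 8 of warp (ty=1, tz=1).
    for tid in (0, 63, 64, 200):
        print(tid, extract_thread_binding(tid))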