Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from e1c5b0 to 71fe7c
207 changes: 207 additions & 0 deletions bitblas/tl/macro_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,3 +472,210 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
)

return _warp_mma(A_local_buf, B_local_buf, C_local_buf)


class INT4TensorCoreIntrinEmitter(TensorCoreIntrinEmitter):
    """Tensor-core MMA emitter specialized for int4 operands.

    Each (warp_row, warp_col) tile issues four ``m16n8k32`` ``ptx_mma``
    instructions: the n dimension (16) is split into two halves of 8 and
    the k dimension is split into two halves, so the four calls together
    consume the full A/B register fragments and accumulate into C.
    """

    def mma(self, A_local_buf, B_local_buf, C_local_buf):
        """Emit the warp-level int4 MMA loop.

        Parameters
        ----------
        A_local_buf, B_local_buf : register fragments of A and B.
        C_local_buf : accumulator fragment, updated in place.
        """
        warp_rows = self.warp_rows
        warp_cols = self.warp_cols
        local_size_a = self.local_size_a
        local_size_b = self.local_size_b
        local_size_out = self.local_size_out
        a_dtype_abbrv = "int4"
        b_dtype_abbrv = "int4"
        accum_dtype = self.accum_dtype
        # PTX int4 MMA accumulates in int32; hard-code the abbreviation to
        # match INT4TensorCoreIntrinEmitterWithLadderTransform rather than
        # assuming self.accum_dtype is already in abbreviated form.
        accum_dtype_abbrv = "int32"
        mma_prefix = "m16n8k32"

        @T.macro
        def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
            for i, j in T.grid(warp_rows, warp_cols):
                '''
                A[16, 32], B[16, 32], C[16, 16]
                A_local_size -> 16
                B_local_size -> 16
                C_local_size -> 8
                For each m16n8k32 inst
                For A: m16k32 consume 16 int4 elements -> 8 A_local_size
                For B: n8k32 consume 8 int4 elements -> 4 B_local_size
                For C: m16n8 consume 4 int32 elements -> 4 C_local_size
                '''

                # A[0:16, 0:16] * B[0:8, 0:16] -> C[0:16, 0:8]
                T.ptx_mma(
                    accum_dtype,
                    mma_prefix,
                    "row",
                    "col",
                    a_dtype_abbrv,
                    b_dtype_abbrv,
                    accum_dtype_abbrv,
                    A_local_buf.data,
                    i * local_size_a,
                    B_local_buf.data,
                    j * local_size_b,
                    C_local_buf.data,
                    i * warp_cols * local_size_out + j * local_size_out,
                    T.bool(False),
                )

                # A[0:16, 0:16] * B[8:16, 0:16] -> C[0:16, 8:16]
                T.ptx_mma(
                    accum_dtype,
                    mma_prefix,
                    "row",
                    "col",
                    a_dtype_abbrv,
                    b_dtype_abbrv,
                    accum_dtype_abbrv,
                    A_local_buf.data,
                    i * local_size_a,
                    B_local_buf.data,
                    j * local_size_b + lift(local_size_b) // 2,
                    C_local_buf.data,
                    i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
                    T.bool(False),
                )

                # A[0:16, 16:32] * B[0:8, 16:32] -> C[0:16, 0:8]
                T.ptx_mma(
                    accum_dtype,
                    mma_prefix,
                    "row",
                    "col",
                    a_dtype_abbrv,
                    b_dtype_abbrv,
                    accum_dtype_abbrv,
                    A_local_buf.data,
                    i * local_size_a + lift(local_size_a) // 2,
                    B_local_buf.data,
                    j * local_size_b + lift(local_size_b) // 4,
                    C_local_buf.data,
                    i * warp_cols * local_size_out + j * local_size_out,
                    T.bool(False),
                )

                # A[0:16, 16:32] * B[8:16, 16:32] -> C[0:16, 8:16]
                T.ptx_mma(
                    accum_dtype,
                    mma_prefix,
                    "row",
                    "col",
                    a_dtype_abbrv,
                    b_dtype_abbrv,
                    accum_dtype_abbrv,
                    A_local_buf.data,
                    # BUGFIX: second k-half offset of A was computed with
                    # local_size_b; the symmetric call above uses local_size_a.
                    i * local_size_a + lift(local_size_a) // 2,
                    B_local_buf.data,
                    j * local_size_b + lift(local_size_b) // 2 + lift(local_size_b) // 4,
                    C_local_buf.data,
                    i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
                    T.bool(False),
                )

        return _warp_mma(A_local_buf, B_local_buf, C_local_buf)


class INT4TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitterWithLadderTransform):
    """Int4 tensor-core MMA emitter for the ladder-transformed data layout.

    Same four-call ``m16n8k32`` decomposition as
    ``INT4TensorCoreIntrinEmitter`` (n split into two halves of 8, k split
    into two halves), but inherits the ladder-transform fragment layout
    from its base class.
    """

    def mma(self, A_local_buf, B_local_buf, C_local_buf):
        """Emit the warp-level int4 MMA loop.

        Parameters
        ----------
        A_local_buf, B_local_buf : register fragments of A and B.
        C_local_buf : accumulator fragment, updated in place.
        """
        warp_rows = self.warp_rows
        warp_cols = self.warp_cols
        local_size_a = self.local_size_a
        local_size_b = self.local_size_b
        local_size_out = self.local_size_out
        a_dtype_abbrv = "int4"
        b_dtype_abbrv = "int4"
        accum_dtype = self.accum_dtype
        # PTX int4 MMA accumulates in int32.
        accum_dtype_abbrv = "int32"
        mma_prefix = "m16n8k32"

        @T.macro
        def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
            for i, j in T.grid(warp_rows, warp_cols):
                '''
                A[16, 32], B[16, 32], C[16, 16]
                A_local_size -> 16
                B_local_size -> 16
                C_local_size -> 8
                For each m16n8k32 inst
                For A: m16k32 consume 16 int4 elements -> 8 A_local_size
                For B: n8k32 consume 8 int4 elements -> 4 B_local_size
                For C: m16n8 consume 4 int32 elements -> 4 C_local_size
                '''

                # A[0:16, 0:16] * B[0:8, 0:16] -> C[0:16, 0:8]
                T.ptx_mma(
                    accum_dtype,
                    mma_prefix,
                    "row",
                    "col",
                    a_dtype_abbrv,
                    b_dtype_abbrv,
                    accum_dtype_abbrv,
                    A_local_buf.data,
                    i * local_size_a,
                    B_local_buf.data,
                    j * local_size_b,
                    C_local_buf.data,
                    i * warp_cols * local_size_out + j * local_size_out,
                    T.bool(False),
                )

                # A[0:16, 0:16] * B[8:16, 0:16] -> C[0:16, 8:16]
                T.ptx_mma(
                    accum_dtype,
                    mma_prefix,
                    "row",
                    "col",
                    a_dtype_abbrv,
                    b_dtype_abbrv,
                    accum_dtype_abbrv,
                    A_local_buf.data,
                    i * local_size_a,
                    B_local_buf.data,
                    j * local_size_b + lift(local_size_b) // 2,
                    C_local_buf.data,
                    i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
                    T.bool(False),
                )

                # A[0:16, 16:32] * B[0:8, 16:32] -> C[0:16, 0:8]
                T.ptx_mma(
                    accum_dtype,
                    mma_prefix,
                    "row",
                    "col",
                    a_dtype_abbrv,
                    b_dtype_abbrv,
                    accum_dtype_abbrv,
                    A_local_buf.data,
                    i * local_size_a + lift(local_size_a) // 2,
                    B_local_buf.data,
                    j * local_size_b + lift(local_size_b) // 4,
                    C_local_buf.data,
                    i * warp_cols * local_size_out + j * local_size_out,
                    T.bool(False),
                )

                # A[0:16, 16:32] * B[8:16, 16:32] -> C[0:16, 8:16]
                T.ptx_mma(
                    accum_dtype,
                    mma_prefix,
                    "row",
                    "col",
                    a_dtype_abbrv,
                    b_dtype_abbrv,
                    accum_dtype_abbrv,
                    A_local_buf.data,
                    # BUGFIX: second k-half offset of A was computed with
                    # local_size_b; the symmetric call above uses local_size_a.
                    i * local_size_a + lift(local_size_a) // 2,
                    B_local_buf.data,
                    j * local_size_b + lift(local_size_b) // 2 + lift(local_size_b) // 4,
                    C_local_buf.data,
                    i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
                    T.bool(False),
                )

        return _warp_mma(A_local_buf, B_local_buf, C_local_buf)
4 changes: 2 additions & 2 deletions bitblas/tl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,12 @@ def get_mma_micro_size(dtype: Literal["float16", "int8"]):
return micro_size_x, micro_size_y, micro_size_k


def make_swizzle_layout(shared_buf):
def make_swizzle_layout(shared_buf, is_smooth: bool = False):
dtype = shared_buf.dtype
shape = shared_buf.shape

can_swizzle = shape[-1] * DataType(dtype).bits == 512
if not can_swizzle:
if is_smooth or not can_swizzle:
return T.Layout(shape, lambda *args: args)

def transform_func(i, j):
Expand Down
Loading
Loading