2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 5ec617 to 430758
4 changes: 4 additions & 0 deletions bitblas/base/roller/hint.py
@@ -173,6 +173,10 @@ def __init__(self) -> None:
# Config for block reduction
self.block_reduction_depth = None # type: int

# TL Specific
# Split-K factor for SM waste optimization
self.split_k_factor: int = 1

# Experimental
self._raxis_order = []
self._step = []
4 changes: 2 additions & 2 deletions bitblas/ops/general_flashatten/tilelang/flashatten.py
@@ -60,7 +60,7 @@ def apply_config(
block_N=64,
num_stages=2,
threads=128,
enable_rasterization=False,
enable_rasterization: bool = False,
):
batch, heads, seq_len, dim = self.batch, self.heads, self.seq_len, self.dim
trans_K = self.trans_K
@@ -185,7 +185,7 @@ def flashatten_blocked(
num_stages=2,
threads=128,
is_causal=False,
enable_rasterization=False, # Enhance L2 Locality
enable_rasterization: bool = False, # Enhance L2 Locality
):
Q_shape = (batch, seq_len, heads, dim) if not trans_Q else (batch, dim, heads, seq_len)
K_shape = (batch, seq_len, heads, dim) if not trans_K else (batch, dim, heads, seq_len)
2 changes: 2 additions & 0 deletions bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py
@@ -47,6 +47,8 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
return self.serialize_hints_to_configs(roller_hints)

def get_hardware_aware_configs(self, arch: TileDevice = None, topk=10):
if arch is None:
arch = self.arch
return self.get_roller_configs(arch, topk)

# check if required shared memory cache
22 changes: 11 additions & 11 deletions bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py
@@ -381,7 +381,7 @@ def apply_config(
warp_col_tiles: Optional[int] = None,
chunk: Optional[int] = None,
num_stages: Optional[int] = None,
enable_rasterization=False,
enable_rasterization: bool = False,
):
assert block_row_warps is not None, "block_row_warps is required"
assert block_col_warps is not None, "block_col_warps is required"
@@ -578,7 +578,7 @@ def apply_config(
warp_col_tiles=32,
chunk=16,
num_stages=2,
enable_rasterization=False,
enable_rasterization: bool = False,
):

M = self.maybe_dynamic(self.M, "m")
@@ -706,8 +706,8 @@ def main(
micro_size_x,
micro_size_k,
):
A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x),
ko * (block_K // micro_size_k), ii, kk]
A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x) + i,
ko * (block_K // micro_size_k) + k, ii, kk]
else:
T.copy(A[by * block_M, ko * block_K], A_shared)

@@ -850,7 +850,7 @@ def apply_config(
warp_col_tiles: Optional[int] = None,
chunk: Optional[int] = None,
num_stages: Optional[int] = None,
enable_rasterization=False,
enable_rasterization: bool = False,
):
assert block_row_warps is not None, "block_row_warps is required"
assert block_col_warps is not None, "block_col_warps is required"
@@ -1061,7 +1061,7 @@ def apply_config(
warp_col_tiles=32,
chunk=16,
num_stages=2,
enable_rasterization=False,
enable_rasterization: bool = False,
):

M = self.maybe_dynamic(self.M, "m")
@@ -1183,8 +1183,8 @@ def main(
micro_size_x,
micro_size_k,
):
A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x),
ko * (block_K // micro_size_k), ii, kk]
A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x) + i,
ko * (block_K // micro_size_k) + k, ii, kk]
else:
T.copy(A[by * block_M, ko * block_K], A_shared)

@@ -1264,7 +1264,7 @@ def matmul_blocked(
accum_dtype="float16",
num_stages=2,
threads=128,
enable_rasterization=False, # Enhance L2 Locality
enable_rasterization: bool = False, # Enhance L2 Locality
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
@@ -1316,7 +1316,7 @@ def matmul_macro_tensorcore(
warp_col_tiles,
chunk,
num_stages=2,
enable_rasterization=False,
enable_rasterization: bool = False,
):
assert trans_A is False, "Currently only support Matrix A is not transposed"
assert trans_B is True, "Currently only support Matrix B is transposed"
@@ -1445,7 +1445,7 @@ def matmul_macro_tensorcore_weight_propagation_level_ldmatrix(
warp_col_tiles,
chunk,
num_stages=2,
enable_rasterization=False,
enable_rasterization: bool = False,
):
assert trans_A is False, "Currently only support Matrix A is not transposed"
assert trans_B is True, "Currently only support Matrix B is transposed"
(next file in this PR; filename not captured in this excerpt)
@@ -66,6 +66,8 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
return self.serialize_hints_to_configs(roller_hints)

def get_hardware_aware_configs(self, arch: TileDevice = None, topk=10):
if arch is None:
arch = self.arch
return self.get_roller_configs(arch, topk)

# check if required shared memory cache
(next file in this PR; filename not captured in this excerpt)
@@ -68,6 +68,8 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
return self.serialize_hints_to_configs(roller_hints)

def get_hardware_aware_configs(self, arch: TileDevice = None, topk=10):
if arch is None:
arch = self.arch
return self.get_roller_configs(arch, topk)

# check if required shared memory cache
(next file in this PR; filename not captured in this excerpt)
@@ -12,7 +12,7 @@
from bitblas.ops.general_matmul.tirscript import (
matmul_dequantize_select_implementation,)
from bitblas.tl.mma_macro_generator import (TensorCoreIntrinEmitter, INT4TensorCoreIntrinEmitter)
from bitblas.base.arch import TileDevice
from bitblas.base.arch import TileDevice, is_cuda_arch
from bitblas.base.roller.hint import Hint
from bitblas.base.roller.rasterization import NoRasterization
from bitblas.base.utils import get_roller_hints_from_func
@@ -37,8 +37,9 @@ class MatmulDequantizeFineGrainedScheduler(MatmulDequantizeBaseScheduler):
chunk: int = 32 # Usually determines the K-dimension split size

# Other Optimization Parameters
num_stages: int = 2
num_stages: int = 0
enable_rasterization: bool = False # Enhance L2 Locality
split_k_factor: int = 1 # Split-K factor for SM waste optimization

class TLHint(BaseTLHint):

@@ -76,6 +77,7 @@ def from_roller_hint(cls, hint: Hint):
tl_hint.chunk = chunk
tl_hint.num_stages = num_stages
tl_hint.enable_rasterization = enable_rasterization
tl_hint.split_k_factor = hint.split_k_factor

return tl_hint

@@ -88,6 +90,7 @@ def get_config_params(self):
"chunk": self.chunk,
"num_stages": self.num_stages,
"enable_rasterization": self.enable_rasterization,
"split_k_factor": self.split_k_factor,
}

def __repr__(self):
@@ -99,7 +102,8 @@ def __repr__(self):
f"block_K={self.chunk},"
f"threads={self.block_row_warps * self.block_col_warps * warp_size},"
f"num_stages={self.num_stages},"
f"enable_rasterization={self.enable_rasterization}"
f"enable_rasterization={self.enable_rasterization},"
f"split_k_factor={self.split_k_factor}"
"}")

def get_hint_type(self) -> str:
@@ -108,7 +112,61 @@ def get_hint_type(self) -> str:
def serialize_hints_to_configs(self, hints: List[Hint]):
configs = []
for hint in hints:
# Extract static shape dimensions for matrix multiplication
M, N, K = self.M, self.N, self.K

# Determine if the shapes are statically defined (not dynamic)
is_static_shape = isinstance(M, int) and isinstance(N, int) and isinstance(K, int)

# Check if the architecture is CUDA-based
arch_is_cuda = is_cuda_arch(self.arch)

# If the architecture is CUDA and we have a static shape, proceed with optimization
if arch_is_cuda and is_static_shape:
sm_waste_threshold = 5e-2 # Allow at most 5% SM waste
num_sms = self.arch.compute_max_core # Get the maximum number of streaming multiprocessors

# Compute block sizes based on the configuration
block_M = hint.block[0] # Block size in the M dimension
block_N = hint.block[1] # Block size in the N dimension
block_K = hint.rstep[0] # Block size in the K dimension

# Calculate the grid dimensions in M and N directions
grid_m = M // block_M
grid_n = N // block_N
total_grids = grid_m * grid_n # Total number of grids

# Initialize the split-k factor (used to distribute K-dimension work across blocks)
split_k_factor = 1

# Optimize the split-k factor to minimize SM waste
while True:
# Total grids after applying split-k
total_grids_split_k = total_grids * split_k_factor

# Calculate the waste in SMs after split-k distribution
waste_sm_splitk = total_grids_split_k - (total_grids_split_k //
num_sms) * num_sms
waste_sm_splitk_ratio = waste_sm_splitk / total_grids_split_k

# If the SM waste ratio is within the allowed threshold, stop optimization
if waste_sm_splitk_ratio <= sm_waste_threshold:
break

# Double the split-k factor and check if the resulting K-dimension size is too large
expand_split_k = split_k_factor * 2
if expand_split_k * block_K >= K:
break

# Update the split-k factor for the next iteration
split_k_factor = expand_split_k

# Note: The optimized split_k_factor can be stored or applied to the config if needed
hint.split_k_factor = split_k_factor

# Convert the hint to a configuration object using the TLHint mapping
config = self.TLHint.from_roller_hint(hint)

configs.append(config)
return configs
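
For reference, the split-K search above can be exercised on its own. The following is a minimal standalone sketch of the same loop, using hypothetical values (a 1024x1024x8192 GEMM, 128x128x32 tiles, 108 SMs); it is an illustration added here, not code from this PR.

# Standalone sketch of the SM-waste-driven split-K search (hypothetical values).
def pick_split_k_factor(M, N, K, block_M, block_N, block_K, num_sms,
                        sm_waste_threshold=5e-2):
    total_grids = (M // block_M) * (N // block_N)
    split_k_factor = 1
    while True:
        total_grids_split_k = total_grids * split_k_factor
        # Tail-wave waste: CTAs left over after full waves of num_sms.
        waste = total_grids_split_k - (total_grids_split_k // num_sms) * num_sms
        if waste / total_grids_split_k <= sm_waste_threshold:
            break
        expand_split_k = split_k_factor * 2
        if expand_split_k * block_K >= K:  # stop before K slices shrink to a single block_K tile
            break
        split_k_factor = expand_split_k
    return split_k_factor

# A 1024x1024 output with 128x128 tiles yields only 64 CTAs, fewer than the
# 108 SMs assumed here, so the search keeps doubling the split-K factor until
# the tail-wave waste ratio drops under 5% or the K slices get too small.
print(pick_split_k_factor(1024, 1024, 8192, 128, 128, 32, num_sms=108))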

@@ -123,6 +181,7 @@ def with_default_config(self):

num_stages = getattr(self, "num_stages", 2)
enable_rasterization = getattr(self, "enable_rasterization", False)
split_k_factor = getattr(self, "split_k_factor", 1)

return self.apply_config(
block_row_warps=block_row_warps,
@@ -132,6 +191,7 @@ def apply_config(
chunk=chunk,
num_stages=num_stages,
enable_rasterization=enable_rasterization,
split_k_factor=split_k_factor,
)

def apply_config(
@@ -142,7 +202,8 @@ def apply_config(
warp_col_tiles: Optional[int] = None,
chunk: Optional[int] = None,
num_stages: Optional[int] = None,
enable_rasterization=False,
enable_rasterization: bool = False,
split_k_factor: Optional[int] = None,
):
assert block_row_warps is not None, "block_row_warps is required"
assert block_col_warps is not None, "block_col_warps is required"
@@ -204,6 +265,8 @@ def apply_config(
Qzeros_shape = ((K // group_size), N // storage_nbit * num_bits)
Bias_shape = (N,)

splitK = K // split_k_factor

A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
@@ -253,7 +316,15 @@ def apply_config(
chunk=chunk,
)

cache_write_required = self.check_require_cache()
enable_split_k = split_k_factor > 1

def check_require_cache():
conditions = [False]
conditions.append(self.check_require_cache())
conditions.append(enable_split_k)
return any(conditions)

cache_write_required = check_require_cache()

@T.prim_func
def general_dequant_matmul(
@@ -267,7 +338,8 @@ def general_dequant_matmul(
Bias: T.Buffer(Bias_shape, in_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k_factor,
threads=threads) as (bx, by, bz):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
@@ -296,10 +368,13 @@ def general_dequant_matmul(

T.clear(C_frag)

for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
for ko in T.Pipelined(T.ceildiv(splitK, block_K), num_stages=num_stages):

T.copy(A[by * block_M, ko * block_K], A_shared)
T.copy(B[bx * block_N, ko * block_K // num_elems_per_byte], B_shared)
T.copy(A[by * block_M, bz * splitK + ko * block_K], A_shared)
T.copy(
B[bx * block_N,
bz * (splitK // num_elems_per_byte) + ko * block_K // num_elems_per_byte],
B_shared)

for i in T.serial(block_N * block_K // num_elems_per_byte //
(threads * local_size_compressed)):
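
As a sanity check on the copy offsets above: block bz starts its K slice at bz * splitK elements of A, while B is bit-packed, so its offset is counted in storage elements (divided by num_elems_per_byte). A tiny illustration with made-up values, not code from this PR:

# Hypothetical values: 4-bit weights packed two per int8 storage element.
K, split_k_factor, block_K, num_elems_per_byte = 4096, 4, 32, 2
splitK = K // split_k_factor

bz, ko = 3, 5  # arbitrary block indices
a_col = bz * splitK + ko * block_K  # K offset into A, in elements
b_col = bz * (splitK // num_elems_per_byte) + ko * block_K // num_elems_per_byte
# B's K axis is measured in packed storage elements, so the two offsets
# point at the same logical K position:
assert b_col == a_col // num_elems_per_byte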
@@ -359,6 +434,7 @@ def general_dequant_matmul(

# Matrix multiplication on fragments
mma_emitter.mma(A_frag, B_frag, C_frag)

if cache_write_required:
# Store the result back to C shared memory
mma_emitter.stmatrix(
@@ -377,13 +453,24 @@ def general_dequant_matmul(
] += Bias[bx * block_N + j]

# Store results from shared memory to global memory
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i, bx * block_N + j] = C_shared[
i // micro_size_x,
j // micro_size_y,
i % micro_size_x,
j % micro_size_y,
]
if enable_split_k:
for i, j in T.Parallel(block_M, block_N // 2):
T.atomic_addx2(
C[by * block_M + i, bx * block_N + j * 2], C_shared[
i // micro_size_x,
j // micro_size_y,
i % micro_size_x,
j % micro_size_y,
])
else:
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i, bx * block_N + j] = C_shared[
i // micro_size_x,
j // micro_size_y,
i % micro_size_x,
j % micro_size_y,
]

else:
# Store the result back to C global memory
mma_emitter.stmatrix(
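
The split-K path also motivates the cache-write and atomic-add changes above: when split_k_factor > 1, several bz blocks produce partial sums for the same C tile, so results are staged through shared memory and accumulated into global memory with atomic adds instead of plain stores. Below is a minimal NumPy sketch of that accumulation semantics (an illustration with made-up sizes, not the kernel's TileLang code), assuming C starts zeroed.

import numpy as np

# Hypothetical sizes; each "bz block" handles one K slice of length splitK.
M, N, K, split_k_factor = 64, 64, 256, 4
splitK = K // split_k_factor

A = np.random.rand(M, K).astype(np.float32)
B = np.random.rand(K, N).astype(np.float32)

C = np.zeros((M, N), dtype=np.float32)  # output must start at zero
for bz in range(split_k_factor):
    k0 = bz * splitK
    # Each slice contributes a partial GEMM; on the GPU these additions race
    # across blocks, hence the atomic accumulation in the kernel.
    C += A[:, k0:k0 + splitK] @ B[k0:k0 + splitK, :]

assert np.allclose(C, A @ B, rtol=1e-3, atol=1e-3)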
@@ -463,14 +550,17 @@ def apply_config(
warp_col_tiles: Optional[int] = None,
chunk: Optional[int] = None,
num_stages: Optional[int] = None,
enable_rasterization=False,
enable_rasterization: bool = False,
split_k_factor: Optional[int] = None,
):
assert block_row_warps is not None, "block_row_warps is required"
assert block_col_warps is not None, "block_col_warps is required"
assert warp_row_tiles is not None, "warp_row_tiles is required"
assert warp_col_tiles is not None, "warp_col_tiles is required"
assert chunk is not None, "chunk is required"
assert num_stages is not None, "num_stages is required"
# unused variable
split_k_factor = split_k_factor

M = self.maybe_dynamic(self.M, "m")
N, K = self.N, self.K