4 changes: 4 additions & 0 deletions bitblas/gpu/matmul_mma_dequantize.py
@@ -693,6 +693,10 @@ def sch_dequantize_in_register_with_config(
V
compute
"""
weight_transform_kind = config.intrin_info.weight_transform_kind
if weight_transform_kind == TransformKind.LDMatrixTransform:
return self.sch_warp_memory_prefetch_with_config(func, config)

from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel
get_mma_intrin_group,)
from .intrin import get_lop3_intrin_group
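A minimal sketch of the dispatch this hunk adds, using `types.SimpleNamespace` stand-ins for the scheduler's config/hint object; `pick_schedule` is illustrative and not a BitBLAS API (the `TransformKind` import path is the one used elsewhere in this PR).

```python
from types import SimpleNamespace
from bitblas.ops.common import TransformKind  # import path as used in this PR's __init__.py

def pick_schedule(config):
    # Mirrors the early return above: configs whose weight layout is
    # LDMatrixTransform take the warp-memory-prefetch schedule; all other
    # configs keep the original in-register dequantize schedule.
    if config.intrin_info.weight_transform_kind == TransformKind.LDMatrixTransform:
        return "sch_warp_memory_prefetch_with_config"
    return "sch_dequantize_in_register_with_config"

config = SimpleNamespace(
    intrin_info=SimpleNamespace(weight_transform_kind=TransformKind.LDMatrixTransform))
assert pick_schedule(config) == "sch_warp_memory_prefetch_with_config"
```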
23 changes: 16 additions & 7 deletions bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py
@@ -383,14 +383,20 @@ def with_default_config(self):

def apply_config(
self,
block_row_warps=2,
block_col_warps=2,
warp_row_tiles=32,
warp_col_tiles=32,
chunk=16,
num_stages=2,
block_row_warps: Optional[int] = None,
block_col_warps: Optional[int] = None,
warp_row_tiles: Optional[int] = None,
warp_col_tiles: Optional[int] = None,
chunk: Optional[int] = None,
num_stages: Optional[int] = None,
enable_rasterization=False,
):
assert block_row_warps is not None, "block_row_warps is required"
assert block_col_warps is not None, "block_col_warps is required"
assert warp_row_tiles is not None, "warp_row_tiles is required"
assert warp_col_tiles is not None, "warp_col_tiles is required"
assert chunk is not None, "chunk is required"
assert num_stages is not None, "num_stages is required"

M, N, K = self.M, self.N, self.K
trans_A, trans_B = self.trans_A, self.trans_B
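A self-contained sketch of the pattern this hunk introduces (`apply_config_like` is an illustrative helper, not a BitBLAS function): the defaults move to `None` so a missing tiling parameter fails loudly instead of silently picking up a built-in value.

```python
from typing import Optional

def apply_config_like(block_row_warps: Optional[int] = None,
                      num_stages: Optional[int] = None):
    # Same guard style as the diff above: every tiling parameter must now be
    # supplied explicitly by the caller (typically from a tuned hint).
    assert block_row_warps is not None, "block_row_warps is required"
    assert num_stages is not None, "num_stages is required"
    return block_row_warps, num_stages

apply_config_like(block_row_warps=2, num_stages=2)   # OK
# apply_config_like(block_row_warps=2)               # AssertionError: num_stages is required
```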
@@ -534,6 +540,9 @@ def __post_init__(self):
@dataclass
class MatmulWeightPropagationScheduler(MatmulFineGrainScheduler):

# Ladder Transform Config
weight_transform_kind: TransformKind = TransformKind.LDMatrixTransform

def apply_config(
self,
block_row_warps=2,
@@ -604,7 +613,7 @@ def apply_config(
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
transform_kind_b=TransformKind.LDMatrixTransform,
transform_kind_b=self.weight_transform_kind,
)

# Define the main kernel using the generated configuration
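A sketch of what the new dataclass field buys (the mini class below is illustrative, not the real scheduler): `transform_kind_b` is no longer pinned to `TransformKind.LDMatrixTransform` at the call site, so callers can construct the scheduler with a different ladder transform kind.

```python
from dataclasses import dataclass
from bitblas.ops.common import TransformKind

@dataclass
class WeightPropagationLike:
    # Mirrors the new field on MatmulWeightPropagationScheduler; the value is
    # read back where transform_kind_b used to be hard-coded.
    weight_transform_kind: TransformKind = TransformKind.LDMatrixTransform

default_sched = WeightPropagationLike()
# IntraWarpTransform here is just an example of a non-default kind.
custom_sched = WeightPropagationLike(weight_transform_kind=TransformKind.IntraWarpTransform)
print(default_sched.weight_transform_kind, custom_sched.weight_transform_kind)
```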
8 changes: 8 additions & 0 deletions bitblas/ops/general_matmul/tilelang/dequantize/__init__.py
@@ -5,6 +5,14 @@
MatmulDequantizeScheduler, # noqa: F401
)

from .finegrained_primitive_tensorcore import (
MatmulDequantizeFineGrainedScheduler, # noqa: F401
)

from .ladder_weight_transform_tensorcore import (
MatmulDequantizeWeightPropagationScheduler, # noqa: F401
)

from bitblas.ops.common import TransformKind
from typing import Union

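The two imports above, together with `TransformKind` and `Union`, suggest a selector over the two schedulers; the rest of this `__init__.py` is collapsed in the diff, so the function below is a hedged guess at the shape of that dispatch, not the code actually added.

```python
from bitblas.ops.common import TransformKind
from bitblas.ops.general_matmul.tilelang.dequantize import (
    MatmulDequantizeFineGrainedScheduler,
    MatmulDequantizeWeightPropagationScheduler,
)

def pick_dequantize_scheduler(weight_transform_kind: TransformKind) -> type:
    # Ladder/LDMatrix-transformed weights go to the weight-propagation
    # scheduler; plain layouts keep the fine-grained tensor-core scheduler.
    if weight_transform_kind == TransformKind.LDMatrixTransform:
        return MatmulDequantizeWeightPropagationScheduler
    return MatmulDequantizeFineGrainedScheduler
```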
@@ -312,7 +312,6 @@ def general_dequant_matmul(
Zeros,
Qzeros,
local_size,
local_size_compressed,
bx,
tx,
k,
@@ -384,7 +383,6 @@ def _normal_dequant(
zeros_buffer: T.Buffer,
qzeros_buffer: T.Buffer,
local_size: int,
local_size_compressed: int,
pid_n: T.Var,
tx: T.Var,
k: T.Var,
@@ -413,9 +411,9 @@ def _normal_dequant_impl(
qzeros_buffer: T.Buffer,
):
for v in T.serial(0, local_size):
index = (i * threads * local_size_compressed + tx * local_size_compressed + v)
vi = index // (stride_k // num_elems_per_byte)
vj = index % (stride_k // num_elems_per_byte)
index = (i * threads * local_size + tx * local_size + v)
vi = index // stride_k
vj = index % stride_k
if not with_scaling:
dequant_weight_local[v] = self._decode_func(
num_bits,
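A quick numeric check of the re-indexing above, with illustrative values (4-bit weights packed two per byte): the loop variable `v` walks `local_size` dequantized elements, so indexing with the compressed strides lets neighbouring threads collide on the same coordinate, while the element-space formula keeps every (thread, element) pair distinct.

```python
# Illustrative values: 4-bit weights, two elements per byte.
threads, stride_k, num_elems_per_byte = 4, 16, 2
local_size = 8                                  # dequantized elements per thread
local_size_compressed = local_size // num_elems_per_byte

def old_coords(i, tx, v):
    index = i * threads * local_size_compressed + tx * local_size_compressed + v
    return index // (stride_k // num_elems_per_byte), index % (stride_k // num_elems_per_byte)

def new_coords(i, tx, v):
    index = i * threads * local_size + tx * local_size + v
    return index // stride_k, index % stride_k

pairs = [(tx, v) for tx in range(threads) for v in range(local_size)]
print(len({old_coords(0, tx, v) for tx, v in pairs}))   # 20 -- thread ranges overlap
print(len({new_coords(0, tx, v) for tx, v in pairs}))   # 32 -- one coordinate per element
```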
@@ -486,12 +484,9 @@ def _normal_fast_dequant(
qzeros_buffer: T.Buffer,
func_name: str,
pid_n: T.Var,
tx: T.Var,
k: T.Var,
i: T.Var,
stride_n: int,
stride_k: int,
threads: int,
):
num_elems_per_byte = self.num_elems_per_byte
with_scaling = self.with_scaling