Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 5a8b30 to 1f0e1b
4 changes: 3 additions & 1 deletion bitblas/ops/base_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ class BaseScheduler(ABC):
@staticmethod
def Simplify(stmt: Union[PrimFunc, IRModule]):
if isinstance(stmt, PrimFunc):
return Simplify()(IRModule.from_expr(stmt))["main"]
mod = Simplify()(IRModule.from_expr(stmt))
assert len(mod.functions) == 1, "Simplify should return a single function"
return list(mod.functions.values()).pop()
elif isinstance(stmt, IRModule):
return Simplify()(stmt)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class MatmulDequantizeScheduler(BaseScheduler):
group_size: int = -1
fast_decoding: bool = False
with_bias: bool = False
zeros_mode: Literal["original", "rescale", "quantized"] = "original",
zeros_mode: Literal["original", "rescale", "quantized"] = ("original",)

# Default Tile Related Params
block_M: int = 128
Expand Down Expand Up @@ -132,7 +132,8 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
group_size=self.group_size,
fast_decoding=self.fast_decoding,
with_bias=self.with_bias,
zeros_mode=self.zeros_mode)
zeros_mode=self.zeros_mode,
)

roller_hints = get_roller_hints_from_func(
ir_module["main"],
Expand Down Expand Up @@ -174,7 +175,7 @@ def with_default_config(self):
enable_rasterization=enable_rasterization,
)

def _apply_config_dequant_only(
def apply_config(
self,
block_M: Optional[int] = None,
block_N: Optional[int] = None,
Expand All @@ -191,25 +192,22 @@ def _apply_config_dequant_only(
assert threads is not None, "threads is required"
M, N, K = self.M, self.N, self.K
trans_A, trans_B = self.trans_A, self.trans_B

assert trans_A is False, "Dequantize only implement for trans_A=False currently"
assert trans_B is True, "Dequantize only implement for trans_B=TRue currently"

# check is dequantize only

def check_is_dequantize_only():
return not self.with_scaling

if not check_is_dequantize_only():
raise ValueError("Not a Dequantize Only Configuration")

in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype
in_dtype, out_dtype, accum_dtype = (
self.in_dtype,
self.out_dtype,
self.accum_dtype,
)
fast_decoding = self.fast_decoding

num_bits = self.num_bits
storage_dtype = self.storage_dtype
source_format = self.source_format
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
num_elems_per_byte = 8 // num_bits
num_elems_per_byte = self.num_elems_per_byte

MAX_TRANSACTION_SIZE_IN_BITS = 128
local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits
Expand All @@ -221,6 +219,11 @@ def check_is_dequantize_only():

A_shape = (M, K)
B_shape = (N, K // storage_nbit * num_bits)
LUT_shape = (group_size, K // storage_nbit * num_bits)
Scale_shape = (N, K // group_size)
Zeros_shape = (N, K // group_size)
Qzeros_shape = ((K // group_size), N // storage_nbit * num_bits)
Bias_shape = (N,)

A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
Expand All @@ -241,9 +244,14 @@ def check_is_dequantize_only():
assert func_name is not None, "lop3_intrin_info is not found"

@T.prim_func
def main(
def general_dequant_matmul(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, storage_dtype),
LUT: T.Buffer(LUT_shape, in_dtype),
Scale: T.Buffer(Scale_shape, in_dtype),
Qzeros: T.Buffer(Qzeros_shape, storage_dtype),
Zeros: T.Buffer(Zeros_shape, in_dtype),
Bias: T.Buffer(Bias_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(
Expand All @@ -270,7 +278,9 @@ def main(
for i in T.serial(block_N * block_K // num_elems_per_byte //
(threads * local_size_compressed)):
for v in T.vectorized(0, local_size_compressed):
index = i * threads * local_size_compressed + tx * local_size_compressed + v
index = (
i * threads * local_size_compressed + tx * local_size_compressed +
v)
vi = index // (block_K // num_elems_per_byte)
vj = index % (block_K // num_elems_per_byte)
B_local[v] = B_shared[vi, vj]
Expand All @@ -280,15 +290,25 @@ def main(
func_name,
T.address_of(B_local[0]),
T.address_of(B_dequantize_local[0]),
dtype=in_dtype)
dtype=in_dtype,
)
else:
for v in T.serial(0, local_size):
B_dequantize_local[v] = self._decode_func(
num_bits,
B_local[v // num_elems_per_byte],
v % num_elems_per_byte,
dtype=in_dtype,
)
self._normal_dequant(
B_local,
B_dequantize_local,
Scale,
Zeros,
Qzeros,
local_size,
local_size_compressed,
bx,
tx,
k,
i,
block_N,
block_K,
threads,
)
for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v
vi = index // block_K
Expand All @@ -299,87 +319,7 @@ def main(

T.copy(C_local, C[by * block_M, bx * block_N])

return main

def _apply_config_with_scaling(
self,
block_M: Optional[int] = None,
block_N: Optional[int] = None,
block_K: Optional[int] = None,
num_stages: Optional[int] = None,
threads: Optional[int] = None,
# Enhance L2 Locality
enable_rasterization: bool = False,
):
raise NotImplementedError("Scaling Configuration is not implemented")

def _apply_config_with_scaling_zeros_original_or_rescale(
self,
block_M: Optional[int] = None,
block_N: Optional[int] = None,
block_K: Optional[int] = None,
num_stages: Optional[int] = None,
threads: Optional[int] = None,
# Enhance L2 Locality
enable_rasterization: bool = False,
):
raise NotImplementedError("Scaling and Zeros Original Configuration is not implemented")

def _apply_config_with_scaling_zeros_quantized(
self,
block_M: Optional[int] = None,
block_N: Optional[int] = None,
block_K: Optional[int] = None,
num_stages: Optional[int] = None,
threads: Optional[int] = None,
# Enhance L2 Locality
enable_rasterization: bool = False,
):
raise NotImplementedError("Scaling and Zeros Rescale Configuration is not implemented")

def apply_config(
self,
block_M: Optional[int] = None,
block_N: Optional[int] = None,
block_K: Optional[int] = None,
num_stages: Optional[int] = None,
threads: Optional[int] = None,
# Enhance L2 Locality
enable_rasterization: bool = False,
):
assert block_M is not None, "block_M is required"
assert block_N is not None, "block_N is required"
assert block_K is not None, "block_K is required"
assert num_stages is not None, "num_stages is required"
assert threads is not None, "threads is required"
trans_A, trans_B = self.trans_A, self.trans_B

assert trans_A is False, "Dequantize only implement for trans_A=False currently"
assert trans_B is True, "Dequantize only implement for trans_B=TRue currently"

with_scaling = self.with_scaling
with_zeros = self.with_zeros
zeros_mode = self.zeros_mode

args = [block_M, block_N, block_K, num_stages, threads, enable_rasterization]

dequant_prim_func = None

if not with_scaling:
dequant_prim_func = self._apply_config_dequant_only(*args)
elif not with_zeros:
dequant_prim_func = self._apply_config_with_scaling(*args)
elif zeros_mode in ["original", "rescale"]:
dequant_prim_func = self._apply_config_with_scaling_zeros_original_or_rescale(*args)
elif zeros_mode == "quantized":
dequant_prim_func = self._apply_config_with_scaling_zeros_quantized(*args)
else:
raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode))

if dequant_prim_func is None:
raise ValueError("Unsupported Configuration")

return self.maybe_simplify(dequant_prim_func)
return self.maybe_simplify(general_dequant_matmul)

@property
def _decode_func(self):
Expand Down Expand Up @@ -424,6 +364,125 @@ def naive_cast_dequant(x):

return dequant_func

# proxy method for macro expansion
def _normal_dequant(
self,
compressed_weight_local: T.Buffer,
dequant_weight_local: T.Buffer,
scale_buffer: T.Buffer,
zeros_buffer: T.Buffer,
qzeros_buffer: T.Buffer,
local_size: int,
local_size_compressed: int,
pid_n: T.Var,
tx: T.Var,
k: T.Var,
i: T.Var,
stride_n: int,
stride_k: int,
threads: int,
):
num_elems_per_byte = self.num_elems_per_byte
with_scaling = self.with_scaling
with_zeros = self.with_zeros
zeros_mode = self.zeros_mode
num_bits = self.num_bits
in_dtype = self.in_dtype
group_size = self.group_size
storage_dtype = self.storage_dtype
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
storage_type = str("".join(c for c in storage_dtype if not c.isdigit()))

@T.macro
def _normal_dequant_impl(
compressed_weight_local: T.Buffer,
dequant_weight_local: T.Buffer,
scale_buffer: T.Buffer,
zeros_buffer: T.Buffer,
qzeros_buffer: T.Buffer,
):
print("Normal Dequantize")
print("with_scaling", with_scaling)
print("with_zeros", with_zeros)
print("zeros_mode", zeros_mode)
print("num_bits", num_bits)
for v in T.serial(0, local_size):
index = (i * threads * local_size_compressed + tx * local_size_compressed + v)
vi = index // (stride_k // num_elems_per_byte)
vj = index % (stride_k // num_elems_per_byte)
if not with_scaling:
print("No Scaling")
dequant_weight_local[v] = self._decode_func(
num_bits,
compressed_weight_local[v // num_elems_per_byte],
v % num_elems_per_byte,
dtype=in_dtype,
)
elif not with_zeros:
print("No Zeros")
# Scaling only
dequant_weight_local[v] = (
self._decode_func(
num_bits,
compressed_weight_local[v // num_elems_per_byte],
v % num_elems_per_byte,
dtype=in_dtype,
) * scale_buffer[pid_n * stride_n + vi, (k * stride_k + vj) // group_size])
elif zeros_mode == "original":
print("Original Zeros")
dequant_weight_local[v] = (self._decode_func(
num_bits,
compressed_weight_local[v // num_elems_per_byte],
v % num_elems_per_byte,
dtype=in_dtype,
) - zeros_buffer[pid_n * stride_n + vi, (k * stride_k + vj) //
group_size]) * scale_buffer[pid_n * stride_n + vi,
(k * stride_k + vj) // group_size]
elif zeros_mode == "rescale":
print("rescale")
dequant_weight_local[v] = (
self._decode_func(
num_bits,
compressed_weight_local[v // num_elems_per_byte],
v % num_elems_per_byte,
dtype=in_dtype,
) * scale_buffer[pid_n * stride_n + vi, (k * stride_k + vj) // group_size] -
zeros_buffer[pid_n * stride_n + vi, (k * stride_k + vj) // group_size])
elif zeros_mode == "quantized":
print("Quantized Zeros")
dequant_qzeros = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
num_bits,
qzeros_buffer[
(k * stride_k + vj) // group_size,
(pid_n * stride_n + vi) // num_elems_per_byte,
],
(pid_n * stride_n + vi) % num_elems_per_byte,
dtype=storage_dtype,
)

dequant_weight_local[v] = (self._decode_func(
num_bits,
compressed_weight_local[v // num_elems_per_byte],
v % num_elems_per_byte,
zero=dequant_qzeros,
dtype=in_dtype,
)) * scale_buffer[pid_n * stride_n + vi, (k * stride_k + vj) // group_size]

return _normal_dequant_impl(
compressed_weight_local,
dequant_weight_local,
scale_buffer,
zeros_buffer,
qzeros_buffer,
)

@property
def num_elems_per_byte(self):
storage_nbit = int("".join(c for c in self.storage_dtype if c.isdigit()))
num_bits = self.num_bits
return storage_nbit // num_bits

def __post_init__(self):
# Add Config Validation
return
# Legalize group_size
if self.with_scaling and self.group_size == -1:
object.__setattr__(self, "group_size", self.K)
6 changes: 2 additions & 4 deletions bitblas/ops/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,11 +279,9 @@ def _build_default_module(self, target: Target):
assert (
len(scheduled_mod.get_global_vars()) == 1
), "The optimized module should only have one global variable for default schedule."
assert (
"main" in scheduled_mod
), "The optimized module should have a function named 'main' for default schedule."
global_symbol = scheduled_mod.get_global_vars()[0]
default_kernal_name = self.kernel_name_generator.generate()
func = scheduled_mod["main"].with_attr("global_symbol", default_kernal_name)
func = scheduled_mod[global_symbol].with_attr("global_symbol", default_kernal_name)
scheduled_ir_module = tvm.IRModule({default_kernal_name: func})
self._update_optimized_mod(scheduled_ir_module)
except Exception as apply_schedule_error:
Expand Down
2 changes: 1 addition & 1 deletion install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -65,5 +65,5 @@ cmake .. && make -j && cd ../../..

echo "export TVM_HOME=$(pwd)/3rdparty/tvm" >> ~/.bashrc
echo "export PYTHONPATH=\$TVM_HOME/python:$(pwd):\$PYTHONPATH" >> ~/.bashrc

echo "export CUDA_DEVICE_ORDER=PCI_BUS_ID" >> ~/.bashrc
source ~/.bashrc
Loading
Loading