4 changes: 2 additions & 2 deletions bitblas/base/common_schedules.py
@@ -22,7 +22,7 @@
from typing import Callable, List

from tvm import tir

from bitblas.utils import retrieve_func_from_module
from .analysis import BlockInfo


@@ -74,7 +74,7 @@ def get_output_blocks(
"""

# collect arguments buffer
func = sch.mod["main"]
func = retrieve_func_from_module(sch.mod)
args = list(func.buffer_map.values())

output_blocks = []
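Throughout this PR, hard-coded `sch.mod["main"]` lookups are replaced by `retrieve_func_from_module` from `bitblas.utils`, so schedules keep working when the single function in a module carries a generated name instead of "main". The snippet below is only a sketch of the behavior the call sites rely on, not the actual implementation in `bitblas/utils`:

```python
# Hypothetical sketch: return the sole PrimFunc of an IRModule regardless of
# the global symbol it was registered under.
from tvm import IRModule, tir


def retrieve_func_from_module(ir_module: IRModule) -> tir.PrimFunc:
    if not isinstance(ir_module, IRModule):
        raise ValueError(f"Expected tvm.IRModule, got {type(ir_module)}")
    global_vars = ir_module.get_global_vars()
    if len(global_vars) != 1:
        raise ValueError(f"Expected exactly one function in the module, got {len(global_vars)}")
    return ir_module[global_vars[0]]
```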
21 changes: 18 additions & 3 deletions bitblas/base/utils.py
@@ -21,7 +21,12 @@
import tempfile
import itertools
from tvm.ir.supply import GlobalVarSupply
from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2
from bitblas.utils import (
tensor_replace_dp4a,
tensor_remove_make_int4,
tensor_remove_make_int2,
retrieve_func_from_module,
)
from bitblas.utils.tensor_adapter import (
np_float2np_bf16,)
import logging
@@ -58,18 +63,28 @@ def __init__(self, config, sch, mod: Module):
self.time_evaluator = None

def profile(self, data_distribution="uniform"):
func = self.sch.mod["main"]
func = retrieve_func_from_module(self.sch.mod)
device = self.config.arch.device
profile_tensors = get_dummy_input_arrays(func, device, distribution=data_distribution)
latency = self.time_evaluator(*profile_tensors).mean * 1e3
return latency


def get_roller_hints_from_func(func: tir.PrimFunc,
def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule],
arch: TileDevice,
topk: int = 10,
tensorcore_only: bool = False,
allow_gemv: bool = False) -> Optional[List[Hint]]:
func = None
if isinstance(func_or_module, tir.PrimFunc):
func = func_or_module
elif isinstance(func_or_module, IRModule):
func = retrieve_func_from_module(func_or_module)
else:
raise ValueError("Not supported type: ", type(func_or_module))

assert func is not None, "The function should not be None"

if tensorcore_only:
try:
tensorized_func, tags = get_tensorized_func_and_tags(
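With this change `get_roller_hints_from_func` normalizes its first argument before doing anything else, so callers may pass either a bare `tir.PrimFunc` or the `IRModule` returned by the select-implementation helpers. A standalone sketch of just that normalization step (the helper name `normalize_to_prim_func` is invented for illustration):

```python
from typing import Union

from tvm import IRModule, tir
from bitblas.utils import retrieve_func_from_module


def normalize_to_prim_func(func_or_module: Union[tir.PrimFunc, IRModule]) -> tir.PrimFunc:
    # Accept either a bare PrimFunc or an IRModule wrapping a single PrimFunc.
    if isinstance(func_or_module, tir.PrimFunc):
        return func_or_module
    if isinstance(func_or_module, IRModule):
        return retrieve_func_from_module(func_or_module)
    raise ValueError(f"Not supported type: {type(func_or_module)}")
```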
27 changes: 15 additions & 12 deletions bitblas/gpu/matmul_analysis.py
@@ -431,9 +431,8 @@ def is_dequantize(block: BlockRV) -> bool:
has_uint_input = any("uint" in str(region.buffer.dtype) for region in block_stmt.reads)
if not has_uint_input:
return False
if len(block_stmt.writes) != 1 or "float" not in str(block_stmt.writes[0].buffer.dtype):
return False
return True
return not (len(block_stmt.writes) != 1 or
"float" not in str(block_stmt.writes[0].buffer.dtype))

dequantize_blocks = [block for block in blocks if is_dequantize(block)]
return dequantize_blocks[0] if len(dequantize_blocks) == 1 else None
@@ -552,9 +551,7 @@ def _can_be_tensorized(sch: tir.Schedule, block: BlockRV) -> bool:
len(
collect_block_iter_vars_used_in_access_region(block_stmt,
block_stmt.writes[0].region)) > 0)
if not all(conditions):
return False
return True
return all(conditions)

# step2. transform function to tensorcore matmul (e.g. conv2d with im2col)
def check_sm_version(arch: str) -> int:
@@ -677,14 +674,20 @@ def check_last_trait(region: List[Range]):
block_stmt = sch.get(main_block)

# 16 for 16 bits tensor core while 32 for 8bits tensorcore.
minimal_tensorize_threshold = 16 if in_dtype in ["bfloat16", "float16"] else 32
minimal_tensorize_spatial_threshold = 16
minimal_tensorize_reduce_threshold = 16 if in_dtype in ["bfloat16", "float16"] else 32
# the batch dimension is not taken into consideration.
extent = block_stmt.iter_vars[1].dom.extent
if isinstance(extent, tir.expr.IntImm) and (extent.value < (1 if allow_gemv else
minimal_tensorize_threshold)):
return func, None
for item_var in block_stmt.iter_vars[2:]:
for item_var in block_stmt.iter_vars[1:]:
extent = item_var.dom.extent
iter_type = item_var.iter_type

if iter_type is IterVar.DataPar:
minimal_tensorize_threshold = minimal_tensorize_spatial_threshold
elif iter_type is IterVar.CommReduce:
minimal_tensorize_threshold = minimal_tensorize_reduce_threshold
else:
raise ValueError(f"Unknown IterVar type {iter_type}")

if (isinstance(extent, tir.expr.IntImm) and extent.value < minimal_tensorize_threshold):
return func, None
tags = analysis_tensorcore_tags(sch, main_block, target)
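The matmul_analysis.py hunk above splits the old single `minimal_tensorize_threshold` into a spatial and a reduction threshold and classifies every non-batch iteration variable by its `IterVar` type. A self-contained sketch of that check (the helper name is invented; the thresholds mirror the diff: 16 for spatial axes, 16 or 32 for reduction axes depending on the input dtype):

```python
from tvm.tir import IterVar


def meets_tensorize_thresholds(iter_vars, in_dtype: str) -> bool:
    # Spatial axes must reach 16; reduction axes need 16 for fp16/bf16 inputs, 32 for 8-bit.
    reduce_threshold = 16 if in_dtype in ["bfloat16", "float16"] else 32
    for iv in iter_vars[1:]:  # the batch dimension is not taken into consideration
        if iv.iter_type == IterVar.DataPar:
            threshold = 16
        elif iv.iter_type == IterVar.CommReduce:
            threshold = reduce_threshold
        else:
            raise ValueError(f"Unknown IterVar type {iv.iter_type}")
        extent = iv.dom.extent
        if hasattr(extent, "value") and extent.value < threshold:
            return False
    return True
```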
@@ -141,7 +141,7 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
)

roller_hints = get_roller_hints_from_func(
ir_module["main"],
ir_module,
arch,
topk,
tensorcore_only=True,
@@ -340,7 +340,7 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
)

roller_hints = get_roller_hints_from_func(
ir_module["main"],
ir_module,
arch,
topk,
tensorcore_only=True,
51 changes: 49 additions & 2 deletions bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore_s4.py
@@ -32,11 +32,17 @@ class MatmulINT4FineGrainScheduler(MatmulFineGrainScheduler):

def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}"
M = self.M
K = self.K // 2 # 2xint4 should be packed into one single int8
# Simple TIR Compute Expression
storage_dtype = "int8"

# This is a hack to utilize tensor core
if isinstance(M, int) and M < 16:
M = 16

ir_module = matmul_select_implementation(
M=self.M,
M=M,
N=self.N,
K=K,
in_dtype=storage_dtype,
@@ -46,7 +52,7 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
)

roller_hints = get_roller_hints_from_func(
ir_module["main"],
ir_module,
arch,
topk,
tensorcore_only=True,
@@ -230,6 +236,47 @@ def __post_init__(self):
@dataclass
class MatmulINT4WeightPropagationScheduler(MatmulWeightPropagationScheduler):

def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}"
M = self.M
K = self.K // 2 # 2xint4 should be packed into one single int8
# Simple TIR Compute Expression
storage_dtype = "int8"

# This is a hack to utilize tensor core
if isinstance(M, int) and M < 16:
M = 16

ir_module = matmul_select_implementation(
M=M,
N=self.N,
K=K,
in_dtype=storage_dtype,
out_dtype=self.out_dtype,
accum_dtype=self.accum_dtype,
layout=layout,
propagate_b=self.weight_transform_kind)

roller_hints = get_roller_hints_from_func(
ir_module,
arch,
topk,
tensorcore_only=True,
allow_gemv=True,
)

if roller_hints is None:
raise ValueError("No Roller Hints Found for TensorCore Scheduling")

def serialze_hints_to_configs(hints: List[Hint]):
configs = []
for hint in hints:
config = self.TLHint.from_roller_hint(hint)
configs.append(config)
return configs

return serialze_hints_to_configs(roller_hints)

def apply_config(
self,
block_row_warps=2,
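The INT4 schedulers touched in this PR all fold the problem shape the same way before asking the roller for hints: two 4-bit values share one int8 storage element, so K is halved (and, in the dequantize variants further below, the effective bit width doubles), while a small M is padded up to 16 so the roller still sees a tensor-core-friendly tile. A quick worked illustration (the helper is hypothetical; the arithmetic mirrors the diff):

```python
def int4_roller_problem_shape(M, N, K, num_bits=4):
    packed_K = K // 2            # 2 x int4 packed into a single int8 element
    packed_bits = num_bits * 2   # int4 x int2 is treated as int8 x int4
    # pad tiny M so tensor core tiling still applies (the "hack" in the diff)
    roller_M = 16 if isinstance(M, int) and M < 16 else M
    return roller_M, N, packed_K, packed_bits


print(int4_roller_problem_shape(1, 16384, 16384))  # -> (16, 16384, 8192, 8)
```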
@@ -136,7 +136,7 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
)

roller_hints = get_roller_hints_from_func(
ir_module["main"],
ir_module,
arch,
topk,
tensorcore_only=True,
@@ -154,7 +154,7 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
zeros_mode=self.zeros_mode)

roller_hints = get_roller_hints_from_func(
ir_module["main"],
ir_module,
arch,
topk,
tensorcore_only=True,
@@ -32,9 +32,15 @@ class MatmulINT4DequantizeFineGrainedScheduler(MatmulDequantizeFineGrainedSchedu

def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}"
M = self.M
K = self.K // 2 # 2xint4 should be packed into one single int8
storage_dtype = "int8"
num_bits = self.num_bits * 2

# This is a hack to utilize tensor core
if isinstance(M, int) and M < 16:
M = 16

# INT4XINT2 is equal to int8xint4 with reduced shape
# Simple TIR Compute Expression
ir_module = matmul_dequantize_select_implementation(
@@ -56,7 +62,7 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
zeros_mode=self.zeros_mode)

roller_hints = get_roller_hints_from_func(
ir_module["main"],
ir_module,
arch,
topk,
tensorcore_only=True,
@@ -3,18 +3,23 @@
from bitblas import tvm as tvm
from tvm import DataType
import tvm.tl.language as T
from typing import Optional
from typing import Optional, List
from bitblas.tl.utils import (
get_mma_micro_size, # noqa: F401
make_swizzle_layout, # noqa: F401
index_to_coordinates, # noqa: F401
)
from bitblas.base.arch import TileDevice
from bitblas.base.roller.hint import Hint
from bitblas.tl.macro_generator import (
INT4TensorCoreIntrinEmitterWithLadderTransform, # noqa: F401
)
from bitblas.ops.common import TransformKind # noqa: F401
from dataclasses import dataclass
from bitblas.base.utils import get_roller_hints_from_func
from bitblas.gpu.intrin.lop3 import get_lop3_intrin_group
from bitblas.ops.general_matmul.tirscript import (
matmul_dequantize_select_implementation,)
from bitblas.ops.general_matmul.tilelang.dequantize.ladder_weight_transform_tensorcore import (
MatmulDequantizeWeightPropagationScheduler,)

@@ -25,6 +30,60 @@
@dataclass
class MatmulINT4DequantizeWeightPropagationScheduler(MatmulDequantizeWeightPropagationScheduler):

def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}"
M = self.M
K = self.K // 2 # 2xint4 should be packed into one single int8
storage_dtype = "int8"
num_bits = self.num_bits * 2

# This is a hack to utilize tensor core
if isinstance(M, int) and M < 16:
M = 16

# INT4XINT2 is equal to int8xint4 with reduced shape
# Simple TIR Compute Expression
ir_module = matmul_dequantize_select_implementation(
M=M,
N=self.N,
K=K,
in_dtype=storage_dtype,
out_dtype=self.out_dtype,
accum_dtype=self.accum_dtype,
layout=layout,
bit=num_bits,
storage_dtype=self.storage_dtype,
source_format=self.source_format,
with_scaling=self.with_scaling,
with_zeros=self.with_zeros,
group_size=self.group_size,
fast_decoding=self.fast_decoding,
with_bias=self.with_bias,
zeros_mode=self.zeros_mode)

roller_hints = get_roller_hints_from_func(
ir_module,
arch,
topk,
tensorcore_only=True,
allow_gemv=True,
)

if roller_hints is None:
raise ValueError("No Roller Hints Found for TensorCore Scheduling")

for hint in roller_hints:
print(hint)

def serialze_hints_to_configs(hints: List[Hint]):
configs = []
for hint in hints:
config = self.TLHint.from_roller_hint(hint)
configs.append(config)
return configs

return serialze_hints_to_configs(roller_hints)

def apply_config(
self,
block_row_warps: Optional[int] = None,
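Both INT4 weight-propagation schedulers add the same nested `serialze_hints_to_configs` helper, which is simply a loop over `self.TLHint.from_roller_hint`. A compact equivalent written as a method (a sketch only, assuming the same `TLHint` class attribute):

```python
from typing import List

from bitblas.base.roller.hint import Hint


def serialize_hints_to_configs(self, hints: List[Hint]):
    # Equivalent list-comprehension form of the nested helper added in the diff.
    return [self.TLHint.from_roller_hint(hint) for hint in hints]
```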
11 changes: 4 additions & 7 deletions bitblas/ops/operator.py
@@ -21,6 +21,7 @@
from bitblas.builder.wrapper import TIRWrapper, TLWrapper
from bitblas.builder.lib_generator import LibraryGenerator
from bitblas.common import MAX_ERROR_MESSAGE_LENGTH
from bitblas.utils import retrieve_func_from_module
from dataclasses import dataclass
import logging
import re
@@ -317,6 +318,7 @@ def apply_fast_tuning(
elif self.is_tilelang_backend():
# Finetune the schedule
tuning_configs = self.get_tl_tuning_config(topk=topk)
assert len(tuning_configs) > 0, "No tuning config found for this operator."
_, best = tl_apply_and_build(
func_or_scheduler, tuning_configs, arch=self.arch, parallel_build=parallel_build)
# Return the best Config as Hint
@@ -368,11 +370,9 @@ def hardware_aware_finetune(
assert (
len(scheduled_mod.get_global_vars()) == 1
), "The optimized module should only have one global variable for default schedule."
assert (
"main" in scheduled_mod
), "The optimized module should have a function named 'main' for default schedule."
default_kernal_name = self.kernel_name_generator.generate(best_hint)
func = scheduled_mod["main"].with_attr("global_symbol", default_kernal_name)
func = retrieve_func_from_module(scheduled_mod).with_attr("global_symbol",
default_kernal_name)
scheduled_ir_module = tvm.IRModule({default_kernal_name: func})
self._update_optimized_mod(scheduled_ir_module)

@@ -465,9 +465,6 @@ def forward(self, *args):
def __call__(self, *args: Any) -> Any:
return self.forward(*args)

def update_func(self, func: PrimFunc):
self.ir_module["main"] = func

def update_runtime_module(self, rt_mod=None, srcpath=None, libpath=None):
if rt_mod is not None:
self.rt_mod = rt_mod
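The operator.py changes drop the assumption that the optimized module exposes a function literally named "main": the single PrimFunc is fetched, stamped with the generated kernel name via `global_symbol`, and the module is rebuilt under that name. A self-contained sketch of the rename pattern (the TVMScript function and the kernel name are placeholders):

```python
import tvm
from tvm.script import tir as T

from bitblas.utils import retrieve_func_from_module


@T.prim_func
def add_one(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
    for i in range(16):
        with T.block("B"):
            vi = T.axis.remap("S", [i])
            B[vi] = A[vi] + T.float32(1)


mod = tvm.IRModule({"main": add_one})
kernel_name = "add_one_generated_kernel"  # stands in for kernel_name_generator output
func = retrieve_func_from_module(mod).with_attr("global_symbol", kernel_name)
renamed = tvm.IRModule({kernel_name: func})
assert kernel_name in [gv.name_hint for gv in renamed.get_global_vars()]
```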