303 changes: 300 additions & 3 deletions bitblas/gpu/matmul.py
@@ -4,13 +4,14 @@
# pylint: disable=missing-docstring, invalid-name
"""A GEMM schedule rule for GPU operators."""
from dataclasses import dataclass
from typing import Optional

from typing import Optional, List
from contextlib import suppress
from functools import reduce
from tvm import tir
from tvm.target import Target
from tvm.tir.stmt import ForKind

from ..base import analysis
from ..base.analysis import get_coalesced_veclen
from .base import GPUScheduleRule
from . import utils
from .matmul_analysis import (
@@ -19,6 +20,7 @@
get_in_out_dtypes,
get_index_map,
normalize_to_matmul,
_collect_producers,
get_reduction_blocks,
)
from .matmul_mma import MatmulTensorizationMMA
@@ -242,6 +244,8 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring
func: tir.PrimFunc,
config,
) -> tir.Schedule:
if "dequantize_info" in func.attrs:
return self.sch_dequantize_in_register_with_config(func, config)
sch = tir.Schedule(func)
root_block = analysis.get_root_block(sch)
blocks = sch.get_child_blocks(root_block)
@@ -278,7 +282,7 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring
# Step 2. Get schedule config.
block_row_warps = config.block[0] // (config.thread[0] * config.step[0])
block_col_warps = config.block[1] // (config.thread[1] * config.step[1])
thread_row_tiles = config.thread[1] // (config.step[0] * 2)
thread_row_tiles = config.thread[0] // (config.step[0] * 2)
thread_col_tiles = config.thread[1] // (config.step[1] * 2)
vthread_row_tiles = (config.step[0] * 2) # expand vthread to avoid shared memory load bank conflicts
vthread_col_tiles = (config.step[1] * 2) # expand vthread to avoid shared memory load bank conflicts
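# Illustrative arithmetic only, with hypothetical config values not taken
# from this PR (block=[128, 128], thread=[16, 16], step=[2, 2]):
#   block_row_warps   = 128 // (16 * 2) = 4
#   thread_row_tiles  = 16 // (2 * 2)   = 4
#   vthread_row_tiles = 2 * 2           = 4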
@@ -370,3 +374,296 @@ def is_trivial_load(block):

sch.decompose_reduction(main_block, ko)
return sch

def sch_dequantize_in_register_with_config( # pylint: disable=too-many-locals,missing-docstring
self,
func: tir.PrimFunc,
config,
) -> tir.Schedule:
'''
For devices without async copy, we can use a simple dequantize schedule without shared memory prefetch.
quantized weight
|
V
dequantized in register
|
V
save into shared memory
|
V
compute
'''
from .intrin import get_lop3_intrin_group

import_source: List[str] = []

sch = tir.Schedule(func)
root_block = analysis.get_root_block(sch)
blocks = sch.get_child_blocks(root_block)

reduction_blocks = get_reduction_blocks(sch, blocks)
if reduction_blocks is None:
return None

# in some cases the conv template will use this rule, but the tile config
# is not analyzed from the matmul expr.
if len(config.block) != 2:
logger.debug(f"Warning: block config {config.block} is not valid for matmul, skip.")
return None

# Check Dequantize Info
dequantize_info = func.attrs["dequantize_info"]

def check_dequantize_info(dequantize_info):
conditions = []
# currently only weight-only dequantization is supported
conditions.append(len(dequantize_info) == 1)
# TODO(@lei) check if the dequantize value name is weight
return all(conditions)

assert check_dequantize_info(dequantize_info)

(weight_decode_info,) = list(dequantize_info.values())

def check_weight_decode_info(weight_decode_info):
conditions = []
# check source format in ["uint", "int", "fp", "nf", "fp_e4m3"]
conditions.append("source_format" in weight_decode_info)
conditions.append(weight_decode_info["source_format"]["format"] in
["uint", "int", "fp", "nf", "fp_e4m3"])
# check source bits in [1, 2, 4, 8]
conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
# check target format in ["float16", "int8"]
conditions.append("target_format" in weight_decode_info)
conditions.append(weight_decode_info["target_format"] in ["float16", "int8"])
return all(conditions)

assert check_weight_decode_info(weight_decode_info), "Invalid Weight Decode Info"

main_block = reduction_blocks[0]

block_stmt = sch.get(main_block)

# dequantize must be in 'n' 't' 'n' layout for fast decoding.
index_maps = get_index_map(block_stmt, ["n", "t", "n"])
if index_maps is None:
return None
matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps

# Step 0. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
block = sch.reindex(main_block, ("read", 0))
sch.transform_layout(block, ("write", 0), a_index_map)
block = sch.reindex(main_block, ("read", 1))
sch.transform_layout(block, ("write", 0), b_index_map)
block = sch.reindex(main_block, ("write", 0))
sch.transform_layout(block, ("read", 0), c_index_map)
sch.transform_block_layout(main_block, matmul_index_map)

# Step 2. Get schedule config.
block_row_warps = config.block[0] // (config.thread[0] * config.step[0])
block_col_warps = config.block[1] // (config.thread[1] * config.step[1])
thread_row_tiles = config.thread[0] // (config.step[0])
thread_col_tiles = config.thread[1] // (config.step[1])
vthread_row_tiles = (config.step[0]) # expand vthread to avoid shared memory load bank conflicts
vthread_col_tiles = (config.step[1]) # expand vthread to avoid shared memory load bank conflicts
chunk = config.rstep[0]
shared_scope = config.shared_scope

num_ty = block_row_warps
num_tx = block_col_warps

# Step 3. Schedule matmul
BM = block_row_warps * vthread_row_tiles * thread_row_tiles
BN = block_col_warps * vthread_col_tiles * thread_col_tiles
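# Illustrative arithmetic only, with hypothetical config values not taken
# from this PR (block=[64, 64], thread=[16, 16], step=[2, 2]):
#   block_row_warps = 64 // (16 * 2) = 2, vthread_row_tiles = 2,
#   thread_row_tiles = 16 // 2 = 8, so BM = 2 * 2 * 8 = 32 (BN likewise).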

# TODO(lei): this is a hack.
def find_valid_number(k, chunk, magic=16):
# Start from the largest multiple of `magic` that does not exceed `chunk`
num = (chunk // magic) * magic
# Iterate downwards to find a multiple of `magic` that evenly divides k
while num > 0:
if k % num == 0:
return num
num -= magic

return None # If no such number is found
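# Worked example with hypothetical shapes: K = 4096, chunk = 40 ->
# num starts at (40 // 16) * 16 = 32; 4096 % 32 == 0, so BK = 32.
# For K = 24, chunk = 40: 24 % 32 != 0 and 24 % 16 != 0 -> returns None.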

K = func.buffer_map[func.params[0]].shape[-1]
BK = find_valid_number(K, chunk)

sch.pad_einsum(
main_block,
[1, BM, BN, BK],
)
batch, y, x, k = sch.get_loops(main_block)
by, vy, ty, yi = sch.split(y, [None, vthread_row_tiles, block_row_warps, thread_row_tiles])
bx, vx, tx, xi = sch.split(x, [None, vthread_col_tiles, block_col_warps, thread_col_tiles])
ko, ki = sch.split(k, factors=[None, BK])
sch.reorder(by, bx, vy, vx, ty, tx, ko, ki, yi, xi)
by = sch.fuse(batch, by)
sch.bind(bx, "blockIdx.x")
sch.bind(by, "blockIdx.y")
sch.bind(vy, "vthread.y")
sch.bind(vx, "vthread.x")
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")
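# Resulting loop nest after the reorder/bind above (schematic):
#   blockIdx.y (batch fused with by) -> blockIdx.x
#     -> vthread.y -> vthread.x -> threadIdx.y -> threadIdx.x
#       -> ko (outer reduction) -> ki (inner reduction) -> yi -> xi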

def prod(iterable):
return reduce(lambda x, y: x * y, iterable, 1)

l2g = sch.cache_write(main_block, 0, "local")
sch.reverse_compute_at(l2g, tx, preserve_unit_loops=True)

def _cooperative_fetch(index, vec_len):
block = sch.cache_read(main_block, index, "shared")
num_loops = len(sch.get_loops(block))
block_local = sch.cache_read(main_block, index, "local")
sch.compute_at(block_local, ki, preserve_unit_loops=True)
sch.compute_at(block, ko, preserve_unit_loops=True)
loops = sch.get_loops(block)[-num_loops:]
_, ty, tx, vec = sch.split(
sch.fuse(*loops),
factors=[None, block_row_warps, block_col_warps, vec_len],
)

auto_inline_producers(sch, block)

def is_trivial_load(block):
# avoid vectorizing the transposed global[v2, v1] -> shared[v1, v2] case
reads = sch.get(block).reads
writes = sch.get(block).writes
if len(reads) != 1 or len(writes) != 1:
return False
return all(
read.region[-1] == write.region[-1] for read, write in zip(reads, writes))
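# e.g. a copy global[v0, v1] -> shared[v0, v1] is trivial (the innermost
# regions match), while the transposed global[v1, v0] -> shared[v0, v1]
# is not, so it is left unvectorized.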

if is_trivial_load(block):
sch.vectorize(vec)

sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")

_, vec = sch.split(
sch.fuse(*sch.get_loops(block_local)[-2:]),
[None, vec_len // prod(config.step)],
)
sch.vectorize(vec)

return block

for i, input_region in enumerate(sch.get(main_block).reads[:1]):
_buffer_name = input_region.buffer.name.replace("_reindex", "").replace("_pad", "")
if _buffer_name not in config.cached_tensors:
logger.warning(
f"Warning: {_buffer_name} is not in cached_tensors {config.cached_tensors}, skip."
)
continue

# otherwise cooperative fetch in shared memory.
vectorize = config.vectorize.get(_buffer_name, 1)

_cooperative_fetch(i, vec_len=vectorize)

def decode_fetch_to_shared(block, idx):
# step1. create memory hierarchy
# global -> local -> shared
block_shared = sch.cache_read(block, idx, shared_scope)
sch.compute_at(block_shared, ko, preserve_unit_loops=True)

decode_factor = get_coalesced_veclen(sch.get(block_shared))
_, B_shared_vi, _ = sch.split(
sch.get_loops(block_shared)[-1], factors=[None, 1, decode_factor])
block_shared_local = sch.cache_read(block_shared, 0, "local")
# global -> dequantized_local -> shared
# step2. inline to local block
weight_dequantize_block = sch.get_block(weight_decode_info["decode_block"])
weight_producers = _collect_producers(sch, weight_dequantize_block)
auto_inline_producers(sch, block_shared_local, weight_producers)

# get target dequantize buffer's idx
def get_idx():
# for LUT dequantize, the expr is LUT(w), so the weight idx is 1
# TODO: a more general, structure-based analysis could derive the idx
if weight_decode_info["source_format"]["format"] == "nf":
return 1
return 0

b_idx = get_idx()
# global -> prefetch_local -> dequantized_local -> shared
block_shared_local_local = sch.cache_read(block_shared_local, b_idx, "local")

sch.compute_at(block_shared_local, B_shared_vi, preserve_unit_loops=True)
sch.compute_at(block_shared_local_local, B_shared_vi, preserve_unit_loops=True)

dequantize_block_local = block_shared_local
if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]):
block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local")
sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True)
# inline the producers of the scale block
auto_inline_producers(sch, block_local_scales)

if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]):
block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local")
sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True)
auto_inline_producers(sch, block_local_zeros)

for producer in weight_producers:
with suppress(Exception):
auto_inline_producers(sch, producer)
sch.compute_inline(producer)

# fast type conversion
if ("fast_decoding" in weight_decode_info and weight_decode_info["fast_decoding"]):
source_bit = weight_decode_info["source_format"]["bits"]
out_dtype = weight_decode_info["target_format"]
lop3_intrin_info = get_lop3_intrin_group(
out_dtype=out_dtype,
storage_dtype=weight_decode_info["storage_dtype"],
source_format=weight_decode_info["source_format"]["format"],
source_bit=source_bit,
with_scaling=weight_decode_info["with_scaling"],
with_zeros=weight_decode_info["with_zeros"],
zeros_mode=weight_decode_info["zeros_mode"],
)
sch.tensorize(
sch.get_loops(dequantize_block_local)[-1],
lop3_intrin_info["compute"],
)
import_source.append(lop3_intrin_info["c_source"])

union_len = (2 + 2)
B_shared_fused = sch.fuse(*sch.get_loops(block_shared)[-union_len:-2])

_, B_shared_ty, B_shared_tx = sch.split(B_shared_fused, factors=[None, num_ty, num_tx])
sch.bind(B_shared_tx, "threadIdx.x")
sch.bind(B_shared_ty, "threadIdx.y")
sch.vectorize(sch.get_loops(block_shared)[-1])
sch.vectorize(sch.get_loops(block_shared_local_local)[-1])

# cache small tensors, e.g. LUT
if b_idx:
block_shared_lut = sch.cache_read(dequantize_block_local, 0, shared_scope)
sch.reverse_compute_at(block_shared_lut, bx)
_, B_shared_tx = sch.split(
sch.get_loops(block_shared_lut)[-1], factors=[None, num_tx])
sch.bind(B_shared_tx, "threadIdx.x")
return block_shared_local

_ = decode_fetch_to_shared(main_block, 1)

auto_inline_consumer_chain(sch, l2g)

_, vec = sch.split(
sch.fuse(*sch.get_loops(l2g)[-2:]), [None, vectorize // prod(config.step)])
sch.vectorize(vec)

sch.decompose_reduction(main_block, ko)
# plan import source
if len(import_source) > 0:
sch.annotate(
ty,
ann_key="pragma_import_c",
ann_val=("\n").join(import_source),
)
return sch
13 changes: 12 additions & 1 deletion bitblas/gpu/matmul_analysis.py
@@ -561,10 +561,15 @@ def check_sm_version(arch: str) -> int:
sm_version = arch.replace("sm_", "")
return int(sm_version) if sm_version.isdigit() else -1
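# e.g. "sm_80" -> 80, while "sm_90a" -> -1 ("90a" is not a pure digit
# string), as is any arch string without a numeric suffix.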

def analysis_tensorcore_tags(sch: tir.Schedule, block: BlockRV, target: Target) -> bool:
def analysis_tensorcore_tags(sch: tir.Schedule, block: BlockRV,
target: Target) -> Union[bool, Dict]:
tags: Dict[str, Union[List[int], int]] = {}
block_stmt = sch.get(block)

# NVIDIA only supports Tensor Cores on devices with SM version 70 or higher.
if check_sm_version(target.arch) < 70:
return False
# analyze the tensorcore axes
# todo(lei): maybe we can remove this in the future
(write_buffer_region,) = block_stmt.writes
@@ -612,6 +617,11 @@ def check_last_trait(region: List[Range]):
in_dtype, out_dtype = get_in_out_dtypes(block_stmt)
intrin_info["in_dtype"] = in_dtype
intrin_info["out_dtype"] = out_dtype

if 70 <= check_sm_version(target.arch) < 80 and out_dtype == "int32":
# INT32-accumulation Tensor Core intrinsics require SM version >= 80.
return False

# if the last dimension is reduce axis, the B is transposed
intrin_info["trans_b"] = check_last_trait(block_stmt.reads[1].region)
if func.attrs is not None and "input_transform_kind" in func.attrs:
@@ -666,6 +676,7 @@ def check_last_trait(region: List[Range]):

block_stmt = sch.get(main_block)

# 16 for 16-bit tensor cores, 32 for 8-bit tensor cores.
minimal_tensorize_threshold = 16 if in_dtype in ["bfloat16", "float16"] else 32
# the batch dimension is not taken into consideration.
extent = block_stmt.iter_vars[1].dom.extent
6 changes: 6 additions & 0 deletions integration/BitNet/eval_correctness.py
@@ -7,6 +7,12 @@
from tokenization_bitnet import BitnetTokenizer
from transformers import GenerationConfig
import time
import transformers
from packaging import version

print(f"transformers version is {transformers.__version__}")

# version must be lower than or equal to 4.40.0; compare parsed versions
# rather than raw strings, since lexicographic comparison misorders
# e.g. "4.5.0" vs "4.40.0"
assert version.parse(transformers.__version__) <= version.parse("4.40.0")

torch.set_grad_enabled(False)
bitblas.set_log_level("INFO")