Merged
Changes from all commits
Commits
36 commits
ca42750
fix for relax
LeiWang1999 Dec 8, 2024
58fa7bf
lint fix
LeiWang1999 Dec 8, 2024
8275513
save import bitblas time
LeiWang1999 Dec 10, 2024
fb7de9b
bug fix for tl backend
LeiWang1999 Dec 10, 2024
02cf643
support input transform_kind
LeiWang1999 Dec 11, 2024
65fb3b4
hint identifier
LeiWang1999 Dec 11, 2024
ad7bc1c
annotate hint type for dequantize
LeiWang1999 Dec 11, 2024
d635713
enhance swizzling
LeiWang1999 Dec 12, 2024
a3e97de
Enhance for hardware aware tuning
LeiWang1999 Dec 12, 2024
bdbc685
test fix
LeiWang1999 Dec 12, 2024
e30b64f
remove pad factor
LeiWang1999 Dec 13, 2024
3b2646a
introduce legalize dynamic pass
LeiWang1999 Dec 13, 2024
b44e42f
Merge branch 'main' of https://github.com/microsoft/BitBLAS into lega…
LeiWang1999 Dec 13, 2024
9462884
update 3rdparty
LeiWang1999 Dec 16, 2024
d662748
testfix
LeiWang1999 Dec 16, 2024
8c05d7b
test code commit
LeiWang1999 Dec 16, 2024
cdd0753
enhance typing and fix test for int4 dequantize gemm
LeiWang1999 Dec 16, 2024
b9c343c
lint fix
LeiWang1999 Dec 16, 2024
bf6903a
TEST FIX
LeiWang1999 Dec 16, 2024
ab0fef2
lint fix
LeiWang1999 Dec 16, 2024
f5e036c
Merge branch 'main' of https://github.com/microsoft/BitBLAS into chan…
LeiWang1999 Dec 16, 2024
ee770d4
Bugfix for bias
LeiWang1999 Dec 16, 2024
7a45262
lint fix
LeiWang1999 Dec 16, 2024
c48302d
lint fix
LeiWang1999 Dec 16, 2024
0b11dfe
test fix
LeiWang1999 Dec 16, 2024
6d1a7e4
Implement Bias
LeiWang1999 Dec 17, 2024
e729caa
fallback nf to tir implementation.
LeiWang1999 Dec 17, 2024
d79cefb
Enhance contiguous batching performance
LeiWang1999 Dec 17, 2024
95d0fc0
separate benchmark op schedule from benchmark matmul schedule
LeiWang1999 Dec 18, 2024
c2146e8
Implement TileLang NF4
LeiWang1999 Dec 18, 2024
2fd4682
lint fix
LeiWang1999 Dec 18, 2024
9fd40bd
Merge branch 'main' of https://github.com/microsoft/BitBLAS into chan…
LeiWang1999 Dec 18, 2024
43a0220
lint fix
LeiWang1999 Dec 18, 2024
13bea47
test fix
LeiWang1999 Dec 18, 2024
f5e2172
test fix
LeiWang1999 Dec 18, 2024
f5a42a3
remove legacy splitk test
LeiWang1999 Dec 18, 2024
266 changes: 172 additions & 94 deletions benchmark/operators/benchmark_bitblas_matmul.py
@@ -1,67 +1,165 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 import bitblas
 
 from bitblas.utils.target_detector import auto_detect_nvidia_target
 from bitblas import Matmul, MatmulConfig
 import argparse
 import json
+from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy
+from bitblas.base.arch import CUDA
+from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags
+from bitblas.base.utils import apply_and_build
 
+bitblas.set_log_level("DEBUG")
 # Initialize the parser
 parser = argparse.ArgumentParser(description="Benchmark BitBLAS int4 on a specific target.")
 
 # Add arguments to the parser
 
 parser.add_argument(
     "--target",
     type=str,
     default=auto_detect_nvidia_target(),
-    help="Specify the target device for benchmarking.")
+    help="Specify the target device for benchmarking.",
+)
 
-parser.add_argument(
-    "--backend",
-    type=str,
-    default="tir",
-    choices=["tir", "tl"],  # Replace with actual modes if applicable
-    help="Specify the mode for calculating zeros.")
-
-parser.add_argument("--verbose", type=bool, default=True, help="Enable verbose logging.")
-
-# [A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode]
-default_test_shapes = json.dumps([
-    # ["MatmulConfig", "Matmul", [1, 16384, 16384, "float16", "int4", "float16", "float16", "nt", False, None, False, False, None]]
-    [
-        "MatmulConfig", "Matmul",
-        [
-            16384, 16384, 16384, "float16", "float16", "float16", "float16", "nt", False, None,
-            False, False, None
-        ]
-    ]
-])
+parser.add_argument(
+    "--M",
+    type=int,
+    default=16384,
+    help="Number of rows in matrix A.",
+)
 
 parser.add_argument(
-    "--test_shapes",
-    type=str,
-    default=default_test_shapes,
-    help="JSON string defining test shapes. Example format: '[[\"MatmulConfig\", \"Matmul\", [1,16384,16384,\"float16\",\"int4\",\"float16\",\"float16\",\"nt\",false,null,false,false,null]]]'")
+    "--N",
+    type=int,
+    default=16384,
+    help="Number of columns in the output matrix (the GEMM N dimension).",
+)
+
+parser.add_argument(
+    "--K",
+    type=int,
+    default=16384,
+    help="Size of the reduction dimension (the GEMM K dimension).",
+)
+
+parser.add_argument(
+    "--A_dtype",
+    type=str,
+    default="float16",
+    choices=[
+        "float16",
+        "float32",
+        "float64",
+        "int32",
+        "int8",
+    ],  # Assuming these are the valid choices
+    help="Data type of activation A.",
+)
+parser.add_argument(
+    "--W_dtype",
+    type=str,
+    default="int4",
+    choices=[
+        "float16",
+        "float32",
+        "float64",
+        "int32",
+        "int8",
+        "int4",
+        "int2",
+        "int1",
+        "nf4",
+        "fp4_e2m1",
+    ],  # Assuming these are the valid choices
+    help="Data type of weight W.",
+)
+parser.add_argument(
+    "--accum_dtype",
+    type=str,
+    default="float16",
+    choices=["float16", "int32"],  # Assuming these are the valid choices
+    help="Data type for accumulation.",
+)
+parser.add_argument(
+    "--out_dtype",
+    type=str,
+    default="float16",
+    choices=[
+        "float16",
+        "float32",
+        "int32",
+        "int8",
+    ],  # Assuming these are the valid choices
+    help="Data type for output.",
+)
+parser.add_argument(
+    "--layout",
+    type=str,
+    default="nt",
+    choices=["nt", "nn"],  # Assuming these are the valid choices
+    help="Matrix layout, 'nt' for non-transpose A and transpose W.",
+)
+parser.add_argument("--with_bias", action="store_true", help="Include bias in the benchmark.")
+parser.add_argument(
+    "--with_scaling",
+    action="store_true",
+    help="Include scaling factor in the quantization.",
+)
+parser.add_argument(
+    "--group_size", type=int, default=None, help="Group size for grouped quantization.")
+parser.add_argument("--with_zeros", action="store_true", help="Include zeros in the quantization.")
+parser.add_argument(
+    "--zeros_mode",
+    type=str,
+    default=None,
+    choices=[
+        "original",
+        "rescale",
+        "quantized",
+    ],  # Replace with actual modes if applicable
+    help="Specify the mode for calculating zeros.",
+)
 
 # Parse the arguments
 args = parser.parse_args()
 
 # Assign arguments to variables
 target = args.target
-backend = args.backend
-verbose = args.verbose
-
-parsed_test_shapes = json.loads(args.test_shapes)
-name_to_class = {"MatmulConfig": MatmulConfig, "Matmul": Matmul}
-
-test_shapes = []
-for item in parsed_test_shapes:
-    config_class_name, operator_class_name, input_args = item
-    config_class = name_to_class[config_class_name]
-    operator_class = name_to_class[operator_class_name]
-    test_shapes.append((config_class, operator_class, tuple(input_args)))
+M, N, K = args.M, args.N, args.K
+A_dtype = args.A_dtype
+W_dtype = args.W_dtype
+accum_dtype = args.accum_dtype
+out_dtype = args.out_dtype
+layout = args.layout
+with_bias = args.with_bias
+group_size = args.group_size
+with_scaling = args.with_scaling
+with_zeros = args.with_zeros
+zeros_mode = args.zeros_mode
+
+test_shapes = [
+    # square test
+    (
+        MatmulConfig,
+        Matmul,
+        (
+            M,
+            N,
+            K,
+            A_dtype,
+            W_dtype,
+            out_dtype,
+            accum_dtype,
+            layout,
+            with_bias,
+            group_size,
+            with_scaling,
+            with_zeros,
+            zeros_mode,
+        ),
+    ),
+]
 
 benchmark_sets = []
 benchmark_sets.extend(test_shapes)
@@ -71,59 +169,39 @@
 benchmark_results = {}
 for config, operator, input_args in benchmark_sets:
     config = config(*input_args)
-    print(f"Running benchmark for {operator.__name__} with config: {config}")
-    op_inst = operator(config, target=target, enable_tuning=True, backend=backend)
-    kernel_latency = op_inst.profile_latency()
-    if op_inst.input_transform is not None:
-        kernel_latency += op_inst.ladder_permutate_a.profile_latency()
-
-    print("Time cost of {} is: {:.3f} ms".format(str(config), kernel_latency))
-
-    if verbose:
-        print(op_inst.scheduled_ir_module)
-        print(op_inst.get_source())
-
-    profile_config = {
-        f"{operator.__name__}-{'-'.join([str(i) for i in input_args])}": {
-            "BitBLAS_top20_latency": kernel_latency,
-        }
-    }
-
-    benchmark_results.update(profile_config)
-
-# Define headers for the table
-headers = [
-    "PrimFunc",
-    "Input Arguments",
-    "BitBLAS Top20 Latency",
-]
-
-col_widths = [0, 0, 0]
-for config, values in benchmark_results.items():
-    args = config.split("-")
-    func_name = args[0]
-    input_args = "-".join(args[1:])
-    col_widths[0] = max((max(len(str(headers[0])), len(func_name)) + 2), col_widths[0])
-    col_widths[1] = max((max(len(str(headers[1])), len(input_args)) + 2, col_widths[1]))
-    col_widths[2] = max(
-        max(len(str(headers[2])), len(f"{values['BitBLAS_top20_latency']:.3f} ms")) + 2,
-        col_widths[2])
-    break
-
-for i, header in enumerate(headers):
-    headers[i] = header.ljust(col_widths[i])
-
-print("".join(headers))
-
-print("-" * sum(col_widths))
-
-for config, values in benchmark_results.items():
-    args = config.split("-")
-    func_name = args[0]
-    input_args = "-".join(args[1:])
-    row = [
-        func_name,
-        input_args,
-        f"{values['BitBLAS_top20_latency']:.3f} ms",
-    ]
-    print("".join([str(i).ljust(col_widths[j]) for j, i in enumerate(row)]))
+    matmul = operator(config, target=target, enable_tuning=False)
+    func = matmul.prim_func
+    arch = CUDA(target)
+    policy = DefaultPolicy(func=func, arch=arch)
+    try:
+        tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target)
+    except Exception:
+        tags = None
+    if tags:
+        policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags)
+
+    configs = policy.emit_config(20)
+    static_configs = []
+    for config in configs:
+        static_config = config
+        static_config.shared_scope = "shared"
+        static_configs.append(static_config)
+    dynamic_configs = []
+    for config in configs:
+        dynamic_config = config
+        dynamic_config.shared_scope = "shared.dyn"
+        dynamic_configs.append(dynamic_config)
+
+    _, best_static = apply_and_build(func, static_configs, arch, parallel_build=True)
+
+    _, best_dynamic = apply_and_build(func, dynamic_configs, arch, parallel_build=True)
+    benchmark_results[input_args] = (
+        best_static.latency,
+        best_dynamic.latency,
+        best_static.latency - best_dynamic.latency,
+    )
+
+for key, value in benchmark_results.items():
+    print(
+        f"Input arguments: {key}, Static latency: {value[0]}, Dynamic latency: {value[1]}, Difference: {value[2]}"
+    )
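
The rewritten loop benchmarks each of the 20 candidate schedules emitted by the policy twice, once with the static "shared" scope and once with "shared.dyn" (dynamically allocated shared memory), and reports the latency gap. The tuned op-level path that this file previously exercised now lives in a separate operator benchmark (see the "separate benchmark op schedule from benchmark matmul schedule" commit); for reference, a minimal sketch of driving it directly, assuming a CUDA GPU and that MatmulConfig accepts keyword fields mirroring the positional tuple above:

# Minimal sketch of the op-level path removed from this file; assumes a CUDA
# GPU and that MatmulConfig's keyword fields mirror the positional tuple above.
import bitblas
from bitblas import Matmul, MatmulConfig
from bitblas.utils.target_detector import auto_detect_nvidia_target

config = MatmulConfig(
    M=16384,
    N=16384,
    K=16384,
    A_dtype="float16",
    W_dtype="int4",
    out_dtype="float16",
    accum_dtype="float16",
    layout="nt",
)
# backend="tl" selects the TileLang backend this PR extends; "tir" is the default.
op = Matmul(config, target=auto_detect_nvidia_target(), enable_tuning=True, backend="tl")
latency = op.profile_latency()  # latency in milliseconds, as the old loop printed
print("Time cost of {} is: {:.3f} ms".format(str(config), latency))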
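
A hypothetical invocation of the updated script, with flag names taken from the argparse definitions above (the shape and quantization values are examples only):

python benchmark/operators/benchmark_bitblas_matmul.py --M 16384 --N 16384 --K 16384 --A_dtype float16 --W_dtype int4 --accum_dtype float16 --out_dtype float16 --layout nt --with_scaling --with_zeros --zeros_mode original --group_size 128

For each input tuple, the script prints the best static-shared-memory latency, the best dynamic-shared-memory latency, and their difference, as produced by the final print loop in the diff.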