Skip to content

Commit 0e86db8

Browse files
committed
Add microbenchmark for composable kernel gemm
Add microbenchmark for composable kernel gemm, automatically generated by GPT-4 based on existing code. --------- Co-authored-by: GPT-4
1 parent 7435f10 commit 0e86db8

4 files changed

Lines changed: 323 additions & 2 deletions

File tree

superbench/benchmarks/micro_benchmarks/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from superbench.benchmarks.micro_benchmarks.ib_validation_performance import IBBenchmark
2929
from superbench.benchmarks.micro_benchmarks.kernel_launch_overhead import KernelLaunch
3030
from superbench.benchmarks.micro_benchmarks.ort_inference_performance import ORTInferenceBenchmark
31+
from superbench.benchmarks.micro_benchmarks.rocm_composable_kernel_performance import RocmComposableKernelBenchmark
3132
from superbench.benchmarks.micro_benchmarks.rocm_gemm_flops_performance import RocmGemmFlopsBenchmark
3233
from superbench.benchmarks.micro_benchmarks.rocm_memory_bw_performance import RocmMemBwBenchmark
3334
from superbench.benchmarks.micro_benchmarks.sharding_matmul import ShardingMatmul
@@ -64,6 +65,7 @@
6465
'MicroBenchmark',
6566
'MicroBenchmarkWithInvoke',
6667
'ORTInferenceBenchmark',
68+
'RocmComposableKernelBenchmark',
6769
'RocmGemmFlopsBenchmark',
6870
'RocmMemBwBenchmark',
6971
'ShardingMatmul',
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT license.
3+
4+
"""Module of the ROCm composable kernel GEMM benchmark."""
5+
6+
import os
7+
import re
8+
9+
from superbench.common.utils import logger
10+
from superbench.benchmarks import BenchmarkRegistry, Platform, ReturnCode
11+
from superbench.benchmarks.micro_benchmarks import BlasLtBaseBenchmark
12+
13+
14+
class RocmComposableKernelBenchmark(BlasLtBaseBenchmark):
15+
"""The composable kernel GEMM benchmark class."""
16+
def __init__(self, name, parameters=''):
17+
"""Constructor.
18+
19+
Args:
20+
name (str): benchmark name.
21+
parameters (str): benchmark parameters.
22+
"""
23+
super().__init__(name, parameters)
24+
25+
self._bin_name = 'ckProfiler'
26+
self._in_types = ['fp32', 'fp16', 'bf16', 'fp8', 'int8']
27+
self._in_type_map = {
28+
'fp16': '1',
29+
'fp32': '0',
30+
'bf16': '2',
31+
'fp8': '4',
32+
'int8': '3',
33+
}
34+
35+
def add_parser_arguments(self):
36+
"""Add the specified arguments."""
37+
super().add_parser_arguments()
38+
39+
self._parser.add_argument(
40+
'--in_types',
41+
type=str,
42+
nargs='+',
43+
default=['fp16'],
44+
required=False,
45+
help='List of input data types, support {}.'.format(' '.join(self._in_types)),
46+
)
47+
self._parser.add_argument(
48+
'--initialization',
49+
type=str,
50+
default='int',
51+
choices=['float', 'int'],
52+
required=False,
53+
help='Initialize matrix data.',
54+
)
55+
self._parser.add_argument(
56+
'--matrixA_layout',
57+
type=str,
58+
default='row',
59+
choices=['row', 'col'],
60+
required=False,
61+
help='Matrix A Layout. RowMajor or ColMajor.',
62+
)
63+
self._parser.add_argument(
64+
'--matrixB_layout',
65+
type=str,
66+
default='row',
67+
choices=['row', 'col'],
68+
required=False,
69+
help='Matrix B Layout. RowMajor or ColMajor.',
70+
)
71+
self._parser.add_argument(
72+
'--check_data',
73+
action='store_true',
74+
required=False,
75+
help='Whether check data correctness.',
76+
)
77+
self._parser.add_argument(
78+
'--splitk',
79+
type=int,
80+
default=None,
81+
required=False,
82+
nargs='+',
83+
help='Split K dimension.',
84+
)
85+
self._parser.add_argument(
86+
'--streamk',
87+
type=int,
88+
default=None,
89+
required=False,
90+
nargs='+',
91+
help='Stream K blocks.',
92+
)
93+
94+
def _preprocess(self):
95+
"""Preprocess/preparation operations before the benchmarking.
96+
97+
Return:
98+
True if _preprocess() succeed.
99+
"""
100+
if not super()._preprocess():
101+
return False
102+
103+
self.__bin_path = os.path.join(self._args.bin_dir, self._bin_name)
104+
105+
self._commands = []
106+
self._precision_in_commands = []
107+
matrix_layout = '0'
108+
if self._args.matrixA_layout == 'row' and self._args.matrixB_layout == 'row':
109+
matrix_layout = '0'
110+
elif self._args.matrixA_layout == 'row' and self._args.matrixB_layout == 'col':
111+
matrix_layout = '1'
112+
elif self._args.matrixA_layout == 'col' and self._args.matrixB_layout == 'row':
113+
matrix_layout = '2'
114+
elif self._args.matrixA_layout == 'col' and self._args.matrixB_layout == 'col':
115+
matrix_layout = '3'
116+
if self._args.check_data:
117+
self._args.check_data = '1'
118+
else:
119+
self._args.check_data = '0'
120+
init = 1 if self._args.initialization == 'int' else 2
121+
for (_m, _n, _k, _b, _in_type) in self._shapes_to_run:
122+
params = f'{self._in_type_map[_in_type]}' + \
123+
f' {matrix_layout} {self._args.check_data} {init} 0 1' + \
124+
f' {_m} {_n} {_k} -1 -1 -1'
125+
command = f'{self.__bin_path} gemm {params} {self._args.num_warmup} {self._args.num_steps}'
126+
self._commands.append(command)
127+
logger.info(command)
128+
if self._args.splitk:
129+
if not isinstance(self._args.splitk, list):
130+
self._args.splitk = [self._args.splitk]
131+
for splitk in self._args.splitk:
132+
command = f'{self.__bin_path} gemm_splitk {params} {splitk}' + \
133+
f' {self._args.num_warmup} {self._args.num_steps}'
134+
self._commands.append(command)
135+
logger.info(command)
136+
if self._args.streamk:
137+
if not isinstance(self._args.streamk, list):
138+
self._args.streamk = [self._args.streamk]
139+
for streamk in self._args.streamk:
140+
command = f'{self.__bin_path} gemm_streamk {params} {streamk}' + \
141+
f' {self._args.num_warmup} {self._args.num_steps}'
142+
self._commands.append(command)
143+
logger.info(command)
144+
return True
145+
146+
def _process_raw_result(self, cmd_idx, raw_output):
147+
"""Function to parse raw results and save the summarized results.
148+
149+
self._result.add_raw_data() and self._result.add_result() need to be called to save the results.
150+
151+
Args:
152+
cmd_idx (int): the index of command corresponding with the raw_output.
153+
raw_output (str): raw output string of the micro-benchmark.
154+
155+
Return:
156+
True if the raw output string is valid and result can be extracted.
157+
"""
158+
self._result.add_raw_data(f'raw_output_{cmd_idx}', raw_output, self._args.log_raw_data)
159+
160+
try:
161+
lines = raw_output.splitlines()
162+
index = None
163+
164+
# Find the line containing 'hipblaslt-Gflops'
165+
for i, line in enumerate(lines):
166+
if 'Best Perf' in line:
167+
index = i
168+
break
169+
170+
if index is not None:
171+
# Search the text for each pattern
172+
datatype_match = re.search(r"datatype = (\w+)", line)
173+
m_match = re.search(r"M = (\d+)", line)
174+
n_match = re.search(r"N = (\d+)", line)
175+
k_match = re.search(r"K = (\d+)", line)
176+
flops_match = re.search(r"(\d+\.?\d*) TFlops", line)
177+
178+
# Extract the matched groups
179+
datatype = datatype_match.group(1) if datatype_match else None
180+
m = int(m_match.group(1)) if m_match else None
181+
n = int(n_match.group(1)) if n_match else None
182+
k = int(k_match.group(1)) if k_match else None
183+
flops = float(flops_match.group(1)) if flops_match else None
184+
185+
metric = f'{datatype}_{m}_{n}_{k}_flops'
186+
self._result.add_result(metric, flops)
187+
else:
188+
self._result.set_return_code(ReturnCode.MICROBENCHMARK_RESULT_PARSING_FAILURE)
189+
logger.error(
190+
'The result format is invalid - round: {}, benchmark: {}, raw output: {}.'.format(
191+
self._curr_run_index, self._name, raw_output
192+
)
193+
)
194+
return False
195+
196+
except BaseException as e:
197+
self._result.set_return_code(ReturnCode.MICROBENCHMARK_RESULT_PARSING_FAILURE)
198+
logger.error(
199+
'The result format is invalid - round: {}, benchmark: {}, raw output: {}, message: {}.'.format(
200+
self._curr_run_index, self._name, raw_output, str(e)
201+
)
202+
)
203+
return False
204+
finally:
205+
if cmd_idx == len(self._commands) - 1:
206+
for metric in self.results:
207+
self.results[metric] = [max(self.results[metric])]
208+
return True
209+
210+
211+
BenchmarkRegistry.register_benchmark('composable-kernel-gemm', RocmComposableKernelBenchmark, platform=Platform.ROCM)
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
"""Tests for ROCm composable kernel benchmark."""
5+
6+
import unittest
7+
from types import SimpleNamespace
8+
9+
from tests.helper.testcase import BenchmarkTestCase
10+
from superbench.benchmarks import BenchmarkRegistry, BenchmarkType, ReturnCode, Platform
11+
from superbench.benchmarks.result import BenchmarkResult
12+
13+
14+
class composable_kernelBenchmarkTestCase(BenchmarkTestCase, unittest.TestCase):
15+
"""Class for composable kernel benchmark test cases."""
16+
@classmethod
17+
def setUpClass(cls):
18+
"""Hook method for setting up class fixture before running tests in the class."""
19+
super().setUpClass()
20+
cls.benchmark_name = 'composable-kernel-gemm'
21+
cls.createMockEnvs(cls)
22+
cls.createMockFiles(cls, ['bin/ckProfiler'])
23+
24+
def get_benchmark(self):
25+
"""Get Benchmark."""
26+
(benchmark_cls, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, Platform.ROCM)
27+
return benchmark_cls(self.benchmark_name, parameters='')
28+
29+
def test_composable_kernel_gemm_cls(self):
30+
"""Test composable-kernel-gemm benchmark class."""
31+
for platform in Platform:
32+
(benchmark_cls, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, platform)
33+
if platform is Platform.ROCM:
34+
self.assertIsNotNone(benchmark_cls)
35+
else:
36+
self.assertIsNone(benchmark_cls)
37+
38+
def test_composable_kernel_gemm_command_generation(self):
39+
"""Test composable-kernel-gemm benchmark command generation."""
40+
(benchmark_cls, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, Platform.ROCM)
41+
benchmark = benchmark_cls(
42+
self.benchmark_name,
43+
parameters=' --shapes 2,4,8 --in_types fp16 fp32',
44+
)
45+
46+
self.assertTrue(benchmark._preprocess())
47+
self.assertEqual(len(benchmark._args.in_types), len(benchmark._commands))
48+
49+
benchmark = benchmark_cls(
50+
self.benchmark_name,
51+
parameters=' --shapes 2,4,8 --in_types fp16 fp32 --splitk 2 4 --streamk -1',
52+
)
53+
54+
self.assertTrue(benchmark._preprocess())
55+
self.assertEqual(4 * len(benchmark._args.in_types), len(benchmark._commands))
56+
for _t in ['fp16', 'fp32']:
57+
params = f'{benchmark._in_type_map[_t]} 0 0 1 0 1 2 4 8 -1 -1 -1'
58+
command = f'{benchmark._RocmComposableKernelBenchmark__bin_path} gemm {params} {benchmark._args.num_warmup} {benchmark._args.num_steps}'
59+
assert (command in benchmark._commands)
60+
61+
for splitk in [2, 4]:
62+
command = f'{benchmark._RocmComposableKernelBenchmark__bin_path} gemm_splitk {params} {splitk} {benchmark._args.num_warmup} {benchmark._args.num_steps}'
63+
assert (command in benchmark._commands)
64+
65+
command = f'{benchmark._RocmComposableKernelBenchmark__bin_path} gemm_streamk {params} -1 {benchmark._args.num_warmup} {benchmark._args.num_steps}'
66+
assert (command in benchmark._commands)
67+
68+
def test_composable_kernel_gemm_result_parsing(self):
69+
"""Test composable-kernel-gemm benchmark result parsing."""
70+
benchmark = self.get_benchmark()
71+
self.assertTrue(benchmark._preprocess())
72+
benchmark._args = SimpleNamespace(shapes=['8192,8192,8192'], in_types=['fp16'], log_raw_data=False)
73+
benchmark._result = BenchmarkResult(self.benchmark_name, BenchmarkType.MICRO, ReturnCode.SUCCESS, run_count=1)
74+
75+
example_raw_output = """
76+
Perf: 17.0853 ms, 64.3544 TFlops, 23.5673 GB/s, GemmXdlSplitKCShuffle_MNKPadding_RRR_B256_Vec8x1x4_512x16x4x8 LoopScheduler: Default, PipelineVersion: v2, KBatch 2
77+
Perf: 51.8717 ms, 21.1967 TFlops, 7.76248 GB/s, GemmXdlSplitKCShuffle_MNKPadding_RRR_B64_Vec8x1x4_16x16x8x8 LoopScheduler: Default, PipelineVersion: v2, KBatch 2
78+
Perf: 51.2179 ms, 21.4673 TFlops, 7.86157 GB/s, GemmXdlSplitKCShuffle_MNKPadding_RRR_B64_Vec8x1x4_16x16x16x8 LoopScheduler: Default, PipelineVersion: v2, KBatch 2
79+
Perf: 24.4389 ms, 44.9902 TFlops, 16.4759 GB/s, GemmXdlSplitKCShuffle_MNKPadding_RRR_B128_Vec8x1x4_16x32x8x8 LoopScheduler: Default, PipelineVersion: v2, KBatch 2
80+
Perf: 12.0388 ms, 91.331 TFlops, 33.4464 GB/s, GemmXdlSplitKCShuffle_MNKPadding_RRR_B128_Vec8x2x4_16x64x8x8 LoopScheduler: Default, PipelineVersion: v2, KBatch 2
81+
Perf: 12.8774 ms, 85.3828 TFlops, 31.2681 GB/s, GemmXdlSplitKCShuffle_MNKPadding_RRR_B128_Vec8x4x4_16x128x8x8 LoopScheduler: Default, PipelineVersion: v2, KBatch 2
82+
Perf: 14.7506 ms, 74.54 TFlops, 27.2974 GB/s, GemmXdlSplitKCShuffle_MNKPadding_RRR_B128_Vec8x8x4_16x256x8x8 LoopScheduler: Default, PipelineVersion: v2, KBatch 2
83+
Perf: 12.0325 ms, 91.3782 TFlops, 33.4637 GB/s, GemmXdlSplitKCShuffle_MNKPadding_RRR_B256_Vec8x4x4_16x256x8x8 LoopScheduler: Default, PipelineVersion: v2, KBatch 2
84+
Perf: 26.055 ms, 42.1996 TFlops, 15.4539 GB/s, GemmXdlSplitKCShuffle_MNKPadding_RRR_B128_Vec8x1x4_32x16x8x8 LoopScheduler: Default, PipelineVersion: v2, KBatch 2
85+
Perf: 13.9292 ms, 78.9358 TFlops, 28.9072 GB/s, GemmXdlSplitKCShuffle_MNKPadding_RRR_B128_Vec8x1x4_64x16x8x8 LoopScheduler: Default, PipelineVersion: v2, KBatch 2
86+
Perf: 8.0511 ms, 136.567 TFlops, 50.0122 GB/s, GemmXdlSplitKCShuffle_MNKPadding_RRR_B128_Vec8x1x4_128x16x8x8 LoopScheduler: Default, PipelineVersion: v2, KBatch 2
87+
Perf: 18.9246 ms, 58.0995 TFlops, 21.2767 GB/s, GemmXdlSplitKCShuffle_MNKPadding_RRR_B128_Vec8x1x4_256x16x8x8 LoopScheduler: Default, PipelineVersion: v2, KBatch 2
88+
Perf: 15.0647 ms, 72.986 TFlops, 26.7283 GB/s, GemmXdlSplitKCShuffle_MNKPadding_RRR_B256_Vec8x1x4_256x16x8x8 LoopScheduler: Default, PipelineVersion: v2, KBatch 2
89+
Best Perf for datatype = f16 ALayout = RowMajor BLayout = RowMajor M = 8192 N = 8192 K = 8192 StrideA = 8192 StrideB = 8192 StrideC = 8192 KBatch = 2 : 2.17246 ms, 506.113 TFlops, 185.344 GB/s, GemmXdlSplitKCShuffle_Default_RRR_B256_Vec8x2x8_256x128x4x8 LoopScheduler: Default, PipelineVersion: v1
90+
"""
91+
# Positive case - valid raw output
92+
self.assertTrue(benchmark._process_raw_result(0, example_raw_output))
93+
self.assertEqual(ReturnCode.SUCCESS, benchmark.return_code)
94+
95+
self.assertEqual(2, len(benchmark.result))
96+
self.assertEqual(506.113, benchmark.result['f16_8192_8192_8192_flops'][0])

third_party/Makefile

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,16 @@ ROCM_PATH ?= /opt/rocm
1212
CUDA_VER ?= $(shell nvcc --version | grep 'release' | awk '{print $$6}' | cut -c2- | cut -d '.' -f1-2)
1313
ROCBLAS_BRANCH ?= rocm-$(shell dpkg -l | grep 'rocm-dev ' | awk '{print $$3}' | cut -d '.' -f1-3)
1414
HIPBLASLT_BRANCH ?= rocm-$(shell dpkg -l | grep 'rocm-dev ' | awk '{print $$3}' | cut -d '.' -f1-3)
15+
COMPOSABLEKERNEL_BRANCH ?= rocm-$(shell dpkg -l | grep 'rocm-dev ' | awk '{print $$3}' | cut -d '.' -f1-3)
1516
ROCM_VER ?= $(shell hipconfig -R | grep -oP '\d+\.\d+\.\d+' || echo "0.0.0")
1617

17-
.PHONY: all cuda_with_msccl cuda rocm common cuda_cutlass cuda_bandwidthTest cuda_nccl_tests cuda_perftest cuda_msccl rocm_perftest fio rocm_rccl_tests rocm_rocblas rocm_bandwidthTest gpcnet cuda_gpuburn cpu_stream cpu_hpl directx_amf_encoding_latency directx_amd rocm_hipblaslt megatron_lm megatron_deepspeed apex_rocm
18+
.PHONY: all cuda_with_msccl cuda rocm common cuda_cutlass cuda_bandwidthTest cuda_nccl_tests cuda_perftest cuda_msccl rocm_perftest fio rocm_rccl_tests rocm_rocblas rocm_bandwidthTest gpcnet cuda_gpuburn cpu_stream cpu_hpl directx_amf_encoding_latency directx_amd rocm_hipblaslt megatron_lm megatron_deepspeed apex_rocm rocm_composable_kernel
1819

1920
# Build all targets.
2021
all: cuda rocm
2122
cuda_with_msccl: cuda cuda_msccl
2223
cuda: common cuda_cutlass cuda_bandwidthTest cuda_nccl_tests cuda_perftest gpcnet cuda_gpuburn megatron_lm megatron_deepspeed
23-
rocm: common rocm_perftest rocm_rccl_tests rocm_rocblas rocm_bandwidthTest rocm_hipblaslt megatron_deepspeed apex_rocm
24+
rocm: common rocm_perftest rocm_rccl_tests rocm_rocblas rocm_bandwidthTest rocm_hipblaslt megatron_deepspeed apex_rocm rocm_composable_kernel
2425
cpu: common cpu_perftest
2526
common: cpu_hpl cpu_stream fio
2627
directx_amd: directx_amf_encoding_latency
@@ -120,6 +121,17 @@ rocm_hipblaslt: sb_micro_path
120121
cp -v $(SB_MICRO_PATH)/third_party/hipBLASLt/build/release/clients/staging/hipblaslt-bench $(SB_MICRO_PATH)/bin/; \
121122
fi
122123

124+
# Build composable_kernel.
125+
# Composable Kernel is released with rocm, like rocm-6.0 and so on.
126+
rocm_composable_kernel: sb_micro_path
127+
@if [ ! -e $(SB_MICRO_PATH)/bin/ckProfiler ] && [ -z `which ckProfiler` ]; then \
128+
if [ -d composable_kernel ]; then rm -rf composable_kernel; fi; \
129+
git clone -b ${COMPOSABLEKERNEL_BRANCH} https://github.com/ROCm/composable_kernel; \
130+
cd composable_kernel && mkdir build && cd build; \
131+
cmake -D CMAKE_PREFIX_PATH=/opt/rocm -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D CMAKE_BUILD_TYPE=Release -D DTYPES="fp64;fp32;fp16;fp8;bf16;int8" ..; \
132+
make -j ckProfiler install; \
133+
fi
134+
123135
# Build hipBusBandwidth.
124136
# HIP is released with rocm, like rocm-4.2.0 and so on.
125137
# The version we use is the released tag which is consistent with the rocm version in the environment or docker.

0 commit comments

Comments
 (0)