diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index cb50ce702..555b1fcaf 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -5,7 +5,7 @@ on: types: [created] jobs: - benchmark: + benchmark_base: if: github.event.issue.pull_request != '' && contains(github.event.comment.body, '/run-benchmark') runs-on: self-hosted @@ -17,7 +17,13 @@ jobs: - name: Get base branch commit ID id: get_base_commit - run: echo "BASE_COMMIT_ID=$(git rev-parse HEAD)" >> $GITHUB_ENV + run: echo "BASE_COMMIT_ID=$(git rev-parse HEAD)" > base_commit_id.txt + + - name: Upload base commit ID + uses: actions/upload-artifact@v3 + with: + name: base-commit-id + path: base_commit_id.txt - name: Set up Python uses: actions/setup-python@v2 @@ -51,7 +57,18 @@ jobs: - name: Get PR branch commit ID id: get_pr_commit - run: echo "PR_COMMIT_ID=$(git rev-parse HEAD)" >> $GITHUB_ENV + run: echo "PR_COMMIT_ID=$(git rev-parse HEAD)" > pr_commit_id.txt + + - name: Upload PR commit ID + uses: actions/upload-artifact@v3 + with: + name: pr-commit-id + path: pr_commit_id.txt + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.9' - name: Create virtual environment run: python -m venv bitblas_benchmark @@ -73,17 +90,49 @@ jobs: cd benchmark/operators python ./benchmark_ops_matmul.py + benchmark_compare: + if: github.event.issue.pull_request != '' && contains(github.event.comment.body, '/run-benchmark') + needs: [benchmark_base, benchmark_head] + runs-on: self-hosted + + steps: + - name: Download commit IDs + uses: actions/download-artifact@v3 + with: + name: base-commit-id + path: . + + - name: Download PR commit ID + uses: actions/download-artifact@v3 + with: + name: pr-commit-id + path: . 
+ + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.9' + + - name: Create virtual environment + run: python -m venv bitblas_benchmark + + - name: Activate virtual environment and install dependencies + run: | + source bitblas_benchmark/bin/activate + python -m pip install --upgrade pip + if [ -f requirements-dev.txt ]; then python -m pip install -r requirements-dev.txt; fi + - name: Compare benchmark results run: | source bitblas_benchmark/bin/activate cd benchmark/operators - python ./compare_benchmark.py --base ${{ env.BASE_COMMIT_ID }} --head ${{ env.PR_COMMIT_ID }} 2>&1 | tee compare_results.txt + python ./compare_benchmark.py --base $(cat base_commit_id.txt) --head $(cat pr_commit_id.txt) 2>&1 | tee compare_results.txt - name: Authenticate GitHub CLI env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: | - gh auth login --with-token <<< $GITHUB_TOKEN + echo "${{ secrets.GITHUB_TOKEN }}" | gh auth login --with-token - name: Post benchmark results run: | diff --git a/3rdparty/tvm b/3rdparty/tvm index d9391a502..8dff258d2 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit d9391a502b5544722eb67c4a0c4dff49a3476c06 +Subproject commit 8dff258d2837b2c0d24619ebf26dd596b2291912 diff --git a/benchmark/operators/benchmark_ops_matmul.py b/benchmark/operators/benchmark_ops_matmul.py index db83be28a..723cf035b 100644 --- a/benchmark/operators/benchmark_ops_matmul.py +++ b/benchmark/operators/benchmark_ops_matmul.py @@ -38,26 +38,109 @@ class BitblasMatmulOpsBenchmark(BitblasOperatorBenchmarkBase): "accum_dtype": "int32", "out_dtype": "int8", }, - "FP16xINT4_ACCINT32_NT": { + "FP16xUINT4_ACCFP16_NT": { "A_dtype": "float16", - "W_dtype": "int4", + "W_dtype": "uint4", "accum_dtype": "float16", }, + "FP16xUINT2_ACCFP16_NT": { + "A_dtype": "float16", + "W_dtype": "uint2", + "accum_dtype": "float16", + }, + "INT8xUINT2_ACCINT32_NT": { + "A_dtype": "int8", + "W_dtype": "uint2", + "accum_dtype": "int32", + "out_dtype": 
"int8", + }, } CURRENT_COMMIT_ID = get_commit_id() + def prepare_set_group_4x(self, name: str, M, N, K) -> List: + return [ + self.generate_op_unit(self.generate_operator_config(name, 1, N, K)), + self.generate_op_unit(self.generate_operator_config(name, M, N, K)), + self.generate_op_unit( + self.generate_operator_config(name, [1, M], N, K), + dynamic_profiling_shape={"m": 1}, + ), + self.generate_op_unit( + self.generate_operator_config(name, [1, M], N, K), + dynamic_profiling_shape={"m": M}, + ), + ] + + def prepare_set_group_llm(self, name: str, N, K) -> List: + return [ + self.generate_op_unit(self.generate_operator_config(name, 1, N, K)), + self.generate_op_unit(self.generate_operator_config(name, 16, N, K)), + self.generate_op_unit(self.generate_operator_config(name, 32, N, K)), + self.generate_op_unit(self.generate_operator_config(name, 64, N, K)), + self.generate_op_unit(self.generate_operator_config(name, 128, N, K)), + self.generate_op_unit(self.generate_operator_config(name, 2048, N, K)), + self.generate_op_unit( + self.generate_operator_config(name, [1, 16], N, K), + dynamic_profiling_shape={"m": 1}, + ), + self.generate_op_unit( + self.generate_operator_config(name, [1, 32], N, K), + dynamic_profiling_shape={"m": 32}, + ), + self.generate_op_unit( + self.generate_operator_config(name, [1, 64], N, K), + dynamic_profiling_shape={"m": 64}, + ), + self.generate_op_unit( + self.generate_operator_config(name, [1, 128], N, K), + dynamic_profiling_shape={"m": 128}, + ), + self.generate_op_unit( + self.generate_operator_config(name, [1, 2048], N, K), + dynamic_profiling_shape={"m": 2048}, + ), + ] + def prepare_benchmark_sets(self): """Prepare benchmark sets.""" self.add_benchmark_set( "FP16xFP16_ACCFP16_NT", [ - self.generate_op_unit( - self.generate_operator_config("FP16xFP16_ACCFP16_NT", 16384, 16384, 16384),), - self.generate_op_unit( - self.generate_operator_config("FP16xFP16_ACCFP16_NT", [1, 1024], 16384, 16384), - dynamic_profiling_shape={"M": 1024}, - 
), + *self.prepare_set_group_4x("FP16xFP16_ACCFP16_NT", 16384, 16384, 16384), + *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 3200, 3200), + *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 8640, 3200), + *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 3200, 8640), + *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 5120, 5120), + *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 13824, 5120), + *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 5120, 13824), + *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 6656, 6656), + *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 17920, 6656), + *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 6656, 17920), + *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 1024, 8192), + *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 8192, 8192), + *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 28672, 8192), + *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 8192, 28672), + ], + ) + + self.add_benchmark_set( + "INT8xINT8_ACCINT32_NT", + [ + *self.prepare_set_group_4x("INT8xINT8_ACCINT32_NT", 16384, 16384, 16384), + *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 3200, 3200), + *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 8640, 3200), + *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 3200, 8640), + *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 5120, 5120), + *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 13824, 5120), + *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 5120, 13824), + *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 6656, 6656), + *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 17920, 6656), + *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 6656, 17920), + *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 1024, 8192), + *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 8192, 8192), + *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 28672, 8192), + *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 8192, 28672), ], ) @@ 
-168,15 +251,15 @@ def legalize_shape(M, N, K, dyn_prof_shape): M: The M dimension (can be an int or a tuple). N: The N dimension (must be an int). K: The K dimension (must be an int). - dyn_prof_shape: The dynamic profiling shape (dict with 'M' key if M is dynamic). + dyn_prof_shape: The dynamic profiling shape (dict with "m" key if M is dynamic). Returns: A string representing the shape in either 'M-N-K' or 'N-K_M' format. """ if isinstance(M, int): return f"{M}-{N}-{K}" - elif dyn_prof_shape and "M" in dyn_prof_shape: - return f"{N}-{K}_{dyn_prof_shape['M']}" + elif dyn_prof_shape and "m" in dyn_prof_shape: + return f"{N}-{K}_{dyn_prof_shape['m']}" else: # Calculate the average of tuple M opt_m = sum(M) / len(M) @@ -195,7 +278,7 @@ def legalize_shape(M, N, K, dyn_prof_shape): f"{(2 * benchmark_M * op_config.N * op_config.K / (latency * 1e-3) / 1e12):.3f}" if latency else "N/A") latency_str = "N/A" if latency is None else f"{latency:.3f}" - tuning_time_str = ("N/A" if tuning_time is None else f"{tuning_time:.3f}") + tuning_time_str = "N/A" if tuning_time is None else f"{tuning_time:.3f}" table_data.append([shape, latency_str, throughput, tuning_time_str]) diff --git a/benchmark/operators/compare_benchmark.py b/benchmark/operators/compare_benchmark.py index 080d49dca..c45a5a680 100644 --- a/benchmark/operators/compare_benchmark.py +++ b/benchmark/operators/compare_benchmark.py @@ -55,23 +55,42 @@ def legalize_shape(M, N, K, dyn_prof_shape): sum(op_config.M) / len(op_config.M) if isinstance(op_config.M, Tuple) else op_config.M) - base_latency = base.benchmark_results[name][i][0] + try: + base_latency = base.benchmark_results[name][i][0] + except IndexError: + print(f"Operator {name} not found in benchmark sets") + base_latency = None + if latency is not None: throughput = (2 * benchmark_M * op_config.N * op_config.K / (latency * 1e-3) / 1e12) - base_throughput = (2 * benchmark_M * op_config.N * op_config.K / - (base_latency * 1e-3) / 1e12) - throughput = 
f"{throughput:.3f}{get_suffix(base_throughput, throughput)}" + if base_latency is not None: + base_throughput = (2 * benchmark_M * op_config.N * op_config.K / + (base_latency * 1e-3) / 1e12) + throughput = f"{throughput:.3f}{get_suffix(base_throughput, throughput)}" + else: + throughput = f"{throughput:.3f}" else: throughput = "N/A" - if base_latency is not None: - latency_str = f"{latency:.3f}{get_suffix(base_latency, latency)}" + if latency is not None: + if base_latency is not None: + latency_str = f"{latency:.3f}{get_suffix(base_latency, latency)}" + else: + latency_str = f"{latency:.3f}" else: latency_str = "N/A" - base_tuning_time = base.benchmark_results[name][i][1] + try: + base_tuning_time = base.benchmark_results[name][i][1] + except IndexError: + print(f"Operator {name} not found in benchmark sets") + base_tuning_time = None + if tuning_time is not None: - tuning_time_str = f"{tuning_time:.3f}{get_suffix(base_tuning_time, tuning_time)}" + if base_tuning_time is not None: + tuning_time_str = f"{tuning_time:.3f}{get_suffix(base_tuning_time, tuning_time)}" + else: + tuning_time_str = f"{tuning_time:.3f}" else: tuning_time_str = "N/A" @@ -95,6 +114,8 @@ def legalize_shape(M, N, K, dyn_prof_shape): ) args = parser.parse_args() + print(f"Comparing base commit {args.base} with head commit {args.head}") + base_benchmark = BitblasMatmulOpsBenchmark.deserialize_from_logs(args.base) head_benchmark = BitblasMatmulOpsBenchmark.deserialize_from_logs(args.head) diff --git a/bitblas/base/utils.py b/bitblas/base/utils.py index 1596b3c86..4bdbbed79 100644 --- a/bitblas/base/utils.py +++ b/bitblas/base/utils.py @@ -51,12 +51,14 @@ def __init__(self, config, sch, mod: Module): self.mod = mod self.code = mod.imported_modules[0].get_source() if mod else None self.latency = 1e9 - self.profile_tensors = [] self.time_evaluator = None - def profile(self): - profile_tensors = self.profile_tensors - return self.time_evaluator(*profile_tensors).mean * 1e3 + def profile(self, 
data_distribution="uniform"): + func = self.sch.mod["main"] + device = self.config.arch.device + profile_tensors = get_dummy_input_arrays(func, device, distribution=data_distribution) + latency = self.time_evaluator(*profile_tensors).mean * 1e3 + return latency def _apply_config( @@ -172,7 +174,6 @@ def apply_and_build_parallel(func, data_distribution="uniform") -> CompileResult: cpresults = [] - profile_tensors = get_dummy_input_arrays(func, arch.device, distribution=data_distribution) max_workers = min(len(configs), os.cpu_count(), max_workers) # apply config in thread parallel @@ -242,7 +243,6 @@ def tvm_callback_cuda_postproc(code, _): cpresult = CompileResult(config, sch, rt_mod) timer_cuda_mod = rt_mod.time_evaluator( rt_mod.entry_name, arch.device, number=num_repeats) - cpresult.profile_tensors = profile_tensors cpresult.time_evaluator = timer_cuda_mod cpresult.code = code cpresults.append(cpresult) @@ -256,7 +256,7 @@ def tvm_callback_cuda_postproc(code, _): for cpresult in cpresults: config = cpresult.config try: - latency = cpresult.profile() + latency = cpresult.profile(data_distribution=data_distribution) except Exception as e_mesg: logger.debug(f"Evaluation with config failed {e_mesg}") continue diff --git a/bitblas/benchmark/operator/__init__.py b/bitblas/benchmark/operator/__init__.py index c5e7852e3..f59ca34ee 100644 --- a/bitblas/benchmark/operator/__init__.py +++ b/bitblas/benchmark/operator/__init__.py @@ -108,6 +108,8 @@ def run_benchmark( latency = op_inst.profile_latency(dynamic_symbolic_constraints=dynamic_profiling_shape) + op_inst.cleanup() + return latency, tuning_time @abstractmethod diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py index 7ed8fbc39..184da0b0a 100644 --- a/bitblas/ops/general_matmul/__init__.py +++ b/bitblas/ops/general_matmul/__init__.py @@ -306,6 +306,11 @@ def dispatch_tir(self, def _alloc_workspace(self): return torch.empty(WORKSPACE_SIZE, dtype=torch.float16).cuda() + def 
_free_workspace(self): + # release the workspace if one is allocated (i.e. it is not None) + if self.workspace is not None: + self.workspace = None + + def _assign_ladder_permutate_a(self, target: Target, enable_tuning: bool): ladder_permutate_a = None if self.propagate_a: @@ -534,6 +539,9 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any: def __call__(self, *args: Any, **kwds: Any) -> Any: return self.forward(*args, **kwds) + def cleanup(self): + self._free_workspace() + @property def M(self): return self.config.M diff --git a/bitblas/ops/matmul.py b/bitblas/ops/matmul.py index af0370294..e515a264c 100644 --- a/bitblas/ops/matmul.py +++ b/bitblas/ops/matmul.py @@ -209,7 +209,6 @@ def var_warpper(v, m): [var_warpper(i, m) for i in arg.shape]).astype(arg.dtype), device=device, )) - self.profile_tensors = profile_tensors latency = self.time_evaluator(*profile_tensors).mean * 1e3 benchmark_latencies.append({"m": m, "latency": latency}) # ms diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index 8617d70b9..d35476ee5 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -47,7 +47,6 @@ def __init__(self, name, config: OperatorConfig, target: Target = None): self.optimized_func = None self.rt_mod = None self.time_evaluator = None - self.profile_tensors = None self.arch = get_arch(target) if target else None self.dynamic_range = None self.pass_context: Dict = {} @@ -262,7 +261,6 @@ def map_numpy_type(intype): [var_warpper(i) for i in arg.shape]).astype(numpy_dtype), device=device, )) - self.profile_tensors = profile_tensors return profile_tensors def profile_latency(self, dynamic_symbolic_constraints: Optional[Dict] = None) -> str: @@ -270,6 +268,9 @@ def profile_latency(self, dynamic_symbolic_constraints: Optional[Dict] = None) - dynamic_symbolic_constraints = {} profile_tensors = self.get_profile_tensors(dynamic_symbolic_constraints) latency = self.time_evaluator(*profile_tensors).mean * 1e3 + # release the memory + for tensor in
profile_tensors: + del tensor return latency def _tensor_adapter(self, tensor, device): @@ -325,6 +326,9 @@ def update_runtime_module(self, rt_mod, srcpath=None, libpath=None): self.lib.init() # TODO: update the lib code from srcpath + def cleanup(self): + raise NotImplementedError + @abstractmethod def _select_implementation(self) -> IRModule: pass