diff --git a/3rdparty/tvm b/3rdparty/tvm
index 1fa647dbf..08af76d06 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit 1fa647dbff6a273cbdf2a6f0a64b3478ba553223
+Subproject commit 08af76d069d9d5906ce85b8a771685812daeecdc
diff --git a/bitblas/cache/operator.py b/bitblas/cache/operator.py
index 0dbbdf96b..1a1418758 100644
--- a/bitblas/cache/operator.py
+++ b/bitblas/cache/operator.py
@@ -12,6 +12,7 @@ from bitblas import tvm
 from tvm.contrib.tar import tar
 import logging
+import threading
 
 logger = logging.getLogger(__name__)
 
 
@@ -24,53 +25,63 @@ class OperatorCache:
     """
     Manages a cache for operator instances (e.g., Matmul, Convolution) based on their configurations.
     """
+    # A lock to synchronize access to the cache
+    # RLock is used to allow reentrant locking
+    # As load_from_database calls _load_operator which
+    # calls _instantiate_and_add_operator
+    cache_locker = threading.RLock()
 
     def __init__(self):
         self.cache = {}
 
     def add(self, config: OperatorConfig, op_inst: Operator):
-        self.cache[config] = op_inst
+        with self.cache_locker:
+            self.cache[config] = op_inst
 
     def get(self, config: OperatorConfig):
-        return self.cache.get(config)
+        with self.cache_locker:
+            return self.cache.get(config)
 
     def exists(self, config):
         return config in self.cache
 
     def clear(self):
-        self.cache.clear()
+        with self.cache_locker:
+            self.cache.clear()
 
     def size(self):
         return len(self.cache)
 
     def save_into_database(self, database_path=None, target=None):
-        database_path = self._ensure_database_path(database_path)
-        for config, op_inst in self.cache.items():
-            arch_str = self._determine_arch_str(op_inst, target)
-            arch_path = os.path.join(database_path, arch_str)
-            self._ensure_directory(arch_path)
-            hash_str = sha256(repr(config).encode()).hexdigest()
-            config_path = os.path.join(arch_path, hash_str)
-            # if the config already exists, skip saving
-            if os.path.exists(config_path):
-                continue
-            self._ensure_directory(config_path)
-            self._save_operator_config_and_artifact(config, op_inst, config_path)
+        with self.cache_locker:
+            database_path = self._ensure_database_path(database_path)
+            for config, op_inst in self.cache.items():
+                arch_str = self._determine_arch_str(op_inst, target)
+                arch_path = os.path.join(database_path, arch_str)
+                self._ensure_directory(arch_path)
+                hash_str = sha256(repr(config).encode()).hexdigest()
+                config_path = os.path.join(arch_path, hash_str)
+                # if the config already exists, skip saving
+                if os.path.exists(config_path):
+                    continue
+                self._ensure_directory(config_path)
+                self._save_operator_config_and_artifact(config, op_inst, config_path)
 
     def load_from_database(self, database_path, target=None):
-        if not os.path.exists(database_path):
-            logger.info(
-                f"Database path {database_path} does not exist, skipping loading operators from the database"
-            )
-            return
-        arch_str = self._determine_target_arch_str(target)
-        arch_path = os.path.join(database_path, arch_str)
-        if not os.path.exists(arch_path):
-            logger.info(
-                f"Target {arch_str} does not exist in the database, skipping loading operators from the database"
-            )
-            return
-        self._load_operators_from_arch_path(arch_path, target)
+        with self.cache_locker:
+            if not os.path.exists(database_path):
+                logger.info(
+                    f"Database path {database_path} does not exist, skipping loading operators from the database"
+                )
+                return
+            arch_str = self._determine_target_arch_str(target)
+            arch_path = os.path.join(database_path, arch_str)
+            if not os.path.exists(arch_path):
+                logger.info(
+                    f"Target {arch_str} does not exist in the database, skipping loading operators from the database"
+                )
+                return
+            self._load_operators_from_arch_path(arch_path, target)
 
     def _ensure_database_path(self, database_path):
         if database_path is None:
diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py
index bc091f910..76d756e96 100644
--- a/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py
+++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py
@@ -1,14 +1,7 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 from bitblas import tvm as tvm
-from tvm import DataType
-import tvm.tl.language as T
 from typing import Optional
-from bitblas.tl.utils import (
-    get_mma_micro_size,
-    make_swizzle_layout,
-)
-
 
 from bitblas.ops.base_scheduler import BaseScheduler
 from dataclasses import dataclass
@@ -43,9 +36,7 @@ class MatmulFineGrainSIMTScheduler(BaseScheduler):
     def with_default_config(self):
         raise NotImplementedError
 
-    def apply_config(
-        self,
-    ):
+    def apply_config(self,):
 
         # M, N, K = self.M, self.N, self.K
         # trans_A, trans_B = self.trans_A, self.trans_B
@@ -53,7 +44,6 @@ def apply_config(
 
         raise NotImplementedError
 
-
     def __post_init__(self):
         # Validate the matrix transpose settings
         assert self.trans_A is False, "Currently only support Matrix A not transposed"
diff --git a/bitblas/tl/tuner.py b/bitblas/tl/tuner.py
index 8f9ab4f84..fd3c98ef4 100644
--- a/bitblas/tl/tuner.py
+++ b/bitblas/tl/tuner.py
@@ -5,14 +5,11 @@ import os
 from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind
 from concurrent.futures import ThreadPoolExecutor, as_completed
-import numpy as np
-from typing import List, Tuple, Optional, Dict, Union, Literal, Callable
+from typing import List, Tuple, Optional, Dict, Literal
 from tvm import tir, IRModule
 from tvm.runtime import Module
 from tvm.tir import Schedule
-from tvm.relax.expr import Function
 import tvm.tl as tl
-import bitblas
 from bitblas.ops.base_scheduler import BaseScheduler
 from bitblas.base.arch import CUDA
 from bitblas.base import Hint
@@ -20,11 +17,7 @@ from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy
 from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags
 import tempfile
-import itertools
-from tvm.ir.supply import GlobalVarSupply
 from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2
-from bitblas.utils.tensor_adapter import (
-    np_float2np_bf16,)
 import logging
 
 logger = logging.getLogger(__name__)
 
@@ -67,8 +60,8 @@ def profile(self, data_distribution="uniform"):
 
 
 def _apply_config(
-    scheduler: BaseScheduler,
-    config: Dict = None,
+        scheduler: BaseScheduler,
+        config: Dict = None,
 ) -> Optional[IRModule]:
     """
    find rules:
@@ -121,6 +114,7 @@ def _build(context) -> str:
             return idx, None, None
 
         config = configs[idx]
+        assert config is not None
 
         @tvm.register_func(func_name="tvm_callback_cuda_postproc", override=True)
         def tvm_callback_cuda_postproc(code, _):
@@ -128,6 +122,7 @@ def tvm_callback_cuda_postproc(code, _):
             code = tensor_remove_make_int4(code)
             code = tensor_remove_make_int2(code)
             return code
+
         # check only have one function in the module
         if len(mod.functions) > 1:
             raise ValueError("Only support one function in the module")
@@ -168,12 +163,12 @@ def tvm_callback_cuda_postproc(code, _):
             continue
         rt_mod = tvm.runtime.load_module(artifact_path)
         # Transform Tuning Config to Hint
-        hint = Hint.from_dict(
-            {
-                **{"arch": arch},
-                **config,
-            }
-        )
+        hint = Hint.from_dict({
+            **{
+                "arch": arch
+            },
+            **config,
+        })
         cpresult = CompileResult(hint, sch, rt_mod)
         timer_cuda_mod = rt_mod.time_evaluator(
             rt_mod.entry_name, arch.device, number=num_repeats)
@@ -250,11 +245,8 @@ def fast_tune(
         raise NotImplementedError(
             "Currently do not support fast tune with none-dynamic range set")
     if opt_shapes:
-        for name, shape in opt_shapes.items():
-            var = find_var_from_func(func, name)
-            specilized_func = func.specialize({
-                var: shape.astype(var.dtype)
-            }).with_attr("is_specialized")
+        raise NotImplementedError(
+            "Currently do not support fast tune with none-dynamic range set")
 
     arch = CUDA(target)
 
@@ -281,4 +273,3 @@ def fast_tune(
     )
 
     return cpresults, best
-
diff --git a/setup.py b/setup.py
index 5fe71db40..bfc6b3830 100644
--- a/setup.py
+++ b/setup.py
@@ -24,7 +24,6 @@ PYPI_BUILD = os.environ.get("PYPI_BUILD", "False").lower() == "true"
 PACKAGE_NAME = "bitblas"
 ROOT_DIR = os.path.dirname(__file__)
-MAIN_CUDA_VERSION = "12.1"
 
 # BitBLAS only supports Linux platform
 assert sys.platform.startswith("linux"), "BitBLAS only supports Linux platform (including WSL)."
diff --git a/testing/python/cache/test_operator_cache_spin_lock.py b/testing/python/cache/test_operator_cache_spin_lock.py
new file mode 100644
index 000000000..983acb85e
--- /dev/null
+++ b/testing/python/cache/test_operator_cache_spin_lock.py
@@ -0,0 +1,126 @@
+import pytest
+import os
+import torch
+import bitblas
+import threading
+from bitblas import Matmul, MatmulConfig
+from bitblas.cache import global_operator_cache
+from bitblas import tvm as tvm
+from tvm.contrib import utils
+
+target = bitblas.utils.auto_detect_nvidia_target()
+bitblas.set_log_level("DEBUG")
+
+
+def get_codegen_result(ops, target):
+    code = ops.get_source(target=target)
+    return code
+
+
+def tune_op_in_thread(thread_id, matmul_config, database_path):
+    """Each thread tunes the given Matmul operation and tries to save it into the global cache."""
+    matmul = Matmul(
+        config=matmul_config,
+        target=target,
+        enable_tuning=False,
+    )
+    print(f"Thread {thread_id}: Starting hardware-aware tuning...")
+    # matmul.hardware_aware_finetune(topk=20)
+    success = False
+    try:
+        print(f"Thread {thread_id}: Adding operation to global cache...")
+        global_operator_cache.add(matmul.config, matmul)
+
+        global_operator_cache.save_into_database(database_path, target=target)
+        assert os.path.exists(database_path), "Database file was not created."
+        global_operator_cache.clear()
+        assert global_operator_cache.size() == 0, "Global cache was not cleared properly."
+        global_operator_cache.load_from_database(database_path, target=target)
+        assert global_operator_cache.size() > 0, (
+            f"Thread {thread_id}: Global cache was not loaded properly as it is empty.")
+
+        success = True
+    except Exception as hash_error:
+        print(f"Thread {thread_id}: Error encountered - {hash_error}")
+    assert success, f"Thread {thread_id}: Failed to add operation to global cache."
+
+
+@pytest.mark.parametrize(
+    "M,N,K,in_dtype,out_dtype,accum_dtype,with_bias,propagate_a,propagate_b,layout",
+    [
+        (1, 1024, 1024, "float16", "float16", "float16", False, False, False, "nt"),
+    ],
+)
+def test_global_cache_save_to_database_multithreaded(
+        M,
+        N,
+        K,
+        in_dtype,
+        out_dtype,
+        accum_dtype,
+        with_bias,
+        propagate_a,
+        propagate_b,
+        layout,
+):
+    num_threads = 4
+    global_operator_cache.clear()
+
+    # For real world scenarios, all workers should share the same database path
+    tempdir = utils.tempdir()
+    database_path = str(tempdir.path)
+
+    matmul_config = MatmulConfig(
+        M=M,
+        N=N,
+        K=K,
+        A_dtype=in_dtype,
+        out_dtype=out_dtype,
+        accum_dtype=accum_dtype,
+        with_bias=with_bias,
+        propagate_a=propagate_a,
+        propagate_b=propagate_b,
+        layout=layout,
+    )
+
+    # Launch four threads, each tuning the same operation
+    threads = []
+    for thread_id in range(num_threads):
+        thread = threading.Thread(
+            target=tune_op_in_thread, args=(thread_id, matmul_config, database_path))
+        threads.append(thread)
+        thread.start()
+
+    # Wait for all threads to complete
+    for thread in threads:
+        thread.join()
+
+    matmul = global_operator_cache.get(matmul_config)
+    assert matmul is not None, "Matmul operation not found in cache after reload."
+
+    # Verify that the operation produces correct results
+    input_shape = (M, K)
+    weight_shape = (N, K) if layout == "nt" else (K, N)
+
+    inputs = []
+    inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda())
+    inputs.append(torch.rand(weight_shape, dtype=torch.float16).cuda())
+    ref_result = torch.matmul(inputs[0], inputs[1].t() if layout == "nt" else inputs[1])
+
+    permuted_inputs = []
+    if matmul.input_transform is not None:
+        permuted_inputs.append(matmul.input_transform(inputs[0].cpu()).cuda())
+    else:
+        permuted_inputs.append(inputs[0])
+    if matmul.weight_transform is not None:
+        permuted_inputs.append(matmul.weight_transform(inputs[1].cpu()).cuda())
+    else:
+        permuted_inputs.append(inputs[1])
+
+    bitblas_output = matmul(*permuted_inputs)
+    torch.testing.assert_close(bitblas_output, ref_result, rtol=1e-2, atol=1e-2)
+
+
+# fmt: on
+if __name__ == "__main__":
+    bitblas.testing.main()
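
Note on the locking change in bitblas/cache/operator.py: threading.RLock is used rather than threading.Lock because the cache's entry points nest. load_from_database acquires cache_locker and, while still holding it, reaches add() through the loading helpers, so the same thread takes the lock a second time; a non-reentrant Lock would deadlock at that point. Below is a minimal, self-contained sketch of that reentrancy pattern; the Cache class and the add/load_many names are illustrative only, not the BitBLAS API.

import threading


class Cache:
    # Reentrant lock: the thread that already holds it may acquire it again.
    _lock = threading.RLock()

    def __init__(self):
        self._data = {}

    def add(self, key, value):
        with self._lock:  # second acquisition by the same thread when called from load_many
            self._data[key] = value

    def load_many(self, items):
        with self._lock:  # first acquisition
            for key, value in items:
                self.add(key, value)  # a plain threading.Lock would deadlock on this nested acquire


cache = Cache()
cache.load_many([("a", 1), ("b", 2)])
assert cache._data == {"a": 1, "b": 2}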
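
The new test in testing/python/cache/test_operator_cache_spin_lock.py boils down to the pattern condensed below: several threads share one global_operator_cache and one on-disk database path, and each thread adds an operator, saves the cache, and reloads it. This is only a sketch of that usage, assuming BitBLAS is installed on a CUDA-capable machine; the shapes and dtypes mirror the test's single parameter set.

import threading

import bitblas
from bitblas import Matmul, MatmulConfig
from bitblas.cache import global_operator_cache
from tvm.contrib import utils

target = bitblas.utils.auto_detect_nvidia_target()
config = MatmulConfig(
    M=1, N=1024, K=1024, A_dtype="float16", out_dtype="float16", accum_dtype="float16", layout="nt")
database_path = str(utils.tempdir().path)  # one directory shared by every worker


def worker(thread_id):
    # Build without tuning, then push the operator through the lock-protected cache paths.
    op = Matmul(config=config, target=target, enable_tuning=False)
    global_operator_cache.add(op.config, op)
    global_operator_cache.save_into_database(database_path, target=target)
    global_operator_cache.load_from_database(database_path, target=target)


threads = [threading.Thread(target=worker, args=(i,)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert global_operator_cache.get(config) is not None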