diff --git a/python/setup.py b/python/setup.py index a84b896fd..999933281 100644 --- a/python/setup.py +++ b/python/setup.py @@ -29,7 +29,7 @@ import pybind11 from build_helpers import get_base_dir, get_cmake_dir -import setup_helper as helper +from setup_tools import setup_helper as helper @dataclass @@ -400,7 +400,6 @@ def run(self): cmake_major, cmake_minor = int(match.group("major")), int(match.group("minor")) if (cmake_major, cmake_minor) < (3, 18): raise RuntimeError("CMake >= 3.18.0 is required") - for ext in self.extensions: self.build_extension(ext) @@ -432,7 +431,6 @@ def build_extension(self, ext): thirdparty_cmake_args = get_thirdparty_packages([get_llvm_package_info()]) thirdparty_cmake_args += self.get_pybind11_cmake_args() extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path))) - ext_base_dir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) # create build directories if not os.path.exists(self.build_temp): os.makedirs(self.build_temp) @@ -449,6 +447,7 @@ def build_extension(self, ext): "-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends if not b.is_external]), "-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external]) ] + cmake_args += helper.get_backend_cmake_args(build_ext=self) if lit_dir is not None: cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir) cmake_args.extend(thirdparty_cmake_args) @@ -472,7 +471,6 @@ def build_extension(self, ext): "-DCMAKE_EXE_LINKER_FLAGS=-fuse-ld=lld", "-DCMAKE_MODULE_LINKER_FLAGS=-fuse-ld=lld", "-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld", - f"-DCMAKE_INSTALL_PREFIX={ext_base_dir}", ] # Note that asan doesn't work with binaries that use the GPU, so this is @@ -515,6 +513,7 @@ def build_extension(self, ext): subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=cmake_dir) subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir) subprocess.check_call(["cmake", "--install", "."], cwd=cmake_dir) + helper.install_extension(build_ext=self) nvidia_version_path = os.path.join(get_base_dir(), "cmake", "nvidia-toolchain-version.json") @@ -652,6 +651,7 @@ class plugin_install(install): def run(self): add_links() install.run(self) + helper.post_install() class plugin_develop(develop): @@ -659,6 +659,7 @@ class plugin_develop(develop): def run(self): add_links() develop.run(self) + helper.post_install() class plugin_bdist_wheel(bdist_wheel): @@ -666,6 +667,7 @@ class plugin_bdist_wheel(bdist_wheel): def run(self): add_links() bdist_wheel.run(self) + helper.post_install() class plugin_egginfo(egg_info): @@ -673,6 +675,7 @@ class plugin_egginfo(egg_info): def run(self): add_links() egg_info.run(self) + helper.post_install() # TODO: package_data_tools diff --git a/python/setup_tools/__init__.py b/python/setup_tools/__init__.py new file mode 100644 index 000000000..c3411313f --- /dev/null +++ b/python/setup_tools/__init__.py @@ -0,0 +1,4 @@ +from . import setup_helper +from . import utils + +__all__ = ["setup_helper", "utils"] diff --git a/python/setup_helper.py b/python/setup_tools/setup_helper.py similarity index 72% rename from python/setup_helper.py rename to python/setup_tools/setup_helper.py index ebc371e8c..f9a40a0b7 100644 --- a/python/setup_helper.py +++ b/python/setup_tools/setup_helper.py @@ -8,35 +8,19 @@ import urllib.request from pathlib import Path import hashlib -from dataclasses import dataclass +from distutils.sysconfig import get_python_lib +from . import utils -flagtree_backend = os.getenv("FLAGTREE_BACKEND", "").lower() -flagtree_plugin = os.getenv("FLAGTREE_PLUGIN", "").lower() -use_triton_shared = False -necessary_third_party = ["" if flagtree_backend == "tsingmicro" else "flir"] -default_backends = ["nvidia", "amd"] extend_backends = [] +default_backends = ["nvidia", "amd"] +plugin_backends = ["cambricon", "ascend", "aipu", "tsingmicro"] ext_sourcedir = "triton/_C/" - - -@dataclass -class FlagTreeBackend: - name: str - url: str - tag: str - - -flagtree_backend_info = { - "flir": - FlagTreeBackend(name="flir", url="git@github.com:FlagTree/flir.git", - tag="e72b83ba46a5a9dd6466c7102f93fd600cde909e"), - "triton_shared": - FlagTreeBackend(name="triton_shared", url="https://github.com/microsoft/triton-shared.git", - tag="5842469a16b261e45a2c67fbfc308057622b03ee"), - "cambricon": - FlagTreeBackend(name="cambricon", url="https://github.com/Cambricon/triton-linalg.git", - tag="00f51c2e48a943922f86f03d58e29f514def646d"), -} +flagtree_backend = os.getenv("FLAGTREE_BACKEND", "").lower() +flagtree_plugin = os.getenv("FLAGTREE_PLUGIN", "").lower() +offline_build = os.getenv("FLAGTREE_PLUGIN", "OFF") +device_mapping = {"xpu": "xpu", "mthreads": "musa", "ascend": "ascend"} +flagtree_backends = utils.flagtree_backends +backend_utils = utils.activate(flagtree_backend) set_llvm_env = lambda path: set_env({ 'LLVM_INCLUDE_DIRS': Path(path) / "include", @@ -45,6 +29,101 @@ class FlagTreeBackend: }) +def install_extension(*args, **kargs): + try: + backend_utils.install_extension(*args, **kargs) + except Exception: + pass + + +def get_backend_cmake_args(*args, **kargs): + try: + return backend_utils.get_backend_cmake_args(*args, **kargs) + except Exception: + return [] + + +def get_device_name(): + return device_mapping[flagtree_backend] + + +def get_extra_packages(): + packages = [] + try: + packages = backend_utils.get_extra_install_packages() + except Exception: + packages = [] + return packages + + +def get_package_data_tools(): + package_data = ["compile.h", "compile.c"] + try: + package_data += backend_utils.get_package_data_tools() + except Exception: + package_data + return package_data + + +def git_clone(lib, lib_path): + import git + MAX_RETRY = 4 + print(f"Clone {lib.name} into {lib_path} ...") + retry_count = MAX_RETRY + while (retry_count): + try: + repo = git.Repo.clone_from(lib.url, lib_path) + if lib.tag is not None: + repo.git.checkout(lib.tag) + sub_triton_path = Path(lib_path) / "triton" + if os.path.exists(sub_triton_path): + shutil.rmtree(sub_triton_path) + print(f"successfully clone {lib.name} into {lib_path} ...") + return True + except Exception: + retry_count -= 1 + print(f"\n[{MAX_RETRY - retry_count}] retry to clone {lib.name} to {lib_path}") + return False + + +def dir_rollback(deep, base_path): + while (deep): + base_path = os.path.dirname(base_path) + deep -= 1 + return Path(base_path) + + +def download_flagtree_third_party(name, condition, required=False, hock=None): + if not condition: + return + backend = None + for _backend in flagtree_backends: + if _backend.name in name: + backend = _backend + break + if backend is None: + return backend + base_dir = dir_rollback(3, __file__) / "third_party" + prelib_path = Path(base_dir) / name + lib_path = Path(base_dir) / _backend.name + if not os.path.exists(prelib_path) and not os.path.exists(lib_path): + succ = git_clone(lib=backend, lib_path=prelib_path) + if not succ and required: + raise RuntimeError("Bad network ! ") + else: + print(f'Found third_party {backend.name} at {lib_path}\n') + + if callable(hock): + hock(third_party_base_dir=base_dir, backend=backend, default_backends=default_backends) + + +def post_install(): + try: + backend_utils.post_install() + except Exception: + pass + + class FlagTreeCache: def __init__(self): @@ -211,8 +290,13 @@ class CommonUtils: @staticmethod def unlink(): - cur_path = os.path.dirname(__file__) - backends_dir_path = Path(cur_path) / "triton" / "backends" + cur_path = dir_rollback(2, __file__) + if "editable_wheel" in sys.argv: + installation_dir = cur_path + else: + installation_dir = get_python_lib() + backends_dir_path = Path(installation_dir) / "triton" / "backends" + # raise RuntimeError(backends_dir_path) if not os.path.exists(backends_dir_path): return for name in os.listdir(backends_dir_path): @@ -230,15 +314,15 @@ def unlink(): def skip_package_dir(package): if 'backends' in package or 'profiler' in package: return True - if flagtree_backend in ['cambricon']: - if package not in ['triton', 'triton/_C']: - return True - return False + try: + return backend_utils.skip_package_dir(package) + except Exception: + return False @staticmethod def get_package_dir(packages): package_dict = {} - if flagtree_backend and flagtree_backend not in ("cambricon", "aipu", "tsingmicro"): + if flagtree_backend and flagtree_backend not in plugin_backends: connection = [] backend_triton_path = f"../third_party/{flagtree_backend}/python/" for package in packages: @@ -247,70 +331,20 @@ def get_package_dir(packages): pair = (package, f"{backend_triton_path}{package}") connection.append(pair) package_dict.update(connection) + try: + package_dict.update(backend_utils.get_package_dir()) + except Exception: + pass return package_dict - @staticmethod - def download_third_party(): - import git - MAX_RETRY = 4 - global use_triton_shared, flagtree_backend - third_party_base_dir = Path(os.path.dirname(os.path.dirname(__file__))) / "third_party" - - def git_clone(lib, lib_path): - global use_triton_shared - print(f"Clone {lib.name} into {lib_path} ...") - retry_count = MAX_RETRY - while (retry_count): - try: - repo = git.Repo.clone_from(lib.url, lib_path) - repo.git.checkout(lib.tag) - if lib.name in flagtree_backend_info: - sub_triton_path = Path(lib_path) / "triton" - if os.path.exists(sub_triton_path): - shutil.rmtree(sub_triton_path) - print(f"successfully clone {lib.name} into {lib_path} ...") - return - except Exception: - retry_count -= 1 - print(f"\n[{MAX_RETRY - retry_count}] retry to clone {lib.name} to {lib_path}") - - print(f"Unable to clone third_party {lib.name}") - if lib.name in necessary_third_party: - use_triton_shared = False # TODO - print(f"\n\t{lib.name} is compiled by default, but for " - "some reason we couldn't download triton_shared\n" - "as third_party (most likely for network reasons), " - "so we couldn't compile triton_shared\n") - - third_partys = [] - if flagtree_backend != "tsingmicro": - third_partys.append(flagtree_backend_info["flir"]) - if os.environ.get("USE_TRITON_SHARED", "ON") == "ON": - third_partys.append(flagtree_backend_info["triton_shared"]) - else: - use_triton_shared = False - if flagtree_backend in flagtree_backend_info: - third_partys.append(flagtree_backend_info[flagtree_backend]) - - for lib in third_partys: - lib_path = Path(third_party_base_dir) / lib.name - if not os.path.exists(lib_path): - git_clone(lib=lib, lib_path=lib_path) - else: - print(f'Found third_party {lib.name} at {lib_path}\n') - def handle_flagtree_backend(): global ext_sourcedir if flagtree_backend: - print(f"flagtree_backend is {flagtree_backend}") + print(f"\033[1;32m[INFO] FlagtreeBackend is {flagtree_backend}\033[0m") extend_backends.append(flagtree_backend) - if "editable_wheel" in sys.argv and flagtree_backend not in ("aipu", "tsingmicro"): + if "editable_wheel" in sys.argv and flagtree_backend not in plugin_backends: ext_sourcedir = os.path.abspath(f"../third_party/{flagtree_backend}/python/{ext_sourcedir}") + "/" - if flagtree_backend != "tsingmicro": - default_backends.append("flir") - if use_triton_shared: - default_backends.append("triton_shared") def set_env(env_dict: dict): @@ -322,8 +356,18 @@ def check_env(env_val): return os.environ.get(env_val, '') != '' -CommonUtils.download_third_party() +download_flagtree_third_party("triton_shared", hock=utils.default.precompile_hock, condition=(not flagtree_backend)) + +download_flagtree_third_party("triton_ascend", condition=(flagtree_backend == "ascend"), + hock=utils.ascend.precompile_hock, required=True) + +download_flagtree_third_party("cambricon", condition=(flagtree_backend == "cambricon"), required=True) + +download_flagtree_third_party("flir", condition=(flagtree_backend == "aipu"), hock=utils.aipu.precompile_hock, + required=True) + handle_flagtree_backend() + cache = FlagTreeCache() # iluvatar @@ -375,6 +419,15 @@ def check_env(env_val): post_hock=set_llvm_env, ) +# ascend +cache.store( + file="ascend-llvm-b5cc222d-ubuntu-arm64", + condition=("ascend" == flagtree_backend), + url="https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-b5cc222d-ubuntu-arm64.tar.gz", + pre_hock=lambda: check_env('LLVM_SYSPATH'), + post_hock=set_llvm_env, +) + # aipu cache.store( file="aipu-llvm-a66376b0-ubuntu-x64", diff --git a/python/setup_tools/utils/__init__.py b/python/setup_tools/utils/__init__.py new file mode 100644 index 000000000..00c86e5a4 --- /dev/null +++ b/python/setup_tools/utils/__init__.py @@ -0,0 +1,45 @@ +from dataclasses import dataclass +from pathlib import Path +import importlib.util +import os +from . import ascend +from . import aipu +from . import default + + +@dataclass +class FlagTreeBackend: + name: str + url: str + tag: str = None + + +flagtree_backends = ( + FlagTreeBackend(name="triton_shared", url="https://github.com/microsoft/triton-shared.git", + tag="5842469a16b261e45a2c67fbfc308057622b03ee"), + FlagTreeBackend(name="cambricon", url="https://github.com/Cambricon/triton-linalg.git", + tag="00f51c2e48a943922f86f03d58e29f514def646d"), + FlagTreeBackend(name="flir", url="git@github.com:FlagTree/flir.git", + tag="e72b83ba46a5a9dd6466c7102f93fd600cde909e"), + FlagTreeBackend( + name="ascend", + url="https://gitee.com/ascend/triton-ascend.git", + ), +) + + +def activate(backend, suffix=".py"): + if not backend: + return + module_path = Path(os.path.dirname(__file__)) / backend + module_path = str(module_path) + suffix + spec = importlib.util.spec_from_file_location("module", module_path) + module = importlib.util.module_from_spec(spec) + try: + spec.loader.exec_module(module) + except Exception: + pass + return module + + +__all__ = ["ascend", "aipu", "default", "activate"] diff --git a/python/setup_tools/utils/aipu.py b/python/setup_tools/utils/aipu.py new file mode 100644 index 000000000..f4a81f6a5 --- /dev/null +++ b/python/setup_tools/utils/aipu.py @@ -0,0 +1,3 @@ +def precompile_hock(*args, **kargs): + default_backends = kargs["default_backends"] + default_backends.append('flir') diff --git a/python/setup_tools/utils/ascend.py b/python/setup_tools/utils/ascend.py new file mode 100644 index 000000000..bb0184ec6 --- /dev/null +++ b/python/setup_tools/utils/ascend.py @@ -0,0 +1,234 @@ +import os +import shutil +from pathlib import Path + + +def get_backend_cmake_args(*args, **kargs): + build_ext = kargs['build_ext'] + src_ext_path = build_ext.get_ext_fullpath("triton-adapter-opt") + src_ext_path = os.path.abspath(os.path.dirname(src_ext_path)) + return [ + "-DCMAKE_RUNTIME_OUTPUT_DIRECTORY=" + src_ext_path, + ] + + +def install_extension(*args, **kargs): + build_ext = kargs['build_ext'] + src_ext_path = build_ext.get_ext_fullpath("triton-adapter-opt") + src_ext_path = os.path.join(os.path.abspath(os.path.dirname(src_ext_path)), "triton-adapter-opt") + python_root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + dst_ext_path = os.path.join(python_root_dir, "triton/backends/ascend/triton-adapter-opt") + shutil.copy(src_ext_path, dst_ext_path) + + +def get_package_dir(): + package_dict = {} + triton_patch_root_rel_dir = "../third_party/ascend/triton_patch/python/triton_patch" + package_dict["triton/triton_patch"] = f"{triton_patch_root_rel_dir}" + package_dict["triton/triton_patch/language"] = f"{triton_patch_root_rel_dir}/language" + package_dict["triton/triton_patch/compiler"] = f"{triton_patch_root_rel_dir}/compiler" + package_dict["triton/triton_patch/runtime"] = f"{triton_patch_root_rel_dir}/runtime" + return package_dict + + +def insert_at_file_start(filepath, import_lines): + import tempfile + try: + with open(filepath, 'r', encoding='utf-8') as f: + content = f.read() + if import_lines in content: + return False + with tempfile.NamedTemporaryFile(mode='w', delete=False) as tmp_file: + tmp_file.write(import_lines + '\n\n') + with open(filepath, 'r') as original_file: + tmp_file.write(original_file.read()) + backup_path = filepath + '.bak' + if os.path.exists(backup_path): + os.remove(backup_path) + shutil.move(filepath, backup_path) + shutil.move(tmp_file.name, filepath) + print(f"[INFO]: {filepath} is patched") + return True + except PermissionError: + print(f"[ERROR]: No permission to write to {filepath}!") + except FileNotFoundError: + print(f"[ERROR]: {filepath} does not exist!") + except Exception as e: + print(f"[ERROR]: Unknown error: {str(e)}") + return False + + +def append_at_file_end(filepath, import_lines): + try: + with open(filepath, 'r', encoding='utf-8') as f: + content = f.read() + if import_lines in content: + return False + with open(filepath, 'a', encoding='utf-8') as f: + f.write('\n' + import_lines) + return True + except PermissionError: + print(f"[ERROR]: No permission to write to {filepath}!") + except FileNotFoundError: + print(f"[ERROR]: {filepath} does not exist!") + except Exception as e: + print(f"[ERROR]: Unknown error: {str(e)}") + return False + + +def post_install(): + import site + install_dir = site.getsitepackages()[0] + install_dir = os.path.join(install_dir, "triton") + init_path = os.path.join(install_dir, "__init__.py") + patched_content = """ +import sys +from .triton_patch.language import _utils as ascend_utils +sys.modules['triton.language._utils'] = ascend_utils +from .triton_patch.compiler import compiler as ascend_compiler +sys.modules['triton.compiler.compiler'] = ascend_compiler +from .triton_patch.compiler import code_generator as ascend_code_generator +sys.modules['triton.compiler.code_generator'] = ascend_code_generator +from .triton_patch.compiler import errors as ascend_errors +sys.modules['triton.compiler.errors'] = ascend_errors +from .triton_patch.runtime import autotuner as ascend_autotuner +sys.modules['triton.runtime.autotuner'] = ascend_autotuner +from .triton_patch import testing as ascend_testing +sys.modules['triton.testing'] = ascend_testing +""" + insert_at_file_start(init_path, patched_content) + + content_to_append = """ +from .triton_patch.language.core import dot, gather, insert, subview +from .triton_patch.language.standard import flip +from .triton_patch.language.math import umulhi, exp, exp2, log, log2, cos, sin, sqrt, sqrt_rn, rsqrt, div_rn, erf, tanh, floor, ceil +from . import language + +language.dot = dot +language.flip = flip +language.gather = gather +language.insert = insert +language.subview = subview + +# from .triton_patch.language.core import dtype, pointer_type, block_type, function_type +# language.core.dtype = dtype +# language.core.pointer_type = pointer_type +# language.core.block_type = block_type +# language.core.function_type = function_type + +from .triton_patch.language.semantic import arange, floordiv +language.semantic.arange = arange +language.semantic.floordiv = floordiv + +language.umulhi = umulhi +language.exp = exp +language.exp2 = exp2 +language.log = log +language.log2 = log2 +language.cos = cos +language.sin = sin +language.sqrt = sqrt +language.sqrt_rn = sqrt_rn +language.rsqrt = rsqrt +language.div_rn = div_rn +language.erf = erf +language.tanh = tanh +language.floor = floor +language.ceil = ceil +language.math.umulhi = umulhi +language.math.exp = exp +language.math.exp2 = exp2 +language.math.log = log +language.math.log2 = log2 +language.math.cos = cos +language.math.sin = sin +language.math.sqrt = sqrt +language.math.sqrt_rn = sqrt_rn +language.math.rsqrt = rsqrt +language.math.div_rn = div_rn +language.math.erf = erf +language.math.tanh = tanh +language.math.floor = floor +language.math.ceil = ceil +""" + append_at_file_end(init_path, content_to_append) + + +def get_ascend_patch_packages(backends): + packages = [] + # packages += get_language_extra_packages() + packages += [ + "triton/triton_patch", + "triton/triton_patch/language", + "triton/triton_patch/compiler", + "triton/triton_patch/runtime", + ] + return packages + + +def get_ascend_patch_package_dir(backends): + package_dir = {} + # language_extra_list = get_language_extra_packages() + # for extra_full in language_extra_list: + # extra_name = extra_full.replace("triton/language/extra/", "") + # package_dir[extra_full] = f"{triton_root_rel_dir}/language/extra/{extra_name}" + # + triton_patch_root_rel_dir = "triton_patch/python/triton_patch" + package_dir["triton/triton_patch"] = f"{triton_patch_root_rel_dir}" + package_dir["triton/triton_patch/language"] = f"{triton_patch_root_rel_dir}/language" + package_dir["triton/triton_patch/compiler"] = f"{triton_patch_root_rel_dir}/compiler" + package_dir["triton/triton_patch/runtime"] = f"{triton_patch_root_rel_dir}/runtime" + return package_dir + + +def get_extra_install_packages(): + return [ + "triton/triton_patch", + "triton/triton_patch/language", + "triton/triton_patch/compiler", + "triton/triton_patch/runtime", + ] + + +def precompile_hock(*args, **kargs): + third_party_base_dir = Path(kargs['third_party_base_dir']) + ascend_path = Path(third_party_base_dir) / "ascend" + patch_path = Path(ascend_path) / "triton_patch" + project_path = Path(third_party_base_dir) / "triton_ascend" + if os.path.exists(ascend_path): + shutil.rmtree(ascend_path) + if not os.path.exists(project_path): + raise RuntimeError(f"{project_path} can't be found. It might be due to a network issue") + ascend_src_path = Path(project_path) / "ascend" + patch_src_path = Path(project_path) / "triton_patch" + shutil.copytree(ascend_src_path, ascend_path, dirs_exist_ok=True) + shutil.copytree(patch_src_path, patch_path, dirs_exist_ok=True) + shutil.rmtree(project_path) + patched_code = """ set(triton_abs_dir "${TRITON_ROOT_DIR}/include/triton/Dialect/Triton/IR") """ + src_code = """set(triton_abs_dir""" + + filepath = Path(patch_path) / "include" / "triton" / "Dialect" / "Triton" / "IR" / "CMakeLists.txt" + try: + import tempfile + with tempfile.NamedTemporaryFile(mode='w+t', delete=False) as tmp_file: + with open(filepath, 'r') as file: + lines = file.readlines() + for line in lines: + if src_code in line: + tmp_file.writelines(patched_code) + else: + tmp_file.writelines(line) + backup_path = str(filepath) + '.bak' + if os.path.exists(backup_path): + os.remove(backup_path) + shutil.move(filepath, backup_path) + shutil.move(tmp_file.name, filepath) + print(f"[INFO]: {filepath} is patched") + return True + except PermissionError: + print(f"[ERROR]: No permission to write to {filepath}!") + except FileNotFoundError: + print(f"[ERROR]: {filepath} does not exist!") + except Exception as e: + print(f"[ERROR]: Unknown error: {str(e)}") + return False diff --git a/python/setup_tools/utils/cambricon.py b/python/setup_tools/utils/cambricon.py new file mode 100644 index 000000000..3246fc695 --- /dev/null +++ b/python/setup_tools/utils/cambricon.py @@ -0,0 +1,4 @@ +def skip_package_dir(package): + if package not in ['triton', 'triton/_C']: + return True + return False diff --git a/python/setup_tools/utils/default.py b/python/setup_tools/utils/default.py new file mode 100644 index 000000000..5aba98a0f --- /dev/null +++ b/python/setup_tools/utils/default.py @@ -0,0 +1,3 @@ +def precompile_hock(*args, **kargs): + default_backends = kargs['default_backends'] + default_backends.append('triton_shared') diff --git a/python/setup_tools/utils/tsingmicro.py b/python/setup_tools/utils/tsingmicro.py new file mode 100644 index 000000000..6107ea83b --- /dev/null +++ b/python/setup_tools/utils/tsingmicro.py @@ -0,0 +1,10 @@ +import os + + +def get_backend_cmake_args(*args, **kargs): + build_ext = kargs['build_ext'] + src_ext_path = build_ext.get_ext_fullpath("triton") + src_ext_path = os.path.abspath(os.path.dirname(src_ext_path)) + return [ + "-DCMAKE_INSTALL_PREFIX=" + src_ext_path, + ] diff --git a/python/setup_tools/utils/xpu.py b/python/setup_tools/utils/xpu.py new file mode 100644 index 000000000..92424b1b2 --- /dev/null +++ b/python/setup_tools/utils/xpu.py @@ -0,0 +1,2 @@ +def get_package_data_tools(): + return ["compile_xpu.h", "compile_xpu.c"]