diff --git a/python/setup.py b/python/setup.py index 0fe2fa3df..e2e5bc21e 100644 --- a/python/setup.py +++ b/python/setup.py @@ -24,7 +24,7 @@ from setuptools.command.egg_info import egg_info from wheel.bdist_wheel import bdist_wheel -import setup_helper as helper +from setup_tools import setup_helper as helper @dataclass @@ -372,6 +372,7 @@ def build_extension(self, ext): "-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends if not b.is_external]), "-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external]) ] + helper.get_backend_cmake_args(build_ext=self) if lit_dir is not None: cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir) cmake_args.extend(thirdparty_cmake_args) @@ -428,6 +429,7 @@ def build_extension(self, ext): subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=cmake_dir, env=env) subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=cmake_dir) subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir) + helper.install_extension(build_ext=self) nvidia_version_path = os.path.join(get_base_dir(), "cmake", "nvidia-toolchain-version.txt") @@ -520,6 +522,7 @@ class plugin_install(install): def run(self): add_links() install.run(self) + helper.post_install() class plugin_develop(develop): @@ -527,6 +530,7 @@ class plugin_develop(develop): def run(self): add_links() develop.run(self) + helper.post_install() class plugin_bdist_wheel(bdist_wheel): @@ -534,6 +538,7 @@ class plugin_bdist_wheel(bdist_wheel): def run(self): add_links() bdist_wheel.run(self) + helper.post_install() class plugin_egginfo(egg_info): @@ -541,11 +546,11 @@ class plugin_egginfo(egg_info): def run(self): add_links() egg_info.run(self) + helper.post_install() -package_data_tools = ["compile.h", "compile.c"] -if helper.flagtree_backend == "xpu": - package_data_tools += ["compile_xpu.h", "compile_xpu.c"] +package_data_tools = helper.get_package_data_tools() + package_data = { "triton/tools": package_data_tools, **{f"triton/backends/{b.name}": b.package_data @@ -568,10 +573,7 @@ def get_packages(): "triton/backends", "triton/tools", ] - if helper.flagtree_backend == "xpu": - packages.append("triton/language/extra/xpu") - elif helper.flagtree_backend == "mthreads": - packages.append("triton/language/extra/musa") + packages += helper.get_language_extra() packages += [f'triton/backends/{backend.name}' for backend in backends] packages += ["triton/profiler"] return packages diff --git a/python/setup_tools/__init__.py b/python/setup_tools/__init__.py new file mode 100644 index 000000000..c3411313f --- /dev/null +++ b/python/setup_tools/__init__.py @@ -0,0 +1,4 @@ +from . import setup_helper +from . import utils + +__all__ = ["setup_helper", "utils"] diff --git a/python/setup_helper.py b/python/setup_tools/setup_helper.py similarity index 66% rename from python/setup_helper.py rename to python/setup_tools/setup_helper.py index 262b18e5a..eb6e1d94d 100644 --- a/python/setup_helper.py +++ b/python/setup_tools/setup_helper.py @@ -8,32 +8,20 @@ import urllib.request from pathlib import Path import hashlib -from dataclasses import dataclass +from distutils.sysconfig import get_python_lib +from . import utils -use_triton_shared = True -necessary_third_party = ["triton_shared"] -default_backends = ["nvidia", "amd"] extend_backends = [] +default_backends = ["nvidia", "amd"] +plugin_backends = ["cambricon", "ascend", "aipu", "tsingmicro"] ext_sourcedir = "triton/_C/" flagtree_backend = os.getenv("FLAGTREE_BACKEND", "").lower() flagtree_plugin = os.getenv("FLAGTREE_PLUGIN", "").lower() - - -@dataclass -class FlagTreeBackend: - name: str - url: str - tag: str - - -flagtree_backend_info = { - "triton_shared": - FlagTreeBackend(name="triton_shared", url="https://github.com/microsoft/triton-shared.git", - tag="380b87122c88af131530903a702d5318ec59bb33"), - "cambricon": - FlagTreeBackend(name="cambricon", url="https://github.com/Cambricon/triton-linalg.git", - tag="00f51c2e48a943922f86f03d58e29f514def646d"), -} +offline_build = os.getenv("FLAGTREE_PLUGIN", "OFF") +device_mapping = {"xpu": "xpu", "mthreads": "musa", "ascend": "ascend"} +language_extra_backends = ['xpu', 'musa'] +flagtree_backends = utils.flagtree_backends +backend_utils = utils.activate(flagtree_backend) set_llvm_env = lambda path: set_env({ 'LLVM_INCLUDE_DIRS': Path(path) / "include", @@ -42,6 +30,110 @@ class FlagTreeBackend: }) +def install_extension(*args, **kargs): + try: + backend_utils.install_extension(*args, **kargs) + except Exception: + pass + + +def get_backend_cmake_args(*args, **kargs): + try: + return backend_utils.get_backend_cmake_args(*args, **kargs) + except Exception: + return [] + + +def get_device_name(): + return device_mapping[flagtree_backend] + + +def get_extra_packages(): + packages = [] + try: + packages = backend_utils.get_extra_install_packages() + except Exception: + packages = [] + return packages + + +def get_language_extra(): + packages = [] + if flagtree_backend in language_extra_backends: + device_name = device_mapping[flagtree_backend] + extra_path = f"triton/language/extra/{device_name}" + packages.append(extra_path) + return packages + + +def get_package_data_tools(): + package_data = ["compile.h", "compile.c"] + try: + package_data += backend_utils.get_package_data_tools() + except Exception: + package_data + return package_data + + +def git_clone(lib, lib_path): + import git + MAX_RETRY = 4 + print(f"Clone {lib.name} into {lib_path} ...") + retry_count = MAX_RETRY + while (retry_count): + try: + repo = git.Repo.clone_from(lib.url, lib_path) + if lib.tag is not None: + repo.git.checkout(lib.tag) + sub_triton_path = Path(lib_path) / "triton" + if os.path.exists(sub_triton_path): + shutil.rmtree(sub_triton_path) + print(f"successfully clone {lib.name} into {lib_path} ...") + return True + except Exception: + retry_count -= 1 + print(f"\n[{MAX_RETRY - retry_count}] retry to clone {lib.name} to {lib_path}") + return False + + +def dir_rollback(deep, base_path): + while (deep): + base_path = os.path.dirname(base_path) + deep -= 1 + return Path(base_path) + + +def download_flagtree_third_party(name, condition, required=False, hock=None): + if not condition: + return + backend = None + for _backend in flagtree_backends: + if _backend.name in name: + backend = _backend + break + if backend is None: + return backend + base_dir = dir_rollback(3, __file__) / "third_party" + prelib_path = Path(base_dir) / name + lib_path = Path(base_dir) / _backend.name + + if not os.path.exists(prelib_path) and not os.path.exists(lib_path): + succ = git_clone(lib=backend, lib_path=prelib_path) + if not succ and required: + raise RuntimeError("Bad network ! ") + else: + print(f'Found third_party {backend.name} at {lib_path}\n') + if callable(hock): + hock(third_party_base_dir=base_dir, backend=backend, default_backends=default_backends) + + +def post_install(): + try: + backend_utils.post_install() + except Exception: + pass + + class FlagTreeCache: def __init__(self): @@ -208,8 +300,13 @@ class CommonUtils: @staticmethod def unlink(): - cur_path = os.path.dirname(__file__) - backends_dir_path = Path(cur_path) / "triton" / "backends" + cur_path = dir_rollback(2, __file__) + if "editable_wheel" in sys.argv: + installation_dir = cur_path + else: + installation_dir = get_python_lib() + backends_dir_path = Path(installation_dir) / "triton" / "backends" + # raise RuntimeError(backends_dir_path) if not os.path.exists(backends_dir_path): return for name in os.listdir(backends_dir_path): @@ -227,15 +324,15 @@ def unlink(): def skip_package_dir(package): if 'backends' in package or 'profiler' in package: return True - if flagtree_backend in ['cambricon']: - if package not in ['triton', 'triton/_C']: - return True - return False + try: + return backend_utils.skip_package_dir(package) + except Exception: + return False @staticmethod def get_package_dir(packages): package_dict = {} - if flagtree_backend and flagtree_backend != 'cambricon': + if flagtree_backend and flagtree_backend not in plugin_backends: connection = [] backend_triton_path = f"../third_party/{flagtree_backend}/python/" for package in packages: @@ -244,66 +341,20 @@ def get_package_dir(packages): pair = (package, f"{backend_triton_path}{package}") connection.append(pair) package_dict.update(connection) + try: + package_dict.update(backend_utils.get_package_dir()) + except Exception: + pass return package_dict - @staticmethod - def download_third_party(): - import git - MAX_RETRY = 4 - global use_triton_shared, flagtree_backend - third_party_base_dir = Path(os.path.dirname(os.path.dirname(__file__))) / "third_party" - - def git_clone(lib, lib_path): - global use_triton_shared - print(f"Clone {lib.name} into {lib_path} ...") - retry_count = MAX_RETRY - while (retry_count): - try: - repo = git.Repo.clone_from(lib.url, lib_path) - repo.git.checkout(lib.tag) - if lib.name in flagtree_backend_info: - sub_triton_path = Path(lib_path) / "triton" - if os.path.exists(sub_triton_path): - shutil.rmtree(sub_triton_path) - print(f"successfully clone {lib.name} into {lib_path} ...") - return - except Exception: - retry_count -= 1 - print(f"\n[{MAX_RETRY - retry_count}] retry to clone {lib.name} to {lib_path}") - - print(f"Unable to clone third_party {lib.name}") - if lib.name in necessary_third_party: - use_triton_shared = False - print("\n\ttriton_shared is compiled by default, but for " - "some reason we couldn't download triton_shared\n" - "as third_party (most likely for network reasons), " - "so we couldn't compile triton_shared\n") - - third_partys = [] - if os.environ.get("USE_TRITON_SHARED", "ON") == "ON" and not flagtree_backend: - third_partys.append(flagtree_backend_info["triton_shared"]) - else: - use_triton_shared = False - if flagtree_backend in flagtree_backend_info: - third_partys.append(flagtree_backend_info[flagtree_backend]) - - for lib in third_partys: - lib_path = Path(third_party_base_dir) / lib.name - if not os.path.exists(lib_path): - git_clone(lib=lib, lib_path=lib_path) - else: - print(f'Found third_party {lib.name} at {lib_path}\n') - def handle_flagtree_backend(): global ext_sourcedir if flagtree_backend: - print(f"flagtree_backend is {flagtree_backend}") + print(f"\033[1;32m[INFO] FlagtreeBackend is {flagtree_backend}\033[0m") extend_backends.append(flagtree_backend) - if "editable_wheel" in sys.argv: + if "editable_wheel" in sys.argv and flagtree_backend != "ascend": ext_sourcedir = os.path.abspath(f"../third_party/{flagtree_backend}/python/{ext_sourcedir}") + "/" - if use_triton_shared and not flagtree_backend: - default_backends.append("triton_shared") def set_env(env_dict: dict): @@ -315,8 +366,18 @@ def check_env(env_val): return os.environ.get(env_val, '') != '' -CommonUtils.download_third_party() +download_flagtree_third_party("triton_shared", hock=utils.default.precompile_hock, condition=(not flagtree_backend)) + +download_flagtree_third_party("triton_ascend", condition=(flagtree_backend == "ascend"), + hock=utils.ascend.precompile_hock, required=True) + +download_flagtree_third_party("cambricon", condition=(flagtree_backend == "cambricon"), required=True) + +download_flagtree_third_party("flir", condition=(flagtree_backend == "aipu"), hock=utils.aipu.precompile_hock, + required=True) + handle_flagtree_backend() + cache = FlagTreeCache() # iluvatar @@ -367,3 +428,31 @@ def check_env(env_val): pre_hock=lambda: check_env('LLVM_SYSPATH'), post_hock=set_llvm_env, ) + +# ascend +cache.store( + file="ascend-llvm-b5cc222d-ubuntu-arm64", + condition=("ascend" == flagtree_backend), + url="https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-b5cc222d-ubuntu-arm64.tar.gz", + pre_hock=lambda: check_env('LLVM_SYSPATH'), + post_hock=set_llvm_env, +) + +# aipu +cache.store( + file="aipu-llvm-a66376b0-ubuntu-x64", + condition=("aipu" == flagtree_backend), + url="https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-a66376b0-ubuntu-x64.tar.gz", + pre_hock=lambda: check_env('LLVM_SYSPATH'), + post_hock=set_llvm_env, +) + +# tsingmicro +cache.store( + file="tsingmicro-llvm21-glibc2.30-glibcxx3.4.28-x64", + condition=("tsingmicro" == flagtree_backend), + url= + "https://github.com/FlagTree/flagtree/releases/download/v0.2.0-build-deps/tsingmicro-llvm21-glibc2.30-glibcxx3.4.28-x64.tar.gz", + pre_hock=lambda: check_env('LLVM_SYSPATH'), + post_hock=set_llvm_env, +) diff --git a/python/setup_tools/utils/__init__.py b/python/setup_tools/utils/__init__.py new file mode 100644 index 000000000..190fc98e0 --- /dev/null +++ b/python/setup_tools/utils/__init__.py @@ -0,0 +1,41 @@ +from dataclasses import dataclass +from pathlib import Path +import importlib.util +import os +from . import ascend, aipu, cambricon, default, tsingmicro, xpu + + +@dataclass +class FlagTreeBackend: + name: str + url: str + tag: str = None + + +flagtree_backends = ( + FlagTreeBackend(name="triton_shared", url="https://github.com/microsoft/triton-shared.git", + tag="380b87122c88af131530903a702d5318ec59bb33"), + FlagTreeBackend(name="cambricon", url="https://github.com/Cambricon/triton-linalg.git", + tag="00f51c2e48a943922f86f03d58e29f514def646d"), + FlagTreeBackend( + name="ascend", + url="https://gitee.com/ascend/triton-ascend.git", + ), +) + + +def activate(backend, suffix=".py"): + if not backend: + return + module_path = Path(os.path.dirname(__file__)) / backend + module_path = str(module_path) + suffix + spec = importlib.util.spec_from_file_location("module", module_path) + module = importlib.util.module_from_spec(spec) + try: + spec.loader.exec_module(module) + except Exception: + pass + return module + + +__all__ = ["ascend", "aipu", "cambricon", "default", "tsingmicro", "xpu"] diff --git a/python/setup_tools/utils/aipu.py b/python/setup_tools/utils/aipu.py new file mode 100644 index 000000000..f4a81f6a5 --- /dev/null +++ b/python/setup_tools/utils/aipu.py @@ -0,0 +1,3 @@ +def precompile_hock(*args, **kargs): + default_backends = kargs["default_backends"] + default_backends.append('flir') diff --git a/python/setup_tools/utils/ascend.py b/python/setup_tools/utils/ascend.py new file mode 100644 index 000000000..bb0184ec6 --- /dev/null +++ b/python/setup_tools/utils/ascend.py @@ -0,0 +1,234 @@ +import os +import shutil +from pathlib import Path + + +def get_backend_cmake_args(*args, **kargs): + build_ext = kargs['build_ext'] + src_ext_path = build_ext.get_ext_fullpath("triton-adapter-opt") + src_ext_path = os.path.abspath(os.path.dirname(src_ext_path)) + return [ + "-DCMAKE_RUNTIME_OUTPUT_DIRECTORY=" + src_ext_path, + ] + + +def install_extension(*args, **kargs): + build_ext = kargs['build_ext'] + src_ext_path = build_ext.get_ext_fullpath("triton-adapter-opt") + src_ext_path = os.path.join(os.path.abspath(os.path.dirname(src_ext_path)), "triton-adapter-opt") + python_root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + dst_ext_path = os.path.join(python_root_dir, "triton/backends/ascend/triton-adapter-opt") + shutil.copy(src_ext_path, dst_ext_path) + + +def get_package_dir(): + package_dict = {} + triton_patch_root_rel_dir = "../third_party/ascend/triton_patch/python/triton_patch" + package_dict["triton/triton_patch"] = f"{triton_patch_root_rel_dir}" + package_dict["triton/triton_patch/language"] = f"{triton_patch_root_rel_dir}/language" + package_dict["triton/triton_patch/compiler"] = f"{triton_patch_root_rel_dir}/compiler" + package_dict["triton/triton_patch/runtime"] = f"{triton_patch_root_rel_dir}/runtime" + return package_dict + + +def insert_at_file_start(filepath, import_lines): + import tempfile + try: + with open(filepath, 'r', encoding='utf-8') as f: + content = f.read() + if import_lines in content: + return False + with tempfile.NamedTemporaryFile(mode='w', delete=False) as tmp_file: + tmp_file.write(import_lines + '\n\n') + with open(filepath, 'r') as original_file: + tmp_file.write(original_file.read()) + backup_path = filepath + '.bak' + if os.path.exists(backup_path): + os.remove(backup_path) + shutil.move(filepath, backup_path) + shutil.move(tmp_file.name, filepath) + print(f"[INFO]: {filepath} is patched") + return True + except PermissionError: + print(f"[ERROR]: No permission to write to {filepath}!") + except FileNotFoundError: + print(f"[ERROR]: {filepath} does not exist!") + except Exception as e: + print(f"[ERROR]: Unknown error: {str(e)}") + return False + + +def append_at_file_end(filepath, import_lines): + try: + with open(filepath, 'r', encoding='utf-8') as f: + content = f.read() + if import_lines in content: + return False + with open(filepath, 'a', encoding='utf-8') as f: + f.write('\n' + import_lines) + return True + except PermissionError: + print(f"[ERROR]: No permission to write to {filepath}!") + except FileNotFoundError: + print(f"[ERROR]: {filepath} does not exist!") + except Exception as e: + print(f"[ERROR]: Unknown error: {str(e)}") + return False + + +def post_install(): + import site + install_dir = site.getsitepackages()[0] + install_dir = os.path.join(install_dir, "triton") + init_path = os.path.join(install_dir, "__init__.py") + patched_content = """ +import sys +from .triton_patch.language import _utils as ascend_utils +sys.modules['triton.language._utils'] = ascend_utils +from .triton_patch.compiler import compiler as ascend_compiler +sys.modules['triton.compiler.compiler'] = ascend_compiler +from .triton_patch.compiler import code_generator as ascend_code_generator +sys.modules['triton.compiler.code_generator'] = ascend_code_generator +from .triton_patch.compiler import errors as ascend_errors +sys.modules['triton.compiler.errors'] = ascend_errors +from .triton_patch.runtime import autotuner as ascend_autotuner +sys.modules['triton.runtime.autotuner'] = ascend_autotuner +from .triton_patch import testing as ascend_testing +sys.modules['triton.testing'] = ascend_testing +""" + insert_at_file_start(init_path, patched_content) + + content_to_append = """ +from .triton_patch.language.core import dot, gather, insert, subview +from .triton_patch.language.standard import flip +from .triton_patch.language.math import umulhi, exp, exp2, log, log2, cos, sin, sqrt, sqrt_rn, rsqrt, div_rn, erf, tanh, floor, ceil +from . import language + +language.dot = dot +language.flip = flip +language.gather = gather +language.insert = insert +language.subview = subview + +# from .triton_patch.language.core import dtype, pointer_type, block_type, function_type +# language.core.dtype = dtype +# language.core.pointer_type = pointer_type +# language.core.block_type = block_type +# language.core.function_type = function_type + +from .triton_patch.language.semantic import arange, floordiv +language.semantic.arange = arange +language.semantic.floordiv = floordiv + +language.umulhi = umulhi +language.exp = exp +language.exp2 = exp2 +language.log = log +language.log2 = log2 +language.cos = cos +language.sin = sin +language.sqrt = sqrt +language.sqrt_rn = sqrt_rn +language.rsqrt = rsqrt +language.div_rn = div_rn +language.erf = erf +language.tanh = tanh +language.floor = floor +language.ceil = ceil +language.math.umulhi = umulhi +language.math.exp = exp +language.math.exp2 = exp2 +language.math.log = log +language.math.log2 = log2 +language.math.cos = cos +language.math.sin = sin +language.math.sqrt = sqrt +language.math.sqrt_rn = sqrt_rn +language.math.rsqrt = rsqrt +language.math.div_rn = div_rn +language.math.erf = erf +language.math.tanh = tanh +language.math.floor = floor +language.math.ceil = ceil +""" + append_at_file_end(init_path, content_to_append) + + +def get_ascend_patch_packages(backends): + packages = [] + # packages += get_language_extra_packages() + packages += [ + "triton/triton_patch", + "triton/triton_patch/language", + "triton/triton_patch/compiler", + "triton/triton_patch/runtime", + ] + return packages + + +def get_ascend_patch_package_dir(backends): + package_dir = {} + # language_extra_list = get_language_extra_packages() + # for extra_full in language_extra_list: + # extra_name = extra_full.replace("triton/language/extra/", "") + # package_dir[extra_full] = f"{triton_root_rel_dir}/language/extra/{extra_name}" + # + triton_patch_root_rel_dir = "triton_patch/python/triton_patch" + package_dir["triton/triton_patch"] = f"{triton_patch_root_rel_dir}" + package_dir["triton/triton_patch/language"] = f"{triton_patch_root_rel_dir}/language" + package_dir["triton/triton_patch/compiler"] = f"{triton_patch_root_rel_dir}/compiler" + package_dir["triton/triton_patch/runtime"] = f"{triton_patch_root_rel_dir}/runtime" + return package_dir + + +def get_extra_install_packages(): + return [ + "triton/triton_patch", + "triton/triton_patch/language", + "triton/triton_patch/compiler", + "triton/triton_patch/runtime", + ] + + +def precompile_hock(*args, **kargs): + third_party_base_dir = Path(kargs['third_party_base_dir']) + ascend_path = Path(third_party_base_dir) / "ascend" + patch_path = Path(ascend_path) / "triton_patch" + project_path = Path(third_party_base_dir) / "triton_ascend" + if os.path.exists(ascend_path): + shutil.rmtree(ascend_path) + if not os.path.exists(project_path): + raise RuntimeError(f"{project_path} can't be found. It might be due to a network issue") + ascend_src_path = Path(project_path) / "ascend" + patch_src_path = Path(project_path) / "triton_patch" + shutil.copytree(ascend_src_path, ascend_path, dirs_exist_ok=True) + shutil.copytree(patch_src_path, patch_path, dirs_exist_ok=True) + shutil.rmtree(project_path) + patched_code = """ set(triton_abs_dir "${TRITON_ROOT_DIR}/include/triton/Dialect/Triton/IR") """ + src_code = """set(triton_abs_dir""" + + filepath = Path(patch_path) / "include" / "triton" / "Dialect" / "Triton" / "IR" / "CMakeLists.txt" + try: + import tempfile + with tempfile.NamedTemporaryFile(mode='w+t', delete=False) as tmp_file: + with open(filepath, 'r') as file: + lines = file.readlines() + for line in lines: + if src_code in line: + tmp_file.writelines(patched_code) + else: + tmp_file.writelines(line) + backup_path = str(filepath) + '.bak' + if os.path.exists(backup_path): + os.remove(backup_path) + shutil.move(filepath, backup_path) + shutil.move(tmp_file.name, filepath) + print(f"[INFO]: {filepath} is patched") + return True + except PermissionError: + print(f"[ERROR]: No permission to write to {filepath}!") + except FileNotFoundError: + print(f"[ERROR]: {filepath} does not exist!") + except Exception as e: + print(f"[ERROR]: Unknown error: {str(e)}") + return False diff --git a/python/setup_tools/utils/cambricon.py b/python/setup_tools/utils/cambricon.py new file mode 100644 index 000000000..3246fc695 --- /dev/null +++ b/python/setup_tools/utils/cambricon.py @@ -0,0 +1,4 @@ +def skip_package_dir(package): + if package not in ['triton', 'triton/_C']: + return True + return False diff --git a/python/setup_tools/utils/default.py b/python/setup_tools/utils/default.py new file mode 100644 index 000000000..5aba98a0f --- /dev/null +++ b/python/setup_tools/utils/default.py @@ -0,0 +1,3 @@ +def precompile_hock(*args, **kargs): + default_backends = kargs['default_backends'] + default_backends.append('triton_shared') diff --git a/python/setup_tools/utils/tsingmicro.py b/python/setup_tools/utils/tsingmicro.py new file mode 100644 index 000000000..6107ea83b --- /dev/null +++ b/python/setup_tools/utils/tsingmicro.py @@ -0,0 +1,10 @@ +import os + + +def get_backend_cmake_args(*args, **kargs): + build_ext = kargs['build_ext'] + src_ext_path = build_ext.get_ext_fullpath("triton") + src_ext_path = os.path.abspath(os.path.dirname(src_ext_path)) + return [ + "-DCMAKE_INSTALL_PREFIX=" + src_ext_path, + ] diff --git a/python/setup_tools/utils/xpu.py b/python/setup_tools/utils/xpu.py new file mode 100644 index 000000000..92424b1b2 --- /dev/null +++ b/python/setup_tools/utils/xpu.py @@ -0,0 +1,2 @@ +def get_package_data_tools(): + return ["compile_xpu.h", "compile_xpu.c"]