From 0c23042f21f23387e779ba1407181a863e408203 Mon Sep 17 00:00:00 2001 From: yrl <2535184404@qq.com> Date: Tue, 17 Jun 2025 08:54:16 +0000 Subject: [PATCH 1/2] [refactor_code_3.3] --- python/setup.py | 11 +- python/setup_tools/__init__.py | 4 + python/setup_tools/setup_helper.py | 429 +++++++++++++++++++++++++ python/setup_tools/utils/__init__.py | 45 +++ python/setup_tools/utils/aipu.py | 3 + python/setup_tools/utils/ascend.py | 234 ++++++++++++++ python/setup_tools/utils/cambricon.py | 4 + python/setup_tools/utils/default.py | 3 + python/setup_tools/utils/tsingmicro.py | 10 + python/setup_tools/utils/xpu.py | 2 + 10 files changed, 741 insertions(+), 4 deletions(-) create mode 100644 python/setup_tools/__init__.py create mode 100644 python/setup_tools/setup_helper.py create mode 100644 python/setup_tools/utils/__init__.py create mode 100644 python/setup_tools/utils/aipu.py create mode 100644 python/setup_tools/utils/ascend.py create mode 100644 python/setup_tools/utils/cambricon.py create mode 100644 python/setup_tools/utils/default.py create mode 100644 python/setup_tools/utils/tsingmicro.py create mode 100644 python/setup_tools/utils/xpu.py diff --git a/python/setup.py b/python/setup.py index a84b896fd..999933281 100644 --- a/python/setup.py +++ b/python/setup.py @@ -29,7 +29,7 @@ import pybind11 from build_helpers import get_base_dir, get_cmake_dir -import setup_helper as helper +from setup_tools import setup_helper as helper @dataclass @@ -400,7 +400,6 @@ def run(self): cmake_major, cmake_minor = int(match.group("major")), int(match.group("minor")) if (cmake_major, cmake_minor) < (3, 18): raise RuntimeError("CMake >= 3.18.0 is required") - for ext in self.extensions: self.build_extension(ext) @@ -432,7 +431,6 @@ def build_extension(self, ext): thirdparty_cmake_args = get_thirdparty_packages([get_llvm_package_info()]) thirdparty_cmake_args += self.get_pybind11_cmake_args() extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path))) - ext_base_dir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) # create build directories if not os.path.exists(self.build_temp): os.makedirs(self.build_temp) @@ -449,6 +447,7 @@ def build_extension(self, ext): "-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends if not b.is_external]), "-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external]) ] + cmake_args += helper.get_backend_cmake_args(build_ext=self) if lit_dir is not None: cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir) cmake_args.extend(thirdparty_cmake_args) @@ -472,7 +471,6 @@ def build_extension(self, ext): "-DCMAKE_EXE_LINKER_FLAGS=-fuse-ld=lld", "-DCMAKE_MODULE_LINKER_FLAGS=-fuse-ld=lld", "-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld", - f"-DCMAKE_INSTALL_PREFIX={ext_base_dir}", ] # Note that asan doesn't work with binaries that use the GPU, so this is @@ -515,6 +513,7 @@ def build_extension(self, ext): subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=cmake_dir) subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir) subprocess.check_call(["cmake", "--install", "."], cwd=cmake_dir) + helper.install_extension(build_ext=self) nvidia_version_path = os.path.join(get_base_dir(), "cmake", "nvidia-toolchain-version.json") @@ -652,6 +651,7 @@ class plugin_install(install): def run(self): add_links() install.run(self) + helper.post_install() class plugin_develop(develop): @@ -659,6 +659,7 @@ class plugin_develop(develop): def run(self): add_links() develop.run(self) + helper.post_install() class plugin_bdist_wheel(bdist_wheel): @@ -666,6 +667,7 @@ class plugin_bdist_wheel(bdist_wheel): def run(self): add_links() bdist_wheel.run(self) + helper.post_install() class plugin_egginfo(egg_info): @@ -673,6 +675,7 @@ class plugin_egginfo(egg_info): def run(self): add_links() egg_info.run(self) + helper.post_install() # TODO: package_data_tools diff --git a/python/setup_tools/__init__.py b/python/setup_tools/__init__.py new file mode 100644 index 000000000..c3411313f --- /dev/null +++ b/python/setup_tools/__init__.py @@ -0,0 +1,4 @@ +from . import setup_helper +from . import utils + +__all__ = ["setup_helper", "utils"] diff --git a/python/setup_tools/setup_helper.py b/python/setup_tools/setup_helper.py new file mode 100644 index 000000000..cc5f83a4c --- /dev/null +++ b/python/setup_tools/setup_helper.py @@ -0,0 +1,429 @@ +import os +import shutil +import sys +import functools +import tarfile +import zipfile +from io import BytesIO +import urllib.request +from pathlib import Path +import hashlib +from distutils.sysconfig import get_python_lib +from . import utils + +extend_backends = [] +default_backends = ["nvidia", "amd"] +plugin_backends = ["cambricon", "ascend", "aipu", "tsingmicro"] +ext_sourcedir = "triton/_C/" +flagtree_backend = os.getenv("FLAGTREE_BACKEND", "").lower() +flagtree_plugin = os.getenv("FLAGTREE_PLUGIN", "").lower() +offline_build = os.getenv("FLAGTREE_PLUGIN", "OFF") +device_mapping = {"xpu": "xpu", "mthreads": "musa", "ascend": "ascend"} +flagtree_backends = utils.flagtree_backends +backend_utils = utils.activate(flagtree_backend) + +set_llvm_env = lambda path: set_env({ + 'LLVM_INCLUDE_DIRS': Path(path) / "include", + 'LLVM_LIBRARY_DIR': Path(path) / "lib", + 'LLVM_SYSPATH': path, +}) + + +def install_extension(*args, **kargs): + try: + backend_utils.install_extension(*args, **kargs) + except Exception: + pass + + +def get_backend_cmake_args(*args, **kargs): + try: + return backend_utils.get_backend_cmake_args(*args, **kargs) + except Exception: + return [] + + +def get_device_name(): + return device_mapping[flagtree_backend] + + +def get_extra_packages(): + packages = [] + try: + packages = backend_utils.get_extra_install_packages() + except Exception: + packages = [] + return packages + + +def get_package_data_tools(): + package_data = ["compile.h", "compile.c"] + try: + package_data += backend_utils.get_package_data_tools() + except Exception: + package_data + return package_data + + +def git_clone(lib, lib_path): + import git + MAX_RETRY = 4 + print(f"Clone {lib.name} into {lib_path} ...") + retry_count = MAX_RETRY + while (retry_count): + try: + repo = git.Repo.clone_from(lib.url, lib_path) + if lib.tag is not None: + repo.git.checkout(lib.tag) + sub_triton_path = Path(lib_path) / "triton" + if os.path.exists(sub_triton_path): + shutil.rmtree(sub_triton_path) + print(f"successfully clone {lib.name} into {lib_path} ...") + return True + except Exception: + retry_count -= 1 + print(f"\n[{MAX_RETRY - retry_count}] retry to clone {lib.name} to {lib_path}") + return False + + +def dir_rollback(deep, base_path): + while (deep): + base_path = os.path.dirname(base_path) + deep -= 1 + return Path(base_path) + + +def download_flagtree_third_party(name, condition, required=False, hock=None): + if not condition: + return + backend = None + for _backend in flagtree_backends: + if _backend.name in name: + backend = _backend + break + if backend is None: + return backend + base_dir = dir_rollback(3, __file__) / "third_party" + prelib_path = Path(base_dir) / name + lib_path = Path(base_dir) / _backend.name + if not os.path.exists(prelib_path) and not os.path.exists(lib_path): + succ = git_clone(lib=backend, lib_path=prelib_path) + if not succ and required: + raise RuntimeError("Bad network ! ") + else: + print(f'Found third_party {backend.name} at {lib_path}\n') + + if callable(hock): + hock(third_party_base_dir=base_dir, backend=backend, default_backends=default_backends) + + +def post_install(): + try: + backend_utils.post_install() + except Exception: + pass + + +class FlagTreeCache: + + def __init__(self): + self.flagtree_dir = os.path.dirname(os.getcwd()) + self.dir_name = ".flagtree" + self.sub_dirs = {} + self.cache_files = {} + self.dir_path = self._get_cache_dir_path() + self._create_cache_dir() + if flagtree_backend: + self._create_subdir(subdir_name=flagtree_backend) + + @functools.lru_cache(maxsize=None) + def _get_cache_dir_path(self) -> Path: + _cache_dir = os.environ.get("FLAGTREE_CACHE_DIR") + if _cache_dir is None: + _cache_dir = Path.home() / self.dir_name + else: + _cache_dir = Path(_cache_dir) + return _cache_dir + + def _create_cache_dir(self) -> Path: + if not os.path.exists(self.dir_path): + os.makedirs(self.dir_path, exist_ok=True) + + def _create_subdir(self, subdir_name, path=None): + if path is None: + subdir_path = Path(self.dir_path) / subdir_name + else: + subdir_path = Path(path) / subdir_name + + if not os.path.exists(subdir_path): + os.makedirs(subdir_path, exist_ok=True) + self.sub_dirs[subdir_name] = subdir_path + + def _md5(self, file_path): + md5_hash = hashlib.md5() + with open(file_path, "rb") as file: + while chunk := file.read(4096): + md5_hash.update(chunk) + return md5_hash.hexdigest() + + def _download(self, url, path, file_name): + MAX_RETRY_COUNT = 4 + user_agent = 'Mozilla/5.0 (X11; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/119.0' + headers = { + 'User-Agent': user_agent, + } + request = urllib.request.Request(url, None, headers) + retry_count = MAX_RETRY_COUNT + content = None + print(f'downloading {url} ...') + while (retry_count): + try: + with urllib.request.urlopen(request, timeout=300) as response: + content = response.read() + break + except Exception: + retry_count -= 1 + print(f"\n[{MAX_RETRY_COUNT - retry_count}] retry to downloading and extracting {url}") + + if retry_count == 0: + raise RuntimeError("The download failed, probably due to network problems") + + print(f'extracting {url} ...') + file_bytes = BytesIO(content) + file_names = [] + if url.endswith(".zip"): + with zipfile.ZipFile(file_bytes, "r") as file: + file.extractall(path=path) + file_names = file.namelist() + else: + with tarfile.open(fileobj=file_bytes, mode="r|*") as file: + file.extractall(path=path) + file_names = file.getnames() + os.rename(Path(path) / file_names[0], Path(path) / file_name) + + def check_file(self, file_name=None, url=None, path=None, md5_digest=None): + origin_file_path = None + if url is not None: + origin_file_name = url.split("/")[-1].split('.')[0] + origin_file_path = self.cache_files.get(origin_file_name, "") + if path is not None: + _path = path + else: + _path = self.cache_files.get(file_name, "") + empty = (not os.path.exists(_path)) or (origin_file_path and not os.path.exists(origin_file_path)) + if empty: + return False + if md5_digest is None: + return True + else: + cur_md5 = self._md5(_path) + return cur_md5[:8] == md5_digest + + def clear(self): + shutil.rmtree(self.dir_path) + + def reverse_copy(self, src_path, cache_file_path, md5_digest): + if src_path is None or not os.path.exists(src_path): + return False + if os.path.exists(cache_file_path): + return False + copy_needed = True + if md5_digest is None or self._md5(src_path) == md5_digest: + copy_needed = False + if copy_needed: + print(f"copying {src_path} to {cache_file_path}") + if os.path.isdir(src_path): + shutil.copytree(src_path, cache_file_path, dirs_exist_ok=True) + else: + shutil.copy(src_path, cache_file_path) + return True + return False + + def store(self, file=None, condition=None, url=None, copy_src_path=None, copy_dst_path=None, files=None, + md5_digest=None, pre_hock=None, post_hock=None): + + if not condition or (pre_hock and pre_hock()): + return + is_url = False if url is None else True + path = self.sub_dirs[flagtree_backend] if flagtree_backend else self.dir_path + + if files is not None: + for single_files in files: + self.cache_files[single_files] = Path(path) / single_files + else: + self.cache_files[file] = Path(path) / file + if url is not None: + origin_file_name = url.split("/")[-1].split('.')[0] + self.cache_files[origin_file_name] = Path(path) / file + if copy_dst_path is not None: + dst_path_root = Path(self.flagtree_dir) / copy_dst_path + dst_path = Path(dst_path_root) / file + if self.reverse_copy(dst_path, self.cache_files[file], md5_digest): + return + + if is_url and not self.check_file(file_name=file, url=url, md5_digest=md5_digest): + self._download(url, path, file_name=file) + + if copy_dst_path is not None: + file_lists = [file] if files is None else list(files) + for single_file in file_lists: + dst_path_root = Path(self.flagtree_dir) / copy_dst_path + os.makedirs(dst_path_root, exist_ok=True) + dst_path = Path(dst_path_root) / single_file + if not self.check_file(path=dst_path, md5_digest=md5_digest): + if copy_src_path: + src_path = Path(copy_src_path) / single_file + else: + src_path = self.cache_files[single_file] + print(f"copying {src_path} to {dst_path}") + if os.path.isdir(src_path): + shutil.copytree(src_path, dst_path, dirs_exist_ok=True) + else: + shutil.copy(src_path, dst_path) + post_hock(self.cache_files[file]) if post_hock else False + + def get(self, file_name) -> Path: + return self.cache_files[file_name] + + +class CommonUtils: + + @staticmethod + def unlink(): + cur_path = dir_rollback(2, __file__) + if "editable_wheel" in sys.argv: + installation_dir = cur_path + else: + installation_dir = get_python_lib() + backends_dir_path = Path(installation_dir) / "triton" / "backends" + # raise RuntimeError(backends_dir_path) + if not os.path.exists(backends_dir_path): + return + for name in os.listdir(backends_dir_path): + exist_backend_path = os.path.join(backends_dir_path, name) + if not os.path.isdir(exist_backend_path): + continue + if name.startswith('__'): + continue + if os.path.islink(exist_backend_path): + os.unlink(exist_backend_path) + if os.path.exists(exist_backend_path): + shutil.rmtree(exist_backend_path) + + @staticmethod + def skip_package_dir(package): + if 'backends' in package or 'profiler' in package: + return True + try: + return backend_utils.skip_package_dir(package) + except Exception: + return False + + @staticmethod + def get_package_dir(packages): + package_dict = {} + if flagtree_backend and flagtree_backend not in plugin_backends: + connection = [] + backend_triton_path = f"../third_party/{flagtree_backend}/python/" + for package in packages: + if CommonUtils.skip_package_dir(package): + continue + pair = (package, f"{backend_triton_path}{package}") + connection.append(pair) + package_dict.update(connection) + try: + package_dict.update(backend_utils.get_package_dir()) + except Exception: + pass + return package_dict + + +def handle_flagtree_backend(): + global ext_sourcedir + if flagtree_backend: + print(f"\033[1;32m[INFO] FlagtreeBackend is {flagtree_backend}\033[0m") + extend_backends.append(flagtree_backend) + if "editable_wheel" in sys.argv and flagtree_backend not in plugin_backends: + ext_sourcedir = os.path.abspath(f"../third_party/{flagtree_backend}/python/{ext_sourcedir}") + "/" + + +def set_env(env_dict: dict): + for env_k, env_v in env_dict.items(): + os.environ[env_k] = str(env_v) + + +def check_env(env_val): + return os.environ.get(env_val, '') != '' + + +download_flagtree_third_party("triton_shared", hock=utils.default.precompile_hock, condition=(not flagtree_backend)) + +download_flagtree_third_party("triton_ascend", condition=(flagtree_backend == "ascend"), + hock=utils.ascend.precompile_hock, required=True) + +download_flagtree_third_party("cambricon", condition=(flagtree_backend == "cambricon"), required=True) + +download_flagtree_third_party("flir", condition=(flagtree_backend == "aipu"), hock=utils.aipu.precompile_hock, + required=True) + +handle_flagtree_backend() + +cache = FlagTreeCache() + +# iluvatar +cache.store( + file="iluvatarTritonPlugin.so", condition=("iluvatar" == flagtree_backend) and (flagtree_plugin == ''), url= + "https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/iluvatarTritonPlugin-cpython3.10-glibc2.30-glibcxx3.4.28-cxxabi1.3.12-ubuntu-x86_64.tar.gz", + copy_dst_path="third_party/iluvatar", md5_digest="7d4e136c") + +cache.store( + file="iluvatar-llvm18-x86_64", + condition=("iluvatar" == flagtree_backend), + url="https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/iluvatar-llvm18-x86_64.tar.gz", + pre_hock=lambda: check_env('LLVM_SYSPATH'), + post_hock=set_llvm_env, +) + +# xpu(kunlunxin) +cache.store( + file="XTDK-llvm18-ubuntu2004_x86_64", + condition=("xpu" == flagtree_backend), + url="https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/XTDK-llvm18-ubuntu2004_x86_64.tar", + pre_hock=lambda: check_env('LLVM_SYSPATH'), + post_hock=set_llvm_env, +) + +cache.store(file="xre-Linux-x86_64", condition=("xpu" == flagtree_backend), + url="https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/xre-Linux-x86_64.tar.gz", + copy_dst_path='python/_deps/xre3') + +cache.store( + files=("clang", "xpu-xxd", "xpu3-crt.xpu", "xpu-kernel.t", "ld.lld", "llvm-readelf", "llvm-objdump", + "llvm-objcopy"), condition=("xpu" == flagtree_backend), + copy_src_path=f"{os.environ.get('LLVM_SYSPATH','')}/bin", copy_dst_path="third_party/xpu/backend/xpu3/bin") + +cache.store(files=("libclang_rt.builtins-xpu3.a", "libclang_rt.builtins-xpu3s.a"), + condition=("xpu" == flagtree_backend), copy_src_path=f"{os.environ.get('LLVM_SYSPATH','')}/lib/linux", + copy_dst_path="third_party/xpu/backend/xpu3/lib/linux") + +cache.store(files=("include", "so"), condition=("xpu" == flagtree_backend), + copy_src_path=f"{cache.dir_path}/xpu/xre-Linux-x86_64", copy_dst_path="third_party/xpu/backend/xpu3") + +# mthreads +cache.store( + file="mthreads-llvm19-glibc2.34-glibcxx3.4.30-x64", + condition=("mthreads" == flagtree_backend), + url= + "https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/mthreads-llvm19-glibc2.34-glibcxx3.4.30-x64.tar.gz", + pre_hock=lambda: check_env('LLVM_SYSPATH'), + post_hock=set_llvm_env, +) + +# ascend +cache.store( + file="ascend-llvm-b5cc222d-ubuntu-arm64", + condition=("ascend" == flagtree_backend), + url="https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-b5cc222d-ubuntu-arm64.tar.gz", + pre_hock=lambda: check_env('LLVM_SYSPATH'), + post_hock=set_llvm_env, +) diff --git a/python/setup_tools/utils/__init__.py b/python/setup_tools/utils/__init__.py new file mode 100644 index 000000000..00c86e5a4 --- /dev/null +++ b/python/setup_tools/utils/__init__.py @@ -0,0 +1,45 @@ +from dataclasses import dataclass +from pathlib import Path +import importlib.util +import os +from . import ascend +from . import aipu +from . import default + + +@dataclass +class FlagTreeBackend: + name: str + url: str + tag: str = None + + +flagtree_backends = ( + FlagTreeBackend(name="triton_shared", url="https://github.com/microsoft/triton-shared.git", + tag="5842469a16b261e45a2c67fbfc308057622b03ee"), + FlagTreeBackend(name="cambricon", url="https://github.com/Cambricon/triton-linalg.git", + tag="00f51c2e48a943922f86f03d58e29f514def646d"), + FlagTreeBackend(name="flir", url="git@github.com:FlagTree/flir.git", + tag="e72b83ba46a5a9dd6466c7102f93fd600cde909e"), + FlagTreeBackend( + name="ascend", + url="https://gitee.com/ascend/triton-ascend.git", + ), +) + + +def activate(backend, suffix=".py"): + if not backend: + return + module_path = Path(os.path.dirname(__file__)) / backend + module_path = str(module_path) + suffix + spec = importlib.util.spec_from_file_location("module", module_path) + module = importlib.util.module_from_spec(spec) + try: + spec.loader.exec_module(module) + except Exception: + pass + return module + + +__all__ = ["ascend", "aipu", "default", "activate"] diff --git a/python/setup_tools/utils/aipu.py b/python/setup_tools/utils/aipu.py new file mode 100644 index 000000000..f4a81f6a5 --- /dev/null +++ b/python/setup_tools/utils/aipu.py @@ -0,0 +1,3 @@ +def precompile_hock(*args, **kargs): + default_backends = kargs["default_backends"] + default_backends.append('flir') diff --git a/python/setup_tools/utils/ascend.py b/python/setup_tools/utils/ascend.py new file mode 100644 index 000000000..bb0184ec6 --- /dev/null +++ b/python/setup_tools/utils/ascend.py @@ -0,0 +1,234 @@ +import os +import shutil +from pathlib import Path + + +def get_backend_cmake_args(*args, **kargs): + build_ext = kargs['build_ext'] + src_ext_path = build_ext.get_ext_fullpath("triton-adapter-opt") + src_ext_path = os.path.abspath(os.path.dirname(src_ext_path)) + return [ + "-DCMAKE_RUNTIME_OUTPUT_DIRECTORY=" + src_ext_path, + ] + + +def install_extension(*args, **kargs): + build_ext = kargs['build_ext'] + src_ext_path = build_ext.get_ext_fullpath("triton-adapter-opt") + src_ext_path = os.path.join(os.path.abspath(os.path.dirname(src_ext_path)), "triton-adapter-opt") + python_root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + dst_ext_path = os.path.join(python_root_dir, "triton/backends/ascend/triton-adapter-opt") + shutil.copy(src_ext_path, dst_ext_path) + + +def get_package_dir(): + package_dict = {} + triton_patch_root_rel_dir = "../third_party/ascend/triton_patch/python/triton_patch" + package_dict["triton/triton_patch"] = f"{triton_patch_root_rel_dir}" + package_dict["triton/triton_patch/language"] = f"{triton_patch_root_rel_dir}/language" + package_dict["triton/triton_patch/compiler"] = f"{triton_patch_root_rel_dir}/compiler" + package_dict["triton/triton_patch/runtime"] = f"{triton_patch_root_rel_dir}/runtime" + return package_dict + + +def insert_at_file_start(filepath, import_lines): + import tempfile + try: + with open(filepath, 'r', encoding='utf-8') as f: + content = f.read() + if import_lines in content: + return False + with tempfile.NamedTemporaryFile(mode='w', delete=False) as tmp_file: + tmp_file.write(import_lines + '\n\n') + with open(filepath, 'r') as original_file: + tmp_file.write(original_file.read()) + backup_path = filepath + '.bak' + if os.path.exists(backup_path): + os.remove(backup_path) + shutil.move(filepath, backup_path) + shutil.move(tmp_file.name, filepath) + print(f"[INFO]: {filepath} is patched") + return True + except PermissionError: + print(f"[ERROR]: No permission to write to {filepath}!") + except FileNotFoundError: + print(f"[ERROR]: {filepath} does not exist!") + except Exception as e: + print(f"[ERROR]: Unknown error: {str(e)}") + return False + + +def append_at_file_end(filepath, import_lines): + try: + with open(filepath, 'r', encoding='utf-8') as f: + content = f.read() + if import_lines in content: + return False + with open(filepath, 'a', encoding='utf-8') as f: + f.write('\n' + import_lines) + return True + except PermissionError: + print(f"[ERROR]: No permission to write to {filepath}!") + except FileNotFoundError: + print(f"[ERROR]: {filepath} does not exist!") + except Exception as e: + print(f"[ERROR]: Unknown error: {str(e)}") + return False + + +def post_install(): + import site + install_dir = site.getsitepackages()[0] + install_dir = os.path.join(install_dir, "triton") + init_path = os.path.join(install_dir, "__init__.py") + patched_content = """ +import sys +from .triton_patch.language import _utils as ascend_utils +sys.modules['triton.language._utils'] = ascend_utils +from .triton_patch.compiler import compiler as ascend_compiler +sys.modules['triton.compiler.compiler'] = ascend_compiler +from .triton_patch.compiler import code_generator as ascend_code_generator +sys.modules['triton.compiler.code_generator'] = ascend_code_generator +from .triton_patch.compiler import errors as ascend_errors +sys.modules['triton.compiler.errors'] = ascend_errors +from .triton_patch.runtime import autotuner as ascend_autotuner +sys.modules['triton.runtime.autotuner'] = ascend_autotuner +from .triton_patch import testing as ascend_testing +sys.modules['triton.testing'] = ascend_testing +""" + insert_at_file_start(init_path, patched_content) + + content_to_append = """ +from .triton_patch.language.core import dot, gather, insert, subview +from .triton_patch.language.standard import flip +from .triton_patch.language.math import umulhi, exp, exp2, log, log2, cos, sin, sqrt, sqrt_rn, rsqrt, div_rn, erf, tanh, floor, ceil +from . import language + +language.dot = dot +language.flip = flip +language.gather = gather +language.insert = insert +language.subview = subview + +# from .triton_patch.language.core import dtype, pointer_type, block_type, function_type +# language.core.dtype = dtype +# language.core.pointer_type = pointer_type +# language.core.block_type = block_type +# language.core.function_type = function_type + +from .triton_patch.language.semantic import arange, floordiv +language.semantic.arange = arange +language.semantic.floordiv = floordiv + +language.umulhi = umulhi +language.exp = exp +language.exp2 = exp2 +language.log = log +language.log2 = log2 +language.cos = cos +language.sin = sin +language.sqrt = sqrt +language.sqrt_rn = sqrt_rn +language.rsqrt = rsqrt +language.div_rn = div_rn +language.erf = erf +language.tanh = tanh +language.floor = floor +language.ceil = ceil +language.math.umulhi = umulhi +language.math.exp = exp +language.math.exp2 = exp2 +language.math.log = log +language.math.log2 = log2 +language.math.cos = cos +language.math.sin = sin +language.math.sqrt = sqrt +language.math.sqrt_rn = sqrt_rn +language.math.rsqrt = rsqrt +language.math.div_rn = div_rn +language.math.erf = erf +language.math.tanh = tanh +language.math.floor = floor +language.math.ceil = ceil +""" + append_at_file_end(init_path, content_to_append) + + +def get_ascend_patch_packages(backends): + packages = [] + # packages += get_language_extra_packages() + packages += [ + "triton/triton_patch", + "triton/triton_patch/language", + "triton/triton_patch/compiler", + "triton/triton_patch/runtime", + ] + return packages + + +def get_ascend_patch_package_dir(backends): + package_dir = {} + # language_extra_list = get_language_extra_packages() + # for extra_full in language_extra_list: + # extra_name = extra_full.replace("triton/language/extra/", "") + # package_dir[extra_full] = f"{triton_root_rel_dir}/language/extra/{extra_name}" + # + triton_patch_root_rel_dir = "triton_patch/python/triton_patch" + package_dir["triton/triton_patch"] = f"{triton_patch_root_rel_dir}" + package_dir["triton/triton_patch/language"] = f"{triton_patch_root_rel_dir}/language" + package_dir["triton/triton_patch/compiler"] = f"{triton_patch_root_rel_dir}/compiler" + package_dir["triton/triton_patch/runtime"] = f"{triton_patch_root_rel_dir}/runtime" + return package_dir + + +def get_extra_install_packages(): + return [ + "triton/triton_patch", + "triton/triton_patch/language", + "triton/triton_patch/compiler", + "triton/triton_patch/runtime", + ] + + +def precompile_hock(*args, **kargs): + third_party_base_dir = Path(kargs['third_party_base_dir']) + ascend_path = Path(third_party_base_dir) / "ascend" + patch_path = Path(ascend_path) / "triton_patch" + project_path = Path(third_party_base_dir) / "triton_ascend" + if os.path.exists(ascend_path): + shutil.rmtree(ascend_path) + if not os.path.exists(project_path): + raise RuntimeError(f"{project_path} can't be found. It might be due to a network issue") + ascend_src_path = Path(project_path) / "ascend" + patch_src_path = Path(project_path) / "triton_patch" + shutil.copytree(ascend_src_path, ascend_path, dirs_exist_ok=True) + shutil.copytree(patch_src_path, patch_path, dirs_exist_ok=True) + shutil.rmtree(project_path) + patched_code = """ set(triton_abs_dir "${TRITON_ROOT_DIR}/include/triton/Dialect/Triton/IR") """ + src_code = """set(triton_abs_dir""" + + filepath = Path(patch_path) / "include" / "triton" / "Dialect" / "Triton" / "IR" / "CMakeLists.txt" + try: + import tempfile + with tempfile.NamedTemporaryFile(mode='w+t', delete=False) as tmp_file: + with open(filepath, 'r') as file: + lines = file.readlines() + for line in lines: + if src_code in line: + tmp_file.writelines(patched_code) + else: + tmp_file.writelines(line) + backup_path = str(filepath) + '.bak' + if os.path.exists(backup_path): + os.remove(backup_path) + shutil.move(filepath, backup_path) + shutil.move(tmp_file.name, filepath) + print(f"[INFO]: {filepath} is patched") + return True + except PermissionError: + print(f"[ERROR]: No permission to write to {filepath}!") + except FileNotFoundError: + print(f"[ERROR]: {filepath} does not exist!") + except Exception as e: + print(f"[ERROR]: Unknown error: {str(e)}") + return False diff --git a/python/setup_tools/utils/cambricon.py b/python/setup_tools/utils/cambricon.py new file mode 100644 index 000000000..3246fc695 --- /dev/null +++ b/python/setup_tools/utils/cambricon.py @@ -0,0 +1,4 @@ +def skip_package_dir(package): + if package not in ['triton', 'triton/_C']: + return True + return False diff --git a/python/setup_tools/utils/default.py b/python/setup_tools/utils/default.py new file mode 100644 index 000000000..17c95b82b --- /dev/null +++ b/python/setup_tools/utils/default.py @@ -0,0 +1,3 @@ +def precompile_hock(*args, **kargs): + default_backends = kargs['default_backends'] + default_backends.append['triton_shared'] diff --git a/python/setup_tools/utils/tsingmicro.py b/python/setup_tools/utils/tsingmicro.py new file mode 100644 index 000000000..6107ea83b --- /dev/null +++ b/python/setup_tools/utils/tsingmicro.py @@ -0,0 +1,10 @@ +import os + + +def get_backend_cmake_args(*args, **kargs): + build_ext = kargs['build_ext'] + src_ext_path = build_ext.get_ext_fullpath("triton") + src_ext_path = os.path.abspath(os.path.dirname(src_ext_path)) + return [ + "-DCMAKE_INSTALL_PREFIX=" + src_ext_path, + ] diff --git a/python/setup_tools/utils/xpu.py b/python/setup_tools/utils/xpu.py new file mode 100644 index 000000000..92424b1b2 --- /dev/null +++ b/python/setup_tools/utils/xpu.py @@ -0,0 +1,2 @@ +def get_package_data_tools(): + return ["compile_xpu.h", "compile_xpu.c"] From e903d11dbe09337b1dbaa052c258180cbfd0e3ff Mon Sep 17 00:00:00 2001 From: yrl <2535184404@qq.com> Date: Tue, 17 Jun 2025 09:07:46 +0000 Subject: [PATCH 2/2] fix bugs --- python/setup_helper.py | 395 ---------------------------- python/setup_tools/setup_helper.py | 19 ++ python/setup_tools/utils/default.py | 2 +- 3 files changed, 20 insertions(+), 396 deletions(-) delete mode 100644 python/setup_helper.py diff --git a/python/setup_helper.py b/python/setup_helper.py deleted file mode 100644 index ebc371e8c..000000000 --- a/python/setup_helper.py +++ /dev/null @@ -1,395 +0,0 @@ -import os -import shutil -import sys -import functools -import tarfile -import zipfile -from io import BytesIO -import urllib.request -from pathlib import Path -import hashlib -from dataclasses import dataclass - -flagtree_backend = os.getenv("FLAGTREE_BACKEND", "").lower() -flagtree_plugin = os.getenv("FLAGTREE_PLUGIN", "").lower() -use_triton_shared = False -necessary_third_party = ["" if flagtree_backend == "tsingmicro" else "flir"] -default_backends = ["nvidia", "amd"] -extend_backends = [] -ext_sourcedir = "triton/_C/" - - -@dataclass -class FlagTreeBackend: - name: str - url: str - tag: str - - -flagtree_backend_info = { - "flir": - FlagTreeBackend(name="flir", url="git@github.com:FlagTree/flir.git", - tag="e72b83ba46a5a9dd6466c7102f93fd600cde909e"), - "triton_shared": - FlagTreeBackend(name="triton_shared", url="https://github.com/microsoft/triton-shared.git", - tag="5842469a16b261e45a2c67fbfc308057622b03ee"), - "cambricon": - FlagTreeBackend(name="cambricon", url="https://github.com/Cambricon/triton-linalg.git", - tag="00f51c2e48a943922f86f03d58e29f514def646d"), -} - -set_llvm_env = lambda path: set_env({ - 'LLVM_INCLUDE_DIRS': Path(path) / "include", - 'LLVM_LIBRARY_DIR': Path(path) / "lib", - 'LLVM_SYSPATH': path, -}) - - -class FlagTreeCache: - - def __init__(self): - self.flagtree_dir = os.path.dirname(os.getcwd()) - self.dir_name = ".flagtree" - self.sub_dirs = {} - self.cache_files = {} - self.dir_path = self._get_cache_dir_path() - self._create_cache_dir() - if flagtree_backend: - self._create_subdir(subdir_name=flagtree_backend) - - @functools.lru_cache(maxsize=None) - def _get_cache_dir_path(self) -> Path: - _cache_dir = os.environ.get("FLAGTREE_CACHE_DIR") - if _cache_dir is None: - _cache_dir = Path.home() / self.dir_name - else: - _cache_dir = Path(_cache_dir) - return _cache_dir - - def _create_cache_dir(self) -> Path: - if not os.path.exists(self.dir_path): - os.makedirs(self.dir_path, exist_ok=True) - - def _create_subdir(self, subdir_name, path=None): - if path is None: - subdir_path = Path(self.dir_path) / subdir_name - else: - subdir_path = Path(path) / subdir_name - - if not os.path.exists(subdir_path): - os.makedirs(subdir_path, exist_ok=True) - self.sub_dirs[subdir_name] = subdir_path - - def _md5(self, file_path): - md5_hash = hashlib.md5() - with open(file_path, "rb") as file: - while chunk := file.read(4096): - md5_hash.update(chunk) - return md5_hash.hexdigest() - - def _download(self, url, path, file_name): - MAX_RETRY_COUNT = 4 - user_agent = 'Mozilla/5.0 (X11; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/119.0' - headers = { - 'User-Agent': user_agent, - } - request = urllib.request.Request(url, None, headers) - retry_count = MAX_RETRY_COUNT - content = None - print(f'downloading {url} ...') - while (retry_count): - try: - with urllib.request.urlopen(request, timeout=300) as response: - content = response.read() - break - except Exception: - retry_count -= 1 - print(f"\n[{MAX_RETRY_COUNT - retry_count}] retry to downloading and extracting {url}") - - if retry_count == 0: - raise RuntimeError("The download failed, probably due to network problems") - - print(f'extracting {url} ...') - file_bytes = BytesIO(content) - file_names = [] - if url.endswith(".zip"): - with zipfile.ZipFile(file_bytes, "r") as file: - file.extractall(path=path) - file_names = file.namelist() - else: - with tarfile.open(fileobj=file_bytes, mode="r|*") as file: - file.extractall(path=path) - file_names = file.getnames() - os.rename(Path(path) / file_names[0], Path(path) / file_name) - - def check_file(self, file_name=None, url=None, path=None, md5_digest=None): - origin_file_path = None - if url is not None: - origin_file_name = url.split("/")[-1].split('.')[0] - origin_file_path = self.cache_files.get(origin_file_name, "") - if path is not None: - _path = path - else: - _path = self.cache_files.get(file_name, "") - empty = (not os.path.exists(_path)) or (origin_file_path and not os.path.exists(origin_file_path)) - if empty: - return False - if md5_digest is None: - return True - else: - cur_md5 = self._md5(_path) - return cur_md5[:8] == md5_digest - - def clear(self): - shutil.rmtree(self.dir_path) - - def reverse_copy(self, src_path, cache_file_path, md5_digest): - if src_path is None or not os.path.exists(src_path): - return False - if os.path.exists(cache_file_path): - return False - copy_needed = True - if md5_digest is None or self._md5(src_path) == md5_digest: - copy_needed = False - if copy_needed: - print(f"copying {src_path} to {cache_file_path}") - if os.path.isdir(src_path): - shutil.copytree(src_path, cache_file_path, dirs_exist_ok=True) - else: - shutil.copy(src_path, cache_file_path) - return True - return False - - def store(self, file=None, condition=None, url=None, copy_src_path=None, copy_dst_path=None, files=None, - md5_digest=None, pre_hock=None, post_hock=None): - - if not condition or (pre_hock and pre_hock()): - return - is_url = False if url is None else True - path = self.sub_dirs[flagtree_backend] if flagtree_backend else self.dir_path - - if files is not None: - for single_files in files: - self.cache_files[single_files] = Path(path) / single_files - else: - self.cache_files[file] = Path(path) / file - if url is not None: - origin_file_name = url.split("/")[-1].split('.')[0] - self.cache_files[origin_file_name] = Path(path) / file - if copy_dst_path is not None: - dst_path_root = Path(self.flagtree_dir) / copy_dst_path - dst_path = Path(dst_path_root) / file - if self.reverse_copy(dst_path, self.cache_files[file], md5_digest): - return - - if is_url and not self.check_file(file_name=file, url=url, md5_digest=md5_digest): - self._download(url, path, file_name=file) - - if copy_dst_path is not None: - file_lists = [file] if files is None else list(files) - for single_file in file_lists: - dst_path_root = Path(self.flagtree_dir) / copy_dst_path - os.makedirs(dst_path_root, exist_ok=True) - dst_path = Path(dst_path_root) / single_file - if not self.check_file(path=dst_path, md5_digest=md5_digest): - if copy_src_path: - src_path = Path(copy_src_path) / single_file - else: - src_path = self.cache_files[single_file] - print(f"copying {src_path} to {dst_path}") - if os.path.isdir(src_path): - shutil.copytree(src_path, dst_path, dirs_exist_ok=True) - else: - shutil.copy(src_path, dst_path) - post_hock(self.cache_files[file]) if post_hock else False - - def get(self, file_name) -> Path: - return self.cache_files[file_name] - - -class CommonUtils: - - @staticmethod - def unlink(): - cur_path = os.path.dirname(__file__) - backends_dir_path = Path(cur_path) / "triton" / "backends" - if not os.path.exists(backends_dir_path): - return - for name in os.listdir(backends_dir_path): - exist_backend_path = os.path.join(backends_dir_path, name) - if not os.path.isdir(exist_backend_path): - continue - if name.startswith('__'): - continue - if os.path.islink(exist_backend_path): - os.unlink(exist_backend_path) - if os.path.exists(exist_backend_path): - shutil.rmtree(exist_backend_path) - - @staticmethod - def skip_package_dir(package): - if 'backends' in package or 'profiler' in package: - return True - if flagtree_backend in ['cambricon']: - if package not in ['triton', 'triton/_C']: - return True - return False - - @staticmethod - def get_package_dir(packages): - package_dict = {} - if flagtree_backend and flagtree_backend not in ("cambricon", "aipu", "tsingmicro"): - connection = [] - backend_triton_path = f"../third_party/{flagtree_backend}/python/" - for package in packages: - if CommonUtils.skip_package_dir(package): - continue - pair = (package, f"{backend_triton_path}{package}") - connection.append(pair) - package_dict.update(connection) - return package_dict - - @staticmethod - def download_third_party(): - import git - MAX_RETRY = 4 - global use_triton_shared, flagtree_backend - third_party_base_dir = Path(os.path.dirname(os.path.dirname(__file__))) / "third_party" - - def git_clone(lib, lib_path): - global use_triton_shared - print(f"Clone {lib.name} into {lib_path} ...") - retry_count = MAX_RETRY - while (retry_count): - try: - repo = git.Repo.clone_from(lib.url, lib_path) - repo.git.checkout(lib.tag) - if lib.name in flagtree_backend_info: - sub_triton_path = Path(lib_path) / "triton" - if os.path.exists(sub_triton_path): - shutil.rmtree(sub_triton_path) - print(f"successfully clone {lib.name} into {lib_path} ...") - return - except Exception: - retry_count -= 1 - print(f"\n[{MAX_RETRY - retry_count}] retry to clone {lib.name} to {lib_path}") - - print(f"Unable to clone third_party {lib.name}") - if lib.name in necessary_third_party: - use_triton_shared = False # TODO - print(f"\n\t{lib.name} is compiled by default, but for " - "some reason we couldn't download triton_shared\n" - "as third_party (most likely for network reasons), " - "so we couldn't compile triton_shared\n") - - third_partys = [] - if flagtree_backend != "tsingmicro": - third_partys.append(flagtree_backend_info["flir"]) - if os.environ.get("USE_TRITON_SHARED", "ON") == "ON": - third_partys.append(flagtree_backend_info["triton_shared"]) - else: - use_triton_shared = False - if flagtree_backend in flagtree_backend_info: - third_partys.append(flagtree_backend_info[flagtree_backend]) - - for lib in third_partys: - lib_path = Path(third_party_base_dir) / lib.name - if not os.path.exists(lib_path): - git_clone(lib=lib, lib_path=lib_path) - else: - print(f'Found third_party {lib.name} at {lib_path}\n') - - -def handle_flagtree_backend(): - global ext_sourcedir - if flagtree_backend: - print(f"flagtree_backend is {flagtree_backend}") - extend_backends.append(flagtree_backend) - if "editable_wheel" in sys.argv and flagtree_backend not in ("aipu", "tsingmicro"): - ext_sourcedir = os.path.abspath(f"../third_party/{flagtree_backend}/python/{ext_sourcedir}") + "/" - if flagtree_backend != "tsingmicro": - default_backends.append("flir") - if use_triton_shared: - default_backends.append("triton_shared") - - -def set_env(env_dict: dict): - for env_k, env_v in env_dict.items(): - os.environ[env_k] = str(env_v) - - -def check_env(env_val): - return os.environ.get(env_val, '') != '' - - -CommonUtils.download_third_party() -handle_flagtree_backend() -cache = FlagTreeCache() - -# iluvatar -cache.store( - file="iluvatarTritonPlugin.so", condition=("iluvatar" == flagtree_backend) and (flagtree_plugin == ''), url= - "https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/iluvatarTritonPlugin-cpython3.10-glibc2.30-glibcxx3.4.28-cxxabi1.3.12-ubuntu-x86_64.tar.gz", - copy_dst_path="third_party/iluvatar", md5_digest="7d4e136c") - -cache.store( - file="iluvatar-llvm18-x86_64", - condition=("iluvatar" == flagtree_backend), - url="https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/iluvatar-llvm18-x86_64.tar.gz", - pre_hock=lambda: check_env('LLVM_SYSPATH'), - post_hock=set_llvm_env, -) - -# xpu(kunlunxin) -cache.store( - file="XTDK-llvm18-ubuntu2004_x86_64", - condition=("xpu" == flagtree_backend), - url="https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/XTDK-llvm18-ubuntu2004_x86_64.tar", - pre_hock=lambda: check_env('LLVM_SYSPATH'), - post_hock=set_llvm_env, -) - -cache.store(file="xre-Linux-x86_64", condition=("xpu" == flagtree_backend), - url="https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/xre-Linux-x86_64.tar.gz", - copy_dst_path='python/_deps/xre3') - -cache.store( - files=("clang", "xpu-xxd", "xpu3-crt.xpu", "xpu-kernel.t", "ld.lld", "llvm-readelf", "llvm-objdump", - "llvm-objcopy"), condition=("xpu" == flagtree_backend), - copy_src_path=f"{os.environ.get('LLVM_SYSPATH','')}/bin", copy_dst_path="third_party/xpu/backend/xpu3/bin") - -cache.store(files=("libclang_rt.builtins-xpu3.a", "libclang_rt.builtins-xpu3s.a"), - condition=("xpu" == flagtree_backend), copy_src_path=f"{os.environ.get('LLVM_SYSPATH','')}/lib/linux", - copy_dst_path="third_party/xpu/backend/xpu3/lib/linux") - -cache.store(files=("include", "so"), condition=("xpu" == flagtree_backend), - copy_src_path=f"{cache.dir_path}/xpu/xre-Linux-x86_64", copy_dst_path="third_party/xpu/backend/xpu3") - -# mthreads -cache.store( - file="mthreads-llvm19-glibc2.34-glibcxx3.4.30-x64", - condition=("mthreads" == flagtree_backend), - url= - "https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/mthreads-llvm19-glibc2.34-glibcxx3.4.30-x64.tar.gz", - pre_hock=lambda: check_env('LLVM_SYSPATH'), - post_hock=set_llvm_env, -) - -# aipu -cache.store( - file="aipu-llvm-a66376b0-ubuntu-x64", - condition=("aipu" == flagtree_backend), - url="https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-a66376b0-ubuntu-x64.tar.gz", - pre_hock=lambda: check_env('LLVM_SYSPATH'), - post_hock=set_llvm_env, -) - -# tsingmicro -cache.store( - file="tsingmicro-llvm21-glibc2.30-glibcxx3.4.28-x64", - condition=("tsingmicro" == flagtree_backend), - url= - "https://github.com/FlagTree/flagtree/releases/download/v0.2.0-build-deps/tsingmicro-llvm21-glibc2.30-glibcxx3.4.28-x64.tar.gz", - pre_hock=lambda: check_env('LLVM_SYSPATH'), - post_hock=set_llvm_env, -) diff --git a/python/setup_tools/setup_helper.py b/python/setup_tools/setup_helper.py index cc5f83a4c..f9a40a0b7 100644 --- a/python/setup_tools/setup_helper.py +++ b/python/setup_tools/setup_helper.py @@ -427,3 +427,22 @@ def check_env(env_val): pre_hock=lambda: check_env('LLVM_SYSPATH'), post_hock=set_llvm_env, ) + +# aipu +cache.store( + file="aipu-llvm-a66376b0-ubuntu-x64", + condition=("aipu" == flagtree_backend), + url="https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-a66376b0-ubuntu-x64.tar.gz", + pre_hock=lambda: check_env('LLVM_SYSPATH'), + post_hock=set_llvm_env, +) + +# tsingmicro +cache.store( + file="tsingmicro-llvm21-glibc2.30-glibcxx3.4.28-x64", + condition=("tsingmicro" == flagtree_backend), + url= + "https://github.com/FlagTree/flagtree/releases/download/v0.2.0-build-deps/tsingmicro-llvm21-glibc2.30-glibcxx3.4.28-x64.tar.gz", + pre_hock=lambda: check_env('LLVM_SYSPATH'), + post_hock=set_llvm_env, +) diff --git a/python/setup_tools/utils/default.py b/python/setup_tools/utils/default.py index 17c95b82b..5aba98a0f 100644 --- a/python/setup_tools/utils/default.py +++ b/python/setup_tools/utils/default.py @@ -1,3 +1,3 @@ def precompile_hock(*args, **kargs): default_backends = kargs['default_backends'] - default_backends.append['triton_shared'] + default_backends.append('triton_shared')