Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ascend-build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ jobs:
shell: bash
run: |
source /usr/local/Ascend/ascend-toolkit/set_env.sh
python3.9 third_party/ascend/python/tutorials/01-vector-add.py
python3.9 third_party/tests/ascend/vector-add.py
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ third_party/cambricon/
third_party/iluvatar/iluvatarTritonPlugin.so
third_party/triton_shared/
third_party/xpu/backend/xpu3
third_party/ascend

# Proton
python/triton/profiler
Expand Down
12 changes: 7 additions & 5 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

import pybind11

import setup_helper as helper
from setup_tools import setup_helper as helper


@dataclass
Expand Down Expand Up @@ -423,6 +423,7 @@ def build_extension(self, ext):
"-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends if not b.is_external]),
"-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external])
]
cmake_args += helper.get_backend_cmake_args(build_ext=self)
if lit_dir is not None:
cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir)
cmake_args.extend(thirdparty_cmake_args)
Expand Down Expand Up @@ -487,6 +488,7 @@ def build_extension(self, ext):
subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=cmake_dir, env=env)
subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=cmake_dir)
subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir)
helper.install_extension(build_ext=self)


nvidia_version_path = os.path.join(get_base_dir(), "cmake", "nvidia-toolchain-version.json")
Expand Down Expand Up @@ -611,31 +613,31 @@ class plugin_install(install):
def run(self):
add_links()
install.run(self)
helper.post_install(self)
helper.post_install()


class plugin_develop(develop):

def run(self):
add_links()
develop.run(self)
helper.post_install(self)
helper.post_install()


class plugin_bdist_wheel(bdist_wheel):

def run(self):
add_links()
bdist_wheel.run(self)
helper.post_install(self)
helper.post_install()


class plugin_egginfo(egg_info):

def run(self):
add_links()
egg_info.run(self)
helper.post_install(self)
helper.post_install()


package_data_tools = helper.get_package_data_tools()
Expand Down
4 changes: 4 additions & 0 deletions python/setup_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from . import setup_helper
from . import utils

__all__ = ["setup_helper", "utils"]
210 changes: 100 additions & 110 deletions python/setup_helper.py → python/setup_tools/setup_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,19 @@
import urllib.request
from pathlib import Path
import hashlib
from dataclasses import dataclass
from distutils.sysconfig import get_python_lib
from . import utils

use_triton_shared = False
necessary_third_party = ["triton_shared"]
default_backends = ["nvidia", "amd"]
extend_backends = []
default_backends = ["nvidia", "amd"]
plugin_backends = ["cambricon", "ascend"]
ext_sourcedir = "triton/_C/"
flagtree_backend = os.getenv("FLAGTREE_BACKEND", "").lower()
flagtree_plugin = os.getenv("FLAGTREE_PLUGIN", "").lower()
offline_build = os.getenv("FLAGTREE_PLUGIN", "OFF")
device_mapping = {"xpu": "xpu", "mthreads": "musa", "ascend": "ascend"}


@dataclass
class FlagTreeBackend:
name: str
url: str
tag: str


flagtree_backend_info = {
"triton_shared":
FlagTreeBackend(name="triton_shared", url="https://github.com/microsoft/triton-shared.git",
tag="380b87122c88af131530903a702d5318ec59bb33"),
"cambricon":
FlagTreeBackend(name="cambricon", url="https://github.com/Cambricon/triton-linalg.git",
tag="00f51c2e48a943922f86f03d58e29f514def646d"),
}
flagtree_backends = utils.flagtree_backends
backend_utils = utils.activate(flagtree_backend)

set_llvm_env = lambda path: set_env({
'LLVM_INCLUDE_DIRS': Path(path) / "include",
Expand All @@ -45,48 +29,98 @@ class FlagTreeBackend:
})


def install_extension(*args, **kargs):
try:
backend_utils.install_extension(*args, **kargs)
except Exception:
pass


def get_backend_cmake_args(*args, **kargs):
try:
return backend_utils.get_backend_cmake_args(*args, **kargs)
except Exception:
return []


def get_device_name():
return device_mapping[flagtree_backend]


def get_extra_packages():
packages = []
if flagtree_backend == 'ascend':
packages = [
"triton/triton_patch",
"triton/triton_patch/language",
"triton/triton_patch/compiler",
"triton/triton_patch/runtime",
]
try:
packages = backend_utils.get_extra_install_packages()
except Exception:
packages = []
return packages


def get_package_data_tools():
package_data = ["compile.h", "compile.c"]
if flagtree_backend == 'xpu':
package_data += ["compile_xpu.h", "compile_xpu.c"]
try:
package_data += backend_utils.get_package_data_tools()
except Exception:
package_data
return package_data


def post_install(self):

def get_module(module_path):
import importlib.util
import os
module_path = os.path.abspath(module_path)
spec = importlib.util.spec_from_file_location("module", module_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module

def ascend():
utils = get_module("../third_party/ascend/utils.py")
utils.post_install()

code = f"{flagtree_backend}()"
def git_clone(lib, lib_path):
import git
MAX_RETRY = 4
print(f"Clone {lib.name} into {lib_path} ...")
retry_count = MAX_RETRY
while (retry_count):
try:
repo = git.Repo.clone_from(lib.url, lib_path)
if lib.tag is not None:
repo.git.checkout(lib.tag)
sub_triton_path = Path(lib_path) / "triton"
if os.path.exists(sub_triton_path):
shutil.rmtree(sub_triton_path)
print(f"successfully clone {lib.name} into {lib_path} ...")
return True
except Exception:
retry_count -= 1
print(f"\n[{MAX_RETRY - retry_count}] retry to clone {lib.name} to {lib_path}")
return False


def dir_rollback(deep, base_path):
while (deep):
base_path = os.path.dirname(base_path)
deep -= 1
return Path(base_path)


def download_flagtree_third_party(name, condition, required=False, hock=None):
if not condition:
return
backend = None
for _backend in flagtree_backends:
if _backend.name in name:
backend = _backend
break
if backend is None:
return backend
base_dir = dir_rollback(3, __file__) / "third_party"
prelib_path = Path(base_dir) / name
lib_path = Path(base_dir) / _backend.name

if not os.path.exists(prelib_path) and not os.path.exists(lib_path):
succ = git_clone(lib=backend, lib_path=prelib_path)
if not succ and required:
raise RuntimeError("Bad network ! ")
if callable(hock):
hock(third_party_base_dir=base_dir, backend=backend)
else:
print(f'Found third_party {backend.name} at {lib_path}\n')


def post_install():
try:
exec(code, globals(), locals())
except: #noqa: E722
backend_utils.post_install()
except Exception:
pass


Expand Down Expand Up @@ -256,12 +290,13 @@ class CommonUtils:

@staticmethod
def unlink():
cur_path = os.path.dirname(__file__)
cur_path = dir_rollback(2, __file__)
if "editable_wheel" in sys.argv:
installation_dir = cur_path
else:
installation_dir = get_python_lib()
backends_dir_path = Path(installation_dir) / "triton" / "backends"
# raise RuntimeError(backends_dir_path)
if not os.path.exists(backends_dir_path):
return
for name in os.listdir(backends_dir_path):
Expand All @@ -279,10 +314,10 @@ def unlink():
def skip_package_dir(package):
if 'backends' in package or 'profiler' in package:
return True
if flagtree_backend in ['cambricon']:
if package not in ['triton', 'triton/_C']:
return True
return False
try:
return backend_utils.skip_package_dir(package)
except Exception:
return False

@staticmethod
def get_package_dir(packages):
Expand All @@ -296,62 +331,12 @@ def get_package_dir(packages):
pair = (package, f"{backend_triton_path}{package}")
connection.append(pair)
package_dict.update(connection)
if flagtree_backend == "ascend":
triton_patch_root_rel_dir = "../third_party/ascend/triton_patch/python/triton_patch"
package_dict["triton/triton_patch"] = f"{triton_patch_root_rel_dir}"
package_dict["triton/triton_patch/language"] = f"{triton_patch_root_rel_dir}/language"
package_dict["triton/triton_patch/compiler"] = f"{triton_patch_root_rel_dir}/compiler"
package_dict["triton/triton_patch/runtime"] = f"{triton_patch_root_rel_dir}/runtime"
try:
package_dict.update(backend_utils.get_package_dir())
except Exception:
pass
return package_dict

@staticmethod
def download_third_party():
import git
MAX_RETRY = 4
global use_triton_shared, flagtree_backend
third_party_base_dir = Path(os.path.dirname(os.path.dirname(__file__))) / "third_party"

def git_clone(lib, lib_path):
global use_triton_shared
print(f"Clone {lib.name} into {lib_path} ...")
retry_count = MAX_RETRY
while (retry_count):
try:
repo = git.Repo.clone_from(lib.url, lib_path)
repo.git.checkout(lib.tag)
if lib.name in flagtree_backend_info:
sub_triton_path = Path(lib_path) / "triton"
if os.path.exists(sub_triton_path):
shutil.rmtree(sub_triton_path)
print(f"successfully clone {lib.name} into {lib_path} ...")
return
except Exception:
retry_count -= 1
print(f"\n[{MAX_RETRY - retry_count}] retry to clone {lib.name} to {lib_path}")

print(f"Unable to clone third_party {lib.name}")
if lib.name in necessary_third_party:
use_triton_shared = False
print("\n\ttriton_shared is compiled by default, but for "
"some reason we couldn't download triton_shared\n"
"as third_party (most likely for network reasons), "
"so we couldn't compile triton_shared\n")

third_partys = []
if os.environ.get("USE_TRITON_SHARED", "ON") == "ON" and not flagtree_backend:
third_partys.append(flagtree_backend_info["triton_shared"])
else:
use_triton_shared = False
if flagtree_backend in flagtree_backend_info:
third_partys.append(flagtree_backend_info[flagtree_backend])

for lib in third_partys:
lib_path = Path(third_party_base_dir) / lib.name
if not os.path.exists(lib_path):
git_clone(lib=lib, lib_path=lib_path)
else:
print(f'Found third_party {lib.name} at {lib_path}\n')


def handle_flagtree_backend():
global ext_sourcedir
Expand All @@ -360,8 +345,6 @@ def handle_flagtree_backend():
extend_backends.append(flagtree_backend)
if "editable_wheel" in sys.argv and flagtree_backend != "ascend":
ext_sourcedir = os.path.abspath(f"../third_party/{flagtree_backend}/python/{ext_sourcedir}") + "/"
if use_triton_shared and not flagtree_backend:
default_backends.append("triton_shared")


def set_env(env_dict: dict):
Expand All @@ -373,8 +356,15 @@ def check_env(env_val):
return os.environ.get(env_val, '') != ''


CommonUtils.download_third_party()
download_flagtree_third_party("triton_shared", condition=(not flagtree_backend))

download_flagtree_third_party("triton_ascend", condition=(flagtree_backend == "ascend"),
hock=utils.ascend.precompile_hock, required=True)

download_flagtree_third_party("cambricon", condition=(flagtree_backend == "cambricon"), required=True)

handle_flagtree_backend()

cache = FlagTreeCache()

# iluvatar
Expand Down
Loading