Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 75 additions & 9 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ elseif(FLAGTREE_BACKEND STREQUAL "mthreads")
set(CMAKE_C_COMPILER clang)
set(CMAKE_CXX_COMPILER clang++)
set(ENV{FLAGTREE_PLUGIN} $ENV{FLAGTREE_BACKEND})
elseif(FLAGTREE_BACKEND STREQUAL "ascend")
set(CMAKE_C_COMPILER clang)
set(CMAKE_CXX_COMPILER clang++)
endif()
set(FLAGTREE_PLUGIN "$ENV{FLAGTREE_PLUGIN}")
if(FLAGTREE_PLUGIN)
Expand All @@ -35,6 +38,19 @@ endif()
project(triton)
include(CTest)

if (FLAGTREE_BACKEND STREQUAL "ascend")
set(TRITON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
set(PATCHED_TRITON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party/ascend/triton_patch")
set(PATCHED_TRITON_LIBRARIES
"TritonIR"
)
set(PATCHED_TRITON_DEPENDS
"TritonTableGen"
)
include_directories(${PATCHED_TRITON_ROOT_DIR}/include)
include_directories(${PROJECT_BINARY_DIR}/third_party/ascend/triton_patch/include) # Tablegen'd files
endif()

if(NOT WIN32)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
endif()
Expand Down Expand Up @@ -112,18 +128,54 @@ function(add_triton_object name)
INTERFACE $<TARGET_OBJECTS:${name}>
)


# add_library(${name} OBJECT ${ARG_UNPARSED_ARGUMENTS})
if(ARG_DEPENDS)
add_dependencies(${name} ${ARG_DEPENDS})
endif()
if(ARG_LINK_LIBS)
target_link_libraries(${name} PUBLIC ${ARG_LINK_LIBS})
if (FLAGTREE_BACKEND STREQUAL "ascend")
set(patched_depends "")
foreach(dep ${ARG_DEPENDS})
list(FIND PATCHED_TRITON_DEPENDS "${dep}" index)
if(index GREATER_EQUAL 0)
list(APPEND patched_depends "Patched_${dep}")
message(STATUS "Replace ${dep} by Patched_${dep} as a dependent of ${name}")
else()
list(APPEND patched_depends ${dep})
endif()
endforeach()
if(patched_depends)
add_dependencies(${name} ${patched_depends})
endif()

set(patched_link_libs "")
foreach(lib ${ARG_LINK_LIBS})
list(FIND PATCHED_TRITON_LIBRARIES "${lib}" index)
if(index GREATER_EQUAL 0)
list(APPEND patched_link_libs "Patched_${lib}")
message(STATUS "Replace ${lib} by Patched_${lib} to be linked by ${name}")
else()
list(APPEND patched_link_libs ${lib})
endif()
endforeach()
if(patched_link_libs)
target_link_libraries(${name} PUBLIC ${patched_link_libs})
endif()
else()
#add_library(${name} OBJECT ${ARG_UNPARSED_ARGUMENTS})
if(ARG_DEPENDS)
add_dependencies(${name} ${ARG_DEPENDS})
endif()
if(ARG_LINK_LIBS)
target_link_libraries(${name} PUBLIC ${ARG_LINK_LIBS})
endif()
endif()
endfunction(add_triton_object)

set_property(GLOBAL PROPERTY TRITON_LIBS "")
function(add_triton_library name)
if (FLAGTREE_BACKEND STREQUAL "ascend")
list(FIND PATCHED_TRITON_LIBRARIES "${name}" index)
if(index GREATER_EQUAL 0)
message(STATUS "Adding Patched_${name} as a lib, instead of ${name}")
return()
endif()
endif()
set_property(GLOBAL APPEND PROPERTY TRITON_LIBS ${name})
add_triton_object(${name} ${ARGN})
llvm_update_compile_flags(${name})
Expand Down Expand Up @@ -162,7 +214,7 @@ include_directories(${PROJECT_SOURCE_DIR}/third_party)
include_directories(${PROJECT_BINARY_DIR}/third_party) # Tablegen'd files

# link_directories(${LLVM_LIBRARY_DIR})
if (FLAGTREE_BACKEND STREQUAL "cambricon")
if (FLAGTREE_BACKEND STREQUAL "cambricon" OR FLAGTREE_BACKEND STREQUAL "ascend")
include_directories(${PROJECT_SOURCE_DIR}/include)
include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files
add_subdirectory(include)
Expand All @@ -172,6 +224,11 @@ elseif(NOT FLAGTREE_BACKEND)
add_subdirectory(lib)
endif()

if (FLAGTREE_BACKEND STREQUAL "ascend")
add_subdirectory(${PATCHED_TRITON_ROOT_DIR}/include)
add_subdirectory(${PATCHED_TRITON_ROOT_DIR}/lib)
endif()

# find_package(PythonLibs REQUIRED)
set(TRITON_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
set(TRITON_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
Expand All @@ -190,6 +247,10 @@ endif()
if(TRITON_BUILD_PYTHON_MODULE)
message(STATUS "Adding Python module")
set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/third_party/${FLAGTREE_BACKEND}/python/src)

set(PATCHED_PYTHON_SRC_PATH ${PATCHED_TRITON_ROOT_DIR}/python/src)
include_directories(${PYTHON_SRC_PATH})

if(NOT (FLAGTREE_BACKEND AND EXISTS "${PYTHON_SRC_PATH}"))
set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
endif()
Expand Down Expand Up @@ -364,8 +425,13 @@ if(TRITON_BUILD_PYTHON_MODULE)
if(FLAGTREE_BACKEND STREQUAL "cambricon")
add_library(triton SHARED)
else()
if(FLAGTREE_BACKEND STREQUAL "ascend")
set(PYTHON_IR_SRC_PATH ${PATCHED_PYTHON_SRC_PATH})
else()
set(PYTHON_IR_SRC_PATH ${PYTHON_SRC_PATH})
endif()
add_library(triton SHARED ${PYTHON_SRC_PATH}/main.cc
${PYTHON_SRC_PATH}/ir.cc
${PYTHON_IR_SRC_PATH}/ir.cc
${PYTHON_SRC_PATH}/passes.cc
${PYTHON_SRC_PATH}/interpreter.cc
${PYTHON_SRC_PATH}/llvm.cc)
Expand Down
20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,26 @@ cd ${YOUR_CODE_DIR}/flagtree/python
export FLAGTREE_BACKEND=mthreads
python3 -m pip install . --no-build-isolation -v
```
```shell
# ascend
# manually download LLVM
cd ${YOUR_LLVM_DOWNLOAD_DIR}
# if the output of `uname -a` is x64 or x86_64
wget https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-b5cc222d-ubuntu-x64.tar.gz
tar -zxvf llvm-b5cc222d-ubuntu-x64.tar.gz
export LLVM_BUILD_DIR=${YOUR_LLVM_DOWNLOAD_DIR}/llvm-b5cc222d-ubuntu-x64
# if the output of `uname -a` is aarch64
wget https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-b5cc222d-ubuntu-arm64.tar.gz
tar -zxvf llvm-b5cc222d-ubuntu-arm64.tar.gz
export LLVM_BUILD_DIR=${YOUR_LLVM_DOWNLOAD_DIR}/llvm-b5cc222d-ubuntu-arm64
# build
cd ${YOUR_CODE_DIR}/flagtree/python
export LLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include
export LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR/lib
export LLVM_SYSPATH=$LLVM_BUILD_DIR
export FLAGTREE_BACKEND=ascend
python3 -m pip install . --no-build-isolation -v
```

To build with default backends (nvidia, amd, triton_shared):
```shell
Expand Down
20 changes: 20 additions & 0 deletions README_cn.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,26 @@ cd ${YOUR_CODE_DIR}/flagtree/python
export FLAGTREE_BACKEND=mthreads
python3 -m pip install . --no-build-isolation -v
```
```shell
# ascend
# 自行下载 LLVM
cd ${YOUR_LLVM_DOWNLOAD_DIR}
# 如果 `uname -a` 的输出是 x64 或 x86_64
wget https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-b5cc222d-ubuntu-x64.tar.gz
tar -zxvf llvm-b5cc222d-ubuntu-x64.tar.gz
export LLVM_BUILD_DIR=${YOUR_LLVM_DOWNLOAD_DIR}/llvm-b5cc222d-ubuntu-x64
# 如果 `uname -a` 的输出是 aarch64
wget https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-b5cc222d-ubuntu-arm64.tar.gz
tar -zxvf llvm-b5cc222d-ubuntu-arm64.tar.gz
export LLVM_BUILD_DIR=${YOUR_LLVM_DOWNLOAD_DIR}/llvm-b5cc222d-ubuntu-arm64
# 编译安装
cd ${YOUR_CODE_DIR}/flagtree/python
export LLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include
export LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR/lib
export LLVM_SYSPATH=$LLVM_BUILD_DIR
export FLAGTREE_BACKEND=ascend
python3 -m pip install . --no-build-isolation -v
```

使用默认的编译命令,可以编译安装 nvidia、amd、triton_shared 后端:
```shell
Expand Down
1 change: 1 addition & 0 deletions python/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ wheel
GitPython
pytest
scipy
pybind11
16 changes: 9 additions & 7 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,32 +611,34 @@ class plugin_install(install):
def run(self):
add_links()
install.run(self)
helper.post_install(self)


class plugin_develop(develop):

def run(self):
add_links()
develop.run(self)
helper.post_install(self)


class plugin_bdist_wheel(bdist_wheel):

def run(self):
add_links()
bdist_wheel.run(self)
helper.post_install(self)


class plugin_egginfo(egg_info):

def run(self):
add_links()
egg_info.run(self)
helper.post_install(self)


package_data_tools = ["compile.h", "compile.c"]
if helper.flagtree_backend == "xpu":
package_data_tools += ["compile_xpu.h", "compile_xpu.c"]
package_data_tools = helper.get_package_data_tools()
package_data = {
"triton/tools": package_data_tools, **{f"triton/backends/{b.name}": b.package_data
for b in backends}, "triton/language/extra": sum(
Expand Down Expand Up @@ -676,10 +678,10 @@ def get_packages():
"triton/backends",
"triton/tools",
]
if helper.flagtree_backend == "xpu":
packages.append("triton/language/extra/xpu")
elif helper.flagtree_backend == "mthreads":
packages.append("triton/language/extra/musa")
if helper.flagtree_backend:
packages.append(f"triton/language/extra/{helper.get_device_name()}")
packages += helper.get_extra_packages()

packages += [f'triton/backends/{backend.name}' for backend in backends]
packages += get_language_extra_packages()
if check_env_flag("TRITON_BUILD_PROTON", "ON"): # Default ON
Expand Down
75 changes: 71 additions & 4 deletions python/setup_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,17 @@
from pathlib import Path
import hashlib
from dataclasses import dataclass
from distutils.sysconfig import get_python_lib

use_triton_shared = False
necessary_third_party = ["triton_shared"]
default_backends = ["nvidia", "amd"]
extend_backends = []
plugin_backends = ["cambricon", "ascend"]
ext_sourcedir = "triton/_C/"
flagtree_backend = os.getenv("FLAGTREE_BACKEND", "").lower()
flagtree_plugin = os.getenv("FLAGTREE_PLUGIN", "").lower()
device_mapping = {"xpu": "xpu", "mthreads": "musa", "ascend": "ascend"}


@dataclass
Expand All @@ -43,6 +46,51 @@ class FlagTreeBackend:
})


def get_device_name():
return device_mapping[flagtree_backend]


def get_extra_packages():
packages = []
if flagtree_backend == 'ascend':
packages = [
"triton/triton_patch",
"triton/triton_patch/language",
"triton/triton_patch/compiler",
"triton/triton_patch/runtime",
]
return packages


def get_package_data_tools():
package_data = ["compile.h", "compile.c"]
if flagtree_backend == 'xpu':
package_data += ["compile_xpu.h", "compile_xpu.c"]
return package_data


def post_install(self):

def get_module(module_path):
import importlib.util
import os
module_path = os.path.abspath(module_path)
spec = importlib.util.spec_from_file_location("module", module_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module

def ascend():
utils = get_module("../third_party/ascend/utils.py")
utils.post_install()

code = f"{flagtree_backend}()"
try:
exec(code, globals(), locals())
except: #noqa: E722
pass


class FlagTreeCache:

def __init__(self):
Expand Down Expand Up @@ -210,7 +258,11 @@ class CommonUtils:
@staticmethod
def unlink():
cur_path = os.path.dirname(__file__)
backends_dir_path = Path(cur_path) / "triton" / "backends"
if "editable_wheel" in sys.argv:
installation_dir = cur_path
else:
installation_dir = get_python_lib()
backends_dir_path = Path(installation_dir) / "triton" / "backends"
if not os.path.exists(backends_dir_path):
return
for name in os.listdir(backends_dir_path):
Expand All @@ -236,7 +288,7 @@ def skip_package_dir(package):
@staticmethod
def get_package_dir(packages):
package_dict = {}
if flagtree_backend and flagtree_backend != 'cambricon':
if flagtree_backend and flagtree_backend not in plugin_backends:
connection = []
backend_triton_path = f"../third_party/{flagtree_backend}/python/"
for package in packages:
Expand All @@ -245,6 +297,12 @@ def get_package_dir(packages):
pair = (package, f"{backend_triton_path}{package}")
connection.append(pair)
package_dict.update(connection)
if flagtree_backend == "ascend":
triton_patch_root_rel_dir = "../third_party/ascend/triton_patch/python/triton_patch"
package_dict["triton/triton_patch"] = f"{triton_patch_root_rel_dir}"
package_dict["triton/triton_patch/language"] = f"{triton_patch_root_rel_dir}/language"
package_dict["triton/triton_patch/compiler"] = f"{triton_patch_root_rel_dir}/compiler"
package_dict["triton/triton_patch/runtime"] = f"{triton_patch_root_rel_dir}/runtime"
return package_dict

@staticmethod
Expand Down Expand Up @@ -299,9 +357,9 @@ def git_clone(lib, lib_path):
def handle_flagtree_backend():
global ext_sourcedir
if flagtree_backend:
print(f"flagtree_backend is {flagtree_backend}")
print(f"\033[1;32m[INFO] FlagtreeBackend is {flagtree_backend}\033[0m")
extend_backends.append(flagtree_backend)
if "editable_wheel" in sys.argv:
if "editable_wheel" in sys.argv and flagtree_backend != "ascend":
ext_sourcedir = os.path.abspath(f"../third_party/{flagtree_backend}/python/{ext_sourcedir}") + "/"
if use_triton_shared and not flagtree_backend:
default_backends.append("triton_shared")
Expand Down Expand Up @@ -368,3 +426,12 @@ def check_env(env_val):
pre_hock=lambda: check_env('LLVM_BUILD_DIR'),
post_hock=set_llvm_env,
)

# ascend
cache.store(
file="ascend-llvm-b5cc222d-ubuntu-x64.tar.gz",
condition=("ascend" == flagtree_backend),
url="https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-b5cc222d-ubuntu-x64.tar.gz",
pre_hock=lambda: check_env('LLVM_SYSPATH'),
post_hock=set_llvm_env,
)
1 change: 1 addition & 0 deletions third_party/ascend/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
triton-adapter-opt
12 changes: 12 additions & 0 deletions third_party/ascend/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
add_subdirectory(triton-adapter triton-adapter)

add_triton_plugin(TritonHUAWEI ${CMAKE_CURRENT_SOURCE_DIR}/triton_ascend.cpp)

# Copy triton-adapter-opt to python files
add_custom_target(COPY_TRITON_ADAPTER_OPT)
add_custom_command(TARGET COPY_TRITON_ADAPTER_OPT POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
$<TARGET_FILE:triton-adapter-opt>
${TRITON_ROOT_DIR}/python/triton/backends/ascend/triton-adapter-opt
DEPENDS triton-adapter-opt)
add_dependencies(TritonHUAWEI COPY_TRITON_ADAPTER_OPT)
2 changes: 2 additions & 0 deletions third_party/ascend/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
Loading